anfastech committed
Commit a62077e (0 parents)

fix: resolve torch security error by pinning torch 2.6.0 and updating requirements
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .env
+ hello.wav
+ venv/
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,30 @@
+ FROM python:3.10
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     libffi-dev \
+     libsndfile1 \
+     libasound2 \
+     libxt6 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first to leverage the Docker cache
+ COPY requirements.txt .
+
+ # Install PyTorch CPU first (2.6.0 is available for CPU)
+ RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
+
+ # Install the rest of the requirements
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application files
+ COPY . .
+
+ EXPOSE 7860
+
+ ENV PYTHONUNBUFFERED=1
+
+ # Run the application
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
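The Dockerfile above can be built and run locally; a hypothetical sketch, in which the image tag `stutter-api` and the `hf_xxx` token value are placeholders:

```shell
# Build the image from the repository root (tag name is a placeholder)
docker build -t stutter-api .

# Run it, passing the Hugging Face token the model loader expects
docker run -p 7860:7860 -e HF_TOKEN=hf_xxx stutter-api
```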
Docs/ARCHITECTURE.md ADDED
@@ -0,0 +1,185 @@
+ # AI Engine Architecture
+
+ ## Clean Architecture Implementation
+
+ This AI engine follows clean architecture principles with a proper separation of concerns.
+
+ ---
+
+ ## Module Structure
+
+ ```
+ diagnosis/ai_engine/
+ ├── detect_stuttering.py   # Main detector class (business logic)
+ ├── model_loader.py        # Singleton pattern for model loading
+ └── features.py            # Feature extraction (ASR features)
+ ```
+
+ ---
+
+ ## Architecture Pattern
+
+ ### 1. Model Loader (`model_loader.py`)
+ **Responsibility**: Singleton pattern for model instance management
+
+ - Ensures models are loaded only once
+ - Provides a clean interface: `get_stutter_detector()`
+ - Handles initialization and error handling
+ - Used by the API layer (`app.py`)
+
+ **Usage:**
+ ```python
+ from diagnosis.ai_engine.model_loader import get_stutter_detector
+
+ detector = get_stutter_detector()  # Singleton instance
+ ```
+
+ ---
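The singleton described above can be sketched in a few lines. This is a minimal stand-in, not the real module: `_StubDetector` is hypothetical, since the actual `AdvancedStutterDetector` needs the model weights to construct; only the `get_stutter_detector()` name comes from the source.

```python
# Minimal sketch of the model_loader.py singleton pattern.
# _StubDetector is a hypothetical stand-in for AdvancedStutterDetector.
_detector = None

class _StubDetector:
    def analyze_audio(self, audio_path, transcript=""):
        return {"actual_transcript": "", "severity": "none"}

def get_stutter_detector():
    """Return the shared detector instance, loading it on first call."""
    global _detector
    if _detector is None:
        _detector = _StubDetector()  # real code would build AdvancedStutterDetector()
    return _detector
```

Because the module-level `_detector` is created at most once, every caller (the API layer, background tasks) shares the same loaded model.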
+
+ ### 2. Feature Extractor (`features.py`)
+ **Responsibility**: Feature extraction from audio using IndicWav2Vec Hindi
+
+ **Class**: `ASRFeatureExtractor`
+
+ **Methods:**
+ - `extract_audio_features()` - raw audio feature extraction
+ - `get_transcription_features()` - transcription with confidence scores
+ - `get_word_level_features()` - word-level timestamps and confidence
+
+ **Design Pattern:**
+ - Takes a pre-loaded model and processor as dependencies
+ - Single responsibility: feature extraction only
+ - Reusable across different use cases
+
+ **Usage:**
+ ```python
+ from .features import ASRFeatureExtractor
+
+ extractor = ASRFeatureExtractor(model, processor, device)
+ features = extractor.get_transcription_features(audio)
+ ```
+
+ ---
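Because the extractor takes its model and processor as constructor arguments, tests can inject stubs. A hedged sketch of that dependency-injection shape: `ExtractorSketch`, `FakeModel`, and `FakeProcessor` are illustrative names, not the real classes.

```python
# Sketch of the dependency-injection pattern described above.
# The real extractor wraps a Wav2Vec2ForCTC model and its processor;
# here simple stubs are injected instead.
class ExtractorSketch:
    def __init__(self, model, processor, device="cpu"):
        self.model = model
        self.processor = processor
        self.device = device

    def get_transcription_features(self, audio):
        ids = self.model.predict(audio)             # stand-in for forward + argmax
        return {"text": self.processor.decode(ids)}

class FakeModel:
    def predict(self, audio):
        return [7, 8, 9]

class FakeProcessor:
    def decode(self, ids):
        return " ".join(str(i) for i in ids)
```

Swapping `FakeModel`/`FakeProcessor` for the real objects changes nothing in the extractor itself, which is the point of injecting the dependencies.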
+
+ ### 3. Detector (`detect_stuttering.py`)
+ **Responsibility**: High-level stutter detection orchestration
+
+ **Class**: `AdvancedStutterDetector`
+
+ **Design:**
+ - Uses the feature extractor for transcription (composition)
+ - Orchestrates the analysis pipeline
+ - Returns structured results
+
+ **Flow:**
+ ```
+ Audio Input
+      ↓
+ Feature Extractor (ASR)
+      ↓
+ Text Analysis
+      ↓
+ Results
+ ```
+
+ ---
+
+ ## Benefits of This Architecture
+
+ ### ✅ Separation of Concerns
+ - **Model Loading**: isolated in `model_loader.py`
+ - **Feature Extraction**: isolated in `features.py`
+ - **Business Logic**: in `detect_stuttering.py`
+
+ ### ✅ Single Responsibility Principle
+ - Each module has one clear purpose
+ - Easy to test and maintain
+ - Easy to extend or replace components
+
+ ### ✅ Dependency Injection
+ - The feature extractor receives the model/processor as dependencies
+ - No tight coupling
+ - Easy to mock for testing
+
+ ### ✅ Reusability
+ - The feature extractor can be used independently
+ - The model loader can be used by other modules
+ - Clean interfaces between layers
+
+ ---
+
+ ## Data Flow
+
+ ```
+ API Request (app.py)
+      ↓
+ get_stutter_detector()  [model_loader.py]
+      ↓
+ AdvancedStutterDetector  [detect_stuttering.py]
+      ↓
+ ASRFeatureExtractor  [features.py]
+      ↓
+ IndicWav2Vec Hindi Model
+      ↓
+ Results back through the layers
+ ```
+
+ ---
+
+ ## Comparison with the Django App
+
+ **Before (Django app):**
+ - Model-loading logic in the Django app
+ - Feature extraction in the Django app
+ - Tight coupling between the web app and ML logic
+
+ **After (AI engine service):**
+ - ✅ Model loading in the AI engine service
+ - ✅ Feature extraction in the AI engine service
+ - ✅ The Django app only calls the API (loose coupling)
+ - ✅ ML logic isolated in a dedicated service
+
+ ---
+
+ ## Extension Points
+
+ ### Adding New Features
+ 1. Add a method to `ASRFeatureExtractor` in `features.py`
+ 2. Use it in `AdvancedStutterDetector` via composition
+ 3. No changes needed to the model loader
+
+ ### Adding New Models
+ 1. Update `detect_stuttering.py` to load the new model
+ 2. Create a new feature extractor if needed
+ 3. The model loader remains unchanged
+
+ ### Testing
+ - Mock `ASRFeatureExtractor` in unit tests
+ - Mock the model loader for integration tests
+ - Each component can be tested independently
+
+ ---
+
+ ## Key Principles Applied
+
+ 1. **Dependency Inversion**: high-level modules don't depend on low-level modules
+ 2. **Open/Closed**: open for extension, closed for modification
+ 3. **Interface Segregation**: clean, focused interfaces
+ 4. **Don't Repeat Yourself (DRY)**: feature-extraction logic is centralized
+ 5. **Single Source of Truth**: the model instance is managed by a singleton
+
+ ---
+
+ ## File Responsibilities
+
+ | File | Responsibility | Depends On |
+ |------|---------------|------------|
+ | `model_loader.py` | Singleton model management | `detect_stuttering.py` |
+ | `features.py` | Feature extraction | `transformers`, `torch` |
+ | `detect_stuttering.py` | Business-logic orchestration | `features.py`, `model_loader.py` |
+ | `app.py` | API layer | `model_loader.py` |
+
+ ---
+
+ This architecture keeps the ML/AI logic in the AI engine service rather than in the Django web application, following microservices best practices.
+
Docs/MODEL_SUMMARY.md ADDED
@@ -0,0 +1,130 @@
+ # AI Engine Model Summary
+
+ ## Simplified ASR-Only Configuration
+
+ This engine has been simplified to use **ONLY** the IndicWav2Vec Hindi model for Automatic Speech Recognition (ASR).
+
+ ---
+
+ ## Active Model
+
+ ### 1. IndicWav2Vec Hindi (Primary & Only Model)
+ - **Model ID**: `ai4bharat/indicwav2vec-hindi`
+ - **Type**: `Wav2Vec2ForCTC`
+ - **Purpose**: Automatic Speech Recognition (ASR) for Hindi and Indian languages
+ - **Status**: ✅ Active - loaded at startup
+ - **Location**: `detect_stuttering.py` lines 26, 148-156
+ - **Authentication**: Requires the `HF_TOKEN` environment variable
+
+ **Features:**
+ - Speech-to-text transcription
+ - Confidence scoring from model predictions
+ - Text-based stutter analysis (simple repetition detection)
+
+ ---
+
+ ## Removed Models
+
+ The following models have been **removed** to simplify the engine:
+
+ 1. ❌ **MMS Language Identification (LID)** - `facebook/mms-lid-126`
+    - Previously used for language detection
+    - No longer needed - IndicWav2Vec handles Hindi natively
+
+ 2. ❌ **Isolation Forest** (sklearn)
+    - Previously used for anomaly detection
+    - Removed - simple text-based analysis is used instead
+
+ ---
+
+ ## Removed Libraries
+
+ The following signal-processing libraries are no longer used:
+
+ - ❌ `parselmouth` (Praat) - voice-quality analysis
+ - ❌ `fastdtw` - repetition detection via DTW
+ - ❌ `sklearn` - machine-learning algorithms
+ - ❌ Complex acoustic feature extraction (MFCC, formants, etc.)
+
+ ---
+
+ ## Current Pipeline
+
+ ```
+ Audio Input
+      ↓
+ IndicWav2Vec Hindi ASR
+      ↓
+ Text Transcription
+      ↓
+ Basic Text Analysis
+      ↓
+ Results (transcript + simple stutter detection)
+ ```
+
+ ---
+
+ ## API Response Format
+
+ The simplified engine returns:
+
+ ```json
+ {
+     "actual_transcript": "transcribed text",
+     "target_transcript": "expected text (if provided)",
+     "mismatched_chars": ["timestamps of low confidence regions"],
+     "mismatch_percentage": 0.0,
+     "ctc_loss_score": 0.0,
+     "stutter_timestamps": [{"type": "repetition", "start": 0.0, "end": 0.5, ...}],
+     "total_stutter_duration": 0.0,
+     "stutter_frequency": 0.0,
+     "severity": "none|mild|moderate|severe",
+     "confidence_score": 0.8,
+     "speaking_rate_sps": 0.0,
+     "analysis_duration_seconds": 0.0,
+     "model_version": "indicwav2vec-hindi-asr-v1"
+ }
+ ```
+
+ ---
+
+ ## Dependencies
+
+ **Required:**
+ - `transformers` 4.35.0 - for the IndicWav2Vec model
+ - `torch` 2.6.0 - PyTorch backend (pinned by this commit's security fix)
+ - `librosa` ≥0.10.0 - audio loading (16kHz resampling)
+ - `numpy` - array operations
+
+ **Optional (for legacy methods, not used in ASR mode):**
+ - `parselmouth` - voice quality (not used)
+ - `fastdtw` - DTW algorithm (not used)
+ - `sklearn` - ML algorithms (not used)
+
+ ---
+
+ ## Usage
+
+ ```python
+ from diagnosis.ai_engine.detect_stuttering import get_stutter_detector
+
+ detector = get_stutter_detector()
+ result = detector.analyze_audio(
+     audio_path="path/to/audio.wav",
+     proper_transcript="expected text",  # optional
+     language="hindi"  # default: hindi
+ )
+
+ print(result['actual_transcript'])  # ASR transcription
+ ```
+
+ ---
+
+ ## Notes
+
+ - The engine focuses **only** on ASR transcription
+ - Stutter detection is simplified to text-based repetition analysis
+ - No complex acoustic feature extraction
+ - Faster and lighter than the previous multi-model approach
+ - Optimized for Hindi but can handle other Indian languages
+
Docs/TRANSCRIPT_DEBUG.md ADDED
@@ -0,0 +1,213 @@
+ # Transcript Debugging Guide
+
+ ## Issue: Empty Transcripts ("No transcript available")
+
+ ## Complete Flow Analysis
+
+ ### 1. Django App → API Request (`slaq-version-c/diagnosis/ai_engine/detect_stuttering.py`)
+
+ **Location:** Lines 269-274
+ ```python
+ response = requests.post(
+     self.api_url,
+     files=files,
+     data={
+         "transcript": proper_transcript if proper_transcript else "",
+         "language": lang_code,
+     },
+     timeout=self.api_timeout
+ )
+ ```
+
+ **Status:** ✅ Sends the transcript parameter correctly
+
+ ---
+
+ ### 2. API Receives Request (`slaq-version-c-ai-enginee/app.py`)
+
+ **Location:** Lines 70-73
+ ```python
+ @app.post("/analyze")
+ async def analyze_audio(
+     audio: UploadFile = File(...),
+     transcript: str = Form("")  # ✅ Fixed: now uses Form() for multipart
+ ):
+ ```
+
+ **Status:** ✅ Fixed - now correctly receives the transcript via Form()
+
+ ---
+
+ ### 3. API Calls Model (`slaq-version-c-ai-enginee/app.py`)
+
+ **Location:** Line 106
+ ```python
+ result = detector.analyze_audio(temp_file, transcript)
+ ```
+
+ **Status:** ✅ Passes the transcript correctly
+
+ ---
+
+ ### 4. Model Transcribes Audio (`slaq-version-c-ai-enginee/diagnosis/ai_engine/detect_stuttering.py`)
+
+ **Location:** Lines 313-369 (`_transcribe_with_timestamps`)
+
+ **Potential issues:**
+ - ❓ IndicWav2Vec decoding might not work with `processor.batch_decode()`
+ - ❓ May need to use the tokenizer directly
+ - ❓ The model might not be producing valid predictions
+
+ **Status:** ⚠️ **LIKELY ISSUE HERE** - the decoding method may be incorrect
+
+ ---
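For reference while debugging the decoding step above: greedy CTC decoding is just argmax token IDs with consecutive repeats collapsed and blanks dropped. A self-contained sketch; the blank ID of 0 and the toy vocab are illustrative, not the model's actual values.

```python
def ctc_greedy_decode(ids, blank_id=0, vocab=None):
    """Collapse repeated IDs, drop blanks; optionally map through a vocab."""
    out, prev = [], None
    for i in ids:
        if i != blank_id and i != prev:   # keep only changes that aren't blank
            out.append(i)
        prev = i
    return "".join(vocab[i] for i in out) if vocab else out
```

`processor.batch_decode()` (or the tokenizer's own `decode`) performs this collapse internally; logging the raw predicted IDs and comparing against a manual collapse like this helps tell a decoding problem apart from a model that emits only blanks.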
+
+ ### 5. Model Returns Result (`slaq-version-c-ai-enginee/diagnosis/ai_engine/detect_stuttering.py`)
+
+ **Location:** Lines 787-794
+ ```python
+ actual_transcript = transcript if transcript else ""
+ target_transcript = proper_transcript if proper_transcript else transcript if transcript else ""
+
+ return {
+     'actual_transcript': actual_transcript,
+     'target_transcript': target_transcript,
+     ...
+ }
+ ```
+
+ **Status:** ✅ Returns the transcripts correctly (if the transcript is not empty)
+
+ ---
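The chained conditionals in step 5 are easy to misread; isolated as a helper they can be checked directly. `pick_transcripts` is a hypothetical name for illustration, mirroring the logic above.

```python
def pick_transcripts(transcript, proper_transcript):
    """Mirror step 5: the ASR text is the actual transcript; the target
    falls back to the ASR text when no expected transcript was provided."""
    actual = transcript if transcript else ""
    target = proper_transcript if proper_transcript else (transcript if transcript else "")
    return actual, target
```

Note that when the ASR output itself is empty, both values are empty strings, which is exactly the "No transcript available" symptom this guide is tracing.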
+
+ ### 6. API Returns Response (`slaq-version-c-ai-enginee/app.py`)
+
+ **Location:** Lines 109-113
+ ```python
+ actual = result.get('actual_transcript', '')
+ target = result.get('target_transcript', '')
+ logger.info(f"📝 Result transcripts - Actual: '{actual[:100]}' (len: {len(actual)}), Target: '{target[:100]}' (len: {len(target)})")
+ return result
+ ```
+
+ **Status:** ✅ Returns JSON with the transcripts
+
+ ---
+
+ ### 7. Django Receives Response (`slaq-version-c/diagnosis/ai_engine/detect_stuttering.py`)
+
+ **Location:** Lines 279-410
+ ```python
+ result = response.json()
+ # ... formatting ...
+ actual_transcript = str(api_result.get('actual_transcript', '')).strip()
+ target_transcript = str(api_result.get('target_transcript', '')).strip()
+ ```
+
+ **Status:** ✅ Extracts the transcripts correctly
+
+ ---
+
+ ### 8. Django Saves to Database (`slaq-version-c/diagnosis/tasks.py`)
+
+ **Location:** Lines 141-142
+ ```python
+ actual_transcript=actual_transcript,
+ target_transcript=target_transcript,
+ ```
+
+ **Status:** ✅ Saves correctly
+
+ ---
+
+ ## Root Cause Analysis
+
+ ### Most Likely Issue: Transcription Decoding
+
+ The IndicWav2Vec model (`ai4bharat/indicwav2vec-hindi`) may require:
+ 1. **Direct tokenizer access** instead of `processor.batch_decode()`
+ 2. **CTC decoding** with the proper tokenizer
+ 3. **Special handling** for Indic scripts
+
+ ### Fix Applied
+
+ Updated `_transcribe_with_timestamps()` to:
+ 1. Try multiple decoding methods
+ 2. Use the tokenizer directly if available
+ 3. Add comprehensive error logging
+ 4. Log predicted IDs for debugging
+
+ ---
+
+ ## Debugging Steps
+
+ ### 1. Check API Logs
+
+ When processing audio, look for:
+ ```
+ 📝 Transcribed text: '...' (length: X)
+ 📝 Final return - Actual: '...' (len: X), Target: '...' (len: Y)
+ 📝 Result transcripts - Actual: '...' (len: X), Target: '...' (len: Y)
+ ```
+
+ ### 2. Check Django Logs
+
+ Look for:
+ ```
+ 📝 Final transcripts - Actual: X chars, Target: Y chars
+ 📝 Saving transcripts - Actual: X chars, Target: Y chars
+ ```
+
+ ### 3. Check the Database
+
+ Query the `AnalysisResult` table:
+ ```sql
+ SELECT actual_transcript, target_transcript,
+        LENGTH(actual_transcript) AS actual_len, LENGTH(target_transcript) AS target_len
+ FROM diagnosis_analysisresult
+ ORDER BY created_at DESC LIMIT 5;
+ ```
+
+ ### 4. Test the API Directly
+
+ ```bash
+ curl -X POST "http://localhost:7860/analyze" \
+      -F "audio=@test.wav" \
+      -F "transcript=test transcript" \
+      -F "language=hin"
+ ```
+
+ Check the response JSON for `actual_transcript` and `target_transcript`.
+
+ ---
+
+ ## Next Steps
+
+ 1. **Rebuild the Docker image** with the latest changes
+ 2. **Check logs** during audio processing
+ 3. **Verify the processor structure** - the logs will show the processor's attributes
+ 4. **Test with Hindi audio** - the model is optimized for Hindi
+ 5. **Check that the model loaded correctly** - verify HF_TOKEN is working
+
+ ---
+
+ ## Expected Log Output (Success)
+
+ ```
+ 🚀 Initializing Advanced AI Engine on cpu...
+ ✅ HF_TOKEN found - using authenticated model access
+ 📋 Processor type: <class 'transformers.models.wav2vec2.processing_wav2vec2.Wav2Vec2Processor'>
+ 📋 Processor attributes: ['batch_decode', 'decode', 'feature_extractor', 'tokenizer', ...]
+ 📋 Tokenizer type: <class 'transformers.models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizer'>
+ 📝 Transcribed text: 'नमस्ते मैं हिंदी बोल रहा हूं' (length: 25)
+ 📝 Final return - Actual: 'नमस्ते मैं हिंदी बोल रहा हूं' (len: 25), Target: '...' (len: X)
+ ```
+
+ ---
+
+ ## If Still Empty
+
+ 1. **Model may not be loaded correctly** - check HF_TOKEN
+ 2. **Audio format issue** - ensure 16kHz mono WAV
+ 3. **Model not producing predictions** - check predicted_ids in the logs
+ 4. **Tokenizer mismatch** - IndicWav2Vec may need special tokenizer initialization
+
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Zlaqa Version B Ai Enginee
+ emoji: ⚡
+ colorFrom: purple
+ colorTo: red
+ sdk: docker
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,186 @@
+ # app.py
+ import logging
+ import os
+ import sys
+ from pathlib import Path
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+ from fastapi.responses import JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ import gradio as gr
+
+ # Configure logging FIRST
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     stream=sys.stdout
+ )
+ logger = logging.getLogger(__name__)
+
+ # Add project root to path
+ sys.path.insert(0, str(Path(__file__).parent))
+
+ # Import detector using the model loader (clean architecture)
+ try:
+     from diagnosis.ai_engine.model_loader import get_stutter_detector
+     logger.info("✅ Successfully imported model loader")
+ except ImportError as e:
+     logger.error(f"❌ Failed to import model loader: {e}")
+     raise
+
+ # Initialize FastAPI
+ app = FastAPI(
+     title="Stutter Detector API",
+     description="Speech analysis using Wav2Vec2 models for stutter detection",
+     version="1.0.0"
+ )
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Global detector instance
+ detector = None
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Load models on startup"""
+     global detector
+     try:
+         logger.info("🚀 Startup event: Loading AI models...")
+         detector = get_stutter_detector()
+         logger.info("✅ Models loaded successfully!")
+     except Exception as e:
+         logger.error(f"❌ Failed to load models: {e}", exc_info=True)
+         raise
+
+ def gradio_analyze(audio_path, transcript=""):
+     """Analyze audio for stuttering via the Gradio interface"""
+     if not detector:
+         return {"error": "Models not loaded yet. Please try again later."}
+     try:
+         result = detector.analyze_audio(audio_path, transcript)
+         return result
+     except Exception as e:
+         return {"error": f"Analysis failed: {str(e)}"}
+
+ # Create Gradio interface
+ gradio_app = gr.Interface(
+     fn=gradio_analyze,
+     inputs=[
+         gr.Audio(type="filepath", label="Upload Audio File"),
+         gr.Textbox(label="Optional Transcript", placeholder="Enter expected transcript here...", lines=2)
+     ],
+     outputs=gr.JSON(label="Analysis Results"),
+     title="Stutter Detection",
+     description="Upload an audio file and optionally provide a transcript to analyze for stuttering."
+ )
+
+ # Mount Gradio app to FastAPI
+ gr.mount_gradio_app(app, gradio_app, path="/gradio")
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     from datetime import datetime
+     return {
+         "status": "healthy",
+         "models_loaded": detector is not None,
+         "timestamp": datetime.utcnow().isoformat() + "Z"
+     }
+
+ @app.post("/analyze")
+ async def analyze_audio(
+     audio: UploadFile = File(...),
+     transcript: str = Form("")
+ ):
+     """
+     Analyze an audio file for stuttering
+
+     Parameters:
+     - audio: WAV or MP3 audio file
+     - transcript: Optional expected transcript
+
+     Returns: Complete stutter analysis results
+     """
+     temp_file = None
+     try:
+         if not detector:
+             raise HTTPException(status_code=503, detail="Models not loaded yet. Try again in a moment.")
+
+         logger.info(f"📥 Processing: {audio.filename}")
+
+         # Create temp directory if needed
+         temp_dir = "/tmp/stutter_analysis"
+         os.makedirs(temp_dir, exist_ok=True)
+
+         # Save the uploaded file
+         temp_file = os.path.join(temp_dir, audio.filename)
+         content = await audio.read()
+
+         with open(temp_file, "wb") as f:
+             f.write(content)
+
+         logger.info(f"📂 Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")
+
+         # Analyze
+         logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
+         result = detector.analyze_audio(temp_file, transcript)
+
+         # Log the transcript values from the result
+         actual = result.get('actual_transcript', '')
+         target = result.get('target_transcript', '')
+         logger.info(f"✅ Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
+         logger.info(f"📝 Result transcripts - Actual: '{actual[:100]}' (len: {len(actual)}), Target: '{target[:100]}' (len: {len(target)})")
+         return result
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error(f"❌ Error during analysis: {str(e)}", exc_info=True)
+         raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
+
+     finally:
+         # Cleanup
+         if temp_file and os.path.exists(temp_file):
+             try:
+                 os.remove(temp_file)
+                 logger.info(f"🧹 Cleaned up: {temp_file}")
+             except Exception as e:
+                 logger.warning(f"Could not clean up {temp_file}: {e}")
+
+ @app.get("/")
+ async def root():
+     """API documentation"""
+     return {
+         "name": "SLAQ Stutter Detector API",
+         "version": "1.0.0",
+         "status": "running",
+         "endpoints": {
+             "health": "GET /health",
+             "analyze": "POST /analyze (multipart: audio file + optional transcript field)",
+             "docs": "GET /docs (interactive API docs)",
+             "gradio": "GET /gradio (web UI for stutter detection)"
+         },
+         "models": {
+             "asr": "ai4bharat/indicwav2vec-hindi"
+         }
+     }
+
+ if __name__ == "__main__":
+     import uvicorn
+     logger.info("🚀 Starting SLAQ Stutter Detector API...")
+     uvicorn.run(
+         app,
+         host="0.0.0.0",
+         port=7860,
+         log_level="info"
+     )
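A client for the `/analyze` endpoint above needs a multipart request with an `audio` file part and a `transcript` form field. A hedged sketch: `build_analyze_request` is a hypothetical helper that only assembles the payload (the Django caller also sends a `language` field), leaving the actual `requests.post(...)` to the caller so no running server is assumed.

```python
# Hypothetical client-side helper for POST /analyze.
def build_analyze_request(audio_bytes, filename, transcript="", language="hin"):
    """Assemble the multipart parts the /analyze endpoint expects."""
    files = {"audio": (filename, audio_bytes, "audio/wav")}
    data = {"transcript": transcript, "language": language}
    return files, data

# Usage (not executed here):
#   files, data = build_analyze_request(open("test.wav", "rb").read(), "test.wav")
#   requests.post("http://localhost:7860/analyze", files=files, data=data)
```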
diagnosis/ai_engine/detect_stuttering.py ADDED
@@ -0,0 +1,824 @@
+ # diagnosis/ai_engine/detect_stuttering.py
+ import os
+ import librosa
+ import torch
+ import logging
+ import numpy as np
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
+ import time
+ from dataclasses import dataclass, field
+ from typing import List, Dict, Any, Tuple
+ # Simplified: only ASR transcription is used; the complex signal-processing libraries were removed
+
+ logger = logging.getLogger(__name__)
+
+ # === CONFIGURATION ===
+ MODEL_ID = "ai4bharat/indicwav2vec-hindi"  # Only model used - IndicWav2Vec Hindi for ASR
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token for authenticated model access
+
+ INDIAN_LANGUAGES = {
+     'hindi': 'hin', 'english': 'eng', 'tamil': 'tam', 'telugu': 'tel',
+     'bengali': 'ben', 'marathi': 'mar', 'gujarati': 'guj', 'kannada': 'kan',
+     'malayalam': 'mal', 'punjabi': 'pan', 'urdu': 'urd', 'assamese': 'asm',
+     'odia': 'ory', 'bhojpuri': 'bho', 'maithili': 'mai'
+ }
+
+ # === RESEARCH-BASED THRESHOLDS (2024-2025 Literature) ===
+ # Prolongation detection (spectral correlation + duration)
+ PROLONGATION_CORRELATION_THRESHOLD = 0.90  # >0.9 spectral similarity
+ PROLONGATION_MIN_DURATION = 0.25           # >250ms (Revisiting Rule-Based, 2025)
+
+ # Block detection (silence analysis)
+ BLOCK_SILENCE_THRESHOLD = 0.35  # >350ms silence mid-utterance
+ BLOCK_ENERGY_PERCENTILE = 10    # Bottom 10% energy = silence
+
+ # Repetition detection (DTW + text matching)
+ REPETITION_DTW_THRESHOLD = 0.15   # Normalized DTW distance
+ REPETITION_MIN_SIMILARITY = 0.85  # Text-based similarity
+
+ # Speaking-rate norms (syllables/second)
+ SPEECH_RATE_MIN = 2.0
+ SPEECH_RATE_MAX = 6.0
+ SPEECH_RATE_TYPICAL = 4.0
+
+ # Formant analysis (vowel centralization - research finding)
+ # People who stutter show a reduced vowel space area
+ VOWEL_SPACE_REDUCTION_THRESHOLD = 0.70  # 70% of typical area
+
+ # Voice quality (jitter, shimmer, HNR)
+ JITTER_THRESHOLD = 0.01   # >1% jitter indicates instability
+ SHIMMER_THRESHOLD = 0.03  # >3% shimmer
+ HNR_THRESHOLD = 15.0      # <15 dB harmonics-to-noise ratio
+
+ # Zero-crossing rate (voiced/unvoiced discrimination)
+ ZCR_VOICED_THRESHOLD = 0.1    # Low ZCR = voiced
+ ZCR_UNVOICED_THRESHOLD = 0.3  # High ZCR = unvoiced
+
+ # Entropy-based uncertainty
+ ENTROPY_HIGH_THRESHOLD = 3.5     # High confusion in model predictions
+ CONFIDENCE_LOW_THRESHOLD = 0.40  # Low-confidence frame threshold
+
+ @dataclass
+ class StutterEvent:
+     """Enhanced stutter event with multi-modal features"""
+     type: str  # 'repetition', 'prolongation', 'block', 'dysfluency'
+     start: float
+     end: float
+     text: str
+     confidence: float
+     acoustic_features: Dict[str, float] = field(default_factory=dict)
+     voice_quality: Dict[str, float] = field(default_factory=dict)
+     formant_data: Dict[str, Any] = field(default_factory=dict)
+
+
+ class AdvancedStutterDetector:
+     """
+     🎤 IndicWav2Vec Hindi ASR Engine
+
+     Simplified engine using ONLY ai4bharat/indicwav2vec-hindi for Automatic Speech Recognition.
+
+     Features:
+     - Speech-to-text transcription using the IndicWav2Vec Hindi model
+     - Text-based stutter analysis from the transcription
+     - Confidence scoring from model predictions
+     - Basic dysfluency detection from transcript patterns
+
+     Model: ai4bharat/indicwav2vec-hindi (Wav2Vec2ForCTC)
+     Purpose: Automatic Speech Recognition (ASR) for Hindi and Indian languages
+     """
+
+     def __init__(self):
+         logger.info(f"🚀 Initializing Advanced AI Engine on {DEVICE}...")
+         if HF_TOKEN:
+             logger.info("✅ HF_TOKEN found - using authenticated model access")
+         else:
+             logger.warning("⚠️ HF_TOKEN not found - model access may fail if authentication is required")
+         try:
+             # Wav2Vec2 model loading - IndicWav2Vec Hindi
+             self.processor = AutoProcessor.from_pretrained(
+                 MODEL_ID,
+                 token=HF_TOKEN
+             )
+             self.model = Wav2Vec2ForCTC.from_pretrained(
+                 MODEL_ID,
+                 token=HF_TOKEN,
+                 torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
+             ).to(DEVICE)
+             self.model.eval()
+
+             # Initialize the feature extractor (clean architecture pattern)
+             from .features import ASRFeatureExtractor
+             self.feature_extractor = ASRFeatureExtractor(
+                 model=self.model,
+                 processor=self.processor,
+                 device=DEVICE
+             )
+
+             # Debug: log the processor structure
+             logger.info(f"📋 Processor type: {type(self.processor)}")
+             if hasattr(self.processor, 'tokenizer'):
+                 logger.info(f"📋 Tokenizer type: {type(self.processor.tokenizer)}")
+             if hasattr(self.processor, 'feature_extractor'):
+                 logger.info(f"📋 Feature extractor type: {type(self.processor.feature_extractor)}")
+
+             logger.info("✅ IndicWav2Vec Hindi ASR Engine loaded with feature extractor")
+         except Exception as e:
+             logger.error(f"🔥 Engine failure: {e}")
+             raise
+
+     def _init_common_adapters(self):
+         """Not applicable - IndicWav2Vec Hindi doesn't use adapters"""
+         pass
+
+     def _activate_adapter(self, lang_code: str):
+         """Not applicable - IndicWav2Vec Hindi doesn't use adapters"""
+         logger.info("Using IndicWav2Vec Hindi model (optimized for Hindi)")
+
+     # ===== LEGACY METHODS (NOT USED IN ASR-ONLY MODE) =====
+     # These methods are kept for reference but are not called in the simplified ASR pipeline.
+     # They require additional libraries (parselmouth, fastdtw, sklearn) that are not needed for ASR-only mode.
+
+     def _extract_comprehensive_features(self, audio: np.ndarray, sr: int, audio_path: str) -> Dict[str, Any]:
+         """Extract multi-modal acoustic features"""
+         features = {}
+
+         # MFCC (20 coefficients)
+         mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=20, hop_length=512)
+         features['mfcc'] = mfcc.T  # Transpose to time x features
+
+         # Zero-crossing rate
+         zcr = librosa.feature.zero_crossing_rate(audio, hop_length=512)[0]
+         features['zcr'] = zcr
+
+         # RMS energy
+         rms_energy = librosa.feature.rms(y=audio, hop_length=512)[0]
+         features['rms_energy'] = rms_energy
+
+         # Spectral flux
+         stft = librosa.stft(audio, hop_length=512)
+         magnitude = np.abs(stft)
+         spectral_flux = np.sum(np.diff(magnitude, axis=1) * (np.diff(magnitude, axis=1) > 0), axis=0)
+         features['spectral_flux'] = spectral_flux
+
+         # Energy entropy
+         frame_energy = np.sum(magnitude ** 2, axis=0)
+         frame_energy = frame_energy + 1e-10  # Avoid log(0)
+         energy_entropy = -np.sum((magnitude ** 2 / frame_energy) * np.log(magnitude ** 2 / frame_energy + 1e-10), axis=0)
+         features['energy_entropy'] = energy_entropy
+
+         # Formant analysis using Parselmouth
+         try:
+             sound = parselmouth.Sound(audio_path)
+             formant = sound.to_formant_burg(time_step=0.01)
+             times = np.arange(0, sound.duration, 0.01)
+             f1, f2, f3, f4 = [], [], [], []
+
+             for t in times:
+                 try:
+                     f1.append(formant.get_value_at_time(1, t) if formant.get_value_at_time(1, t) > 0 else np.nan)
+                     f2.append(formant.get_value_at_time(2, t) if formant.get_value_at_time(2, t) > 0 else np.nan)
+                     f3.append(formant.get_value_at_time(3, t) if formant.get_value_at_time(3, t) > 0 else np.nan)
183
+ f4.append(formant.get_value_at_time(4, t) if formant.get_value_at_time(4, t) > 0 else np.nan)
184
+ except Exception:
185
+ f1.append(np.nan)
186
+ f2.append(np.nan)
187
+ f3.append(np.nan)
188
+ f4.append(np.nan)
189
+
190
+ formants = np.array([f1, f2, f3, f4]).T
191
+ features['formants'] = formants
192
+
193
+ # Calculate vowel space area (F1-F2 plane)
194
+ valid_f1f2 = formants[~np.isnan(formants[:, 0]) & ~np.isnan(formants[:, 1]), :2]
195
+ if len(valid_f1f2) > 0:
196
+ # Convex hull area approximation
197
+ try:
198
+ hull = ConvexHull(valid_f1f2)
199
+ vowel_space_area = hull.volume
200
+ except Exception:
201
+ vowel_space_area = np.nan
202
+ else:
203
+ vowel_space_area = np.nan
204
+
205
+ features['formant_summary'] = {
206
+ 'vowel_space_area': float(vowel_space_area) if not np.isnan(vowel_space_area) else 0.0,
207
+ 'f1_mean': float(np.nanmean(f1)) if len(f1) > 0 else 0.0,
208
+ 'f2_mean': float(np.nanmean(f2)) if len(f2) > 0 else 0.0,
209
+ 'f1_std': float(np.nanstd(f1)) if len(f1) > 0 else 0.0,
210
+ 'f2_std': float(np.nanstd(f2)) if len(f2) > 0 else 0.0
211
+ }
212
+ except Exception as e:
213
+ logger.warning(f"Formant analysis failed: {e}")
214
+ features['formants'] = np.zeros((len(audio) // 100, 4))
215
+ features['formant_summary'] = {
216
+ 'vowel_space_area': 0.0,
217
+ 'f1_mean': 0.0, 'f2_mean': 0.0,
218
+ 'f1_std': 0.0, 'f2_std': 0.0
219
+ }
220
+
221
+ # Voice Quality Metrics (Jitter, Shimmer, HNR)
222
+ try:
223
+ sound = parselmouth.Sound(audio_path)
224
+ pitch = sound.to_pitch()
225
+ point_process = parselmouth.praat.call([sound, pitch], "To PointProcess")
226
+
227
+ jitter = parselmouth.praat.call(point_process, "Get jitter (local)", 0.0, 0.0, 1.1, 1.6, 1.3, 1.6)
228
+ shimmer = parselmouth.praat.call([sound, point_process], "Get shimmer (local)", 0.0, 0.0, 0.0001, 0.02, 1.3, 1.6)
229
+ hnr = parselmouth.praat.call(sound, "Get harmonicity (cc)", 0.0, 0.0, 0.01, 1.5, 1.0, 0.1, 1.0)
230
+
231
+ features['voice_quality'] = {
232
+ 'jitter': float(jitter) if jitter is not None else 0.0,
233
+ 'shimmer': float(shimmer) if shimmer is not None else 0.0,
234
+ 'hnr_db': float(hnr) if hnr is not None else 20.0
235
+ }
236
+ except Exception as e:
237
+ logger.warning(f"Voice quality analysis failed: {e}")
238
+ features['voice_quality'] = {
239
+ 'jitter': 0.0,
240
+ 'shimmer': 0.0,
241
+ 'hnr_db': 20.0
242
+ }
243
+
244
+ return features
245
+
246
+ def _transcribe_with_timestamps(self, audio: np.ndarray) -> Tuple[str, List[Dict], torch.Tensor]:
247
+ """
248
+ Transcribe audio and return word timestamps and logits.
249
+
250
+ Uses the feature extractor for clean separation of concerns.
251
+ """
252
+ try:
253
+ # Use feature extractor for transcription (clean architecture)
254
+ features = self.feature_extractor.get_transcription_features(audio, sample_rate=16000)
255
+ transcript = features['transcript']
256
+ logits = torch.from_numpy(features['logits'])
257
+
258
+ # Get word-level features for timestamps
259
+ word_features = self.feature_extractor.get_word_level_features(audio, sample_rate=16000)
260
+ word_timestamps = word_features['word_timestamps']
261
+
262
+ logger.info(f"📝 Transcription via feature extractor: '{transcript}' (length: {len(transcript)}, words: {len(word_timestamps)})")
263
+
264
+ return transcript, word_timestamps, logits
265
+ except Exception as e:
266
+ logger.error(f"❌ Transcription failed: {e}", exc_info=True)
267
+ return "", [], torch.zeros((1, 100, 32)) # Dummy return
268
+
269
+ def _calculate_uncertainty(self, logits: torch.Tensor) -> Tuple[float, List[Dict]]:
270
+ """Calculate entropy-based uncertainty and low-confidence regions"""
271
+ try:
272
+ probs = torch.softmax(logits, dim=-1)
273
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
274
+ entropy_mean = float(torch.mean(entropy).item())
275
+
276
+ # Find low-confidence regions
277
+ frame_duration = 0.02
278
+ low_conf_regions = []
279
+ confidence = torch.max(probs, dim=-1)[0]
280
+
281
+ for i in range(confidence.shape[1]):
282
+ conf = float(confidence[0, i].item())
283
+ if conf < CONFIDENCE_LOW_THRESHOLD:
284
+ low_conf_regions.append({
285
+ 'time': i * frame_duration,
286
+ 'confidence': conf
287
+ })
288
+
289
+ return entropy_mean, low_conf_regions
290
+ except Exception as e:
291
+ logger.warning(f"Uncertainty calculation failed: {e}")
292
+ return 0.0, []
293
+
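The entropy and low-confidence scan above can be illustrated without torch. Below is a minimal numpy sketch of the same softmax → entropy → low-confidence-frame computation; the `0.5` threshold and 20 ms frame step stand in for `CONFIDENCE_LOW_THRESHOLD` and the hard-coded `frame_duration`:

```python
import numpy as np

def entropy_uncertainty(logits, conf_threshold=0.5, frame_duration=0.02):
    """Numpy mirror of _calculate_uncertainty: mean entropy + low-confidence frames."""
    # Softmax over the vocabulary axis (subtract max for numerical stability)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)
    entropy = -(probs * np.log(probs + 1e-10)).sum(axis=-1)
    confidence = probs.max(axis=-1)
    low_conf = [
        {"time": i * frame_duration, "confidence": float(c)}
        for i, c in enumerate(confidence[0])
        if c < conf_threshold
    ]
    return float(entropy.mean()), low_conf

# Frame 1 is sharply peaked (high confidence); frame 2 is uniform (low confidence).
logits = np.array([[[10.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
mean_entropy, regions = entropy_uncertainty(logits)
```

A peaked distribution contributes near-zero entropy, while the uniform frame is flagged as a low-confidence region at `1 * frame_duration = 0.02 s`.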
294
+ def _estimate_speaking_rate(self, audio: np.ndarray, sr: int) -> float:
295
+ """Estimate speaking rate in syllables per second"""
296
+ try:
297
+ # Simple syllable estimation using energy peaks
298
+ rms = librosa.feature.rms(y=audio, hop_length=512)[0]
299
+ peaks, _ = librosa.util.peak_pick(rms, pre_max=3, post_max=3, pre_avg=3, post_avg=5, delta=0.1, wait=10)
300
+
301
+ duration = len(audio) / sr
302
+ num_syllables = len(peaks)
303
+ speaking_rate = num_syllables / duration if duration > 0 else SPEECH_RATE_TYPICAL
304
+
305
+ return max(SPEECH_RATE_MIN, min(SPEECH_RATE_MAX, speaking_rate))
306
+ except Exception as e:
307
+ logger.warning(f"Speaking rate estimation failed: {e}")
308
+ return SPEECH_RATE_TYPICAL
309
+
310
+ def _detect_prolongations_advanced(self, mfcc: np.ndarray, spectral_flux: np.ndarray,
311
+ speaking_rate: float, word_timestamps: List[Dict]) -> List[StutterEvent]:
312
+ """Detect prolongations using spectral correlation"""
313
+ events = []
314
+ frame_duration = 0.02
315
+
316
+ # Adaptive threshold based on speaking rate
317
+ min_duration = PROLONGATION_MIN_DURATION * (SPEECH_RATE_TYPICAL / max(speaking_rate, 0.1))
318
+
319
+ window_size = int(min_duration / frame_duration)
320
+ if window_size < 2:
321
+ return events
322
+
323
+ for i in range(len(mfcc) - window_size):
324
+ window = mfcc[i:i+window_size]
325
+
326
+ # Calculate spectral correlation
327
+ if len(window) > 1:
328
+ corr_matrix = np.corrcoef(window.T)
329
+ avg_correlation = np.mean(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])
330
+
331
+ if avg_correlation > PROLONGATION_CORRELATION_THRESHOLD:
332
+ start_time = i * frame_duration
333
+ end_time = (i + window_size) * frame_duration
334
+
335
+ # Check if within a word boundary
336
+ for word_ts in word_timestamps:
337
+ if word_ts['start'] <= start_time <= word_ts['end']:
338
+ events.append(StutterEvent(
339
+ type='prolongation',
340
+ start=start_time,
341
+ end=end_time,
342
+ text=word_ts.get('word', ''),
343
+ confidence=float(avg_correlation),
344
+ acoustic_features={
345
+ 'spectral_correlation': float(avg_correlation),
346
+ 'duration': end_time - start_time
347
+ }
348
+ ))
349
+ break
350
+
351
+ return events
352
+
353
+ def _detect_blocks_enhanced(self, audio: np.ndarray, sr: int, rms_energy: np.ndarray,
354
+ zcr: np.ndarray, word_timestamps: List[Dict],
355
+ speaking_rate: float) -> List[StutterEvent]:
356
+ """Detect blocks using silence analysis"""
357
+ events = []
358
+ frame_duration = 0.02
359
+
360
+ # Adaptive threshold
361
+ silence_threshold = BLOCK_SILENCE_THRESHOLD * (SPEECH_RATE_TYPICAL / max(speaking_rate, 0.1))
362
+ energy_threshold = np.percentile(rms_energy, BLOCK_ENERGY_PERCENTILE)
363
+
364
+ in_silence = False
365
+ silence_start = 0
366
+
367
+ for i, energy in enumerate(rms_energy):
368
+ is_silent = energy < energy_threshold and zcr[i] < ZCR_VOICED_THRESHOLD
369
+
370
+ if is_silent and not in_silence:
371
+ silence_start = i * frame_duration
372
+ in_silence = True
373
+ elif not is_silent and in_silence:
374
+ silence_duration = (i * frame_duration) - silence_start
375
+ if silence_duration > silence_threshold:
376
+ # Check if mid-utterance (not at start/end)
377
+ audio_duration = len(audio) / sr
378
+ if silence_start > 0.1 and silence_start < audio_duration - 0.1:
379
+ events.append(StutterEvent(
380
+ type='block',
381
+ start=silence_start,
382
+ end=i * frame_duration,
383
+ text="<silence>",
384
+ confidence=0.8,
385
+ acoustic_features={
386
+ 'silence_duration': silence_duration,
387
+ 'energy_level': float(energy)
388
+ }
389
+ ))
390
+ in_silence = False
391
+
392
+ return events
393
+
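The silence-run scan in `_detect_blocks_enhanced` reduces to a small state machine. The sketch below isolates that loop on synthetic energy/ZCR arrays; the thresholds are placeholders for `BLOCK_SILENCE_THRESHOLD`, the energy percentile, and `ZCR_VOICED_THRESHOLD`:

```python
import numpy as np

def find_silent_blocks(rms_energy, zcr, energy_threshold, zcr_threshold=0.1,
                       min_silence=0.1, frame_duration=0.02):
    """Sketch of the silence-run scan in _detect_blocks_enhanced (thresholds assumed)."""
    blocks, in_silence, start = [], False, 0.0
    for i, energy in enumerate(rms_energy):
        silent = energy < energy_threshold and zcr[i] < zcr_threshold
        if silent and not in_silence:
            start, in_silence = i * frame_duration, True
        elif not silent and in_silence:
            duration = i * frame_duration - start
            if duration > min_silence:          # only runs longer than the floor count
                blocks.append((start, i * frame_duration))
            in_silence = False
    return blocks

# 20 frames: 5 loud, 10 silent (0.2 s), 5 loud -> one block from 0.1 s to 0.3 s.
energy = np.array([1.0] * 5 + [0.01] * 10 + [1.0] * 5)
zcr = np.zeros(20)
blocks = find_silent_blocks(energy, zcr, energy_threshold=0.1)
```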
394
+ def _detect_repetitions_advanced(self, mfcc: np.ndarray, formants: np.ndarray,
395
+ word_timestamps: List[Dict], transcript: str,
396
+ speaking_rate: float) -> List[StutterEvent]:
397
+ """Detect repetitions using DTW and text matching"""
398
+ events = []
399
+
400
+ if len(word_timestamps) < 2:
401
+ return events
402
+
403
+ # Text-based repetition detection
404
+ words = transcript.lower().split()
405
+ for i in range(len(words) - 1):
406
+ if words[i] == words[i+1]:
407
+ # Find corresponding timestamps
408
+ if i < len(word_timestamps) and i+1 < len(word_timestamps):
409
+ start = word_timestamps[i]['start']
410
+ end = word_timestamps[i+1]['end']
411
+
412
+ # DTW verification on MFCC
413
+ start_frame = int(start / 0.02)
414
+ mid_frame = int((start + end) / 2 / 0.02)
415
+ end_frame = int(end / 0.02)
416
+
417
+ if start_frame < len(mfcc) and end_frame < len(mfcc):
418
+ segment1 = mfcc[start_frame:mid_frame]
419
+ segment2 = mfcc[mid_frame:end_frame]
420
+
421
+ if len(segment1) > 0 and len(segment2) > 0:
422
+ try:
423
+ distance, _ = fastdtw(segment1, segment2)
424
+ normalized_distance = distance / max(len(segment1), len(segment2))
425
+
426
+ if normalized_distance < REPETITION_DTW_THRESHOLD:
427
+ events.append(StutterEvent(
428
+ type='repetition',
429
+ start=start,
430
+ end=end,
431
+ text=words[i],
432
+ confidence=1.0 - normalized_distance,
433
+ acoustic_features={
434
+ 'dtw_distance': float(normalized_distance),
435
+ 'repetition_count': 2
436
+ }
437
+ ))
438
+ except Exception:
439
+ pass
440
+
441
+ return events
442
+
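The DTW verification step can be illustrated without the `fastdtw` dependency. The sketch below is a plain dynamic-programming DTW (not the library call) with Euclidean frame cost; it shows why a genuine repetition — two near-identical MFCC segments — yields a small normalized distance:

```python
import numpy as np

def dtw_distance(a, b):
    """O(n*m) DTW with Euclidean frame cost; illustrative stand-in for fastdtw."""
    n, m = len(a), len(b)
    D = np.full((n + 1, m + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = np.linalg.norm(a[i - 1] - b[j - 1])
            D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
    return D[n, m]

seg = np.random.default_rng(0).normal(size=(10, 20))  # 10 frames of 20-dim MFCC-like features
identical = dtw_distance(seg, seg) / 10               # repeated segment: distance 0
different = dtw_distance(seg, seg[::-1]) / 10         # unrelated segment: larger distance
```

An exact repetition aligns along the diagonal at zero cost, so `normalized_distance < REPETITION_DTW_THRESHOLD` holds trivially; dissimilar segments do not.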
443
+ def _detect_voice_quality_issues(self, audio_path: str, word_timestamps: List[Dict],
444
+ voice_quality: Dict[str, float]) -> List[StutterEvent]:
445
+ """Detect dysfluencies based on voice quality metrics"""
446
+ events = []
447
+
448
+ # Global voice quality issues
449
+ if voice_quality.get('jitter', 0) > JITTER_THRESHOLD or \
450
+ voice_quality.get('shimmer', 0) > SHIMMER_THRESHOLD or \
451
+ voice_quality.get('hnr_db', 20) < HNR_THRESHOLD:
452
+
453
+ # Mark regions with poor voice quality
454
+ for word_ts in word_timestamps:
455
+ if word_ts.get('start', 0) > 0: # Skip first word
456
+ events.append(StutterEvent(
457
+ type='dysfluency',
458
+ start=word_ts['start'],
459
+ end=word_ts['end'],
460
+ text=word_ts.get('word', ''),
461
+ confidence=0.6,
462
+ voice_quality=voice_quality.copy()
463
+ ))
464
+ break # Only mark first occurrence
465
+
466
+ return events
467
+
468
+ def _is_overlapping(self, time: float, events: List[StutterEvent], threshold: float = 0.1) -> bool:
469
+ """Check if time overlaps with existing events"""
470
+ for event in events:
471
+ if event.start - threshold <= time <= event.end + threshold:
472
+ return True
473
+ return False
474
+
475
+ def _detect_anomalies(self, events: List[StutterEvent], features: Dict[str, Any]) -> List[StutterEvent]:
476
+ """Use Isolation Forest to filter anomalous events"""
477
+ if len(events) == 0:
478
+ return events
479
+
480
+ try:
481
+ # Extract features for anomaly detection
482
+ X = []
483
+ for event in events:
484
+ feat_vec = [
485
+ event.end - event.start, # Duration
486
+ event.confidence,
487
+ features.get('voice_quality', {}).get('jitter', 0),
488
+ features.get('voice_quality', {}).get('shimmer', 0)
489
+ ]
490
+ X.append(feat_vec)
491
+
492
+ X = np.array(X)
493
+ if len(X) > 1:
494
+ self.anomaly_detector.fit(X)
495
+ predictions = self.anomaly_detector.predict(X)
496
+
497
+ # Keep only non-anomalous events (predictions == 1)
498
+ filtered_events = [events[i] for i, pred in enumerate(predictions) if pred == 1]
499
+ return filtered_events
500
+ except Exception as e:
501
+ logger.warning(f"Anomaly detection failed: {e}")
502
+
503
+ return events
504
+
505
+ def _deduplicate_events_cascade(self, events: List[StutterEvent]) -> List[StutterEvent]:
506
+ """Remove overlapping events with priority: Block > Repetition > Prolongation > Dysfluency"""
507
+ if len(events) == 0:
508
+ return events
509
+
510
+ # Sort by priority and start time
511
+ priority = {'block': 4, 'repetition': 3, 'prolongation': 2, 'dysfluency': 1}
512
+ events.sort(key=lambda e: (priority.get(e.type, 0), e.start), reverse=True)
513
+
514
+ cleaned = []
515
+ for event in events:
516
+ overlap = False
517
+ for existing in cleaned:
518
+ # Check overlap
519
+ if not (event.end < existing.start or event.start > existing.end):
520
+ overlap = True
521
+ break
522
+
523
+ if not overlap:
524
+ cleaned.append(event)
525
+
526
+ # Sort by start time
527
+ cleaned.sort(key=lambda e: e.start)
528
+ return cleaned
529
+
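The cascade dedup above can be demonstrated on bare `(type, start, end)` tuples. This sketch mirrors `_deduplicate_events_cascade`: sort by priority, keep an event only if it overlaps nothing already kept, then re-sort by start time:

```python
PRIORITY = {"block": 4, "repetition": 3, "prolongation": 2, "dysfluency": 1}

def dedup(events):
    """Keep highest-priority events first; drop anything overlapping a kept event."""
    events = sorted(events, key=lambda e: (PRIORITY.get(e[0], 0), e[1]), reverse=True)
    kept = []
    for ev in events:
        # Intervals (s1, e1) and (s2, e2) overlap unless e1 < s2 or s1 > e2
        if all(ev[2] < k[1] or ev[1] > k[2] for k in kept):
            kept.append(ev)
    return sorted(kept, key=lambda e: e[1])

events = [("prolongation", 1.0, 1.5), ("block", 1.2, 1.8), ("repetition", 3.0, 3.4)]
cleaned = dedup(events)
```

The prolongation overlaps the higher-priority block, so only the block and the non-overlapping repetition survive.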
530
+ def _calculate_clinical_metrics(self, events: List[StutterEvent], duration: float,
531
+ speaking_rate: float, features: Dict[str, Any]) -> Dict[str, Any]:
532
+ """Calculate comprehensive clinical metrics"""
533
+ total_duration = sum(e.end - e.start for e in events)
534
+ frequency = (len(events) / duration * 60) if duration > 0 else 0
535
+
536
+ # Calculate severity score (0-100)
537
+ stutter_percentage = (total_duration / duration * 100) if duration > 0 else 0
538
+ frequency_score = min(frequency / 10 * 100, 100) # Normalize to 100
539
+ severity_score = (stutter_percentage * 0.6 + frequency_score * 0.4)
540
+
541
+ # Determine severity label
542
+ if severity_score < 10:
543
+ severity_label = 'none'
544
+ elif severity_score < 25:
545
+ severity_label = 'mild'
546
+ elif severity_score < 50:
547
+ severity_label = 'moderate'
548
+ else:
549
+ severity_label = 'severe'
550
+
551
+ # Calculate confidence based on multiple factors
552
+ voice_quality = features.get('voice_quality', {})
553
+ confidence = 0.8 # Base confidence
554
+
555
+ # Adjust based on voice quality metrics
556
+ if voice_quality.get('jitter', 0) > JITTER_THRESHOLD:
557
+ confidence -= 0.1
558
+ if voice_quality.get('shimmer', 0) > SHIMMER_THRESHOLD:
559
+ confidence -= 0.1
560
+ if voice_quality.get('hnr_db', 20) < HNR_THRESHOLD:
561
+ confidence -= 0.1
562
+
563
+ confidence = max(0.3, min(1.0, confidence))
564
+
565
+ return {
566
+ 'total_duration': round(total_duration, 2),
567
+ 'frequency': round(frequency, 2),
568
+ 'severity_score': round(severity_score, 2),
569
+ 'severity_label': severity_label,
570
+ 'confidence': round(confidence, 2)
571
+ }
572
+
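The severity arithmetic in `_calculate_clinical_metrics` — 60% weight on stuttered-time percentage, 40% on a frequency score capped at 100 — works through like this (a self-contained restatement of the formula above, not a new metric):

```python
def severity(total_stutter_duration, num_events, audio_duration):
    """Severity score as above: 0.6 * stutter-time % + 0.4 * capped frequency score."""
    stutter_pct = total_stutter_duration / audio_duration * 100
    frequency = num_events / audio_duration * 60           # events per minute
    frequency_score = min(frequency / 10 * 100, 100)       # 10+ events/min saturates at 100
    score = stutter_pct * 0.6 + frequency_score * 0.4
    if score < 10:
        label = "none"
    elif score < 25:
        label = "mild"
    elif score < 50:
        label = "moderate"
    else:
        label = "severe"
    return round(score, 2), label

# 3 s of stuttering and 4 events in a 60 s recording:
score, label = severity(total_stutter_duration=3.0, num_events=4, audio_duration=60.0)
```

Here `stutter_pct = 5`, `frequency_score = 40`, so `score = 3 + 16 = 19`, which falls in the "mild" band.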
573
+ def _event_to_dict(self, event: StutterEvent) -> Dict[str, Any]:
574
+ """Convert StutterEvent to dictionary"""
575
+ return {
576
+ 'type': event.type,
577
+ 'start': round(event.start, 2),
578
+ 'end': round(event.end, 2),
579
+ 'text': event.text,
580
+ 'confidence': round(event.confidence, 2),
581
+ 'acoustic_features': event.acoustic_features,
582
+ 'voice_quality': event.voice_quality,
583
+ 'formant_data': event.formant_data
584
+ }
585
+
586
+
587
+ def analyze_audio(self, audio_path: str, proper_transcript: str = "", language: str = 'hindi') -> dict:
588
+ """
589
+ Main ASR analysis pipeline using IndicWav2Vec Hindi model
590
+
591
+ Focus: Automatic Speech Recognition (ASR) transcription only
592
+ """
593
+ start_time = time.time()
594
+
595
+ # === STEP 1: Audio Loading & Preprocessing ===
596
+ audio, sr = librosa.load(audio_path, sr=16000)
597
+ duration = librosa.get_duration(y=audio, sr=sr)
598
+
599
+ # === STEP 2: ASR Transcription using IndicWav2Vec Hindi ===
600
+ transcript, word_timestamps, logits = self._transcribe_with_timestamps(audio)
601
+ logger.info(f"📝 ASR Transcription: '{transcript}' (length: {len(transcript)}, words: {len(word_timestamps)})")
602
+
603
+ # === STEP 3: Calculate Confidence from Model Predictions ===
604
+ entropy_score, low_conf_regions = self._calculate_uncertainty(logits)
605
+ avg_confidence = 1.0 - (entropy_score / 10.0) if entropy_score > 0 else 0.8
606
+ avg_confidence = max(0.0, min(1.0, avg_confidence))
607
+
608
+ # === STEP 4: Basic Text-based Analysis ===
609
+ # Simple text-based stutter detection (repetitions, hesitations)
610
+ events = []
611
+ if transcript:
612
+ words = transcript.split()
613
+ # Detect word repetitions
614
+ for i in range(len(words) - 1):
615
+ if words[i] == words[i+1] and i < len(word_timestamps) - 1:
616
+ events.append(StutterEvent(
617
+ type='repetition',
618
+ start=word_timestamps[i]['start'] if i < len(word_timestamps) else 0,
619
+ end=word_timestamps[i+1]['end'] if i+1 < len(word_timestamps) else 0,
620
+ text=words[i],
621
+ confidence=0.7
622
+ ))
623
+
624
+ # Add low confidence regions as potential dysfluencies
625
+ for region in low_conf_regions[:5]: # Limit to first 5
626
+ events.append(StutterEvent(
627
+ type='dysfluency',
628
+ start=region['time'],
629
+ end=region['time'] + 0.3,
630
+ text="<uncertainty>",
631
+ confidence=0.4,
632
+ acoustic_features={'entropy': entropy_score}
633
+ ))
634
+
635
+ # === STEP 5: Calculate Basic Metrics ===
636
+ total_duration = sum(e.end - e.start for e in events)
637
+ frequency = (len(events) / duration * 60) if duration > 0 else 0
638
+ stutter_percentage = (total_duration / duration * 100) if duration > 0 else 0
639
+
640
+ # Simple severity assessment
641
+ if stutter_percentage < 5:
642
+ severity = 'none'
643
+ elif stutter_percentage < 15:
644
+ severity = 'mild'
645
+ elif stutter_percentage < 30:
646
+ severity = 'moderate'
647
+ else:
648
+ severity = 'severe'
649
+
650
+ # === STEP 6: Return ASR Results ===
651
+ actual_transcript = transcript if transcript else ""
652
+ target_transcript = proper_transcript if proper_transcript else ""
653
+
654
+ logger.info(f"📝 Final ASR result - Actual: '{actual_transcript}' (len: {len(actual_transcript)}), Target: '{target_transcript}' (len: {len(target_transcript)})")
655
+
656
+ return {
657
+ 'actual_transcript': actual_transcript,
658
+ 'target_transcript': target_transcript,
659
+ 'mismatched_chars': [f"{r['time']:.2f}s" for r in low_conf_regions[:10]],
660
+ 'mismatch_percentage': round(stutter_percentage, 2),
661
+ 'ctc_loss_score': round(entropy_score, 4),
662
+ 'stutter_timestamps': [self._event_to_dict(e) for e in events],
663
+ 'total_stutter_duration': round(total_duration, 2),
664
+ 'stutter_frequency': round(frequency, 2),
665
+ 'severity': severity,
666
+ 'confidence_score': round(avg_confidence, 2),
667
+ 'speaking_rate_sps': round(len(word_timestamps) / duration if duration > 0 else 0, 2),
668
+ 'analysis_duration_seconds': round(time.time() - start_time, 2),
669
+ 'model_version': 'indicwav2vec-hindi-asr-v1'
670
+ }
671
+
672
+
673
+ # Legacy methods - kept for backward compatibility but may not work without additional model initialization
674
+ # These methods reference models (xlsr, base, large) that are not initialized in __init__
675
+ # The main analyze_audio() method uses the IndicWav2Vec Hindi model instead
676
+
677
+ def generate_target_transcript(self, audio_file: str) -> str:
678
+ """Generate expected transcript - Legacy method (uses IndicWav2Vec Hindi model)"""
679
+ try:
680
+ audio, sr = librosa.load(audio_file, sr=16000)
681
+ transcript, _, _ = self._transcribe_with_timestamps(audio)
682
+ return transcript
683
+ except Exception as e:
684
+ logger.error(f"Target transcript generation failed: {e}")
685
+ return ""
686
+
687
+ def transcribe_and_detect(self, audio_file: str, proper_transcript: str) -> Dict:
688
+ """Transcribe audio and detect stuttering patterns - Legacy method"""
689
+ try:
690
+ audio, _ = librosa.load(audio_file, sr=16000)
691
+ transcript, _, _ = self._transcribe_with_timestamps(audio)
692
+
693
+ # Find stuttered sequences
694
+ stuttered_chars = self.find_sequences_not_in_common(transcript, proper_transcript)
695
+
696
+ # Calculate mismatch percentage
697
+ total_mismatched = sum(len(segment) for segment in stuttered_chars)
698
+ mismatch_percentage = (total_mismatched / len(proper_transcript)) * 100 if len(proper_transcript) > 0 else 0
699
+ mismatch_percentage = min(round(mismatch_percentage), 100)
700
+
701
+ return {
702
+ 'transcription': transcript,
703
+ 'stuttered_chars': stuttered_chars,
704
+ 'mismatch_percentage': mismatch_percentage
705
+ }
706
+ except Exception as e:
707
+ logger.error(f"Transcription failed: {e}")
708
+ return {
709
+ 'transcription': '',
710
+ 'stuttered_chars': [],
711
+ 'mismatch_percentage': 0
712
+ }
713
+
714
+ def calculate_stutter_timestamps(self, audio_file: str, proper_transcript: str) -> Tuple[float, List[Tuple[float, float]]]:
715
+ """Calculate stutter timestamps - Legacy method (uses analyze_audio instead)"""
716
+ try:
717
+ # Use main analyze_audio method
718
+ result = self.analyze_audio(audio_file, proper_transcript)
719
+
720
+ # Extract timestamps from result
721
+ timestamps = []
722
+ for event in result.get('stutter_timestamps', []):
723
+ timestamps.append((event['start'], event['end']))
724
+
725
+ ctc_score = result.get('ctc_loss_score', 0.0)
726
+ return float(ctc_score), timestamps
727
+ except Exception as e:
728
+ logger.error(f"Timestamp calculation failed: {e}")
729
+ return 0.0, []
730
+
731
+
732
+ def find_max_common_characters(self, transcription1: str, transcript2: str) -> str:
733
+ """Longest Common Subsequence algorithm"""
734
+ m, n = len(transcription1), len(transcript2)
735
+ lcs_matrix = [[0] * (n + 1) for _ in range(m + 1)]
736
+
737
+ for i in range(1, m + 1):
738
+ for j in range(1, n + 1):
739
+ if transcription1[i - 1] == transcript2[j - 1]:
740
+ lcs_matrix[i][j] = lcs_matrix[i - 1][j - 1] + 1
741
+ else:
742
+ lcs_matrix[i][j] = max(lcs_matrix[i - 1][j], lcs_matrix[i][j - 1])
743
+
744
+ # Backtrack to find LCS
745
+ lcs_characters = []
746
+ i, j = m, n
747
+ while i > 0 and j > 0:
748
+ if transcription1[i - 1] == transcript2[j - 1]:
749
+ lcs_characters.append(transcription1[i - 1])
750
+ i -= 1
751
+ j -= 1
752
+ elif lcs_matrix[i - 1][j] > lcs_matrix[i][j - 1]:
753
+ i -= 1
754
+ else:
755
+ j -= 1
756
+
757
+ lcs_characters.reverse()
758
+ return ''.join(lcs_characters)
759
+
760
+
761
+ def find_sequences_not_in_common(self, transcription1: str, proper_transcript: str) -> List[str]:
762
+ """Find stuttered character sequences"""
763
+ common_characters = self.find_max_common_characters(transcription1, proper_transcript)
764
+ sequences = []
765
+ sequence = ""
766
+ i, j = 0, 0
767
+
768
+ while i < len(transcription1) and j < len(common_characters):
769
+ if transcription1[i] == common_characters[j]:
770
+ if sequence:
771
+ sequences.append(sequence)
772
+ sequence = ""
773
+ i += 1
774
+ j += 1
775
+ else:
776
+ sequence += transcription1[i]
777
+ i += 1
778
+
779
+ if sequence:
780
+ sequences.append(sequence)
781
+
782
+ return sequences
783
+
784
+
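The two text-comparison methods compose as follows: the LCS gives the characters shared with the target transcript, and everything in the actual transcript outside that subsequence is treated as a stuttered insertion. A standalone sketch mirroring `find_max_common_characters` and `find_sequences_not_in_common`:

```python
def lcs(a, b):
    """Longest common subsequence via the standard DP table + backtrack."""
    m, n = len(a), len(b)
    M = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            M[i][j] = M[i-1][j-1] + 1 if a[i-1] == b[j-1] else max(M[i-1][j], M[i][j-1])
    out, i, j = [], m, n
    while i and j:
        if a[i-1] == b[j-1]:
            out.append(a[i-1]); i -= 1; j -= 1
        elif M[i-1][j] > M[i][j-1]:
            i -= 1
        else:
            j -= 1
    return "".join(reversed(out))

def stuttered_segments(actual, target):
    """Runs of characters in `actual` that are not part of the LCS with `target`."""
    common, out, seq, i, j = lcs(actual, target), [], "", 0, 0
    while i < len(actual) and j < len(common):
        if actual[i] == common[j]:
            if seq:
                out.append(seq); seq = ""
            i += 1; j += 1
        else:
            seq += actual[i]; i += 1
    if seq:
        out.append(seq)
    return out

# A repeated initial syllable ("na-namaste") surfaces as the extra segment "na".
segments = stuttered_segments("nanamaste", "namaste")
```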
785
+ def _calculate_total_duration(self, timestamps: List[Tuple[float, float]]) -> float:
786
+ """Calculate total stuttering duration"""
787
+ return sum(end - start for start, end in timestamps)
788
+
789
+
790
+ def _calculate_frequency(self, timestamps: List[Tuple[float, float]], audio_file: str) -> float:
791
+ """Calculate stutters per minute"""
792
+ try:
793
+ audio_duration = librosa.get_duration(path=audio_file)
794
+ if audio_duration > 0:
795
+ return (len(timestamps) / audio_duration) * 60
796
+ return 0.0
797
+ except Exception:
798
+ return 0.0
799
+
800
+
801
+ def _determine_severity(self, mismatch_percentage: float) -> str:
802
+ """Determine severity level"""
803
+ if mismatch_percentage < 10:
804
+ return 'none'
805
+ elif mismatch_percentage < 25:
806
+ return 'mild'
807
+ elif mismatch_percentage < 50:
808
+ return 'moderate'
809
+ else:
810
+ return 'severe'
811
+
812
+
813
+ def _calculate_confidence(self, transcription_result: Dict, ctc_loss: float) -> float:
814
+ """Calculate confidence score for the analysis"""
815
+ # Lower mismatch and lower CTC loss = higher confidence
816
+ mismatch_factor = 1 - (transcription_result['mismatch_percentage'] / 100)
817
+ loss_factor = max(0, 1 - (ctc_loss / 10)) # Normalize loss
818
+ confidence = (mismatch_factor + loss_factor) / 2
819
+ return round(min(max(confidence, 0.0), 1.0), 2)
820
+
821
+
822
+ # Model loader is now in a separate module: model_loader.py
823
+ # This follows clean architecture principles - separation of concerns
824
+ # Import using: from diagnosis.ai_engine.model_loader import get_stutter_detector
diagnosis/ai_engine/features.py ADDED
@@ -0,0 +1,206 @@
1
+ # diagnosis/ai_engine/features.py
2
+ """
3
+ Feature extraction for IndicWav2Vec Hindi ASR
4
+
5
+ This module provides feature extraction capabilities using the IndicWav2Vec Hindi model.
6
+ Focused on ASR transcription features rather than hybrid acoustic+linguistic features.
7
+ """
8
+ import torch
9
+ import numpy as np
10
+ import logging
11
+ from typing import Dict, Any, Tuple, Optional
12
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class ASRFeatureExtractor:
18
+ """
19
+ Feature extractor using IndicWav2Vec Hindi for Automatic Speech Recognition.
20
+
21
+ This extractor focuses on:
22
+ - Audio feature extraction via IndicWav2Vec
23
+ - Transcription confidence scores
24
+ - Frame-level predictions and logits
25
+ - Word-level alignments (estimated)
26
+
27
+ Model: ai4bharat/indicwav2vec-hindi
28
+ """
29
+
30
+ def __init__(self, model: Wav2Vec2ForCTC, processor: AutoProcessor, device: str = "cpu"):
31
+ """
32
+ Initialize the ASR feature extractor.
33
+
34
+ Args:
35
+ model: Pre-loaded IndicWav2Vec Hindi model
36
+ processor: Pre-loaded processor for the model
37
+ device: Device to run inference on ('cpu' or 'cuda')
38
+ """
39
+ self.model = model
40
+ self.processor = processor
41
+ self.device = device
42
+ self.model.eval()
43
+ logger.info(f"✅ ASRFeatureExtractor initialized on {device}")
44
+
45
+ def extract_audio_features(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
46
+ """
47
+ Extract features from audio using IndicWav2Vec Hindi.
48
+
49
+ Args:
50
+ audio: Audio waveform as numpy array
51
+ sample_rate: Sample rate of the audio (default: 16000)
52
+
53
+ Returns:
54
+ Dictionary containing:
55
+ - input_values: Processed audio features
56
+ - attention_mask: Attention mask (if available)
57
+ """
58
+ try:
59
+ # Process audio through the processor
60
+ inputs = self.processor(
61
+ audio,
62
+ sampling_rate=sample_rate,
63
+ return_tensors="pt"
64
+ ).to(self.device)
65
+
66
+ return {
67
+ 'input_values': inputs.input_values,
68
+ 'attention_mask': inputs.get('attention_mask', None)
69
+ }
70
+ except Exception as e:
71
+ logger.error(f"❌ Error extracting audio features: {e}")
72
+ raise
73
+
74
+ def get_transcription_features(
75
+ self,
76
+ audio: np.ndarray,
77
+ sample_rate: int = 16000
78
+ ) -> Dict[str, Any]:
79
+ """
80
+ Get transcription features including logits, predictions, and confidence.
81
+
82
+ Args:
83
+ audio: Audio waveform as numpy array
84
+ sample_rate: Sample rate of the audio (default: 16000)
85
+
86
+ Returns:
87
+ Dictionary containing:
88
+ - transcript: Transcribed text
89
+ - logits: Model logits (raw predictions)
90
+ - predicted_ids: Predicted token IDs
91
+ - probabilities: Softmax probabilities
92
+ - confidence: Average confidence score
93
+ - frame_confidence: Per-frame confidence scores
94
+ """
95
+ try:
96
+ # Process audio
97
+ inputs = self.processor(
98
+ audio,
99
+ sampling_rate=sample_rate,
100
+ return_tensors="pt"
101
+ ).to(self.device)
102
+
103
+ # Get model predictions
104
+ with torch.no_grad():
105
+ outputs = self.model(**inputs)
106
+ logits = outputs.logits
107
+ predicted_ids = torch.argmax(logits, dim=-1)
108
+
109
+ # Calculate probabilities and confidence
110
+ probs = torch.softmax(logits, dim=-1)
111
+ max_probs = torch.max(probs, dim=-1)[0] # Get max probability per frame
112
+ frame_confidence = max_probs[0].cpu().numpy()
113
+ avg_confidence = float(torch.mean(max_probs).item())
114
+
115
+ # Decode transcript
116
+ transcript = ""
117
+ try:
118
+ if hasattr(self.processor, 'tokenizer'):
119
+ transcript = self.processor.tokenizer.decode(
120
+ predicted_ids[0],
121
+ skip_special_tokens=True
122
+ )
123
+ elif hasattr(self.processor, 'batch_decode'):
124
+ transcript = self.processor.batch_decode(predicted_ids)[0]
125
+
126
+ # Clean up transcript
127
+ if transcript:
128
+ transcript = transcript.strip()
129
+ transcript = transcript.replace('<pad>', '').replace('<s>', '').replace('</s>', '').replace('|', ' ').strip()
130
+ transcript = ' '.join(transcript.split())
131
+ except Exception as e:
132
+ logger.warning(f"⚠️ Decode error: {e}")
133
+ transcript = ""
134
+
135
+ return {
136
+ 'transcript': transcript,
137
+ 'logits': logits.cpu().numpy(),
138
+ 'predicted_ids': predicted_ids.cpu().numpy(),
139
+ 'probabilities': probs.cpu().numpy(),
140
+ 'confidence': avg_confidence,
141
+ 'frame_confidence': frame_confidence,
142
+ 'num_frames': logits.shape[1]
143
+ }
144
+ except Exception as e:
145
+ logger.error(f"❌ Error getting transcription features: {e}")
146
+ raise
147
+
148
+ def get_word_level_features(
149
+ self,
150
+ audio: np.ndarray,
151
+ sample_rate: int = 16000
152
+ ) -> Dict[str, Any]:
153
+ """
154
+ Get word-level features including timestamps and confidence.
155
+
156
+ Args:
157
+ audio: Audio waveform as numpy array
158
+ sample_rate: Sample rate of the audio (default: 16000)
159
+
160
+ Returns:
161
+ Dictionary containing:
162
+ - words: List of words
163
+ - word_timestamps: List of (start, end) timestamps for each word
164
+ - word_confidence: Confidence score for each word
165
+ """
166
+ try:
167
+ # Get transcription features
168
+ features = self.get_transcription_features(audio, sample_rate)
169
+ transcript = features['transcript']
170
+ frame_confidence = features['frame_confidence']
171
+ num_frames = features['num_frames']
172
+
173
+ # Estimate word-level timestamps (simplified)
174
+ words = transcript.split() if transcript else []
175
+ audio_duration = len(audio) / sample_rate
176
+ time_per_word = audio_duration / max(len(words), 1) if words else 0
177
+
178
+ word_timestamps = []
179
+ word_confidence = []
180
+
181
+ for i, word in enumerate(words):
182
+ start_time = i * time_per_word
183
+ end_time = (i + 1) * time_per_word
184
+
185
+ # Estimate confidence for this word (average of corresponding frames)
186
+ start_frame = int((start_time / audio_duration) * num_frames)
187
+ end_frame = int((end_time / audio_duration) * num_frames)
188
+ word_conf = float(np.mean(frame_confidence[start_frame:end_frame])) if end_frame > start_frame else 0.5
189
+
190
+ word_timestamps.append({
191
+ 'word': word,
192
+ 'start': start_time,
193
+ 'end': end_time
194
+ })
195
+ word_confidence.append(word_conf)
196
+
197
+ return {
198
+ 'words': words,
199
+ 'word_timestamps': word_timestamps,
200
+ 'word_confidence': word_confidence,
201
+ 'transcript': transcript
202
+ }
203
+ except Exception as e:
204
+ logger.error(f"❌ Error getting word-level features: {e}")
205
+ raise
206
+
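The "simplified" estimate in `get_word_level_features` spreads words uniformly across the audio duration and averages the frame confidences that fall inside each word's span. The core arithmetic can be sketched standalone (function name illustrative; the real method pulls its inputs from `get_transcription_features`):

```python
import numpy as np

def estimate_word_timestamps(transcript: str, frame_confidence: np.ndarray,
                             audio_duration: float):
    """Evenly spread words across the audio; average frame confidences per word."""
    words = transcript.split() if transcript else []
    num_frames = len(frame_confidence)
    time_per_word = audio_duration / max(len(words), 1) if words else 0
    timestamps, confidences = [], []
    for i, word in enumerate(words):
        start, end = i * time_per_word, (i + 1) * time_per_word
        # Map the word's time span onto CTC frame indices
        sf = int((start / audio_duration) * num_frames)
        ef = int((end / audio_duration) * num_frames)
        conf = float(np.mean(frame_confidence[sf:ef])) if ef > sf else 0.5
        timestamps.append({'word': word, 'start': start, 'end': end})
        confidences.append(conf)
    return timestamps, confidences
```

Note this is a uniform-spacing heuristic, not forced alignment: a long pause or a prolonged word shifts every estimate after it, which is why the docstring calls it simplified.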
diagnosis/ai_engine/model_loader.py ADDED
@@ -0,0 +1,51 @@
+ # diagnosis/ai_engine/model_loader.py
+ """Singleton pattern for model loading
+
+ This loader provides a clean interface for getting the detector instance.
+ Uses singleton pattern to ensure models are loaded only once.
+ """
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ _detector_instance = None
+
+ def get_stutter_detector():
+     """
+     Get or create singleton AdvancedStutterDetector instance.
+
+     This ensures models are loaded only once and reused across requests.
+
+     Returns:
+         AdvancedStutterDetector: The singleton detector instance
+
+     Raises:
+         ImportError: If the detector class cannot be imported
+     """
+     global _detector_instance
+
+     if _detector_instance is None:
+         try:
+             from .detect_stuttering import AdvancedStutterDetector
+             logger.info("🔄 Initializing detector instance (first call)...")
+             _detector_instance = AdvancedStutterDetector()
+             logger.info("✅ Detector instance created successfully")
+         except ImportError as e:
+             logger.error(f"❌ Failed to import AdvancedStutterDetector: {e}")
+             raise ImportError("No StutterDetector implementation available in detect_stuttering.py") from e
+         except Exception as e:
+             logger.error(f"❌ Failed to create detector instance: {e}")
+             raise
+
+     return _detector_instance
+
+ def reset_detector():
+     """
+     Reset the singleton instance (useful for testing or reloading models).
+
+     Note: This will force reloading of models on next get_stutter_detector() call.
+     """
+     global _detector_instance
+     _detector_instance = None
+     logger.info("🔄 Detector instance reset")
+
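The module-level singleton above guarantees the heavyweight model load happens once per process, with `reset_detector()` as an escape hatch for tests. Its caching behaviour can be demonstrated with a cheap stand-in class (names here are illustrative, not from the repo):

```python
# Minimal singleton loader mirroring model_loader.py, with a cheap stand-in class
_instance = None

class DummyDetector:
    """Stand-in for AdvancedStutterDetector; the real one loads ASR models."""
    loads = 0  # counts how many times the "expensive" constructor ran
    def __init__(self):
        DummyDetector.loads += 1

def get_detector():
    global _instance
    if _instance is None:
        _instance = DummyDetector()  # expensive step happens only once
    return _instance

def reset_detector():
    global _instance
    _instance = None  # next get_detector() call reconstructs
```

Repeated `get_detector()` calls return the same object, so per-request FastAPI handlers never pay the model-loading cost twice.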
gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ # Core ML
+ numpy>=1.24.0,<2.0.0
+ librosa>=0.10.0
+ transformers>=4.38.0,<5.0
+
+ # Audio
+ soundfile>=0.12.1
+ scipy>=1.11.0
+ praat-parselmouth>=0.4.3
+ fastdtw>=0.3.4
+ pyctcdecode==0.5.0
+
+ # API
+ fastapi>=0.115.2,<1.0
+ uvicorn>=0.24.0
+ python-multipart>=0.0.18
+
+ # Logging
+ python-json-logger>=2.0.0
+
+ # Web UI
+ gradio==6.1.0
+
+ # Explicitly pin torch to 2.6+ for transformers compatibility
+ torch>=2.6.0
+ torchvision>=0.21.0
+ torchaudio>=2.6.0