# diagnosis/ai_engine/detect_stuttering.py import os import librosa import torch import logging import numpy as np from transformers import Wav2Vec2ForCTC, AutoProcessor import time from dataclasses import dataclass, field from typing import List, Dict, Any, Tuple, Optional from difflib import SequenceMatcher import re # Advanced similarity and distance metrics from scipy.spatial.distance import cosine, euclidean from scipy.stats import pearsonr logger = logging.getLogger(__name__) # === CONFIGURATION === MODEL_ID = "ai4bharat/indicwav2vec-hindi" # Only model used - IndicWav2Vec Hindi for ASR DEVICE = "cuda" if torch.cuda.is_available() else "cpu" HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face token for authenticated model access INDIAN_LANGUAGES = { 'hindi': 'hin', 'english': 'eng', 'tamil': 'tam', 'telugu': 'tel', 'bengali': 'ben', 'marathi': 'mar', 'gujarati': 'guj', 'kannada': 'kan', 'malayalam': 'mal', 'punjabi': 'pan', 'urdu': 'urd', 'assamese': 'asm', 'odia': 'ory', 'bhojpuri': 'bho', 'maithili': 'mai' } # === DEVANAGARI PHONETIC MAPPINGS (Research-Based) === # Consonants grouped by phonetic similarity for stutter detection DEVANAGARI_CONSONANT_GROUPS = { # Plosives (stops) 'velar': ['क', 'ख', 'ग', 'घ', 'ङ'], 'palatal': ['च', 'छ', 'ज', 'झ', 'ञ'], 'retroflex': ['ट', 'ठ', 'ड', 'ढ', 'ण'], 'dental': ['त', 'थ', 'द', 'ध', 'न'], 'labial': ['प', 'फ', 'ब', 'भ', 'म'], # Fricatives & Approximants 'sibilants': ['श', 'ष', 'स', 'ह'], 'liquids': ['र', 'ल', 'ळ'], 'semivowels': ['य', 'व'], } # Vowels grouped by phonetic features DEVANAGARI_VOWEL_GROUPS = { 'short': ['अ', 'इ', 'उ', 'ऋ'], 'long': ['आ', 'ई', 'ऊ', 'ॠ'], 'diphthongs': ['ए', 'ऐ', 'ओ', 'औ'], } # Common Hindi stutter patterns (research-based) HINDI_STUTTER_PATTERNS = { 'repetition': [r'(.)\1{2,}', r'(\w+)\s+\1', r'(\w)\s+\1'], # Character/word repetition 'prolongation': [r'(.)\1{3,}', r'[आईऊएओ]{2,}'], # Extended vowels 'filled_pause': ['अ', 'उ', 'ए', 'म', 'उम', 'आ'], # Hesitation sounds } # === RESEARCH-BASED 
# === RESEARCH-BASED THRESHOLDS (2024-2025 Literature) ===

# Prolongation Detection (Spectral Correlation + Duration)
PROLONGATION_CORRELATION_THRESHOLD = 0.90  # >0.9 spectral similarity
PROLONGATION_MIN_DURATION = 0.25  # >250ms (Revisiting Rule-Based, 2025)

# Block Detection (Silence Analysis)
BLOCK_SILENCE_THRESHOLD = 0.35  # >350ms silence mid-utterance
BLOCK_ENERGY_PERCENTILE = 10  # Bottom 10% energy = silence

# Repetition Detection (DTW + Text Matching)
REPETITION_DTW_THRESHOLD = 0.15  # Normalized DTW distance
REPETITION_MIN_SIMILARITY = 0.85  # Text-based similarity

# Speaking Rate Norms (syllables/second)
SPEECH_RATE_MIN = 2.0
SPEECH_RATE_MAX = 6.0
SPEECH_RATE_TYPICAL = 4.0

# Formant Analysis (Vowel Centralization - Research Finding)
# People who stutter show reduced vowel space area
VOWEL_SPACE_REDUCTION_THRESHOLD = 0.70  # 70% of typical area

# Voice Quality (Jitter, Shimmer, HNR)
JITTER_THRESHOLD = 0.01  # >1% jitter indicates instability
SHIMMER_THRESHOLD = 0.03  # >3% shimmer
HNR_THRESHOLD = 15.0  # <15 dB Harmonics-to-Noise Ratio

# Zero-Crossing Rate (Voiced/Unvoiced Discrimination)
ZCR_VOICED_THRESHOLD = 0.1  # Low ZCR = voiced
ZCR_UNVOICED_THRESHOLD = 0.3  # High ZCR = unvoiced

# Entropy-Based Uncertainty
ENTROPY_HIGH_THRESHOLD = 3.5  # High confusion in model predictions
CONFIDENCE_LOW_THRESHOLD = 0.40  # Low confidence frame threshold


@dataclass
class StutterEvent:
    """Enhanced stutter event with multi-modal features"""
    type: str  # 'repetition', 'prolongation', 'block', 'dysfluency', 'mismatch'
    start: float  # Event start time in seconds
    end: float  # Event end time in seconds
    text: str  # Transcript text associated with the event (may be empty)
    confidence: float  # Detector confidence in [0, 1]
    acoustic_features: Dict[str, float] = field(default_factory=dict)
    voice_quality: Dict[str, float] = field(default_factory=dict)
    formant_data: Dict[str, Any] = field(default_factory=dict)
    phonetic_similarity: float = 0.0  # For comparing expected vs actual sounds


class AdvancedStutterDetector:
    """
    🎤 IndicWav2Vec Hindi ASR Engine

    Simplified engine using ONLY ai4bharat/indicwav2vec-hindi for
    Automatic Speech Recognition.

    Features:
    - Speech-to-text transcription using IndicWav2Vec Hindi model
    - Text-based stutter analysis from transcription
    - Confidence scoring from model predictions
    - Basic dysfluency detection from transcript patterns

    Model: ai4bharat/indicwav2vec-hindi (Wav2Vec2ForCTC)
    Purpose: Automatic Speech Recognition (ASR) for Hindi and Indian languages
    """

    def __init__(self):
        """Load the processor, CTC model, and project feature extractor.

        Raises:
            Exception: re-raised after logging if any part of model loading
                or feature-extractor construction fails.
        """
        logger.info(f"🚀 Initializing Advanced AI Engine on {DEVICE}...")
        if HF_TOKEN:
            logger.info("✅ HF_TOKEN found - using authenticated model access")
        else:
            logger.warning("⚠️ HF_TOKEN not found - model access may fail if authentication is required")
        try:
            # Wav2Vec2 Model Loading - IndicWav2Vec Hindi Model
            self.processor = AutoProcessor.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN
            )
            # fp16 only on CUDA; CPU inference stays in fp32.
            self.model = Wav2Vec2ForCTC.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
            ).to(DEVICE)
            self.model.eval()

            # Initialize feature extractor (clean architecture pattern)
            from .features import ASRFeatureExtractor
            self.feature_extractor = ASRFeatureExtractor(
                model=self.model,
                processor=self.processor,
                device=DEVICE
            )

            # Debug: Log processor structure
            logger.info(f"📋 Processor type: {type(self.processor)}")
            if hasattr(self.processor, 'tokenizer'):
                logger.info(f"📋 Tokenizer type: {type(self.processor.tokenizer)}")
            if hasattr(self.processor, 'feature_extractor'):
                logger.info(f"📋 Feature extractor type: {type(self.processor.feature_extractor)}")

            logger.info("✅ IndicWav2Vec Hindi ASR Engine Loaded with Feature Extractor")
        except Exception as e:
            logger.error(f"🔥 Engine Failure: {e}")
            raise

    def _init_common_adapters(self):
        """Not applicable - IndicWav2Vec Hindi doesn't use adapters"""
        pass

    def _activate_adapter(self, lang_code: str):
        """Not applicable - IndicWav2Vec Hindi doesn't use adapters"""
        # NOTE(review): lang_code is intentionally ignored; the single Hindi
        # model is used regardless of requested language.
        logger.info(f"Using IndicWav2Vec Hindi model (optimized for Hindi)")
        pass

    # ===== LEGACY METHODS (NOT USED IN ASR-ONLY MODE) =====
    # These methods are kept for reference but not called in the
    # simplified ASR pipeline.
    # They require additional libraries (parselmouth, fastdtw, sklearn) that
    # are not needed for ASR-only mode.
    # NOTE(review): `parselmouth`, `ConvexHull`, and `fastdtw` are NOT
    # imported anywhere in this file — calling the legacy methods below will
    # raise NameError unless those imports are restored.

    def _extract_comprehensive_features(self, audio: np.ndarray, sr: int, audio_path: str) -> Dict[str, Any]:
        """Extract multi-modal acoustic features.

        Returns a dict with keys: 'mfcc', 'zcr', 'rms_energy',
        'spectral_flux', 'energy_entropy', 'formants', 'formant_summary',
        'voice_quality'. Formant and voice-quality extraction fall back to
        zero/default values when Praat analysis fails.
        """
        features = {}
        # MFCC (20 coefficients)
        mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=20, hop_length=512)
        features['mfcc'] = mfcc.T  # Transpose for time x features
        # Zero-Crossing Rate
        zcr = librosa.feature.zero_crossing_rate(audio, hop_length=512)[0]
        features['zcr'] = zcr
        # RMS Energy
        rms_energy = librosa.feature.rms(y=audio, hop_length=512)[0]
        features['rms_energy'] = rms_energy
        # Spectral Flux (positive magnitude differences between frames)
        stft = librosa.stft(audio, hop_length=512)
        magnitude = np.abs(stft)
        spectral_flux = np.sum(np.diff(magnitude, axis=1) * (np.diff(magnitude, axis=1) > 0), axis=0)
        features['spectral_flux'] = spectral_flux
        # Energy Entropy
        frame_energy = np.sum(magnitude ** 2, axis=0)
        frame_energy = frame_energy + 1e-10  # Avoid log(0)
        energy_entropy = -np.sum((magnitude ** 2 / frame_energy) * np.log(magnitude ** 2 / frame_energy + 1e-10), axis=0)
        features['energy_entropy'] = energy_entropy
        # Formant Analysis using Parselmouth (legacy; see module note above)
        try:
            sound = parselmouth.Sound(audio_path)
            formant = sound.to_formant_burg(time_step=0.01)
            times = np.arange(0, sound.duration, 0.01)
            f1, f2, f3, f4 = [], [], [], []
            for t in times:
                try:
                    # Non-positive formant values are treated as missing.
                    f1.append(formant.get_value_at_time(1, t) if formant.get_value_at_time(1, t) > 0 else np.nan)
                    f2.append(formant.get_value_at_time(2, t) if formant.get_value_at_time(2, t) > 0 else np.nan)
                    f3.append(formant.get_value_at_time(3, t) if formant.get_value_at_time(3, t) > 0 else np.nan)
                    f4.append(formant.get_value_at_time(4, t) if formant.get_value_at_time(4, t) > 0 else np.nan)
                except:
                    f1.append(np.nan)
                    f2.append(np.nan)
                    f3.append(np.nan)
                    f4.append(np.nan)
            formants = np.array([f1, f2, f3, f4]).T
            features['formants'] = formants
            # Calculate vowel space area (F1-F2 plane)
            valid_f1f2 = formants[~np.isnan(formants[:, 0]) & ~np.isnan(formants[:, 1]), :2]
            if len(valid_f1f2) > 0:
                # Convex hull area approximation
                try:
                    hull = ConvexHull(valid_f1f2)
                    # For a 2-D hull, scipy's `volume` attribute is the area.
                    vowel_space_area = hull.volume
                except:
                    vowel_space_area = np.nan
            else:
                vowel_space_area = np.nan
            features['formant_summary'] = {
                'vowel_space_area': float(vowel_space_area) if not np.isnan(vowel_space_area) else 0.0,
                'f1_mean': float(np.nanmean(f1)) if len(f1) > 0 else 0.0,
                'f2_mean': float(np.nanmean(f2)) if len(f2) > 0 else 0.0,
                'f1_std': float(np.nanstd(f1)) if len(f1) > 0 else 0.0,
                'f2_std': float(np.nanstd(f2)) if len(f2) > 0 else 0.0
            }
        except Exception as e:
            logger.warning(f"Formant analysis failed: {e}")
            features['formants'] = np.zeros((len(audio) // 100, 4))
            features['formant_summary'] = {
                'vowel_space_area': 0.0,
                'f1_mean': 0.0,
                'f2_mean': 0.0,
                'f1_std': 0.0,
                'f2_std': 0.0
            }
        # Voice Quality Metrics (Jitter, Shimmer, HNR)
        try:
            sound = parselmouth.Sound(audio_path)
            pitch = sound.to_pitch()
            point_process = parselmouth.praat.call([sound, pitch], "To PointProcess")
            jitter = parselmouth.praat.call(point_process, "Get jitter (local)", 0.0, 0.0, 1.1, 1.6, 1.3, 1.6)
            shimmer = parselmouth.praat.call([sound, point_process], "Get shimmer (local)", 0.0, 0.0, 0.0001, 0.02, 1.3, 1.6)
            hnr = parselmouth.praat.call(sound, "Get harmonicity (cc)", 0.0, 0.0, 0.01, 1.5, 1.0, 0.1, 1.0)
            features['voice_quality'] = {
                'jitter': float(jitter) if jitter is not None else 0.0,
                'shimmer': float(shimmer) if shimmer is not None else 0.0,
                'hnr_db': float(hnr) if hnr is not None else 20.0
            }
        except Exception as e:
            logger.warning(f"Voice quality analysis failed: {e}")
            features['voice_quality'] = {
                'jitter': 0.0,
                'shimmer': 0.0,
                'hnr_db': 20.0
            }
        return features

    def _transcribe_with_timestamps(self, audio: np.ndarray) -> Tuple[str, List[Dict], torch.Tensor]:
        """
        Transcribe audio and return word timestamps and logits.
        Uses the feature extractor for clean separation of concerns.

        Returns:
            (transcript, word_timestamps, logits); on failure returns an
            empty transcript, empty timestamps, and a dummy logits tensor.
        """
        try:
            # Use feature extractor for transcription (clean architecture)
            features = self.feature_extractor.get_transcription_features(audio, sample_rate=16000)
            transcript = features['transcript']
            logits = torch.from_numpy(features['logits'])
            # Get word-level features for timestamps
            word_features = self.feature_extractor.get_word_level_features(audio, sample_rate=16000)
            word_timestamps = word_features['word_timestamps']
            logger.info(f"📝 Transcription via feature extractor: '{transcript}' (length: {len(transcript)}, words: {len(word_timestamps)})")
            return transcript, word_timestamps, logits
        except Exception as e:
            logger.error(f"❌ Transcription failed: {e}", exc_info=True)
            return "", [], torch.zeros((1, 100, 32))  # Dummy return

    def _calculate_uncertainty(self, logits: torch.Tensor) -> Tuple[float, List[Dict]]:
        """Calculate entropy-based uncertainty and low-confidence regions.

        Assumes logits are shaped (batch=1, frames, vocab) — TODO confirm
        against the feature extractor's output.
        """
        try:
            probs = torch.softmax(logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
            entropy_mean = float(torch.mean(entropy).item())
            # Find low-confidence regions
            frame_duration = 0.02  # presumably 20ms per CTC frame — verify
            low_conf_regions = []
            confidence = torch.max(probs, dim=-1)[0]
            for i in range(confidence.shape[1]):
                conf = float(confidence[0, i].item())
                if conf < CONFIDENCE_LOW_THRESHOLD:
                    low_conf_regions.append({
                        'time': i * frame_duration,
                        'confidence': conf
                    })
            return entropy_mean, low_conf_regions
        except Exception as e:
            logger.warning(f"Uncertainty calculation failed: {e}")
            return 0.0, []

    def _estimate_speaking_rate(self, audio: np.ndarray, sr: int) -> float:
        """Estimate speaking rate in syllables per second, clamped to
        [SPEECH_RATE_MIN, SPEECH_RATE_MAX]."""
        try:
            # Simple syllable estimation using energy peaks
            rms = librosa.feature.rms(y=audio, hop_length=512)[0]
            # NOTE(review): librosa.util.peak_pick returns a single ndarray,
            # so the `peaks, _ =` unpacking looks wrong (it will unpack the
            # array itself); any failure is swallowed by the except below,
            # which returns SPEECH_RATE_TYPICAL. Verify against librosa docs.
            peaks, _ = librosa.util.peak_pick(rms, pre_max=3, post_max=3, pre_avg=3, post_avg=5, delta=0.1, wait=10)
            duration = len(audio) / sr
            num_syllables = len(peaks)
            speaking_rate = num_syllables / duration if duration > 0 else SPEECH_RATE_TYPICAL
            return max(SPEECH_RATE_MIN, min(SPEECH_RATE_MAX, speaking_rate))
        except Exception as e:
            logger.warning(f"Speaking rate estimation failed: {e}")
            return SPEECH_RATE_TYPICAL

    def _detect_prolongations_advanced(self, mfcc: np.ndarray, spectral_flux: np.ndarray, speaking_rate: float, word_timestamps: List[Dict]) -> List[StutterEvent]:
        """Detect prolongations using spectral correlation.

        Slides a window over MFCC frames; sustained high inter-frame
        correlation inside a word boundary is flagged as a prolongation.
        """
        events = []
        frame_duration = 0.02
        # Adaptive threshold based on speaking rate: slower speech allows
        # longer windows before flagging.
        min_duration = PROLONGATION_MIN_DURATION * (SPEECH_RATE_TYPICAL / max(speaking_rate, 0.1))
        window_size = int(min_duration / frame_duration)
        if window_size < 2:
            return events
        for i in range(len(mfcc) - window_size):
            window = mfcc[i:i+window_size]
            # Calculate spectral correlation
            if len(window) > 1:
                corr_matrix = np.corrcoef(window.T)
                avg_correlation = np.mean(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])
                if avg_correlation > PROLONGATION_CORRELATION_THRESHOLD:
                    start_time = i * frame_duration
                    end_time = (i + window_size) * frame_duration
                    # Check if within a word boundary
                    for word_ts in word_timestamps:
                        if word_ts['start'] <= start_time <= word_ts['end']:
                            events.append(StutterEvent(
                                type='prolongation',
                                start=start_time,
                                end=end_time,
                                text=word_ts.get('word', ''),
                                confidence=float(avg_correlation),
                                acoustic_features={
                                    'spectral_correlation': float(avg_correlation),
                                    'duration': end_time - start_time
                                }
                            ))
                            break
        return events

    def _detect_blocks_enhanced(self, audio: np.ndarray, sr: int, rms_energy: np.ndarray, zcr: np.ndarray, word_timestamps: List[Dict], speaking_rate: float) -> List[StutterEvent]:
        """Detect blocks using silence analysis.

        A "block" is a mid-utterance silent stretch longer than the adaptive
        silence threshold, where silence = low RMS energy + low ZCR.
        """
        events = []
        frame_duration = 0.02
        # Adaptive threshold
        silence_threshold = BLOCK_SILENCE_THRESHOLD * (SPEECH_RATE_TYPICAL / max(speaking_rate, 0.1))
        energy_threshold = np.percentile(rms_energy, BLOCK_ENERGY_PERCENTILE)
        in_silence = False
        silence_start = 0
        for i, energy in enumerate(rms_energy):
            # NOTE(review): assumes rms_energy and zcr have equal length —
            # both are computed with hop_length=512, but confirm.
            is_silent = energy < energy_threshold and zcr[i] < ZCR_VOICED_THRESHOLD
            if is_silent and not in_silence:
                silence_start = i * frame_duration
                in_silence = True
            elif not is_silent and in_silence:
                silence_duration = (i * frame_duration) - silence_start
                if silence_duration > silence_threshold:
                    # Check if mid-utterance (not at start/end)
                    audio_duration = len(audio) / sr
                    if silence_start > 0.1 and silence_start < audio_duration - 0.1:
                        events.append(StutterEvent(
                            type='block',
                            start=silence_start,
                            end=i * frame_duration,
                            text="",
                            confidence=0.8,
                            acoustic_features={
                                'silence_duration': silence_duration,
                                'energy_level': float(energy)
                            }
                        ))
                in_silence = False
        return events

    def _detect_repetitions_advanced(self, mfcc: np.ndarray, formants: np.ndarray, word_timestamps: List[Dict], transcript: str, speaking_rate: float) -> List[StutterEvent]:
        """Detect repetitions using DTW and text matching.

        Flags identical adjacent words whose two halves also match
        acoustically under DTW. Requires the (unimported) `fastdtw` package;
        failures are silently ignored.
        """
        events = []
        if len(word_timestamps) < 2:
            return events
        # Text-based repetition detection
        words = transcript.lower().split()
        for i in range(len(words) - 1):
            if words[i] == words[i+1]:
                # Find corresponding timestamps
                if i < len(word_timestamps) and i+1 < len(word_timestamps):
                    start = word_timestamps[i]['start']
                    end = word_timestamps[i+1]['end']
                    # DTW verification on MFCC
                    start_frame = int(start / 0.02)
                    mid_frame = int((start + end) / 2 / 0.02)
                    end_frame = int(end / 0.02)
                    if start_frame < len(mfcc) and end_frame < len(mfcc):
                        segment1 = mfcc[start_frame:mid_frame]
                        segment2 = mfcc[mid_frame:end_frame]
                        if len(segment1) > 0 and len(segment2) > 0:
                            try:
                                distance, _ = fastdtw(segment1, segment2)
                                normalized_distance = distance / max(len(segment1), len(segment2))
                                if normalized_distance < REPETITION_DTW_THRESHOLD:
                                    events.append(StutterEvent(
                                        type='repetition',
                                        start=start,
                                        end=end,
                                        text=words[i],
                                        confidence=1.0 - normalized_distance,
                                        acoustic_features={
                                            'dtw_distance': float(normalized_distance),
                                            'repetition_count': 2
                                        }
                                    ))
                            except:
                                pass
        return events

    def _detect_voice_quality_issues(self, audio_path: str, word_timestamps: List[Dict], voice_quality: Dict[str, float]) -> List[StutterEvent]:
        """Detect dysfluencies based on voice quality metrics.

        Emits at most ONE dysfluency event (the first non-initial word) when
        any global jitter/shimmer/HNR threshold is crossed.
        """
        events = []
        # Global voice quality issues
        if voice_quality.get('jitter', 0) > JITTER_THRESHOLD or \
           voice_quality.get('shimmer', 0) > SHIMMER_THRESHOLD or \
           voice_quality.get('hnr_db', 20) < HNR_THRESHOLD:
            # Mark regions with poor voice quality
            for word_ts in word_timestamps:
                if word_ts.get('start', 0) > 0:  # Skip first word
                    events.append(StutterEvent(
                        type='dysfluency',
                        start=word_ts['start'],
                        end=word_ts['end'],
                        text=word_ts.get('word', ''),
                        confidence=0.6,
                        voice_quality=voice_quality.copy()
                    ))
                    break  # Only mark first occurrence
        return events

    def _is_overlapping(self, time: float, events: List[StutterEvent], threshold: float = 0.1) -> bool:
        """Check if time overlaps with existing events (± threshold seconds)."""
        for event in events:
            if event.start - threshold <= time <= event.end + threshold:
                return True
        return False

    def _detect_anomalies(self, events: List[StutterEvent], features: Dict[str, Any]) -> List[StutterEvent]:
        """Use Isolation Forest to filter anomalous events.

        NOTE(review): `self.anomaly_detector` is never initialized in the
        __init__ visible in this file — calling this would raise
        AttributeError, caught by the except and returning events unchanged.
        """
        if len(events) == 0:
            return events
        try:
            # Extract features for anomaly detection
            X = []
            for event in events:
                feat_vec = [
                    event.end - event.start,  # Duration
                    event.confidence,
                    features.get('voice_quality', {}).get('jitter', 0),
                    features.get('voice_quality', {}).get('shimmer', 0)
                ]
                X.append(feat_vec)
            X = np.array(X)
            if len(X) > 1:
                self.anomaly_detector.fit(X)
                predictions = self.anomaly_detector.predict(X)
                # Keep only non-anomalous events (predictions == 1)
                filtered_events = [events[i] for i, pred in enumerate(predictions) if pred == 1]
                return filtered_events
        except Exception as e:
            logger.warning(f"Anomaly detection failed: {e}")
        return events

    def _deduplicate_events_cascade(self, events: List[StutterEvent]) -> List[StutterEvent]:
        """Remove overlapping events with priority: Block > Repetition > Prolongation > Dysfluency"""
        if len(events) == 0:
            return events
        # Sort by priority and start time (highest priority first)
        priority = {'block': 4, 'repetition': 3, 'prolongation': 2, 'dysfluency': 1}
        events.sort(key=lambda e: (priority.get(e.type, 0), e.start), reverse=True)
        cleaned = []
        for event in events:
            overlap = False
            for existing in cleaned:
                # Check overlap (intervals touch or intersect)
                if not (event.end < existing.start or event.start > existing.end):
                    overlap = True
                    break
            if not overlap:
                cleaned.append(event)
        # Sort by start time
        cleaned.sort(key=lambda e: e.start)
        return cleaned

    def _calculate_clinical_metrics(self, events: List[StutterEvent], duration: float, speaking_rate: float, features: Dict[str, Any]) -> Dict[str, Any]:
        """Calculate comprehensive clinical metrics.

        Returns total stutter duration, events/minute frequency, a 0-100
        severity score with label, and a confidence value penalized by poor
        voice-quality readings.
        """
        total_duration = sum(e.end - e.start for e in events)
        frequency = (len(events) / duration * 60) if duration > 0 else 0
        # Calculate severity score (0-100)
        stutter_percentage = (total_duration / duration * 100) if duration > 0 else 0
        frequency_score = min(frequency / 10 * 100, 100)  # Normalize to 100
        severity_score = (stutter_percentage * 0.6 + frequency_score * 0.4)
        # Determine severity label
        if severity_score < 10:
            severity_label = 'none'
        elif severity_score < 25:
            severity_label = 'mild'
        elif severity_score < 50:
            severity_label = 'moderate'
        else:
            severity_label = 'severe'
        # Calculate confidence based on multiple factors
        voice_quality = features.get('voice_quality', {})
        confidence = 0.8  # Base confidence
        # Adjust based on voice quality metrics
        if voice_quality.get('jitter', 0) > JITTER_THRESHOLD:
            confidence -= 0.1
        if voice_quality.get('shimmer', 0) > SHIMMER_THRESHOLD:
            confidence -= 0.1
        if voice_quality.get('hnr_db', 20) < HNR_THRESHOLD:
            confidence -= 0.1
        confidence = max(0.3, min(1.0, confidence))
        return {
            'total_duration': round(total_duration, 2),
            'frequency': round(frequency, 2),
            'severity_score': round(severity_score, 2),
            'severity_label': severity_label,
            'confidence': round(confidence, 2)
        }

    def _event_to_dict(self, event: StutterEvent) -> Dict[str, Any]:
        """Convert StutterEvent to dictionary"""
        return {
            'type': event.type,
            'start': round(event.start, 2),
            'end': round(event.end, 2),
            'text': event.text,
            'confidence': round(event.confidence, 2),
            'acoustic_features': event.acoustic_features,
            'voice_quality': event.voice_quality,
            'formant_data': event.formant_data,
            'phonetic_similarity': round(event.phonetic_similarity, 2)
        }

    # ========== ADVANCED TRANSCRIPT COMPARISON METHODS ==========

    def _get_phonetic_group(self, char: str) -> Optional[str]:
        """Get phonetic group for a Devanagari character, or None if the
        character is not in any known group."""
        for group_name, chars in DEVANAGARI_CONSONANT_GROUPS.items():
            if char in chars:
                return f'consonant_{group_name}'
        for group_name, chars in DEVANAGARI_VOWEL_GROUPS.items():
            if char in chars:
                return f'vowel_{group_name}'
        return None

    def _calculate_phonetic_similarity(self, char1: str, char2: str) -> float:
        """
        Calculate phonetic similarity between two characters (0-1)
        Based on articulatory phonetics research
        """
        if char1 == char2:
            return 1.0
        # Get phonetic groups
        group1 = self._get_phonetic_group(char1)
        group2 = self._get_phonetic_group(char2)
        if group1 is None or group2 is None:
            # Non-Devanagari characters - use simple comparison
            return 1.0 if char1.lower() == char2.lower() else 0.0
        # Same phonetic group = high similarity (common in stuttering)
        if group1 == group2:
            return 0.85  # e.g., क vs ख (both velar)
        # Same major category (both consonants or both vowels)
        if group1.split('_')[0] == group2.split('_')[0]:
            return 0.5  # e.g., क (velar) vs च (palatal)
        # Different categories
        return 0.2

    def _longest_common_subsequence(self, text1: str, text2: str) -> str:
        """
        Find longest common subsequence (LCS) using dynamic programming
        Critical for identifying core message vs stuttered additions
        """
        m, n = len(text1), len(text2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        # Build DP table
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if text1[i-1] == text2[j-1]:
                    dp[i][j] = dp[i-1][j-1] + 1
                else:
                    dp[i][j] = max(dp[i-1][j], dp[i][j-1])
        # Backtrack to construct LCS
        lcs = []
        i, j = m, n
        while i > 0 and j > 0:
            if text1[i-1] == text2[j-1]:
                lcs.append(text1[i-1])
                i -= 1
                j -= 1
            elif dp[i-1][j] > dp[i][j-1]:
                i -= 1
            else:
                j -= 1
        return ''.join(reversed(lcs))

    def _calculate_edit_distance(self, text1: str, text2: str, phonetic_aware: bool = True) -> Tuple[int, List[Dict]]:
        """
        Calculate Levenshtein edit distance with phonetic awareness
        Returns: (distance, list of edit operations)

        When phonetic_aware, substitution cost is reduced toward 0.5 for
        phonetically similar character pairs.
        """
        m, n = len(text1), len(text2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        ops = [[[] for _ in range(n + 1)] for _ in range(m + 1)]
        # Initialize
        for i in range(m + 1):
            dp[i][0] = i
            if i > 0:
                ops[i][0] = ops[i-1][0] + [{'op': 'delete', 'pos': i-1, 'char': text1[i-1]}]
        for j in range(n + 1):
            dp[0][j] = j
            if j > 0:
                ops[0][j] = ops[0][j-1] + [{'op': 'insert', 'pos': j-1, 'char': text2[j-1]}]
        # Fill DP table with phonetic costs
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if text1[i-1] == text2[j-1]:
                    # Exact match - no cost
                    dp[i][j] = dp[i-1][j-1]
                    ops[i][j] = ops[i-1][j-1]
                else:
                    # Calculate phonetic substitution cost
                    if phonetic_aware:
                        phon_sim = self._calculate_phonetic_similarity(text1[i-1], text2[j-1])
                        sub_cost = 1.0 - (phon_sim * 0.5)  # 0.5-1.0 range
                    else:
                        sub_cost = 1.0
                    # Choose minimum cost operation
                    costs = [
                        dp[i-1][j] + 1,  # Delete
                        dp[i][j-1] + 1,  # Insert
                        dp[i-1][j-1] + sub_cost  # Substitute
                    ]
                    min_cost_idx = costs.index(min(costs))
                    dp[i][j] = costs[min_cost_idx]
                    if min_cost_idx == 0:
                        ops[i][j] = ops[i-1][j] + [{'op': 'delete', 'pos': i-1, 'char': text1[i-1]}]
                    elif min_cost_idx == 1:
                        ops[i][j] = ops[i][j-1] + [{'op': 'insert', 'pos': j-1, 'char': text2[j-1]}]
                    else:
                        ops[i][j] = ops[i-1][j-1] + [{'op': 'substitute', 'pos': i-1, 'from': text1[i-1], 'to': text2[j-1], 'phonetic_sim': phon_sim if phonetic_aware else 0}]
        return int(dp[m][n]), ops[m][n]

    def _find_mismatched_segments(self, actual: str, target: str) -> List[str]:
        """
        Find character sequences in actual that don't appear in target
        Uses LCS to identify core message, then extracts mismatches
        """
        if not actual or not target:
            return [actual] if actual else []
        lcs = self._longest_common_subsequence(actual, target)
        # Extract segments not in LCS
        mismatched_segments = []
        segment = ""
        lcs_idx = 0
        for char in actual:
            if lcs_idx < len(lcs) and char == lcs[lcs_idx]:
                if segment:
                    mismatched_segments.append(segment)
                    segment = ""
                lcs_idx += 1
            else:
                segment += char
        if segment:
            mismatched_segments.append(segment)
        return mismatched_segments

    def _detect_stutter_patterns_in_text(self, text: str) -> List[Dict[str, Any]]:
        """
        Detect common Hindi stutter patterns in text
        Based on linguistic research on Hindi dysfluencies

        Note: repetition/prolongation positions are character offsets, but
        filled-pause positions are word indices.
        """
        patterns_found = []
        # Detect repetitions
        for pattern in HINDI_STUTTER_PATTERNS['repetition']:
            matches = re.finditer(pattern, text)
            for match in matches:
                patterns_found.append({
                    'type': 'repetition',
                    'text': match.group(0),
                    'position': match.start(),
                    'pattern': pattern
                })
        # Detect prolongations
        for pattern in HINDI_STUTTER_PATTERNS['prolongation']:
            matches = re.finditer(pattern, text)
            for match in matches:
                patterns_found.append({
                    'type': 'prolongation',
                    'text': match.group(0),
                    'position': match.start(),
                    'pattern': pattern
                })
        # Detect filled pauses
        words = text.split()
        for i, word in enumerate(words):
            if word in HINDI_STUTTER_PATTERNS['filled_pause']:
                patterns_found.append({
                    'type': 'filled_pause',
                    'text': word,
                    'position': i,
                    'pattern': 'hesitation'
                })
        return patterns_found

    def _compare_transcripts_comprehensive(self, actual: str, target: str) -> Dict[str, Any]:
        """
        Comprehensive transcript comparison with multiple metrics
        Returns detailed analysis including phonetic, structural, and acoustic mismatches
        """
        if not target:
            # No target provided - only analyze actual for stutter patterns
            stutter_patterns = self._detect_stutter_patterns_in_text(actual)
            return {
                'has_target': False,
                'mismatched_chars': [],
                'mismatch_percentage': 0,
                'edit_distance': 0,
                'lcs_ratio': 1.0,
                'phonetic_similarity': 1.0,
                'stutter_patterns': stutter_patterns,
                'edit_operations': []
            }
        # Normalize whitespace
        actual = ' '.join(actual.split())
        target = ' '.join(target.split())
        # 1. Find mismatched character segments
        mismatched_segments = self._find_mismatched_segments(actual, target)
        # 2. Calculate edit distance with phonetic awareness
        edit_dist, edit_ops = self._calculate_edit_distance(actual, target, phonetic_aware=True)
        # 3. Calculate LCS ratio (similarity measure)
        lcs = self._longest_common_subsequence(actual, target)
        lcs_ratio = len(lcs) / max(len(target), 1)
        # 4. Calculate overall phonetic similarity
        phonetic_scores = []
        matcher = SequenceMatcher(None, actual, target)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == 'equal':
                phonetic_scores.append(1.0)
            elif tag == 'replace':
                # Calculate phonetic similarity for replacements
                for a_char, t_char in zip(actual[i1:i2], target[j1:j2]):
                    phonetic_scores.append(self._calculate_phonetic_similarity(a_char, t_char))
        avg_phonetic_sim = np.mean(phonetic_scores) if phonetic_scores else 0.0
        # 5. Calculate mismatch percentage (characters not in target)
        total_mismatched = sum(len(seg) for seg in mismatched_segments)
        mismatch_percentage = (total_mismatched / max(len(target), 1)) * 100
        mismatch_percentage = min(round(mismatch_percentage), 100)
        # 6. Detect stutter patterns in actual transcript
        stutter_patterns = self._detect_stutter_patterns_in_text(actual)
        # 7. Word-level analysis
        actual_words = actual.split()
        target_words = target.split()
        word_matcher = SequenceMatcher(None, actual_words, target_words)
        word_accuracy = word_matcher.ratio()
        return {
            'has_target': True,
            'mismatched_chars': mismatched_segments,
            'mismatch_percentage': mismatch_percentage,
            'edit_distance': edit_dist,
            'normalized_edit_distance': edit_dist / max(len(target), 1),
            'lcs': lcs,
            'lcs_ratio': round(lcs_ratio, 3),
            'phonetic_similarity': round(float(avg_phonetic_sim), 3),
            'word_accuracy': round(word_accuracy, 3),
            'stutter_patterns': stutter_patterns,
            'edit_operations': edit_ops[:20],  # Limit for performance
            'actual_length': len(actual),
            'target_length': len(target),
            'actual_words': len(actual_words),
            'target_words': len(target_words)
        }

    # ========== ACOUSTIC SIMILARITY METHODS (SOUND-BASED MATCHING) ==========

    def _extract_mfcc_features(self, audio: np.ndarray, sr: int, n_mfcc: int = 13) -> np.ndarray:
        """Extract per-coefficient-normalized MFCC features (time x features)."""
        mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=n_mfcc, hop_length=512)
        # Normalize each coefficient to zero mean / unit variance over time.
        mfcc = (mfcc - np.mean(mfcc, axis=1, keepdims=True)) / (np.std(mfcc, axis=1, keepdims=True) + 1e-8)
        return mfcc.T  # Time x Features

    def _calculate_dtw_distance(self, seq1: np.ndarray, seq2: np.ndarray) -> float:
        """
        Dynamic Time Warping distance for comparing audio segments
        Critical for detecting phonetic stutters where timing differs

        Returns the DTW cost normalized by (len(seq1) + len(seq2)).
        """
        n, m = len(seq1), len(seq2)
        dtw_matrix = np.full((n + 1, m + 1), np.inf)
        dtw_matrix[0, 0] = 0
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                cost = euclidean(seq1[i-1], seq2[j-1])
                dtw_matrix[i, j] = cost + min(
                    dtw_matrix[i-1, j],    # Insertion
                    dtw_matrix[i, j-1],    # Deletion
                    dtw_matrix[i-1, j-1]   # Match
                )
        # Normalize by path length
        return dtw_matrix[n, m] / (n + m)

    def _compare_audio_segments_acoustic(self, segment1: np.ndarray, segment2: np.ndarray, sr: int = 16000) -> Dict[str, float]:
        """
        Compare two audio segments acoustically using multiple metrics
        Used to detect when sounds are similar but transcripts differ (phonetic stutters)

        Returns a dict of DTW/spectral/energy/ZCR similarities plus a
        weighted 'overall_acoustic_similarity'.
        """
        # Extract MFCC features
        mfcc1 = self._extract_mfcc_features(segment1, sr)
        mfcc2 = self._extract_mfcc_features(segment2, sr)
        # 1. DTW distance
        dtw_dist = self._calculate_dtw_distance(mfcc1, mfcc2)
        dtw_similarity = max(0, 1.0 - (dtw_dist / 10))  # Normalize to 0-1
        # 2. Spectral features comparison
        spec1 = np.abs(librosa.stft(segment1))
        spec2 = np.abs(librosa.stft(segment2))
        # Resize to same shape for comparison
        min_frames = min(spec1.shape[1], spec2.shape[1])
        spec1 = spec1[:, :min_frames]
        spec2 = spec2[:, :min_frames]
        # Spectral correlation
        # NOTE(review): np.mean of an empty list yields NaN, and
        # max(0, nan) does NOT strip NaN in Python — the "Handle NaN"
        # comment below overstates what this line does; verify.
        spec_corr = np.mean([pearsonr(spec1[:, i], spec2[:, i])[0] for i in range(min_frames) if not np.all(spec1[:, i] == 0) and not np.all(spec2[:, i] == 0)])
        spec_corr = max(0, spec_corr)  # Handle NaN/negative
        # 3. Energy comparison
        energy1 = np.sum(segment1 ** 2)
        energy2 = np.sum(segment2 ** 2)
        energy_ratio = min(energy1, energy2) / (max(energy1, energy2) + 1e-8)
        # 4. Zero-crossing rate comparison
        zcr1 = np.mean(librosa.feature.zero_crossing_rate(segment1)[0])
        zcr2 = np.mean(librosa.feature.zero_crossing_rate(segment2)[0])
        zcr_similarity = 1.0 - min(abs(zcr1 - zcr2) / (max(zcr1, zcr2) + 1e-8), 1.0)
        # Overall acoustic similarity (weighted average)
        overall_similarity = (
            dtw_similarity * 0.4 +
            spec_corr * 0.3 +
            energy_ratio * 0.15 +
            zcr_similarity * 0.15
        )
        return {
            'dtw_similarity': round(float(dtw_similarity), 3),
            'spectral_correlation': round(float(spec_corr), 3),
            'energy_ratio': round(float(energy_ratio), 3),
            'zcr_similarity': round(float(zcr_similarity), 3),
            'overall_acoustic_similarity': round(float(overall_similarity), 3)
        }

    def _detect_acoustic_repetitions(self, audio: np.ndarray, sr: int, word_timestamps: List[Dict]) -> List[StutterEvent]:
        """
        Detect repetitions by comparing acoustic similarity between word segments
        Catches stutters even when ASR transcribes them differently
        """
        events = []
        if len(word_timestamps) < 2:
            return events
        # Compare consecutive words acoustically
        for i in range(len(word_timestamps) - 1):
            try:
                # Extract audio segments
                start1 = int(word_timestamps[i]['start'] * sr)
                end1 = int(word_timestamps[i]['end'] * sr)
                start2 = int(word_timestamps[i+1]['start'] * sr)
                end2 = int(word_timestamps[i+1]['end'] * sr)
                if end1 > len(audio) or end2 > len(audio):
                    continue
                segment1 = audio[start1:end1]
                segment2 = audio[start2:end2]
                if len(segment1) < 100 or len(segment2) < 100:  # Skip very short segments
                    continue
                # Calculate acoustic similarity
                acoustic_sim = self._compare_audio_segments_acoustic(segment1, segment2, sr)
                # High acoustic similarity suggests repetition (even if transcripts differ)
                if acoustic_sim['overall_acoustic_similarity'] > 0.75:
                    events.append(StutterEvent(
                        type='repetition',
                        start=word_timestamps[i]['start'],
                        end=word_timestamps[i+1]['end'],
                        text=f"{word_timestamps[i].get('word', '')} → {word_timestamps[i+1].get('word', '')}",
                        confidence=acoustic_sim['overall_acoustic_similarity'],
                        acoustic_features=acoustic_sim,
                        phonetic_similarity=acoustic_sim['overall_acoustic_similarity']
                    ))
            except Exception as e:
                logger.warning(f"Acoustic comparison failed for words {i}-{i+1}: {e}")
                continue
        return events

    def _detect_prolongations_by_sound(self, audio: np.ndarray, sr: int, word_timestamps: List[Dict]) -> List[StutterEvent]:
        """
        Detect prolongations by analyzing spectral stability within words
        High spectral correlation over time = prolonged sound
        """
        events = []
        for word_info in word_timestamps:
            try:
                start = int(word_info['start'] * sr)
                end = int(word_info['end'] * sr)
                if end > len(audio) or end - start < sr * 0.3:  # Skip if < 300ms
                    continue
                segment = audio[start:end]
                # Extract MFCC
                mfcc = self._extract_mfcc_features(segment, sr)
                if len(mfcc) < 10:  # Need sufficient frames
                    continue
                # Calculate frame-to-frame correlation
                correlations = []
                window_size = 5
                for i in range(len(mfcc) - window_size):
                    corr_matrix = np.corrcoef(mfcc[i:i+window_size].T)
                    avg_corr = np.mean(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])
                    correlations.append(avg_corr)
                avg_correlation = np.mean(correlations) if correlations else 0
                # High correlation = prolongation (same sound repeated)
                if avg_correlation > PROLONGATION_CORRELATION_THRESHOLD:
                    duration = (end - start) / sr
                    events.append(StutterEvent(
                        type='prolongation',
                        start=word_info['start'],
                        end=word_info['end'],
                        text=word_info.get('word', ''),
                        confidence=float(avg_correlation),
                        acoustic_features={
                            'spectral_correlation': float(avg_correlation),
                            'duration': duration
                        },
                        phonetic_similarity=float(avg_correlation)
                    ))
            except Exception as e:
                logger.warning(f"Prolongation detection failed for word: {e}")
                continue
        return events

    def analyze_audio(self, audio_path: str, proper_transcript: str = "", language: str = 'hindi') -> dict:
        """
        🎯 ADVANCED Multi-Modal Stutter Detection Pipeline

        Combines:
        1. ASR Transcription (IndicWav2Vec Hindi)
        2. Phonetic-Aware Transcript Comparison
        3. Acoustic Similarity Matching (Sound-Based)
        4. Linguistic Pattern Detection

        This detects stutters that ASR might miss by comparing:
        - What was said (actual) vs what should be said (target)
        - How it sounds (acoustic features)
        - Common Hindi stutter patterns
        """
        start_time = time.time()
        logger.info(f"🚀 Starting advanced analysis: {audio_path}")
        # === STEP 1: Audio Loading & Preprocessing ===
        audio, sr = librosa.load(audio_path, sr=16000)
        duration = librosa.get_duration(y=audio, sr=sr)
        logger.info(f"🎵 Audio loaded: {duration:.2f}s duration")
        # === STEP 2: ASR Transcription using IndicWav2Vec Hindi ===
        transcript, word_timestamps, logits = self._transcribe_with_timestamps(audio)
        logger.info(f"📝 ASR Transcription: '{transcript}' ({len(transcript)} chars, {len(word_timestamps)} words)")
        # === STEP 3: Comprehensive Transcript Comparison ===
        comparison_result = self._compare_transcripts_comprehensive(transcript, proper_transcript)
        logger.info(f"🔍 Transcript comparison: {comparison_result['mismatch_percentage']}% mismatch, "
                    f"phonetic similarity: {comparison_result['phonetic_similarity']:.2f}")
        # === STEP 4: Multi-Modal Stutter Detection ===
        events = []
        # 4a. Text-based stutters from transcript comparison
        if comparison_result['has_target'] and comparison_result['mismatched_chars']:
            for i, segment in enumerate(comparison_result['mismatched_chars'][:10]):  # Limit to top 10
                events.append(StutterEvent(
                    type='mismatch',
                    start=i * 0.5,  # Approximate timing
                    end=(i + 1) * 0.5,
                    text=segment,
                    confidence=0.8,
                    acoustic_features={'source': 'transcript_comparison'},
                    phonetic_similarity=comparison_result['phonetic_similarity']
                ))
        # 4b. Detected linguistic patterns (repetitions, prolongations, filled pauses)
        for pattern in comparison_result.get('stutter_patterns', []):
            events.append(StutterEvent(
                type=pattern['type'],
                start=pattern.get('position', 0) * 0.5,
                end=(pattern.get('position', 0) + 1) * 0.5,
                text=pattern['text'],
                confidence=0.75,
                acoustic_features={'pattern': pattern['pattern']}
            ))
        # 4c. Acoustic-based detection (sound similarity)
        logger.info("🎤 Running acoustic similarity analysis...")
        acoustic_repetitions = self._detect_acoustic_repetitions(audio, sr, word_timestamps)
        events.extend(acoustic_repetitions)
        logger.info(f"✅ Found {len(acoustic_repetitions)} acoustic repetitions")
        acoustic_prolongations = self._detect_prolongations_by_sound(audio, sr, word_timestamps)
        events.extend(acoustic_prolongations)
        logger.info(f"✅ Found {len(acoustic_prolongations)} acoustic prolongations")
        # 4d.
        # NOTE(review): this chunk of the file ends here, mid-method; the
        # remainder of analyze_audio (step 4d onward) is outside this view.
Model uncertainty regions (low confidence) entropy_score, low_conf_regions = self._calculate_uncertainty(logits) for region in low_conf_regions[:5]: # Limit to 5 most uncertain events.append(StutterEvent( type='dysfluency', start=region['time'], end=region['time'] + 0.3, text="", confidence=region['confidence'], acoustic_features={'entropy': entropy_score, 'model_uncertainty': True} )) # === STEP 5: Deduplicate and Rank Events === # Remove overlapping events, keeping highest confidence events.sort(key=lambda e: (e.start, -e.confidence)) deduplicated_events = [] for event in events: # Check if overlaps with existing events overlaps = False for existing in deduplicated_events: if not (event.end < existing.start or event.start > existing.end): overlaps = True break if not overlaps: deduplicated_events.append(event) events = deduplicated_events logger.info(f"📊 Total events after deduplication: {len(events)}") # === STEP 6: Calculate Comprehensive Metrics === total_duration = sum(e.end - e.start for e in events) frequency = (len(events) / duration * 60) if duration > 0 else 0 # Mismatch percentage from transcript comparison (more accurate) mismatch_percentage = comparison_result['mismatch_percentage'] # Severity assessment (multi-factor) severity_score = ( mismatch_percentage * 0.4 + (total_duration / duration * 100) * 0.3 + (frequency / 10 * 100) * 0.3 ) if duration > 0 else 0 if severity_score < 10: severity = 'none' elif severity_score < 25: severity = 'mild' elif severity_score < 50: severity = 'moderate' else: severity = 'severe' # Confidence score (multi-factor) model_confidence = 1.0 - (entropy_score / 10.0) if entropy_score > 0 else 0.8 phonetic_confidence = comparison_result.get('phonetic_similarity', 1.0) acoustic_confidence = np.mean([e.confidence for e in events if e.type in ['repetition', 'prolongation']]) if events else 0.7 overall_confidence = ( model_confidence * 0.4 + phonetic_confidence * 0.3 + acoustic_confidence * 0.3 ) overall_confidence = max(0.0, 
min(1.0, overall_confidence)) # === STEP 7: Return Comprehensive Results === actual_transcript = transcript if transcript else "" target_transcript = proper_transcript if proper_transcript else "" analysis_time = time.time() - start_time result = { # Core transcripts 'actual_transcript': actual_transcript, 'target_transcript': target_transcript, # Mismatch analysis 'mismatched_chars': comparison_result.get('mismatched_chars', []), 'mismatch_percentage': round(mismatch_percentage, 2), # Advanced comparison metrics 'edit_distance': comparison_result.get('edit_distance', 0), 'lcs_ratio': comparison_result.get('lcs_ratio', 1.0), 'phonetic_similarity': comparison_result.get('phonetic_similarity', 1.0), 'word_accuracy': comparison_result.get('word_accuracy', 1.0), # Model metrics 'ctc_loss_score': round(entropy_score, 4), # Stutter events with acoustic features 'stutter_timestamps': [self._event_to_dict(e) for e in events], 'total_stutter_duration': round(total_duration, 2), 'stutter_frequency': round(frequency, 2), # Assessment 'severity': severity, 'severity_score': round(severity_score, 2), 'confidence_score': round(overall_confidence, 2), # Speaking metrics 'speaking_rate_sps': round(len(word_timestamps) / duration if duration > 0 else 0, 2), # Metadata 'analysis_duration_seconds': round(analysis_time, 2), 'model_version': 'indicwav2vec-hindi-advanced-v2', 'features_used': ['asr', 'phonetic_comparison', 'acoustic_similarity', 'pattern_detection'], # Debug info 'debug': { 'total_events_detected': len(events), 'acoustic_repetitions': len(acoustic_repetitions), 'acoustic_prolongations': len(acoustic_prolongations), 'text_patterns': len(comparison_result.get('stutter_patterns', [])), 'has_target_transcript': comparison_result['has_target'] } } logger.info(f"✅ Analysis complete in {analysis_time:.2f}s - Severity: {severity}, " f"Mismatch: {mismatch_percentage}%, Confidence: {overall_confidence:.2f}") return result # Model loader is now in a separate module: 
model_loader.py # This follows clean architecture principles - separation of concerns # Import using: from diagnosis.ai_engine.model_loader import get_stutter_detector