zlaqa-version-c-ai-enginee / detect_stuttering.py
HackerMOne's picture
Upload 6 files
e08baf1 verified
raw
history blame
55.6 kB
# diagnosis/ai_engine/detect_stuttering.py
import os
import librosa
import torch
import logging
import numpy as np
from transformers import Wav2Vec2ForCTC, AutoProcessor
import time
from dataclasses import dataclass, field
from typing import List, Dict, Any, Tuple, Optional
from difflib import SequenceMatcher
import re
# Advanced similarity and distance metrics
from scipy.spatial.distance import cosine, euclidean
from scipy.stats import pearsonr
logger = logging.getLogger(__name__)
# === CONFIGURATION ===
MODEL_ID = "ai4bharat/indicwav2vec-hindi" # Only model used - IndicWav2Vec Hindi for ASR
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face token for authenticated model access
# Supported language names mapped to three-letter codes.
# NOTE(review): these look like ISO 639-3 codes — confirm against whatever consumes this map.
INDIAN_LANGUAGES = {
    'hindi': 'hin', 'english': 'eng', 'tamil': 'tam', 'telugu': 'tel',
    'bengali': 'ben', 'marathi': 'mar', 'gujarati': 'guj', 'kannada': 'kan',
    'malayalam': 'mal', 'punjabi': 'pan', 'urdu': 'urd', 'assamese': 'asm',
    'odia': 'ory', 'bhojpuri': 'bho', 'maithili': 'mai'
}
# === DEVANAGARI PHONETIC MAPPINGS (Research-Based) ===
# Consonants grouped by phonetic similarity (place of articulation) for stutter detection.
# Characters within a group are treated as highly similar by
# _calculate_phonetic_similarity (e.g. क vs ख).
DEVANAGARI_CONSONANT_GROUPS = {
    # Plosives (stops)
    'velar': ['क', 'ख', 'ग', 'घ', 'ङ'],
    'palatal': ['च', 'छ', 'ज', 'झ', 'ञ'],
    'retroflex': ['ट', 'ठ', 'ड', 'ढ', 'ण'],
    'dental': ['त', 'थ', 'द', 'ध', 'न'],
    'labial': ['प', 'फ', 'ब', 'भ', 'म'],
    # Fricatives & Approximants
    'sibilants': ['श', 'ष', 'स', 'ह'],
    'liquids': ['र', 'ल', 'ळ'],
    'semivowels': ['य', 'व'],
}
# Vowels grouped by phonetic features (length / diphthong quality)
DEVANAGARI_VOWEL_GROUPS = {
    'short': ['अ', 'इ', 'उ', 'ऋ'],
    'long': ['आ', 'ई', 'ऊ', 'ॠ'],
    'diphthongs': ['ए', 'ऐ', 'ओ', 'औ'],
}
# Common Hindi stutter patterns (research-based), consumed by
# _detect_stutter_patterns_in_text: 'repetition'/'prolongation' hold regexes,
# 'filled_pause' holds literal hesitation tokens matched word-by-word.
HINDI_STUTTER_PATTERNS = {
    'repetition': [r'(.)\1{2,}', r'(\w+)\s+\1', r'(\w)\s+\1'], # Character/word repetition
    'prolongation': [r'(.)\1{3,}', r'[आईऊएओ]{2,}'], # Extended vowels
    'filled_pause': ['अ', 'उ', 'ए', 'म', 'उम', 'आ'], # Hesitation sounds
}
# === RESEARCH-BASED THRESHOLDS (2024-2025 Literature) ===
# Prolongation Detection (Spectral Correlation + Duration)
PROLONGATION_CORRELATION_THRESHOLD = 0.90 # >0.9 spectral similarity
PROLONGATION_MIN_DURATION = 0.25 # >250ms (Revisiting Rule-Based, 2025)
# Block Detection (Silence Analysis)
BLOCK_SILENCE_THRESHOLD = 0.35 # >350ms silence mid-utterance
BLOCK_ENERGY_PERCENTILE = 10 # Bottom 10% energy = silence
# Repetition Detection (DTW + Text Matching)
REPETITION_DTW_THRESHOLD = 0.15 # Normalized DTW distance
REPETITION_MIN_SIMILARITY = 0.85 # Text-based similarity
# Speaking Rate Norms (syllables/second); used to clamp and adaptively scale thresholds
SPEECH_RATE_MIN = 2.0
SPEECH_RATE_MAX = 6.0
SPEECH_RATE_TYPICAL = 4.0
# Formant Analysis (Vowel Centralization - Research Finding)
# People who stutter show reduced vowel space area
VOWEL_SPACE_REDUCTION_THRESHOLD = 0.70 # 70% of typical area
# Voice Quality (Jitter, Shimmer, HNR)
JITTER_THRESHOLD = 0.01 # >1% jitter indicates instability
SHIMMER_THRESHOLD = 0.03 # >3% shimmer
HNR_THRESHOLD = 15.0 # <15 dB Harmonics-to-Noise Ratio
# Zero-Crossing Rate (Voiced/Unvoiced Discrimination)
ZCR_VOICED_THRESHOLD = 0.1 # Low ZCR = voiced
ZCR_UNVOICED_THRESHOLD = 0.3 # High ZCR = unvoiced
# Entropy-Based Uncertainty
ENTROPY_HIGH_THRESHOLD = 3.5 # High confusion in model predictions
CONFIDENCE_LOW_THRESHOLD = 0.40 # Low confidence frame threshold
@dataclass
class StutterEvent:
    """Enhanced stutter event with multi-modal features"""
    # Event category: 'repetition', 'prolongation', 'block', 'dysfluency', 'mismatch'
    type: str
    # Start/end offsets of the event within the audio, in seconds
    start: float
    end: float
    # Transcript text tied to the event ("<silence>" for blocks)
    text: str
    # Detector confidence in [0, 1]
    confidence: float
    # Per-event acoustic measurements (e.g. DTW distance, spectral correlation)
    acoustic_features: Dict[str, float] = field(default_factory=dict)
    # Jitter / shimmer / HNR snapshot, when voice-quality analysis ran
    voice_quality: Dict[str, float] = field(default_factory=dict)
    # Formant measurements, when available
    formant_data: Dict[str, Any] = field(default_factory=dict)
    phonetic_similarity: float = 0.0 # For comparing expected vs actual sounds
class AdvancedStutterDetector:
"""
🎤 IndicWav2Vec Hindi ASR Engine
Simplified engine using ONLY ai4bharat/indicwav2vec-hindi for Automatic Speech Recognition.
Features:
- Speech-to-text transcription using IndicWav2Vec Hindi model
- Text-based stutter analysis from transcription
- Confidence scoring from model predictions
- Basic dysfluency detection from transcript patterns
Model: ai4bharat/indicwav2vec-hindi (Wav2Vec2ForCTC)
Purpose: Automatic Speech Recognition (ASR) for Hindi and Indian languages
"""
def __init__(self):
    """Load the IndicWav2Vec Hindi processor, CTC model and feature extractor.

    Reads module-level MODEL_ID / DEVICE / HF_TOKEN. Logs diagnostic info
    about the processor structure, and re-raises any loading failure so the
    caller knows the engine is unusable.
    """
    logger.info(f"🚀 Initializing Advanced AI Engine on {DEVICE}...")
    if HF_TOKEN:
        logger.info("✅ HF_TOKEN found - using authenticated model access")
    else:
        logger.warning("⚠️ HF_TOKEN not found - model access may fail if authentication is required")
    try:
        # Wav2Vec2 Model Loading - IndicWav2Vec Hindi Model
        self.processor = AutoProcessor.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN
        )
        # fp16 on GPU halves memory; fp32 on CPU where fp16 inference is poorly supported
        self.model = Wav2Vec2ForCTC.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
        ).to(DEVICE)
        self.model.eval()
        # Initialize feature extractor (clean architecture pattern);
        # imported lazily to avoid a hard dependency at module import time
        from .features import ASRFeatureExtractor
        self.feature_extractor = ASRFeatureExtractor(
            model=self.model,
            processor=self.processor,
            device=DEVICE
        )
        # Debug: Log processor structure
        logger.info(f"📋 Processor type: {type(self.processor)}")
        if hasattr(self.processor, 'tokenizer'):
            logger.info(f"📋 Tokenizer type: {type(self.processor.tokenizer)}")
        if hasattr(self.processor, 'feature_extractor'):
            logger.info(f"📋 Feature extractor type: {type(self.processor.feature_extractor)}")
        logger.info("✅ IndicWav2Vec Hindi ASR Engine Loaded with Feature Extractor")
    except Exception as e:
        logger.error(f"🔥 Engine Failure: {e}")
        raise
def _init_common_adapters(self):
"""Not applicable - IndicWav2Vec Hindi doesn't use adapters"""
pass
def _activate_adapter(self, lang_code: str):
    """No-op adapter hook; IndicWav2Vec Hindi is a single, Hindi-optimized model.

    The *lang_code* argument is accepted for interface compatibility but ignored.
    """
    logger.info(f"Using IndicWav2Vec Hindi model (optimized for Hindi)")
# ===== LEGACY METHODS (NOT USED IN ASR-ONLY MODE) =====
# These methods are kept for reference but not called in the simplified ASR pipeline
# They require additional libraries (parselmouth, fastdtw, sklearn) that are not needed for ASR-only mode
def _extract_comprehensive_features(self, audio: np.ndarray, sr: int, audio_path: str) -> Dict[str, Any]:
    """Extract multi-modal acoustic features.

    Computes MFCC, zero-crossing rate, RMS energy, spectral flux and an
    energy-entropy curve with librosa, then attempts formant and
    voice-quality (jitter/shimmer/HNR) analysis via Praat.

    NOTE(review): `parselmouth` and `ConvexHull` are not imported anywhere
    in this module, so those paths always raise and fall back to the
    zero/default values in the except branches. Confirm whether the legacy
    dependencies should be restored or this method removed.

    Args:
        audio: mono waveform samples.
        sr: sample rate of `audio` in Hz.
        audio_path: path to the audio file (re-read by parselmouth).

    Returns:
        Dict with keys 'mfcc', 'zcr', 'rms_energy', 'spectral_flux',
        'energy_entropy', 'formants', 'formant_summary', 'voice_quality'.
    """
    features = {}
    # MFCC (20 coefficients)
    mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=20, hop_length=512)
    features['mfcc'] = mfcc.T # Transpose for time x features
    # Zero-Crossing Rate
    zcr = librosa.feature.zero_crossing_rate(audio, hop_length=512)[0]
    features['zcr'] = zcr
    # RMS Energy
    rms_energy = librosa.feature.rms(y=audio, hop_length=512)[0]
    features['rms_energy'] = rms_energy
    # Spectral Flux: positive frame-to-frame magnitude change only (half-wave rectified)
    stft = librosa.stft(audio, hop_length=512)
    magnitude = np.abs(stft)
    spectral_flux = np.sum(np.diff(magnitude, axis=1) * (np.diff(magnitude, axis=1) > 0), axis=0)
    features['spectral_flux'] = spectral_flux
    # Energy Entropy per frame
    frame_energy = np.sum(magnitude ** 2, axis=0)
    frame_energy = frame_energy + 1e-10 # Avoid log(0)
    energy_entropy = -np.sum((magnitude ** 2 / frame_energy) * np.log(magnitude ** 2 / frame_energy + 1e-10), axis=0)
    features['energy_entropy'] = energy_entropy
    # Formant Analysis using Parselmouth (Praat bindings)
    try:
        sound = parselmouth.Sound(audio_path)
        formant = sound.to_formant_burg(time_step=0.01)
        times = np.arange(0, sound.duration, 0.01)
        f1, f2, f3, f4 = [], [], [], []
        for t in times:
            try:
                # Non-positive formant values (undefined frames) are stored as NaN
                f1.append(formant.get_value_at_time(1, t) if formant.get_value_at_time(1, t) > 0 else np.nan)
                f2.append(formant.get_value_at_time(2, t) if formant.get_value_at_time(2, t) > 0 else np.nan)
                f3.append(formant.get_value_at_time(3, t) if formant.get_value_at_time(3, t) > 0 else np.nan)
                f4.append(formant.get_value_at_time(4, t) if formant.get_value_at_time(4, t) > 0 else np.nan)
            except:
                f1.append(np.nan)
                f2.append(np.nan)
                f3.append(np.nan)
                f4.append(np.nan)
        formants = np.array([f1, f2, f3, f4]).T
        features['formants'] = formants
        # Calculate vowel space area (F1-F2 plane)
        valid_f1f2 = formants[~np.isnan(formants[:, 0]) & ~np.isnan(formants[:, 1]), :2]
        if len(valid_f1f2) > 0:
            # Convex hull area approximation (in 2-D, ConvexHull.volume is the area)
            try:
                hull = ConvexHull(valid_f1f2)
                vowel_space_area = hull.volume
            except:
                vowel_space_area = np.nan
        else:
            vowel_space_area = np.nan
        features['formant_summary'] = {
            'vowel_space_area': float(vowel_space_area) if not np.isnan(vowel_space_area) else 0.0,
            'f1_mean': float(np.nanmean(f1)) if len(f1) > 0 else 0.0,
            'f2_mean': float(np.nanmean(f2)) if len(f2) > 0 else 0.0,
            'f1_std': float(np.nanstd(f1)) if len(f1) > 0 else 0.0,
            'f2_std': float(np.nanstd(f2)) if len(f2) > 0 else 0.0
        }
    except Exception as e:
        logger.warning(f"Formant analysis failed: {e}")
        # Graceful degradation: zero formants and an empty summary
        features['formants'] = np.zeros((len(audio) // 100, 4))
        features['formant_summary'] = {
            'vowel_space_area': 0.0,
            'f1_mean': 0.0, 'f2_mean': 0.0,
            'f1_std': 0.0, 'f2_std': 0.0
        }
    # Voice Quality Metrics (Jitter, Shimmer, HNR) via Praat commands
    try:
        sound = parselmouth.Sound(audio_path)
        pitch = sound.to_pitch()
        point_process = parselmouth.praat.call([sound, pitch], "To PointProcess")
        jitter = parselmouth.praat.call(point_process, "Get jitter (local)", 0.0, 0.0, 1.1, 1.6, 1.3, 1.6)
        shimmer = parselmouth.praat.call([sound, point_process], "Get shimmer (local)", 0.0, 0.0, 0.0001, 0.02, 1.3, 1.6)
        hnr = parselmouth.praat.call(sound, "Get harmonicity (cc)", 0.0, 0.0, 0.01, 1.5, 1.0, 0.1, 1.0)
        features['voice_quality'] = {
            'jitter': float(jitter) if jitter is not None else 0.0,
            'shimmer': float(shimmer) if shimmer is not None else 0.0,
            'hnr_db': float(hnr) if hnr is not None else 20.0
        }
    except Exception as e:
        logger.warning(f"Voice quality analysis failed: {e}")
        # Defaults chosen to be "healthy": zero jitter/shimmer, 20 dB HNR
        features['voice_quality'] = {
            'jitter': 0.0,
            'shimmer': 0.0,
            'hnr_db': 20.0
        }
    return features
def _transcribe_with_timestamps(self, audio: np.ndarray) -> Tuple[str, List[Dict], torch.Tensor]:
    """
    Run ASR on *audio* and return (transcript, word timestamps, CTC logits).

    All model work is delegated to the feature extractor; on any failure a
    safe placeholder ("", [], zero logits) is returned instead of raising.
    """
    try:
        sample_rate = 16000
        # Transcript + raw logits from the feature extractor (clean architecture)
        tx_feats = self.feature_extractor.get_transcription_features(audio, sample_rate=sample_rate)
        transcript = tx_feats['transcript']
        logits = torch.from_numpy(tx_feats['logits'])
        # Word-level timing comes from a second, word-oriented pass
        word_feats = self.feature_extractor.get_word_level_features(audio, sample_rate=sample_rate)
        word_timestamps = word_feats['word_timestamps']
        logger.info(f"📝 Transcription via feature extractor: '{transcript}' (length: {len(transcript)}, words: {len(word_timestamps)})")
        return transcript, word_timestamps, logits
    except Exception as e:
        logger.error(f"❌ Transcription failed: {e}", exc_info=True)
        return "", [], torch.zeros((1, 100, 32)) # Dummy return
def _calculate_uncertainty(self, logits: torch.Tensor) -> Tuple[float, List[Dict]]:
    """
    Compute the mean prediction entropy plus a list of low-confidence frames.

    Each region dict carries the frame time (assuming 20 ms frames) and the
    frame's max-probability confidence; failures degrade to (0.0, []).
    """
    try:
        probs = torch.softmax(logits, dim=-1)
        # Shannon entropy per frame; epsilon guards log(0)
        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
        entropy_mean = float(torch.mean(entropy).item())
        frame_duration = 0.02
        # Per-frame confidence = probability of the argmax token
        confidence = torch.max(probs, dim=-1)[0]
        frame_confs = [float(confidence[0, i].item()) for i in range(confidence.shape[1])]
        low_conf_regions = [
            {'time': i * frame_duration, 'confidence': conf}
            for i, conf in enumerate(frame_confs)
            if conf < CONFIDENCE_LOW_THRESHOLD
        ]
        return entropy_mean, low_conf_regions
    except Exception as e:
        logger.warning(f"Uncertainty calculation failed: {e}")
        return 0.0, []
def _estimate_speaking_rate(self, audio: np.ndarray, sr: int) -> float:
    """Estimate the speaking rate in syllables per second.

    Approximates syllable nuclei as peaks in the RMS-energy envelope, then
    clamps the result to the physiologically plausible range
    [SPEECH_RATE_MIN, SPEECH_RATE_MAX]. Any failure returns the typical rate.

    Args:
        audio: mono waveform samples.
        sr: sample rate in Hz.

    Returns:
        Speaking rate in syllables/second, clamped to the configured range.
    """
    try:
        # Simple syllable estimation using energy peaks
        rms = librosa.feature.rms(y=audio, hop_length=512)[0]
        # BUG FIX: librosa.util.peak_pick returns a single ndarray of peak
        # indices, not a (peaks, extra) tuple. The previous tuple-unpacking
        # raised ValueError for any peak count != 2, which was silently
        # caught below and always returned SPEECH_RATE_TYPICAL.
        peaks = librosa.util.peak_pick(rms, pre_max=3, post_max=3, pre_avg=3, post_avg=5, delta=0.1, wait=10)
        duration = len(audio) / sr
        num_syllables = len(peaks)
        speaking_rate = num_syllables / duration if duration > 0 else SPEECH_RATE_TYPICAL
        return max(SPEECH_RATE_MIN, min(SPEECH_RATE_MAX, speaking_rate))
    except Exception as e:
        logger.warning(f"Speaking rate estimation failed: {e}")
        return SPEECH_RATE_TYPICAL
def _detect_prolongations_advanced(self, mfcc: np.ndarray, spectral_flux: np.ndarray,
                                   speaking_rate: float, word_timestamps: List[Dict]) -> List[StutterEvent]:
    """Detect prolongations using spectral correlation.

    Slides a window (sized from PROLONGATION_MIN_DURATION and scaled by the
    speaking rate) over the MFCC matrix; windows whose coefficients are
    highly inter-correlated are treated as a held sound and reported when
    they start inside a transcribed word.

    NOTE(review): `spectral_flux` is accepted but never used here.
    NOTE(review): np.corrcoef(window.T) correlates MFCC *coefficient*
    trajectories across the window, not frame-to-frame spectra — confirm
    this is the intended similarity measure.
    """
    events = []
    # Assumed frame step of 20 ms — TODO confirm against the MFCC hop length
    frame_duration = 0.02
    # Adaptive threshold based on speaking rate: slower speech allows longer holds
    min_duration = PROLONGATION_MIN_DURATION * (SPEECH_RATE_TYPICAL / max(speaking_rate, 0.1))
    window_size = int(min_duration / frame_duration)
    if window_size < 2:
        return events
    for i in range(len(mfcc) - window_size):
        window = mfcc[i:i+window_size]
        # Calculate spectral correlation over the upper triangle (pairwise, no diagonal)
        if len(window) > 1:
            corr_matrix = np.corrcoef(window.T)
            avg_correlation = np.mean(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])
            if avg_correlation > PROLONGATION_CORRELATION_THRESHOLD:
                start_time = i * frame_duration
                end_time = (i + window_size) * frame_duration
                # Check if within a word boundary; attach the word's text to the event
                for word_ts in word_timestamps:
                    if word_ts['start'] <= start_time <= word_ts['end']:
                        events.append(StutterEvent(
                            type='prolongation',
                            start=start_time,
                            end=end_time,
                            text=word_ts.get('word', ''),
                            confidence=float(avg_correlation),
                            acoustic_features={
                                'spectral_correlation': float(avg_correlation),
                                'duration': end_time - start_time
                            }
                        ))
                        break
    return events
def _detect_blocks_enhanced(self, audio: np.ndarray, sr: int, rms_energy: np.ndarray,
                            zcr: np.ndarray, word_timestamps: List[Dict],
                            speaking_rate: float) -> List[StutterEvent]:
    """Detect blocks (involuntary stoppages) via mid-utterance silence analysis.

    A frame is "silent" when its RMS energy is in the bottom
    BLOCK_ENERGY_PERCENTILE and its ZCR is below the voiced threshold.
    Silent runs longer than the (rate-adapted) BLOCK_SILENCE_THRESHOLD that
    do not touch the first/last 100 ms of the recording are reported.

    NOTE(review): `word_timestamps` is accepted but never used here.
    """
    events = []
    frame_duration = 0.02  # assumed 20 ms frame step — TODO confirm vs hop length
    # Adaptive threshold: slower speech tolerates longer pauses
    silence_threshold = BLOCK_SILENCE_THRESHOLD * (SPEECH_RATE_TYPICAL / max(speaking_rate, 0.1))
    energy_threshold = np.percentile(rms_energy, BLOCK_ENERGY_PERCENTILE)
    in_silence = False
    silence_start = 0
    for i, energy in enumerate(rms_energy):
        is_silent = energy < energy_threshold and zcr[i] < ZCR_VOICED_THRESHOLD
        if is_silent and not in_silence:
            silence_start = i * frame_duration
            in_silence = True
        elif not is_silent and in_silence:
            # BUG FIX: always clear the silence flag when speech resumes.
            # Previously it was only cleared when the run exceeded the
            # threshold, so short pauses were merged into later silences and
            # produced spuriously long blocks.
            in_silence = False
            silence_duration = (i * frame_duration) - silence_start
            if silence_duration > silence_threshold:
                # Check if mid-utterance (not at start/end)
                audio_duration = len(audio) / sr
                if silence_start > 0.1 and silence_start < audio_duration - 0.1:
                    events.append(StutterEvent(
                        type='block',
                        start=silence_start,
                        end=i * frame_duration,
                        text="<silence>",
                        confidence=0.8,
                        acoustic_features={
                            'silence_duration': silence_duration,
                            'energy_level': float(energy)
                        }
                    ))
    return events
def _detect_repetitions_advanced(self, mfcc: np.ndarray, formants: np.ndarray,
                                 word_timestamps: List[Dict], transcript: str,
                                 speaking_rate: float) -> List[StutterEvent]:
    """Detect repetitions using DTW and text matching.

    Finds adjacent identical words in the transcript, then verifies each
    candidate acoustically by DTW-comparing the MFCC frames of its two
    halves; matches below REPETITION_DTW_THRESHOLD become events.

    NOTE(review): `fastdtw` is not imported in this module, so the DTW call
    always raises and the bare `except: pass` swallows it — in the current
    state this legacy method never emits events. `formants` and
    `speaking_rate` are also unused here.
    """
    events = []
    if len(word_timestamps) < 2:
        return events
    # Text-based repetition detection: adjacent identical (case-folded) words
    words = transcript.lower().split()
    for i in range(len(words) - 1):
        if words[i] == words[i+1]:
            # Find corresponding timestamps; assumes words[] aligns 1:1 with
            # word_timestamps — TODO confirm for multi-space transcripts
            if i < len(word_timestamps) and i+1 < len(word_timestamps):
                start = word_timestamps[i]['start']
                end = word_timestamps[i+1]['end']
                # DTW verification on MFCC: compare first vs second half of the span
                start_frame = int(start / 0.02)
                mid_frame = int((start + end) / 2 / 0.02)
                end_frame = int(end / 0.02)
                if start_frame < len(mfcc) and end_frame < len(mfcc):
                    segment1 = mfcc[start_frame:mid_frame]
                    segment2 = mfcc[mid_frame:end_frame]
                    if len(segment1) > 0 and len(segment2) > 0:
                        try:
                            distance, _ = fastdtw(segment1, segment2)
                            # Normalize by the longer segment so threshold is length-independent
                            normalized_distance = distance / max(len(segment1), len(segment2))
                            if normalized_distance < REPETITION_DTW_THRESHOLD:
                                events.append(StutterEvent(
                                    type='repetition',
                                    start=start,
                                    end=end,
                                    text=words[i],
                                    confidence=1.0 - normalized_distance,
                                    acoustic_features={
                                        'dtw_distance': float(normalized_distance),
                                        'repetition_count': 2
                                    }
                                ))
                        except:
                            pass
    return events
def _detect_voice_quality_issues(self, audio_path: str, word_timestamps: List[Dict],
                                 voice_quality: Dict[str, float]) -> List[StutterEvent]:
    """
    Flag a dysfluency when global voice-quality metrics are abnormal.

    If jitter, shimmer or HNR crosses its clinical threshold, the first word
    that does not start at t=0 is marked (one event at most, mirroring the
    original single-occurrence behaviour). *audio_path* is unused.
    """
    metrics_abnormal = (
        voice_quality.get('jitter', 0) > JITTER_THRESHOLD
        or voice_quality.get('shimmer', 0) > SHIMMER_THRESHOLD
        or voice_quality.get('hnr_db', 20) < HNR_THRESHOLD
    )
    if not metrics_abnormal:
        return []
    for word_ts in word_timestamps:
        # Skip first word (anything starting at exactly 0.0)
        if word_ts.get('start', 0) > 0:
            return [StutterEvent(
                type='dysfluency',
                start=word_ts['start'],
                end=word_ts['end'],
                text=word_ts.get('word', ''),
                confidence=0.6,
                voice_quality=voice_quality.copy()
            )]
    return []
def _is_overlapping(self, time: float, events: List[StutterEvent], threshold: float = 0.1) -> bool:
"""Check if time overlaps with existing events"""
for event in events:
if event.start - threshold <= time <= event.end + threshold:
return True
return False
def _detect_anomalies(self, events: List[StutterEvent], features: Dict[str, Any]) -> List[StutterEvent]:
    """Filter anomalous events with an Isolation Forest.

    Builds a per-event feature vector (duration, confidence, global
    jitter/shimmer) and keeps only events the detector labels as inliers
    (prediction == 1). On any failure, or with fewer than two events, the
    input list is returned unchanged.

    NOTE(review): `self.anomaly_detector` is never created in ASR-only
    mode's __init__, so this legacy path currently always returns `events`
    via the except branch — confirm before relying on filtering.
    """
    if len(events) == 0:
        return events
    try:
        # Extract features for anomaly detection
        X = []
        for event in events:
            feat_vec = [
                event.end - event.start, # Duration
                event.confidence,
                features.get('voice_quality', {}).get('jitter', 0),
                features.get('voice_quality', {}).get('shimmer', 0)
            ]
            X.append(feat_vec)
        X = np.array(X)
        if len(X) > 1:
            self.anomaly_detector.fit(X)
            predictions = self.anomaly_detector.predict(X)
            # Keep only non-anomalous events (predictions == 1)
            return [events[i] for i, pred in enumerate(predictions) if pred == 1]
        # BUG FIX: a single event previously fell through the try body with
        # no return statement, making the method return None; return the
        # input unchanged instead.
        return events
    except Exception as e:
        logger.warning(f"Anomaly detection failed: {e}")
        return events
def _deduplicate_events_cascade(self, events: List[StutterEvent]) -> List[StutterEvent]:
"""Remove overlapping events with priority: Block > Repetition > Prolongation > Dysfluency"""
if len(events) == 0:
return events
# Sort by priority and start time
priority = {'block': 4, 'repetition': 3, 'prolongation': 2, 'dysfluency': 1}
events.sort(key=lambda e: (priority.get(e.type, 0), e.start), reverse=True)
cleaned = []
for event in events:
overlap = False
for existing in cleaned:
# Check overlap
if not (event.end < existing.start or event.start > existing.end):
overlap = True
break
if not overlap:
cleaned.append(event)
# Sort by start time
cleaned.sort(key=lambda e: e.start)
return cleaned
def _calculate_clinical_metrics(self, events: List[StutterEvent], duration: float,
                                speaking_rate: float, features: Dict[str, Any]) -> Dict[str, Any]:
    """
    Derive clinical summary metrics from detected events.

    Severity blends the stuttered-time percentage (60%) with an
    events-per-minute score (40%); confidence starts at 0.8 and loses 0.1
    per abnormal voice-quality metric, clamped to [0.3, 1.0].
    *speaking_rate* is accepted for interface compatibility but unused.
    """
    total_duration = sum(e.end - e.start for e in events)
    frequency = (len(events) / duration * 60) if duration > 0 else 0
    # Severity score on a 0-100 scale
    stutter_percentage = (total_duration / duration * 100) if duration > 0 else 0
    frequency_score = min(frequency / 10 * 100, 100) # Normalize to 100
    severity_score = (stutter_percentage * 0.6 + frequency_score * 0.4)
    # Map the numeric score onto a clinical label
    if severity_score < 10:
        severity_label = 'none'
    elif severity_score < 25:
        severity_label = 'mild'
    elif severity_score < 50:
        severity_label = 'moderate'
    else:
        severity_label = 'severe'
    # Confidence: one 0.1 penalty per abnormal voice-quality metric
    voice_quality = features.get('voice_quality', {})
    confidence = 0.8 # Base confidence
    if voice_quality.get('jitter', 0) > JITTER_THRESHOLD:
        confidence -= 0.1
    if voice_quality.get('shimmer', 0) > SHIMMER_THRESHOLD:
        confidence -= 0.1
    if voice_quality.get('hnr_db', 20) < HNR_THRESHOLD:
        confidence -= 0.1
    confidence = max(0.3, min(1.0, confidence))
    return {
        'total_duration': round(total_duration, 2),
        'frequency': round(frequency, 2),
        'severity_score': round(severity_score, 2),
        'severity_label': severity_label,
        'confidence': round(confidence, 2)
    }
def _event_to_dict(self, event: StutterEvent) -> Dict[str, Any]:
"""Convert StutterEvent to dictionary"""
return {
'type': event.type,
'start': round(event.start, 2),
'end': round(event.end, 2),
'text': event.text,
'confidence': round(event.confidence, 2),
'acoustic_features': event.acoustic_features,
'voice_quality': event.voice_quality,
'formant_data': event.formant_data,
'phonetic_similarity': round(event.phonetic_similarity, 2)
}
# ========== ADVANCED TRANSCRIPT COMPARISON METHODS ==========
def _get_phonetic_group(self, char: str) -> Optional[str]:
    """Return the phonetic-group label for a Devanagari character, or None.

    Labels are 'consonant_<group>' / 'vowel_<group>' keyed off the
    module-level Devanagari group tables.
    """
    tables = (
        ('consonant', DEVANAGARI_CONSONANT_GROUPS),
        ('vowel', DEVANAGARI_VOWEL_GROUPS),
    )
    for prefix, table in tables:
        for group_name, chars in table.items():
            if char in chars:
                return f'{prefix}_{group_name}'
    return None
def _calculate_phonetic_similarity(self, char1: str, char2: str) -> float:
    """
    Score the articulatory similarity of two characters on a 0-1 scale.

    1.0 for identity; 0.85 for the same Devanagari group (e.g. क vs ख,
    both velar); 0.5 for the same major category (consonant/vowel); 0.2
    otherwise. Non-Devanagari pairs fall back to case-insensitive equality.
    """
    if char1 == char2:
        return 1.0
    group1 = self._get_phonetic_group(char1)
    group2 = self._get_phonetic_group(char2)
    if group1 is None or group2 is None:
        # Non-Devanagari characters - use simple comparison
        return 1.0 if char1.lower() == char2.lower() else 0.0
    if group1 == group2:
        # Same phonetic group = high similarity (common in stuttering)
        return 0.85
    same_category = group1.split('_')[0] == group2.split('_')[0]
    return 0.5 if same_category else 0.2
def _longest_common_subsequence(self, text1: str, text2: str) -> str:
"""
Find longest common subsequence (LCS) using dynamic programming
Critical for identifying core message vs stuttered additions
"""
m, n = len(text1), len(text2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
# Build DP table
for i in range(1, m + 1):
for j in range(1, n + 1):
if text1[i-1] == text2[j-1]:
dp[i][j] = dp[i-1][j-1] + 1
else:
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
# Backtrack to construct LCS
lcs = []
i, j = m, n
while i > 0 and j > 0:
if text1[i-1] == text2[j-1]:
lcs.append(text1[i-1])
i -= 1
j -= 1
elif dp[i-1][j] > dp[i][j-1]:
i -= 1
else:
j -= 1
return ''.join(reversed(lcs))
def _calculate_edit_distance(self, text1: str, text2: str, phonetic_aware: bool = True) -> Tuple[int, List[Dict]]:
    """
    Calculate Levenshtein edit distance with phonetic awareness.

    With phonetic_aware=True a substitution costs 1.0 - 0.5*similarity
    (range 0.5-1.0), so phonetically close swaps are cheaper; deletions and
    insertions always cost 1. The op list replays the chosen edit script.

    NOTE(review): the final distance is truncated via int(), so fractional
    phonetic costs are rounded down — confirm that is intended.
    NOTE(review): keeping a full op list per DP cell is O(m*n*(m+n)) memory;
    fine for short transcripts, expensive for long ones.

    Returns: (distance, list of edit operations)
    """
    m, n = len(text1), len(text2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    ops = [[[] for _ in range(n + 1)] for _ in range(m + 1)]
    # Initialize first column: delete every prefix character of text1
    for i in range(m + 1):
        dp[i][0] = i
        if i > 0:
            ops[i][0] = ops[i-1][0] + [{'op': 'delete', 'pos': i-1, 'char': text1[i-1]}]
    # Initialize first row: insert every prefix character of text2
    for j in range(n + 1):
        dp[0][j] = j
        if j > 0:
            ops[0][j] = ops[0][j-1] + [{'op': 'insert', 'pos': j-1, 'char': text2[j-1]}]
    # Fill DP table with phonetic costs
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if text1[i-1] == text2[j-1]:
                # Exact match - no cost
                dp[i][j] = dp[i-1][j-1]
                ops[i][j] = ops[i-1][j-1]
            else:
                # Calculate phonetic substitution cost
                if phonetic_aware:
                    phon_sim = self._calculate_phonetic_similarity(text1[i-1], text2[j-1])
                    sub_cost = 1.0 - (phon_sim * 0.5) # 0.5-1.0 range
                else:
                    sub_cost = 1.0
                # Choose minimum cost operation (tie-break order: delete, insert, substitute)
                costs = [
                    dp[i-1][j] + 1, # Delete
                    dp[i][j-1] + 1, # Insert
                    dp[i-1][j-1] + sub_cost # Substitute
                ]
                min_cost_idx = costs.index(min(costs))
                dp[i][j] = costs[min_cost_idx]
                if min_cost_idx == 0:
                    ops[i][j] = ops[i-1][j] + [{'op': 'delete', 'pos': i-1, 'char': text1[i-1]}]
                elif min_cost_idx == 1:
                    ops[i][j] = ops[i][j-1] + [{'op': 'insert', 'pos': j-1, 'char': text2[j-1]}]
                else:
                    # 'phon_sim' only exists when phonetic_aware; the conditional
                    # expression never evaluates it otherwise
                    ops[i][j] = ops[i-1][j-1] + [{'op': 'substitute', 'pos': i-1,
                                                 'from': text1[i-1], 'to': text2[j-1],
                                                 'phonetic_sim': phon_sim if phonetic_aware else 0}]
    return int(dp[m][n]), ops[m][n]
def _find_mismatched_segments(self, actual: str, target: str) -> List[str]:
"""
Find character sequences in actual that don't appear in target
Uses LCS to identify core message, then extracts mismatches
"""
if not actual or not target:
return [actual] if actual else []
lcs = self._longest_common_subsequence(actual, target)
# Extract segments not in LCS
mismatched_segments = []
segment = ""
lcs_idx = 0
for char in actual:
if lcs_idx < len(lcs) and char == lcs[lcs_idx]:
if segment:
mismatched_segments.append(segment)
segment = ""
lcs_idx += 1
else:
segment += char
if segment:
mismatched_segments.append(segment)
return mismatched_segments
def _detect_stutter_patterns_in_text(self, text: str) -> List[Dict[str, Any]]:
    """
    Detect common Hindi stutter patterns in text.

    Regex categories ('repetition', 'prolongation') report character
    positions; filled pauses are matched token-by-token and report word
    indices. Result order matches the original: repetitions first, then
    prolongations, then filled pauses.
    """
    patterns_found = []
    # Regex-driven categories share identical extraction logic
    for category in ('repetition', 'prolongation'):
        for pattern in HINDI_STUTTER_PATTERNS[category]:
            for match in re.finditer(pattern, text):
                patterns_found.append({
                    'type': category,
                    'text': match.group(0),
                    'position': match.start(),
                    'pattern': pattern
                })
    # Filled pauses: exact-token hesitation sounds
    for idx, token in enumerate(text.split()):
        if token in HINDI_STUTTER_PATTERNS['filled_pause']:
            patterns_found.append({
                'type': 'filled_pause',
                'text': token,
                'position': idx,
                'pattern': 'hesitation'
            })
    return patterns_found
def _compare_transcripts_comprehensive(self, actual: str, target: str) -> Dict[str, Any]:
    """
    Comprehensive transcript comparison with multiple metrics.

    Combines LCS-based mismatch extraction, phonetic-aware edit distance,
    SequenceMatcher-derived phonetic similarity, word-level accuracy and
    text stutter-pattern detection into one report dict.

    Args:
        actual: the ASR transcript of what was said.
        target: the expected transcript; when empty, only stutter patterns
            in *actual* are analysed and neutral metrics are returned.

    Returns:
        Dict of comparison metrics (see the two return statements for keys).
    """
    if not target:
        # No target provided - only analyze actual for stutter patterns
        stutter_patterns = self._detect_stutter_patterns_in_text(actual)
        return {
            'has_target': False,
            'mismatched_chars': [],
            'mismatch_percentage': 0,
            'edit_distance': 0,
            'lcs_ratio': 1.0,
            'phonetic_similarity': 1.0,
            'stutter_patterns': stutter_patterns,
            'edit_operations': []
        }
    # Normalize whitespace (collapses runs, trims ends)
    actual = ' '.join(actual.split())
    target = ' '.join(target.split())
    # 1. Find mismatched character segments
    mismatched_segments = self._find_mismatched_segments(actual, target)
    # 2. Calculate edit distance with phonetic awareness
    edit_dist, edit_ops = self._calculate_edit_distance(actual, target, phonetic_aware=True)
    # 3. Calculate LCS ratio (similarity measure, relative to target length)
    lcs = self._longest_common_subsequence(actual, target)
    lcs_ratio = len(lcs) / max(len(target), 1)
    # 4. Calculate overall phonetic similarity over SequenceMatcher opcodes
    phonetic_scores = []
    matcher = SequenceMatcher(None, actual, target)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            phonetic_scores.append(1.0)
        elif tag == 'replace':
            # Calculate phonetic similarity for replacements.
            # NOTE(review): zip truncates to the shorter span, so unequal
            # replace spans are only partially scored — confirm acceptable.
            for a_char, t_char in zip(actual[i1:i2], target[j1:j2]):
                phonetic_scores.append(self._calculate_phonetic_similarity(a_char, t_char))
    avg_phonetic_sim = np.mean(phonetic_scores) if phonetic_scores else 0.0
    # 5. Calculate mismatch percentage (characters not in target), capped at 100
    total_mismatched = sum(len(seg) for seg in mismatched_segments)
    mismatch_percentage = (total_mismatched / max(len(target), 1)) * 100
    mismatch_percentage = min(round(mismatch_percentage), 100)
    # 6. Detect stutter patterns in actual transcript
    stutter_patterns = self._detect_stutter_patterns_in_text(actual)
    # 7. Word-level analysis via SequenceMatcher ratio over word lists
    actual_words = actual.split()
    target_words = target.split()
    word_matcher = SequenceMatcher(None, actual_words, target_words)
    word_accuracy = word_matcher.ratio()
    return {
        'has_target': True,
        'mismatched_chars': mismatched_segments,
        'mismatch_percentage': mismatch_percentage,
        'edit_distance': edit_dist,
        'normalized_edit_distance': edit_dist / max(len(target), 1),
        'lcs': lcs,
        'lcs_ratio': round(lcs_ratio, 3),
        'phonetic_similarity': round(float(avg_phonetic_sim), 3),
        'word_accuracy': round(word_accuracy, 3),
        'stutter_patterns': stutter_patterns,
        'edit_operations': edit_ops[:20], # Limit for performance
        'actual_length': len(actual),
        'target_length': len(target),
        'actual_words': len(actual_words),
        'target_words': len(target_words)
    }
# ========== ACOUSTIC SIMILARITY METHODS (SOUND-BASED MATCHING) ==========
def _extract_mfcc_features(self, audio: np.ndarray, sr: int, n_mfcc: int = 13) -> np.ndarray:
    """Return per-coefficient z-normalised MFCCs as a (time, features) array."""
    coeffs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=n_mfcc, hop_length=512)
    # Normalise each coefficient across time; epsilon guards flat coefficients
    mean = np.mean(coeffs, axis=1, keepdims=True)
    std = np.std(coeffs, axis=1, keepdims=True)
    normalised = (coeffs - mean) / (std + 1e-8)
    return normalised.T
def _calculate_dtw_distance(self, seq1: np.ndarray, seq2: np.ndarray) -> float:
"""
Dynamic Time Warping distance for comparing audio segments
Critical for detecting phonetic stutters where timing differs
"""
n, m = len(seq1), len(seq2)
dtw_matrix = np.full((n + 1, m + 1), np.inf)
dtw_matrix[0, 0] = 0
for i in range(1, n + 1):
for j in range(1, m + 1):
cost = euclidean(seq1[i-1], seq2[j-1])
dtw_matrix[i, j] = cost + min(
dtw_matrix[i-1, j], # Insertion
dtw_matrix[i, j-1], # Deletion
dtw_matrix[i-1, j-1] # Match
)
# Normalize by path length
return dtw_matrix[n, m] / (n + m)
def _compare_audio_segments_acoustic(self, segment1: np.ndarray, segment2: np.ndarray,
                                     sr: int = 16000) -> Dict[str, float]:
    """
    Compare two audio segments acoustically using multiple metrics.

    Blends DTW similarity on MFCCs (40%), per-frame spectral correlation
    (30%), energy ratio (15%) and ZCR similarity (15%) into one score.
    Used to detect when sounds are similar but transcripts differ
    (phonetic stutters).

    NOTE(review): if every frame is filtered out, np.mean([]) yields NaN
    and max(0, NaN) stays NaN, which then propagates into the overall
    score; pearsonr on a constant frame can also return NaN — confirm
    callers tolerate this or add an explicit NaN guard.
    """
    # Extract MFCC features
    mfcc1 = self._extract_mfcc_features(segment1, sr)
    mfcc2 = self._extract_mfcc_features(segment2, sr)
    # 1. DTW distance, mapped onto a 0-1 similarity (10 is an empirical scale)
    dtw_dist = self._calculate_dtw_distance(mfcc1, mfcc2)
    dtw_similarity = max(0, 1.0 - (dtw_dist / 10)) # Normalize to 0-1
    # 2. Spectral features comparison
    spec1 = np.abs(librosa.stft(segment1))
    spec2 = np.abs(librosa.stft(segment2))
    # Resize to same shape for comparison (truncate to the shorter segment)
    min_frames = min(spec1.shape[1], spec2.shape[1])
    spec1 = spec1[:, :min_frames]
    spec2 = spec2[:, :min_frames]
    # Spectral correlation, averaged over frames; all-zero frames are skipped
    spec_corr = np.mean([pearsonr(spec1[:, i], spec2[:, i])[0]
                         for i in range(min_frames) if not np.all(spec1[:, i] == 0)
                         and not np.all(spec2[:, i] == 0)])
    spec_corr = max(0, spec_corr) # Handle NaN/negative
    # 3. Energy comparison: ratio of quieter to louder segment (0-1)
    energy1 = np.sum(segment1 ** 2)
    energy2 = np.sum(segment2 ** 2)
    energy_ratio = min(energy1, energy2) / (max(energy1, energy2) + 1e-8)
    # 4. Zero-crossing rate comparison (relative difference, clamped)
    zcr1 = np.mean(librosa.feature.zero_crossing_rate(segment1)[0])
    zcr2 = np.mean(librosa.feature.zero_crossing_rate(segment2)[0])
    zcr_similarity = 1.0 - min(abs(zcr1 - zcr2) / (max(zcr1, zcr2) + 1e-8), 1.0)
    # Overall acoustic similarity (weighted average)
    overall_similarity = (
        dtw_similarity * 0.4 +
        spec_corr * 0.3 +
        energy_ratio * 0.15 +
        zcr_similarity * 0.15
    )
    return {
        'dtw_similarity': round(float(dtw_similarity), 3),
        'spectral_correlation': round(float(spec_corr), 3),
        'energy_ratio': round(float(energy_ratio), 3),
        'zcr_similarity': round(float(zcr_similarity), 3),
        'overall_acoustic_similarity': round(float(overall_similarity), 3)
    }
def _detect_acoustic_repetitions(self, audio: np.ndarray, sr: int,
                                 word_timestamps: List[Dict]) -> List[StutterEvent]:
    """
    Detect repetitions by comparing the acoustic similarity of adjacent words.

    Catches stutters even when ASR transcribes the two utterances
    differently: consecutive word segments with overall acoustic similarity
    above 0.75 become repetition events. Failures on a word pair are logged
    and skipped.
    """
    events = []
    if len(word_timestamps) < 2:
        return events
    for i in range(len(word_timestamps) - 1):
        first = word_timestamps[i]
        second = word_timestamps[i + 1]
        try:
            # Sample offsets of the two word segments
            a_lo, a_hi = int(first['start'] * sr), int(first['end'] * sr)
            b_lo, b_hi = int(second['start'] * sr), int(second['end'] * sr)
            if a_hi > len(audio) or b_hi > len(audio):
                continue
            seg_a = audio[a_lo:a_hi]
            seg_b = audio[b_lo:b_hi]
            # Skip very short segments
            if min(len(seg_a), len(seg_b)) < 100:
                continue
            acoustic_sim = self._compare_audio_segments_acoustic(seg_a, seg_b, sr)
            # High acoustic similarity suggests repetition (even if transcripts differ)
            if acoustic_sim['overall_acoustic_similarity'] > 0.75:
                events.append(StutterEvent(
                    type='repetition',
                    start=first['start'],
                    end=second['end'],
                    text=f"{first.get('word', '')}{second.get('word', '')}",
                    confidence=acoustic_sim['overall_acoustic_similarity'],
                    acoustic_features=acoustic_sim,
                    phonetic_similarity=acoustic_sim['overall_acoustic_similarity']
                ))
        except Exception as e:
            logger.warning(f"Acoustic comparison failed for words {i}-{i+1}: {e}")
            continue
    return events
def _detect_prolongations_by_sound(self, audio: np.ndarray, sr: int,
                                   word_timestamps: List[Dict]) -> List[StutterEvent]:
    """
    Flag prolonged sounds by checking spectral stability inside each word.

    If the MFCC content barely changes over a sliding window of frames, the
    speaker is likely holding the same sound, i.e. a prolongation.
    """
    found: List[StutterEvent] = []
    window = 5
    for info in word_timestamps:
        try:
            lo = int(info['start'] * sr)
            hi = int(info['end'] * sr)
            # Ignore segments outside the buffer and anything under 300 ms.
            if hi > len(audio) or hi - lo < sr * 0.3:
                continue
            mfcc = self._extract_mfcc_features(audio[lo:hi], sr)
            if len(mfcc) < 10:  # too few frames to judge stability
                continue
            # Mean pairwise correlation within each sliding window of frames.
            # NOTE(review): corrcoef is applied to mfcc[...].T — this assumes
            # mfcc is laid out (frames, coeffs); confirm against
            # _extract_mfcc_features before changing.
            corrs = []
            for j in range(len(mfcc) - window):
                cmat = np.corrcoef(mfcc[j:j + window].T)
                corrs.append(np.mean(cmat[np.triu_indices_from(cmat, k=1)]))
            mean_corr = np.mean(corrs) if corrs else 0
            # Sustained high correlation => the same sound held over time.
            if mean_corr > PROLONGATION_CORRELATION_THRESHOLD:
                found.append(StutterEvent(
                    type='prolongation',
                    start=info['start'],
                    end=info['end'],
                    text=info.get('word', ''),
                    confidence=float(mean_corr),
                    acoustic_features={
                        'spectral_correlation': float(mean_corr),
                        'duration': (hi - lo) / sr
                    },
                    phonetic_similarity=float(mean_corr)
                ))
        except Exception as e:
            logger.warning(f"Prolongation detection failed for word: {e}")
            continue
    return found
def analyze_audio(self, audio_path: str, proper_transcript: str = "", language: str = 'hindi') -> dict:
    """
    🎯 ADVANCED Multi-Modal Stutter Detection Pipeline

    Combines:
        1. ASR Transcription (IndicWav2Vec Hindi)
        2. Phonetic-Aware Transcript Comparison
        3. Acoustic Similarity Matching (Sound-Based)
        4. Linguistic Pattern Detection

    This detects stutters that ASR might miss by comparing:
        - What was said (actual) vs what should be said (target)
        - How it sounds (acoustic features)
        - Common Hindi stutter patterns

    Args:
        audio_path: Path to the audio file; loaded mono at 16 kHz.
        proper_transcript: Optional target transcript the speaker was meant
            to produce. When non-empty, enables mismatch-based detection.
        language: Language hint (informational; the ASR checkpoint is Hindi).

    Returns:
        dict containing transcripts, mismatch metrics, the deduplicated list
        of stutter events, severity/confidence scores, speaking-rate and
        debug metadata.
    """
    start_time = time.time()
    logger.info(f"🚀 Starting advanced analysis: {audio_path}")
    # === STEP 1: Audio Loading & Preprocessing ===
    audio, sr = librosa.load(audio_path, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)
    logger.info(f"🎵 Audio loaded: {duration:.2f}s duration")
    # === STEP 2: ASR Transcription using IndicWav2Vec Hindi ===
    transcript, word_timestamps, logits = self._transcribe_with_timestamps(audio)
    logger.info(f"📝 ASR Transcription: '{transcript}' ({len(transcript)} chars, {len(word_timestamps)} words)")
    # === STEP 3: Comprehensive Transcript Comparison ===
    comparison_result = self._compare_transcripts_comprehensive(transcript, proper_transcript)
    logger.info(f"🔍 Transcript comparison: {comparison_result['mismatch_percentage']}% mismatch, "
                f"phonetic similarity: {comparison_result['phonetic_similarity']:.2f}")
    # === STEP 4: Multi-Modal Stutter Detection ===
    events = []
    # 4a. Text-based stutters from transcript comparison.
    # Timing here is approximate (0.5 s slots) — mismatched characters carry
    # no timestamps of their own.
    if comparison_result['has_target'] and comparison_result['mismatched_chars']:
        for i, segment in enumerate(comparison_result['mismatched_chars'][:10]):  # Limit to top 10
            events.append(StutterEvent(
                type='mismatch',
                start=i * 0.5,  # Approximate timing
                end=(i + 1) * 0.5,
                text=segment,
                confidence=0.8,
                acoustic_features={'source': 'transcript_comparison'},
                phonetic_similarity=comparison_result['phonetic_similarity']
            ))
    # 4b. Detected linguistic patterns (repetitions, prolongations, filled pauses)
    for pattern in comparison_result.get('stutter_patterns', []):
        events.append(StutterEvent(
            type=pattern['type'],
            start=pattern.get('position', 0) * 0.5,
            end=(pattern.get('position', 0) + 1) * 0.5,
            text=pattern['text'],
            confidence=0.75,
            acoustic_features={'pattern': pattern['pattern']}
        ))
    # 4c. Acoustic-based detection (sound similarity)
    logger.info("🎤 Running acoustic similarity analysis...")
    acoustic_repetitions = self._detect_acoustic_repetitions(audio, sr, word_timestamps)
    events.extend(acoustic_repetitions)
    logger.info(f"✅ Found {len(acoustic_repetitions)} acoustic repetitions")
    acoustic_prolongations = self._detect_prolongations_by_sound(audio, sr, word_timestamps)
    events.extend(acoustic_prolongations)
    logger.info(f"✅ Found {len(acoustic_prolongations)} acoustic prolongations")
    # 4d. Model uncertainty regions (low confidence)
    entropy_score, low_conf_regions = self._calculate_uncertainty(logits)
    for region in low_conf_regions[:5]:  # Limit to 5 most uncertain
        events.append(StutterEvent(
            type='dysfluency',
            start=region['time'],
            end=region['time'] + 0.3,
            text="<low_confidence>",
            confidence=region['confidence'],
            acoustic_features={'entropy': entropy_score, 'model_uncertainty': True}
        ))
    # === STEP 5: Deduplicate and Rank Events ===
    # Sort by start time, then by descending confidence, so that when two
    # events overlap the higher-confidence one is kept.
    events.sort(key=lambda e: (e.start, -e.confidence))
    deduplicated_events = []
    for event in events:
        # Drop any event whose time span overlaps an already-kept event.
        overlaps = False
        for existing in deduplicated_events:
            if not (event.end < existing.start or event.start > existing.end):
                overlaps = True
                break
        if not overlaps:
            deduplicated_events.append(event)
    events = deduplicated_events
    logger.info(f"📊 Total events after deduplication: {len(events)}")
    # === STEP 6: Calculate Comprehensive Metrics ===
    total_duration = sum(e.end - e.start for e in events)
    frequency = (len(events) / duration * 60) if duration > 0 else 0  # events/min
    # Mismatch percentage from transcript comparison (more accurate)
    mismatch_percentage = comparison_result['mismatch_percentage']
    # Severity assessment (multi-factor weighted score in roughly 0-100 range)
    severity_score = (
        mismatch_percentage * 0.4 +
        (total_duration / duration * 100) * 0.3 +
        (frequency / 10 * 100) * 0.3
    ) if duration > 0 else 0
    if severity_score < 10:
        severity = 'none'
    elif severity_score < 25:
        severity = 'mild'
    elif severity_score < 50:
        severity = 'moderate'
    else:
        severity = 'severe'
    # Confidence score (multi-factor). model_confidence can dip below 0 for
    # very high entropy; the final clamp below bounds the combined value.
    model_confidence = 1.0 - (entropy_score / 10.0) if entropy_score > 0 else 0.8
    phonetic_confidence = comparison_result.get('phonetic_similarity', 1.0)
    # FIX: guard on the *filtered* list, not on `events`. Previously, when
    # events existed but none were repetitions/prolongations, np.mean([])
    # returned NaN and poisoned overall_confidence.
    event_confidences = [e.confidence for e in events if e.type in ('repetition', 'prolongation')]
    acoustic_confidence = float(np.mean(event_confidences)) if event_confidences else 0.7
    overall_confidence = (
        model_confidence * 0.4 +
        phonetic_confidence * 0.3 +
        acoustic_confidence * 0.3
    )
    overall_confidence = max(0.0, min(1.0, overall_confidence))
    # === STEP 7: Return Comprehensive Results ===
    actual_transcript = transcript if transcript else ""
    target_transcript = proper_transcript if proper_transcript else ""
    analysis_time = time.time() - start_time
    result = {
        # Core transcripts
        'actual_transcript': actual_transcript,
        'target_transcript': target_transcript,
        # Mismatch analysis
        'mismatched_chars': comparison_result.get('mismatched_chars', []),
        'mismatch_percentage': round(mismatch_percentage, 2),
        # Advanced comparison metrics
        'edit_distance': comparison_result.get('edit_distance', 0),
        'lcs_ratio': comparison_result.get('lcs_ratio', 1.0),
        'phonetic_similarity': comparison_result.get('phonetic_similarity', 1.0),
        'word_accuracy': comparison_result.get('word_accuracy', 1.0),
        # Model metrics
        'ctc_loss_score': round(entropy_score, 4),
        # Stutter events with acoustic features
        'stutter_timestamps': [self._event_to_dict(e) for e in events],
        'total_stutter_duration': round(total_duration, 2),
        'stutter_frequency': round(frequency, 2),
        # Assessment
        'severity': severity,
        'severity_score': round(severity_score, 2),
        'confidence_score': round(overall_confidence, 2),
        # Speaking metrics
        # NOTE(review): key says "sps" (syllables/s) but the value is
        # words per second — confirm intended units with consumers.
        'speaking_rate_sps': round(len(word_timestamps) / duration if duration > 0 else 0, 2),
        # Metadata
        'analysis_duration_seconds': round(analysis_time, 2),
        'model_version': 'indicwav2vec-hindi-advanced-v2',
        'features_used': ['asr', 'phonetic_comparison', 'acoustic_similarity', 'pattern_detection'],
        # Debug info
        'debug': {
            'total_events_detected': len(events),
            'acoustic_repetitions': len(acoustic_repetitions),
            'acoustic_prolongations': len(acoustic_prolongations),
            'text_patterns': len(comparison_result.get('stutter_patterns', [])),
            'has_target_transcript': comparison_result['has_target']
        }
    }
    logger.info(f"✅ Analysis complete in {analysis_time:.2f}s - Severity: {severity}, "
                f"Mismatch: {mismatch_percentage}%, Confidence: {overall_confidence:.2f}")
    return result
# Model loader is now in a separate module: model_loader.py
# This follows clean architecture principles - separation of concerns
# Import using: from diagnosis.ai_engine.model_loader import get_stutter_detector