Spaces:
Sleeping
Sleeping
"""
Enhanced Stuttering Detection API
==================================
FastAPI backend that transcribes speech with a multilingual wav2vec2 CTC model
(MMS) and flags candidate dysfluency events from frame-level acoustic features.

Thresholds: silence (block) and spectral-outlier (interjection) decisions use
adaptive statistics (Modified Z-Score / Median Absolute Deviation). The other
detectors still rely on fixed heuristic cut-offs (coefficient-of-variation,
cosine-similarity and duration limits).

Highlights:
- Adaptive thresholding using Modified Z-Score (Median Absolute Deviation)
- Multi-feature acoustic analysis (RMS, pitch, MFCCs, spectral features)
- Speaking-rate normalization of the severity score
- Heuristic detection of 5 dysfluency types with per-event confidence scores
- Heuristics motivated by recent stuttering-detection literature
"""
import os
import tempfile
import traceback
from typing import Optional, Dict, List

import librosa
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from scipy import signal
from transformers import Wav2Vec2ForCTC, AutoProcessor
app = FastAPI(
    title="SLAQ Enhanced AI Engine",
    description="Adaptive stuttering detection with multi-feature analysis",
    version="2.0.0",
)

# --- CONFIGURATION ---
ASR_MODEL_ID = "facebook/mms-1b-all"  # Hugging Face checkpoint for transcription
SAMPLE_RATE = 16000  # Hz; uploaded audio is resampled to this rate

print("π Loading Enhanced AI Models...")

# The service cannot do anything useful without the ASR model, so a load
# failure is reported and then re-raised to abort start-up.
try:
    processor = AutoProcessor.from_pretrained(ASR_MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_ID)
except Exception as load_error:
    print(f"β Failed to load ASR model: {load_error}")
    raise
else:
    print(f"β ASR Model loaded: {ASR_MODEL_ID}")
# User-facing language names -> MMS adapter codes (ISO 639-3 style).
# 'auto' is a sentinel value handled by the request code, not a real adapter.
LANG_MAP = dict(
    hindi='hin',
    tamil='tam',
    telugu='tel',
    marathi='mar',
    bengali='ben',
    gujarati='guj',
    kannada='kan',
    malayalam='mal',
    punjabi='pan',
    urdu='urd',
    english='eng',
    auto='auto',
)
class EnhancedStutterDetector:
    """
    Heuristic stuttering (dysfluency) detector working on a waveform.

    Blocks (silences) and interjections (spectral-centroid outliers) use an
    adaptive Modified Z-Score (MAD) threshold. Prolongations, sound repetitions
    and word repetitions use fixed heuristic cut-offs.

    Every detector returns a list of event dicts with the keys ``type``,
    ``start``, ``end``, ``duration`` (seconds) and ``confidence``. The
    ``confidence`` values are NOT on a common scale:
    - blocks and interjections report a mean absolute z-score;
    - sound repetitions report a cosine similarity;
    - prolongations report ``1 - mean coefficient of variation``;
    - word repetitions report a constant 0.7.
    """

    def __init__(self, sample_rate: int = 16000):
        # Stored for reference only; analyze() works with the ``sr`` it is given.
        self.sr = sample_rate
        # |Modified Z-Score| beyond this value is treated as an outlier.
        self.mad_threshold = 3.5

    def analyze(self, y: np.ndarray, sr: int) -> Dict:
        """
        Run every detector on ``y`` and summarise the results.

        Args:
            y: Audio samples (presumably mono float audio as produced by
                ``librosa.load`` -- TODO confirm for other callers).
            sr: Sample rate of ``y`` in Hz.

        Returns:
            Dict with these keys:
            - ``events``: all events, sorted by start time;
            - ``total_events``;
            - ``severity_score``: 0-100;
            - ``severity_label``;
            - ``speaking_rate``: detected onsets per second;
            - ``duration``: seconds;
            - ``event_counts``: number of events per type.

            Empty input (or a non-positive ``sr``) yields an all-zero result
            instead of an error from the feature extractors.
        """
        if len(y) == 0 or sr <= 0:
            return {
                'events': [],
                'total_events': 0,
                'severity_score': 0.0,
                'severity_label': self._get_severity_label(0.0),
                'speaking_rate': 0.0,
                'duration': 0.0,
                'event_counts': {},
            }

        duration = len(y) / sr
        # Multi-dimensional acoustic features shared by all detectors
        features = self._extract_features(y, sr)
        speaking_rate = self._estimate_speaking_rate(y, sr)

        events = []
        for detector in (
            self._detect_blocks,
            self._detect_prolongations,
            self._detect_sound_repetitions,
            self._detect_word_repetitions,
            self._detect_interjections,
        ):
            events.extend(detector(y, sr, features))
        events.sort(key=lambda event: event['start'])

        severity_score = self._calculate_severity(events, duration, speaking_rate)
        return {
            'events': events,
            'total_events': len(events),
            'severity_score': severity_score,
            'severity_label': self._get_severity_label(severity_score),
            'speaking_rate': speaking_rate,
            'duration': duration,
            'event_counts': self._count_types(events),
        }

    def _extract_features(self, y: np.ndarray, sr: int) -> Dict:
        """
        Compute frame-level acoustic features on a shared 10 ms hop.

        Returns a dict with:
        - ``rms``, ``f0``, ``spectral_centroid``, ``spectral_rolloff``, ``zcr``:
          one value per frame;
        - ``mfcc``: 13 x frames;
        - ``hop_length``: hop size in samples;
        - ``frame_times``: time in seconds of each RMS frame.
        """
        frame_length = int(0.025 * sr)  # 25 ms analysis window
        hop_length = int(0.010 * sr)  # 10 ms hop
        features = {}

        # Energy (RMS)
        rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        features['rms'] = rms

        # Pitch (F0). The shared hop is passed explicitly: librosa.yin otherwise
        # uses its own default hop (frame_length // 4 = 512 samples). That would
        # give an f0 track whose frames do not line up with the RMS frames that
        # _detect_prolongations indexes it with.
        features['f0'] = librosa.yin(
            y,
            fmin=librosa.note_to_hz('C2'),
            fmax=librosa.note_to_hz('C7'),
            sr=sr,
            hop_length=hop_length,
        )

        # Spectral features
        features['spectral_centroid'] = librosa.feature.spectral_centroid(y=y, sr=sr, hop_length=hop_length)[0]
        features['spectral_rolloff'] = librosa.feature.spectral_rolloff(y=y, sr=sr, hop_length=hop_length)[0]
        features['zcr'] = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]

        # MFCCs
        features['mfcc'] = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13, hop_length=hop_length)

        # Time mapping
        features['hop_length'] = hop_length
        features['frame_times'] = librosa.frames_to_time(np.arange(len(rms)), sr=sr, hop_length=hop_length)
        return features

    def _estimate_speaking_rate(self, y: np.ndarray, sr: int) -> float:
        """Estimate speaking rate as detected onsets per second (a syllable-rate proxy)."""
        onset_env = librosa.onset.onset_strength(y=y, sr=sr)
        onsets = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr, units='time')
        duration = len(y) / sr
        return len(onsets) / duration if duration > 0 else 0.0

    def _modified_z_score(self, data: np.ndarray) -> np.ndarray:
        """
        Return the Modified Z-Score of every element of ``data``.

        The score is ``0.6745 * (x - median) / MAD``. When the MAD is zero
        (more than half the values equal the median), the mean absolute
        deviation is used instead, scaled by 1.253314. That scaling makes the
        mean absolute deviation a consistent estimate of the standard
        deviation, just as 0.6745 does for the MAD. Returns zeros for empty
        input or input with no spread.
        """
        data = np.asarray(data)
        if data.size == 0:
            return np.zeros_like(data)
        median = np.median(data)
        abs_dev = np.abs(data - median)
        mad = np.median(abs_dev)
        if mad >= 1e-10:
            return 0.6745 * (data - median) / mad
        mean_ad = np.mean(abs_dev)
        if mean_ad < 1e-10:
            return np.zeros_like(data)
        return (data - median) / (1.253314 * mean_ad)

    def _detect_blocks(self, y: np.ndarray, sr: int, features: Dict) -> List[Dict]:
        """
        Detect blocks: runs of abnormally quiet frames lasting 0.2-2.0 s.

        A frame is "silent" when its RMS Modified Z-Score is below
        ``-mad_threshold``. A silent run still open at the end of the signal is
        never closed and therefore not reported.
        """
        rms = features['rms']
        frame_times = features['frame_times']

        # Adaptive silence threshold using Modified Z-Score
        rms_z = self._modified_z_score(rms)
        is_silent = rms_z < -self.mad_threshold

        blocks = []
        in_block = False
        block_start = 0
        for i, silent in enumerate(is_silent):
            if silent and not in_block:
                block_start = frame_times[i]
                in_block = True
            elif not silent and in_block:
                block_end = frame_times[i]
                duration = block_end - block_start
                if 0.2 < duration < 2.0:
                    blocks.append({
                        'type': 'block',
                        'start': float(block_start),
                        'end': float(block_end),
                        'duration': float(duration),
                        # Mean |z| over (up to) the last 10 frames of the run.
                        'confidence': float(np.mean(np.abs(rms_z[max(0, i - 10):i]))),
                    })
                in_block = False
        return blocks

    def _detect_prolongations(self, y: np.ndarray, sr: int, features: Dict) -> List[Dict]:
        """
        Detect prolongations: stretches where both energy and pitch stay stable.

        For each frame ``i``, the window ``[i - 20, i + 20)`` qualifies when all
        of these hold:
        - the RMS coefficient of variation is < 0.1;
        - the f0 coefficient of variation is < 0.15;
        - the mean RMS exceeds 30 % of the median RMS.

        A qualifying frame lying within 0.1 s of the current region's end
        extends that region to cover its window. Otherwise it starts a new
        region. Regions lasting 0.3-3.0 s are returned.
        """
        rms = features['rms']
        f0 = features['f0']
        frame_times = features['frame_times']
        window = 20

        # Only walk frames that exist in every feature track.
        n_frames = min(len(rms), len(f0), len(frame_times))
        if n_frames == 0:
            return []
        last = n_frames - 1
        energy_floor = np.median(rms) * 0.3  # loop-invariant

        prolongations = []
        for i in range(window, n_frames - window):
            win_rms = rms[i - window:i + window]
            win_f0 = f0[i - window:i + window]
            mean_rms = np.mean(win_rms)
            rms_cv = np.std(win_rms) / (mean_rms + 1e-10)
            f0_cv = np.std(win_f0) / (np.mean(win_f0) + 1e-10)
            if not (rms_cv < 0.1 and f0_cv < 0.15 and mean_rms > energy_floor):
                continue

            end = float(frame_times[min(last, i + window)])
            if prolongations and frame_times[i] - prolongations[-1]['end'] < 0.1:
                # Extend -- never shrink -- the current region so it still
                # covers the full window of every frame merged into it.
                current = prolongations[-1]
                current['end'] = max(current['end'], end)
                current['duration'] = current['end'] - current['start']
            else:
                start = float(frame_times[i - window])
                prolongations.append({
                    'type': 'prolongation',
                    'start': start,
                    'end': end,
                    'duration': float(end - start),
                    'confidence': float(1.0 - (rms_cv + f0_cv) / 2),
                })
        return [p for p in prolongations if 0.3 < p['duration'] < 3.0]

    def _detect_sound_repetitions(self, y: np.ndarray, sr: int, features: Dict) -> List[Dict]:
        """
        Detect sound repetitions: two consecutive 15-frame MFCC windows that are
        nearly identical (cosine similarity > 0.85).

        Processing details:
        - MFCCs are mean-normalised per coefficient first. Otherwise the large
          c0 (energy) term dominates the cosine similarity and almost every
          frame pair qualifies.
        - Windows whose mean RMS is at most 30 % of the median RMS are skipped,
          so pauses are not reported as repetitions.
        - Overlapping detections are merged into a single event that keeps the
          highest similarity as its confidence.
        - Events lasting 0.1-1.5 s are returned.
        """
        mfcc = features['mfcc']
        rms = features['rms']
        frame_times = features['frame_times']
        window = 15

        n_frames = min(mfcc.shape[1], len(rms), len(frame_times))
        if n_frames == 0:
            return []
        last = n_frames - 1
        mfcc = mfcc[:, :n_frames]
        normalised = mfcc - mfcc.mean(axis=1, keepdims=True)  # cepstral mean normalisation
        energy_floor = np.median(rms[:n_frames]) * 0.3

        repetitions = []
        for i in range(window, n_frames - window * 2):
            if np.mean(rms[i:i + 2 * window]) <= energy_floor:
                continue
            current = normalised[:, i:i + window].ravel()
            following = normalised[:, i + window:i + 2 * window].ravel()
            similarity = float(
                np.dot(current, following)
                / (np.linalg.norm(current) * np.linalg.norm(following) + 1e-10)
            )
            if similarity <= 0.85:
                continue

            start = float(frame_times[i])
            end = float(frame_times[min(last, i + 2 * window)])
            if repetitions and start <= repetitions[-1]['end']:
                # Same repetition seen from a shifted window: merge instead of
                # reporting (and counting) it again.
                merged = repetitions[-1]
                merged['end'] = max(merged['end'], end)
                merged['duration'] = merged['end'] - merged['start']
                merged['confidence'] = max(merged['confidence'], similarity)
            else:
                repetitions.append({
                    'type': 'sound_repetition',
                    'start': start,
                    'end': end,
                    'duration': float(end - start),
                    'confidence': similarity,
                })
        return [r for r in repetitions if 0.1 < r['duration'] < 1.5]

    def _detect_word_repetitions(self, y: np.ndarray, sr: int, features: Dict) -> List[Dict]:
        """
        Detect word repetitions from peaks in the global RMS autocorrelation.

        Peaks at lags of 30-90 frames above the 75th percentile of the
        autocorrelation become events with a fixed confidence of 0.7.
        """
        rms = features['rms']
        frame_times = features['frame_times']
        rms_norm = (rms - np.mean(rms)) / (np.std(rms) + 1e-10)
        autocorr = np.correlate(rms_norm, rms_norm, mode='full')
        autocorr = autocorr[len(autocorr) // 2:]  # non-negative lags only
        word_window = 30
        peaks, _ = signal.find_peaks(
            autocorr[word_window:word_window * 3],
            height=np.percentile(autocorr, 75),
            distance=word_window // 2,
        )
        repetitions = []
        for peak in peaks:
            # NOTE(review): ``idx`` is an autocorrelation LAG, but it is used
            # below as a frame POSITION. The reported start/end are therefore
            # not where a repetition occurs in the audio. TODO: locate the
            # repetition (e.g. via windowed correlation at this lag).
            idx = peak + word_window
            if idx < len(frame_times):
                start = frame_times[max(0, idx - word_window)]
                end = frame_times[min(len(frame_times) - 1, idx + word_window)]
                repetitions.append({
                    'type': 'word_repetition',
                    'start': float(start),
                    'end': float(end),
                    'duration': float(end - start),
                    'confidence': 0.7,
                })
        return [r for r in repetitions if 0.3 < r['duration'] < 2.0]

    def _detect_interjections(self, y: np.ndarray, sr: int, features: Dict) -> List[Dict]:
        """
        Detect interjections (um, uh, ah): runs of 0.1-0.8 s with an outlying
        spectral centroid.

        A frame is active when ``|Modified Z-Score|`` of its centroid exceeds
        ``mad_threshold`` and its RMS is above 20 % of the median RMS. A run
        still active at the end of the signal is closed at the last frame.
        """
        rms = features['rms']
        centroid = features['spectral_centroid']
        frame_times = features['frame_times']
        hop_length = features['hop_length']

        n_frames = min(len(rms), len(centroid), len(frame_times))
        if n_frames == 0:
            return []
        centroid_z = self._modified_z_score(centroid[:n_frames])
        energy_floor = np.median(rms[:n_frames]) * 0.2  # loop-invariant
        active = (np.abs(centroid_z) > self.mad_threshold) & (rms[:n_frames] > energy_floor)

        interjections = []

        def close_run(start_idx: int, end_idx: int) -> None:
            # Record the run [start_idx, end_idx) if its length is plausible.
            duration = (end_idx - start_idx) * hop_length / sr
            if 0.1 < duration < 0.8:
                interjections.append({
                    'type': 'interjection',
                    'start': float(frame_times[start_idx]),
                    'end': float(frame_times[end_idx]),
                    'duration': float(duration),
                    'confidence': float(np.mean(np.abs(centroid_z[start_idx:end_idx]))),
                })

        run_start = None
        for i, is_active in enumerate(active):
            if is_active and run_start is None:
                run_start = i
            elif not is_active and run_start is not None:
                close_run(run_start, i)
                run_start = None
        if run_start is not None:
            close_run(run_start, n_frames - 1)
        return interjections

    def _calculate_severity(self, events: List[Dict], duration: float, rate: float) -> float:
        """
        Combine events into a 0-100 severity score.

        The score is ``0.4 * (% of time in events)``
        ``+ 0.3 * (events per minute / rate_factor)``
        ``+ 0.3 * (type-weighted count / rate_factor)``.

        ``rate_factor`` is ``rate / 4`` clamped to [0.5, 2.0], or 1.0 when no
        rate was measured. The lower clamp stops a very low onset rate from
        inflating the score without bound.
        """
        if duration <= 0:
            return 0.0
        counts = self._count_types(events)
        total_time = sum(e['duration'] for e in events)

        # Dysfluency percentage
        dysfluency_pct = (total_time / duration) * 100
        # Event frequency (per minute)
        event_freq = (len(events) / duration) * 60
        # Weighted count (blocks/prolongations more severe)
        weights = {'block': 2.0, 'prolongation': 1.8, 'sound_repetition': 1.5,
                   'word_repetition': 1.3, 'interjection': 1.0}
        weighted = sum(counts.get(t, 0) * w for t, w in weights.items())

        # Rate normalization
        rate_factor = float(np.clip(rate / 4.0, 0.5, 2.0)) if rate > 0 else 1.0

        severity = (
            dysfluency_pct * 0.4 +
            (event_freq / rate_factor) * 0.3 +
            (weighted / rate_factor) * 0.3
        )
        return float(np.clip(severity, 0, 100))

    def _count_types(self, events: List[Dict]) -> Dict[str, int]:
        """Count events by their ``type`` key."""
        counts = {}
        for e in events:
            counts[e['type']] = counts.get(e['type'], 0) + 1
        return counts

    def _get_severity_label(self, score: float) -> str:
        """Map a 0-100 severity score to a label."""
        if score < 10:
            return 'none'
        elif score < 25:
            return 'mild'
        elif score < 50:
            return 'moderate'
        elif score < 75:
            return 'severe'
        else:
            return 'very_severe'
# One shared detector instance, reused by the analysis endpoint for every request.
stutter_detector = EnhancedStutterDetector(SAMPLE_RATE)
print("β Enhanced Stutter Detector initialized")
# NOTE(review): the source had no route decorator, so this handler was never
# registered with ``app``; "/" is the conventional path for a status page -- confirm.
@app.get("/")
def home():
    """Return service metadata: status, name, version, feature list and ASR model id."""
    return {
        "status": "running",
        "service": "SLAQ Enhanced AI Engine",
        "version": "2.0.0",
        "features": [
            "Adaptive thresholding (Modified Z-Score/MAD)",
            "Multi-feature acoustic analysis",
            "Speaking-rate normalization",
            "5 dysfluency types detection",
            "Multilingual support (MMS-1B)"
        ],
        "model": ASR_MODEL_ID
    }
# NOTE(review): the source had no route decorator, so this handler was never
# registered with ``app``; "/health" is the conventional probe path -- confirm.
@app.get("/health")
def health():
    """
    Liveness probe.

    ``model_loaded`` is always True. Module import re-raises if the ASR model
    fails to load, so a running server always has it.
    """
    return {"status": "healthy", "model_loaded": True}
def _select_asr_language(language: Optional[str]) -> str:
    """
    Activate the MMS tokenizer vocabulary and adapter for ``language``.

    ``language`` is looked up case-insensitively in ``LANG_MAP``. Unknown names
    and ``'auto'`` fall back to English, because no language-identification step
    exists here. If loading the requested adapter raises, English is loaded
    instead.

    Returns:
        The language code whose adapter is active afterwards.
    """
    lang_code = LANG_MAP.get(str(language).lower(), 'eng')
    if lang_code == 'auto':
        # For auto mode, default to English
        lang_code = 'eng'
    try:
        processor.tokenizer.set_target_lang(lang_code)
        model.load_adapter(lang_code)
    except Exception as err:  # was a bare except: don't swallow KeyboardInterrupt/SystemExit
        print(f"β οΈ Adapter not found for {lang_code}, using eng ({err})")
        lang_code = 'eng'
        processor.tokenizer.set_target_lang('eng')
        model.load_adapter('eng')
    return lang_code


def _transcribe(speech: np.ndarray) -> tuple:
    """
    Greedy CTC transcription with the currently loaded adapter.

    Returns:
        ``(text, confidence)``. ``confidence`` is the mean over output frames of
        the highest softmax probability.
    """
    inputs = processor(speech, sampling_rate=SAMPLE_RATE, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    text = processor.batch_decode(predicted_ids)[0]
    confidence = float(torch.mean(torch.nn.functional.softmax(logits, dim=-1).max(dim=-1).values))
    return text, confidence


# NOTE(review): the source had no route decorator, so this endpoint was never
# registered with ``app``; the path is inferred from the function name -- confirm
# against API clients.
@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: Optional[str] = Form(""),
    language: Optional[str] = Form("auto")
):
    """
    Analyze audio for stuttering events with adaptive detection.

    Args:
        audio: Audio file (WAV, MP3, etc.)
        transcript: Optional reference transcript for comparison
        language: Language name from LANG_MAP, or 'auto' (currently English)

    Returns:
        Transcript, per-event timestamps, severity and rate statistics.

    Raises:
        HTTPException: 400 when the decoded audio has no samples. 500 for any
            other processing failure.
    """
    # Never build the temp path from the client-supplied filename: it may contain
    # path separators ("../"), and concurrent uploads with the same name would
    # overwrite each other. Only a plain alphanumeric extension is kept, so the
    # decoder can still recognise the format.
    suffix = os.path.splitext(os.path.basename(audio.filename or ""))[1]
    if not suffix[1:].isalnum():
        suffix = ""
    fd, temp_filename = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(await audio.read())

        speech, sr = librosa.load(temp_filename, sr=SAMPLE_RATE)
        if speech.size == 0:
            raise HTTPException(status_code=400, detail="Uploaded audio contains no samples")

        # NOTE(review): the steps below are synchronous and mutate the shared
        # module-level model (adapter selection). Inside this async handler they
        # run without an await in between, which keeps adapter selection and
        # inference together. It also blocks the event loop. Add a lock before
        # moving this work to a thread pool.
        lang_code = _select_asr_language(language)
        actual_transcript, confidence = _transcribe(speech)

        analysis = stutter_detector.analyze(speech, sr)

        # --- TRANSCRIPT COMPARISON (if provided) ---
        mismatch_pct = 0.0
        if transcript:
            import Levenshtein
            dist = Levenshtein.distance(actual_transcript, transcript)
            mismatch_pct = (dist / max(len(transcript), 1)) * 100

        events = analysis['events']
        timestamps = [
            {
                'type': evt['type'],
                'start': evt['start'],
                'end': evt['end'],
                'duration': evt['duration'],
                'confidence': evt.get('confidence', 0.5)
            }
            for evt in events
        ]
        total_stutter_duration = sum(evt['duration'] for evt in events)
        duration = analysis['duration']

        return {
            "actual_transcript": actual_transcript,
            "target_transcript": transcript or "",
            "mismatch_percentage": round(mismatch_pct, 2),
            "stutter_timestamps": timestamps,
            "total_stutter_duration": round(total_stutter_duration, 2),
            "stutter_frequency": analysis['total_events'],
            "severity": analysis['severity_label'],
            "severity_score": round(analysis['severity_score'], 2),
            "confidence_score": round(confidence, 2),
            "model_version": f"enhanced-adaptive-v2 ({lang_code})",
            "language_detected": lang_code,
            "speaking_rate": round(analysis['speaking_rate'], 2),
            "event_breakdown": analysis['event_counts'],
            "dysfluency_rate": round(analysis['total_events'] / (duration / 60), 2) if duration > 0 else 0
        }
    except HTTPException:
        # Deliberate client errors keep their status code instead of becoming 500s.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if os.path.exists(temp_filename):
            os.remove(temp_filename)
if __name__ == "__main__":
    # Development entry point: serve the app directly with uvicorn.
    import uvicorn

    for banner_line in (
        "\nπ Starting Enhanced SLAQ AI Engine...",
        "π Features: Adaptive thresholds, MAD-based detection, Multi-feature analysis",
        "π Access at: http://localhost:8000",
        "π Docs at: http://localhost:8000/docs\n",
    ):
        print(banner_line)
    uvicorn.run(app, host="0.0.0.0", port=8000)