# NOTE: The "Spaces: Sleeping" banner text below this comment's original position
# is a Hugging Face Spaces page artifact captured with the file, not program source.
| import os | |
| import tempfile | |
| import whisperx | |
| from pyannote.audio import Pipeline | |
| import pandas as pd | |
| import librosa | |
| import soundfile as sf | |
| import numpy as np | |
| from scipy.signal import butter, filtfilt | |
| from typing import Optional, Dict, List, Any | |
| import torch | |
| from dataclasses import dataclass, field | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import time | |
| import shutil | |
# Optional dependency: noisereduce enables spectral denoising in preprocess_audio().
try:
    import noisereduce as nr
    HAVE_NOISEREDUCE = True
except ImportError:
    HAVE_NOISEREDUCE = False

# pyannote is optional as well; the placeholders keep later annotations/lookups
# importable when the package is absent (guarded by pyannote_available checks).
Annotation: Any = None
Segment: Any = None
DiarizationErrorRate: Any = None
pyannote_available = False
try:
    from pyannote.core import Annotation, Segment
    from pyannote.metrics.diarization import DiarizationErrorRate
    pyannote_available = True
except Exception:
    pyannote_available = False

# Pick GPU when available; every model below loads onto this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# HF_TOKEN is read from environment automatically, but explicitly check
token = os.environ.get("HF_TOKEN")
if not token:
    print("Warning: HF_TOKEN not set. Diarization will be skipped.")
# Diarization requires both a HF token (gated pyannote model) and pyannote itself.
perform_diarization = True if token and pyannote_available else False
# Whisper checkpoint name passed to whisperx.load_model().
model_name = "large-v2"
class TimelineItem(BaseModel):
    """One word (or whole segment) on the response timeline."""
    start: float
    end: float
    speaker: Optional[str] = None
    text: str
class AnalysisResult(BaseModel):
    """Public response schema returned by the upload endpoint."""
    duration: float
    language: str
    der: Optional[float] = None
    speaker_error: Optional[float] = None
    missed_speech: Optional[float] = None
    false_alarm: Optional[float] = None
    timeline_data: List[TimelineItem]
# FastAPI application; CORS is restricted to the deployed frontend origin only.
app = FastAPI(title="Audio Analyzer Backend")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://frontend-audio-analyzer.vercel.app"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@dataclass
class AnalysisResults:
    """Mutable container accumulated while analyze_audio() runs.

    BUG FIX: the original class used `field(default_factory=list)` without the
    `@dataclass` decorator, so `timelineData` / `warnings` were plain `Field`
    objects shared at class level and `results.warnings.append(...)` raised
    AttributeError. The decorator makes them proper per-instance lists.
    """
    timelineData: List[Dict[str, Any]] = field(default_factory=list)  # timeline rows (start/end/text/speaker)
    duration: float = 0.0  # audio length in seconds (max word end time)
    languageCode: str = "unknown"  # detected language code
    diarizationErrorRate: Optional[float] = None
    speakerError: Optional[float] = None
    missedSpeech: Optional[float] = None
    falseAlarm: Optional[float] = None
    warnings: List[str] = field(default_factory=list)  # de-duplicated "CODE: detail" strings
    success: bool = False
    message: str = "Analysis initiated."
def warn(results: AnalysisResults, code: str, detail: str) -> None:
    """Record a warning on `results` as 'CODE: detail', skipping exact duplicates."""
    entry = f"{code}: {detail}"
    if entry in results.warnings:
        return
    results.warnings.append(entry)
def set_message(results: AnalysisResults, msg: str) -> None:
    """Set the status message, chaining onto any prior non-default message with ' | '."""
    default = "Analysis initiated."
    has_custom = bool(results.message) and results.message != default
    results.message = f"{results.message} | {msg}" if has_custom else msg
def normalize_speaker(lbl: str) -> str:
    """Canonicalize speaker labels to the 'Speaker_' prefix form (accepts any type)."""
    text = str(lbl)
    for prefix in ("SPEAKER_", "speaker_"):
        text = text.replace(prefix, "Speaker_")
    return text
def temp_wav_path() -> str:
    """Create a closed, persistent temp .wav file and return its path (caller deletes)."""
    handle = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        return handle.name
    finally:
        handle.close()
# GLOBAL HELPER: coerce arbitrary (possibly NumPy) values to a safe native float.
def force_float(value: Optional[Any]) -> Optional[float]:
    """Coerce `value` to a native float; NaN, +/-Inf and unconvertible inputs map to None."""
    if value is None:
        return None
    try:
        converted = float(value)
    except (TypeError, ValueError, AttributeError):
        return None
    return None if (np.isnan(converted) or np.isinf(converted)) else converted
# GLOBAL HELPER: probe several candidate keys for a usable timestamp.
def _get_time_field(d: Dict[str, Any], keys: List[str]) -> Optional[float]:
    """Return the first key from `keys` present in `d` whose value coerces to a finite float."""
    for key in keys:
        if key not in d:
            continue
        coerced = force_float(d[key])
        if coerced is not None:
            return coerced
    return None
| def butter_filter(y, sr, lowpass=None, highpass=None, order=4): | |
| nyq = 0.5 * sr | |
| if highpass and highpass > 0 and highpass < nyq: | |
| b, a = butter(order, highpass / nyq, btype="highpass", analog=False) | |
| y = filtfilt(b, a, y) | |
| if lowpass and lowpass > 0 and lowpass < nyq: | |
| b, a = butter(order, lowpass / nyq, btype="lowpass", analog=False) | |
| y = filtfilt(b, a, y) | |
| return y | |
def rms_normalize(y, target_rms=0.8, eps=1e-6):
    """Scale `y` so its RMS level equals `target_rms`; near-silent input is returned untouched."""
    current = (y ** 2).mean() ** 0.5
    return y if current < eps else y * (target_rms / (current + eps))
def preprocess_audio(input_path,
                     target_sr=16000,
                     normalize_rms=True,
                     target_rms=0.08,
                     denoise=False,
                     highpass=None,
                     lowpass=None,
                     output_subtype="PCM_16",
                     verbose=False):
    """Clean up `input_path` and write the result to a fresh temporary WAV.

    Steps, in order: mono downmix, resample to `target_sr`, optional
    Butterworth band filtering, optional noisereduce denoising, optional RMS
    normalization. Returns the temp WAV path; the caller owns (and must
    delete) that file. Raises FileNotFoundError for a missing input.

    NOTE(review): `verbose` is accepted but never used in this body.
    """
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input audio not found: {input_path}")
    output_path = temp_wav_path()
    # Read as float64 to keep numeric headroom for the filtering below.
    y_stereo, sr = sf.read(input_path, dtype='float64')
    if y_stereo.ndim > 1:
        # soundfile yields (frames, channels); librosa expects (channels, frames).
        y = librosa.to_mono(y_stereo.T)
    else:
        y = y_stereo
    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
        sr = target_sr
    if highpass or lowpass:
        y = butter_filter(y, sr, highpass=highpass, lowpass=lowpass)
    if denoise and HAVE_NOISEREDUCE:
        try:
            # Use up to the first 0.5 s of audio as the noise profile.
            noise_len = int(min(len(y), int(0.5 * sr)))
            noise_clip = y[:noise_len]
            # NOTE(review): newer noisereduce releases dropped the `verbose`
            # keyword — confirm against the pinned dependency version.
            y = nr.reduce_noise(y=y, sr=sr, y_noise=noise_clip, prop_decrease=0.9, verbose=False)
        except Exception:
            # Denoising is best-effort; keep the unfiltered signal on failure.
            pass
    if normalize_rms:
        y = rms_normalize(y, target_rms=target_rms)
    sf.write(output_path, y, sr, subtype=output_subtype)
    return output_path
def load_rttm(path: str) -> Optional["Annotation"]:
    """Parse an RTTM file into a pyannote Annotation.

    Returns None when pyannote is unavailable, the file is missing, or
    parsing fails. Lines starting with ';;' are comments; only 'SPEAKER'
    records with valid onset/duration fields are kept.
    """
    if not pyannote_available or not os.path.exists(path):
        return None
    annotation = Annotation()
    try:
        with open(path, "r", encoding="utf-8") as handle:
            for raw_line in handle:
                if raw_line.startswith(";;"):
                    continue
                fields = raw_line.strip().split()
                if len(fields) < 8 or fields[0] != "SPEAKER":
                    continue
                onset = force_float(fields[3])
                duration = force_float(fields[4])
                if onset is None or duration is None:
                    continue  # skip records with unparseable times
                annotation[Segment(onset, onset + duration)] = normalize_speaker(fields[7])
        return annotation
    except Exception as e:
        print(f"Error loading RTTM: {e}")
        return None
# Relies on module-level normalize_speaker, force_float, Segment, Annotation,
# pyannote_available.
def create_fallback_hypothesis(diarize_output: Any) -> Optional["Annotation"]:
    """Create hypothesis annotation from diarization output with robust error handling.

    Two strategies are tried in order:
      1. pyannote's native Annotation interface (itertracks);
      2. a generic iterable of segment dicts / objects with start/end/speaker.
    Returns None when pyannote is unavailable or no valid segments were found.
    """
    if not pyannote_available:
        return None
    ann = Annotation()
    try:
        # Handle pyannote's native format (itertracks)
        if hasattr(diarize_output, "itertracks"):
            for segment, _, label in diarize_output.itertracks(yield_label=True):
                try:
                    start = force_float(segment.start)
                    end = force_float(segment.end)
                    if start is None or end is None: continue
                    # Ensure valid time range (non-negative, > 10 ms long)
                    if start >= 0 and end > start and end - start > 0.01:
                        ann[Segment(start, end)] = normalize_speaker(label)
                except Exception:  # Catch any internal conversion errors
                    continue
            return ann if len(ann) > 0 else None
    except Exception:
        pass
    # Try to handle as list/dict format (e.g., from an intermediate whisperx output)
    try:
        if isinstance(diarize_output, list) or hasattr(diarize_output, '__iter__'):
            for seg in diarize_output:
                try:
                    start, end, speaker = None, None, None
                    if isinstance(seg, dict):
                        start = force_float(seg.get('start'))
                        end = force_float(seg.get('end'))
                        speaker = seg.get('speaker') or seg.get('label')
                    elif hasattr(seg, 'start') and hasattr(seg, 'end'):
                        start = force_float(seg.start)
                        end = force_float(seg.end)
                        speaker = getattr(seg, 'speaker', getattr(seg, 'label', None))
                    else:
                        continue
                    if start is not None and end is not None and speaker is not None:
                        if start >= 0 and end > start and end - start > 0.01:
                            ann[Segment(start, end)] = normalize_speaker(speaker)
                except Exception:
                    continue
        return ann if len(ann) > 0 else None
    except Exception as e:
        print(f"Error creating fallback hypothesis: {e}")
    return None
def analyze_audio(audio_file: str,
                  reference_rttm_file: Optional[str] = None,
                  preprocess: bool = True,
                  preprocess_params: Optional[Dict[str, Any]] = None,
                  language_code_input: Optional[str] = None) -> AnalysisResults:
    """Transcribe, align, diarize and (optionally) score an audio file.

    Parameters:
        audio_file: path to the input audio.
        reference_rttm_file: optional ground-truth RTTM used to compute DER.
        preprocess: when True, resample/normalize audio before inference.
        preprocess_params: keyword overrides for preprocess_audio().
        language_code_input: accepted for interface compatibility; language is
            currently auto-detected regardless.

    Returns an AnalysisResults with timeline rows, duration, detected language
    and diarization metrics where available. Never raises: failures are
    recorded in .message / .warnings.
    """
    results = AnalysisResults()
    if not os.path.exists(audio_file):
        results.message = f"Error: Input audio file '{audio_file}' not found."
        return results
    audio_for_model = audio_file
    temp_preproc = None
    if preprocess:
        params = {
            "target_sr": 16000, "normalize_rms": True, "target_rms": 0.08,
            "denoise": False, "highpass": None, "lowpass": None,
            "output_subtype": "PCM_16", "verbose": False
        }
        if isinstance(preprocess_params, dict):
            params.update(preprocess_params)
        if params.get("denoise") and not HAVE_NOISEREDUCE:
            warn(results, "DENOISE_SKIP", "Denoise requested but noisereduce not installed; skipping denoise.")
            params["denoise"] = False
        try:
            temp_preproc = preprocess_audio(audio_file, **params)
            audio_for_model = temp_preproc
        except Exception as e:
            warn(results, "PREP_FAIL", f"Preprocessing failed: {e}. Falling back to original audio.")
            audio_for_model = audio_file
            temp_preproc = None
    try:
        print(f"Loading Whisper model '{model_name}' on {device}...")
        model = whisperx.load_model(model_name, device, compute_type="float32")
        audio_loaded = whisperx.load_audio(audio_for_model)
        print("Running robust automatic language detection...")
        # BUG FIX: whisperx's detect_language may return a bare language string
        # (version-dependent); tuple-unpacking a string would split "en" into
        # single characters. Handle both shapes.
        detected = model.detect_language(audio_loaded)
        if isinstance(detected, tuple):
            language_code, language_prob = detected[0], detected[1]
        else:
            language_code, language_prob = detected, None
        language_to_use = language_code if language_code else "en"
        results.languageCode = language_to_use
        prob_txt = f"{language_prob:.2f}" if isinstance(language_prob, float) else "n/a"
        print(f"Detected language: {language_to_use} (Prob: {prob_txt}). Starting transcription...")
        result = model.transcribe(
            audio_loaded,
            batch_size=4,
            language=language_to_use,  # use the detected code
            multilingual=True,
            max_new_tokens=448,
            clip_timestamps=None,
            hallucination_silence_threshold=0.1,
            hotwords=None
        )
        print(f"Transcription successful in {language_to_use}. Proceeding to alignment...")
        # BUG FIX: default `aligned` to the raw transcription so it is always
        # bound even when both alignment attempts below fail (the original
        # raised NameError at `final = aligned` in that case).
        aligned = result
        try:
            align_model, metadata = whisperx.load_align_model(language_code=language_to_use, device=device)
            aligned = whisperx.align(result["segments"], align_model, metadata, audio_loaded, device)
            print("Alignment successful.")
        except Exception as e:
            warn(results, "ALIGN_SKIP", f"Alignment unavailable ({type(e).__name__}: {e}); using raw Whisper segments.")
            if language_to_use != 'en':
                try:
                    print("Trying English alignment model as fallback...")
                    align_model, metadata = whisperx.load_align_model(language_code="en", device=device)
                    aligned = whisperx.align(result["segments"], align_model, metadata, audio_loaded, device)
                except Exception:
                    warn(results, "ALIGN_FAIL", "English fallback alignment also failed. Proceeding with raw unaligned segments.")
            else:
                warn(results, "ALIGN_FAIL", "Alignment failed. Proceeding with raw unaligned segments.")
        # --- Speaker Diarization ---
        diarize_output = None
        if perform_diarization:
            print("Performing speaker diarization (Requires HF_TOKEN)...")
            try:
                diarize_model = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization-3.1",
                    use_auth_token=token)
                diarize_output = diarize_model(audio_for_model)
                for segment, _, label in diarize_output.itertracks(yield_label=True):
                    print(f"start={segment.start:.1f}s stop={segment.end:.1f}s {label}")
            except Exception as e:
                warn(results, "DIAR_SKIP", f"Error during diarization (likely token/model failure): {type(e).__name__}: {e}. Skipping diarization.")
                diarize_output = None
        else:
            warn(results, "DIAR_SKIP", "HF_TOKEN not set or pyannote not available. Skipping speaker diarization.")
        # --- Speaker Assignment ---
        final = aligned  # default to (possibly unaligned) transcription output
        diarize_segments_for_assignment = []
        if diarize_output is not None:
            # Convert pyannote Annotation tracks to plain dicts for whisperx.
            if hasattr(diarize_output, "itertracks"):
                for segment, _, label in diarize_output.itertracks(yield_label=True):
                    diarize_segments_for_assignment.append({
                        "start": force_float(segment.start) or 0.0,
                        "end": force_float(segment.end) or 0.0,
                        "speaker": normalize_speaker(label)
                    })
            if diarize_segments_for_assignment:
                print("Assigning speakers to words...")
                try:
                    final = whisperx.assign_word_speakers(diarize_segments_for_assignment, aligned)
                except Exception as e:
                    warn(results, "ASSIGN_SPEAKERS_ERROR", f"Error assigning speakers: {e}. Falling back to unassigned segments.")
            else:
                warn(results, "ASSIGN_SKIP", "Diarization conversion failed; using raw aligned segments.")
        # Ensure every segment carries a speaker label for downstream rendering.
        for seg in final.get("segments", []):
            if "speaker" not in seg and "speaker_label" not in seg:
                seg["speaker"] = "Speaker_1"
        # --- Timeline creation (robust word/segment extraction) ---
        rows: List[Dict[str, Any]] = []
        for seg in final.get("segments", []):
            seg_speaker = normalize_speaker(seg.get("speaker") or seg.get("speaker_label") or "Speaker_1")
            word_list = seg.get("words") or seg.get("tokens") or seg.get("items") or []
            # No word-level detail: emit the whole segment as one row.
            if not word_list:
                word_start = _get_time_field(seg, ["start", "s", "timestamp", "t0"])
                word_end = _get_time_field(seg, ["end", "e", "t1"])
                if word_start is None:
                    continue
                if word_end is None:
                    word_end = word_start
                rows.append({
                    "start": float(word_start),
                    "end": float(word_end),
                    "text": str(seg.get("text", "")).strip(),
                    "speaker": str(seg_speaker),
                })
                continue
            # Word-level rows; fall back to segment times when a word lacks them.
            for w in word_list:
                if not isinstance(w, dict):
                    continue
                word_start = _get_time_field(w, ["start", "s", "timestamp", "t0"])
                word_end = _get_time_field(w, ["end", "e", "t1"])
                if word_start is None:
                    word_start = _get_time_field(seg, ["start", "s"])
                if word_end is None:
                    word_end = _get_time_field(seg, ["end", "e"])
                if word_start is None:
                    continue
                if word_end is None:
                    word_end = word_start
                word_speaker = normalize_speaker(w.get("speaker") or seg_speaker)
                word_text = (w.get("text") or w.get("word") or w.get("label") or "").strip()
                rows.append({
                    "start": force_float(word_start) or 0.0,
                    "end": force_float(word_end) or 0.0,
                    "text": str(word_text),
                    "speaker": str(word_speaker),
                })
        rows.sort(key=lambda r: r.get("start", 0.0))
        results.timelineData = rows
        # Duration = latest word/segment end time seen on the timeline.
        ends = [force_float(w.get("end")) for w in rows]
        results.duration = force_float(max(ends) if ends else 0.0) or 0.0
        # --- Compute DER when a reference RTTM and diarization output exist ---
        if pyannote_available and reference_rttm_file and diarize_output is not None:
            print(f"Computing DER using reference RTTM: {reference_rttm_file}...")
            reference = load_rttm(reference_rttm_file)
            hypothesis = create_fallback_hypothesis(diarize_output)
            if hypothesis is None:
                print("DEBUG DER: Primary hypothesis failed. Falling back to using final timeline data...")
                hypothesis = Annotation()
                for item in results.timelineData:
                    try:
                        start = force_float(item.get('start'))
                        end = force_float(item.get('end'))
                        speaker = item.get('speaker')
                        # Timeline rows use a looser validity check (no minimum length).
                        if start is not None and end is not None and speaker is not None and start >= 0 and end > start:
                            hypothesis[Segment(start, end)] = normalize_speaker(speaker)
                    except Exception:
                        continue
            ref_len = len(reference) if reference else 0
            hyp_len = len(hypothesis) if hypothesis else 0
            print(f"DEBUG DER: Reference segments loaded: {ref_len}. Hypothesis segments loaded: {hyp_len}.")
            if ref_len == 0:
                warn(results, "DER_FAIL", "Reference RTTM loaded but contains 0 valid segments.")
            elif hyp_len == 0:
                warn(results, "DER_FAIL", "Diarization hypothesis conversion failed and timeline fallback failed (0 segments).")
            else:
                try:
                    metric = DiarizationErrorRate(collar=0.25, skip_overlap=False)
                    # BUG FIX: compute_components() never reports a
                    # 'diarization error rate' key, and pyannote names its
                    # components 'confusion' / 'missed detection' — the original
                    # .get() calls therefore always returned None. The detailed
                    # metric call exposes all the keys we need.
                    der_report = metric(reference, hypothesis, detailed=True)
                    results.diarizationErrorRate = force_float(der_report.get('diarization error rate'))
                    results.speakerError = force_float(der_report.get('confusion'))
                    results.missedSpeech = force_float(der_report.get('missed detection'))
                    results.falseAlarm = force_float(der_report.get('false alarm'))
                    if results.diarizationErrorRate is not None:
                        print(f"DER computed: {results.diarizationErrorRate:.4f}")
                except Exception as e:
                    warn(results, "DER_ERROR", f"Error computing DER (pyannote metrics failure): {type(e).__name__}: {e}. Check compatibility.")
        elif reference_rttm_file:
            warn(results, "DER_MISSING", "Reference RTTM provided but DER calculation skipped (check pyannote availability or diarization output).")
        results.success = True
        set_message(results, "Analysis complete.")
    except Exception as e:
        results.message = f"Core analysis failed: {type(e).__name__}: {e}"
        results.success = False
        warn(results, "FATAL_ERROR", f"Fatal analysis error: {type(e).__name__}: {e}")
    finally:
        # BUG FIX: delete the preprocessed temp WAV (previously leaked).
        if temp_preproc and os.path.exists(temp_preproc):
            try:
                os.remove(temp_preproc)
            except OSError:
                pass
    return results
# NOTE(review): no @app.post(...) decorator is visible above this handler, so as
# written it is never registered as a route — the decorator line may have been
# lost when the file was copied; confirm against the deployed service.
async def upload_file(audio_file: UploadFile = File(...), rttm_file: UploadFile = File(None)):
    """Accept an audio upload (plus optional reference RTTM), run analyze_audio,
    and map the internal AnalysisResults onto the public AnalysisResult model.

    Temp copies of the uploads are always removed in the `finally` block.
    Raises HTTPException(500) when analysis fails or an unexpected error occurs.
    """
    start_time = time.time()
    audio_path: Optional[str] = None
    rttm_path: Optional[str] = None
    try:
        print("Incoming upload:", getattr(audio_file, "filename", None), getattr(rttm_file, "filename", None))
        # Keep the original extension so downstream readers can sniff the format.
        suffix = audio_file.filename.split(".")[-1] if audio_file.filename else "tmp"
        with tempfile.NamedTemporaryFile(suffix=f".{suffix}", delete=False) as tmp_audio:
            # Rewind in case the framework already consumed part of the stream.
            audio_file.file.seek(0)
            shutil.copyfileobj(audio_file.file, tmp_audio)
            audio_path = tmp_audio.name
        print(f"Received audio file: {audio_file.filename} (saved to {audio_path}), size: {os.path.getsize(audio_path)} bytes")
        if rttm_file and rttm_file.filename:
            with tempfile.NamedTemporaryFile(suffix=".rttm", delete=False) as tmp_rttm:
                # Rewind before copying (same reasoning as above).
                rttm_file.file.seek(0)
                shutil.copyfileobj(rttm_file.file, tmp_rttm)
                rttm_path = tmp_rttm.name
            print(f"Received rttm file: {rttm_file.filename} (saved to {rttm_path}), size: {os.path.getsize(rttm_path)} bytes")
        else:
            print("No RTTM file received in this request.")
        preprocessing_config = {"denoise": False}
        print(f"Starting ML processing with audio: {audio_path}, rttm: {rttm_path}")
        analysis_result = analyze_audio(
            audio_file=audio_path,
            reference_rttm_file=rttm_path,
            preprocess_params=preprocessing_config,
        )
        # NOTE(review): this prints on success too, despite the label.
        print("FAILURE MESSAGE:", analysis_result.message)
        if not analysis_result.success:
            raise HTTPException(status_code=500, detail=analysis_result.message)
        print("DURATION BEFORE RETURN:", analysis_result.duration)
        if analysis_result.duration is None:
            analysis_result.duration = 0.0
        # Coerce internal dict rows into the typed response model; force_float
        # guards against NaN/Inf leaking into the JSON payload.
        return AnalysisResult(
            duration=force_float(analysis_result.duration) or 0.0,
            language=analysis_result.languageCode,
            der=analysis_result.diarizationErrorRate,
            speaker_error=analysis_result.speakerError,
            missed_speech=analysis_result.missedSpeech,
            false_alarm=analysis_result.falseAlarm,
            timeline_data=[
                TimelineItem(
                    start=force_float(item.get('start')) or 0.0,
                    end=force_float(item.get('end')) or 0.0,
                    speaker=str(item.get('speaker')) if item.get('speaker') else None,
                    text=str(item.get('text', ""))
                ) for item in analysis_result.timelineData
            ]
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error during upload process: {type(e).__name__}: {e}")
    finally:
        # Best-effort cleanup of the temp copies regardless of outcome.
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
        if rttm_path and os.path.exists(rttm_path):
            os.remove(rttm_path)
        end_time = time.time()
        print(f"API Request processed in {end_time - start_time:.2f} seconds.")
def root():
    """Health-check payload for the service root.

    NOTE(review): likely intended to carry an @app.get("/") decorator, which
    appears to have been lost in formatting — confirm against the deployed app.
    """
    status_payload = {"message": "Audio Analyzer Backend is running."}
    return status_payload