"""FastAPI backend that transcribes audio with WhisperX, optionally performs
speaker diarization with pyannote, and computes the Diarization Error Rate
(DER) against an optional reference RTTM file."""

import os
import shutil
import tempfile
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import whisperx
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pyannote.audio import Pipeline
from pydantic import BaseModel
from scipy.signal import butter, filtfilt

# Optional dependency: spectral-gating denoiser used during preprocessing.
try:
    import noisereduce as nr
    HAVE_NOISEREDUCE = True
except ImportError:
    HAVE_NOISEREDUCE = False

# Optional dependency: pyannote core/metrics, needed only for DER computation.
Annotation: Any = None
Segment: Any = None
DiarizationErrorRate: Any = None
pyannote_available = False
try:
    from pyannote.core import Annotation, Segment
    from pyannote.metrics.diarization import DiarizationErrorRate
    pyannote_available = True
except Exception:
    pyannote_available = False

device = "cuda" if torch.cuda.is_available() else "cpu"

# HF_TOKEN is required by pyannote's gated speaker-diarization pipeline.
token = os.environ.get("HF_TOKEN")
if not token:
    print("Warning: HF_TOKEN not set. Diarization will be skipped.")
perform_diarization = bool(token and pyannote_available)

model_name = "large-v2"


class TimelineItem(BaseModel):
    """One word/segment in the response timeline."""
    start: float
    end: float
    speaker: str | None = None
    text: str


class AnalysisResult(BaseModel):
    """Response payload for /upload."""
    duration: float
    language: str
    der: float | None = None
    speaker_error: float | None = None
    missed_speech: float | None = None
    false_alarm: float | None = None
    timeline_data: list[TimelineItem]


app = FastAPI(title="Audio Analyzer Backend")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://frontend-audio-analyzer.vercel.app"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@dataclass
class AnalysisResults:
    """Internal mutable result accumulated by analyze_audio()."""
    timelineData: List[Dict[str, Any]] = field(default_factory=list)
    duration: float = 0.0
    languageCode: str = "unknown"
    diarizationErrorRate: Optional[float] = None
    speakerError: Optional[float] = None
    missedSpeech: Optional[float] = None
    falseAlarm: Optional[float] = None
    warnings: List[str] = field(default_factory=list)
    success: bool = False
    message: str = "Analysis initiated."


def warn(results: AnalysisResults, code: str, detail: str) -> None:
    """Append a de-duplicated 'CODE: detail' warning to the results."""
    msg = f"{code}: {detail}"
    if msg not in results.warnings:
        results.warnings.append(msg)


def set_message(results: AnalysisResults, msg: str) -> None:
    """Set (or append to) the results message, replacing the initial placeholder."""
    initial_message = "Analysis initiated."
    if results.message and results.message != initial_message:
        results.message += f" | {msg}"
    else:
        results.message = msg


def normalize_speaker(lbl: str) -> str:
    """Canonicalize speaker labels to the 'Speaker_N' spelling."""
    lbl_str = str(lbl)
    return lbl_str.replace("SPEAKER_", "Speaker_").replace("speaker_", "Speaker_")


def temp_wav_path() -> str:
    """Reserve a temp .wav path on disk and return it (caller must delete)."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        return f.name


# GLOBAL HELPER: Robust helper function to handle NumPy types, NaN, and Infinity
def force_float(value: Optional[Any]) -> Optional[float]:
    """Ensures value is a native Python float or None. Returns None for NaN/Inf."""
    if value is None:
        return None
    try:
        f_val = float(value)
        if np.isnan(f_val) or np.isinf(f_val):
            return None
        return f_val
    except (TypeError, ValueError, AttributeError):
        return None


# GLOBAL HELPER: Robust helper function to safely extract time fields
def _get_time_field(d: Dict[str, Any], keys: List[str]) -> Optional[float]:
    """Try multiple possible keys and coerce to native float, returning None if not possible."""
    for k in keys:
        if k in d:
            f_val = force_float(d[k])
            if f_val is not None:
                return f_val
    return None


def butter_filter(y, sr, lowpass=None, highpass=None, order=4):
    """Apply optional Butterworth high-pass and/or low-pass filters (zero-phase)."""
    nyq = 0.5 * sr
    # Cutoffs outside (0, Nyquist) are silently ignored rather than erroring.
    if highpass and 0 < highpass < nyq:
        b, a = butter(order, highpass / nyq, btype="highpass", analog=False)
        y = filtfilt(b, a, y)
    if lowpass and 0 < lowpass < nyq:
        b, a = butter(order, lowpass / nyq, btype="lowpass", analog=False)
        y = filtfilt(b, a, y)
    return y


def rms_normalize(y, target_rms=0.8, eps=1e-6):
    """Scale the signal to the target RMS level; near-silent input is returned unchanged.

    NOTE(review): the default target_rms (0.8) differs from the 0.08 that
    preprocess_audio passes explicitly — confirm which is intended.
    """
    rms = (y ** 2).mean() ** 0.5
    if rms < eps:
        return y
    gain = target_rms / (rms + eps)
    return y * gain


def preprocess_audio(input_path, target_sr=16000, normalize_rms=True, target_rms=0.08,
                     denoise=False, highpass=None, lowpass=None,
                     output_subtype="PCM_16", verbose=False):
    """Convert audio to mono, resample, optionally filter/denoise/normalize.

    Returns the path of a newly written temp WAV; the caller owns (and must
    delete) that file.

    Raises:
        FileNotFoundError: if input_path does not exist.
    """
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input audio not found: {input_path}")
    output_path = temp_wav_path()

    y_stereo, sr = sf.read(input_path, dtype='float64')
    # soundfile returns (frames, channels); librosa expects (channels, frames).
    if y_stereo.ndim > 1:
        y = librosa.to_mono(y_stereo.T)
    else:
        y = y_stereo

    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    if highpass or lowpass:
        y = butter_filter(y, sr, highpass=highpass, lowpass=lowpass)

    if denoise and HAVE_NOISEREDUCE:
        try:
            # Use up to the first 0.5 s as the noise profile.
            noise_len = min(len(y), int(0.5 * sr))
            noise_clip = y[:noise_len]
            y = nr.reduce_noise(y=y, sr=sr, y_noise=noise_clip,
                                prop_decrease=0.9, verbose=False)
        except Exception:
            # Best-effort: denoising failure falls back to the unfiltered signal.
            pass

    if normalize_rms:
        y = rms_normalize(y, target_rms=target_rms)

    sf.write(output_path, y, sr, subtype=output_subtype)
    return output_path


def load_rttm(path: str) -> Optional["Annotation"]:
    """Parse a reference RTTM file into a pyannote Annotation (None on failure)."""
    if not pyannote_available or not os.path.exists(path):
        return None
    ann = Annotation()
    try:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.startswith(";;"):  # RTTM comment line
                    continue
                parts = line.strip().split()
                # RTTM SPEAKER rows: type file chan start dur <NA> <NA> speaker ...
                if len(parts) >= 8 and parts[0] == "SPEAKER":
                    start = force_float(parts[3])
                    dur = force_float(parts[4])
                    if start is None or dur is None:
                        continue  # Skip invalid time entries
                    spk = normalize_speaker(parts[7])
                    ann[Segment(start, start + dur)] = spk
        return ann
    except Exception as e:
        print(f"Error loading RTTM: {e}")
        return None


def create_fallback_hypothesis(diarize_output: Any) -> Optional["Annotation"]:
    """Create hypothesis annotation from diarization output with robust error handling."""
    if not pyannote_available:
        return None
    ann = Annotation()

    # Handle pyannote's native format (itertracks)
    try:
        if hasattr(diarize_output, "itertracks"):
            for segment, _, label in diarize_output.itertracks(yield_label=True):
                try:
                    start = force_float(segment.start)
                    end = force_float(segment.end)
                    if start is None or end is None:
                        continue
                    # Ensure valid, non-degenerate time range (>10 ms).
                    if start >= 0 and end > start and end - start > 0.01:
                        ann[Segment(start, end)] = normalize_speaker(label)
                except Exception:
                    # Catch any internal conversion errors
                    continue
            return ann if len(ann) > 0 else None
    except Exception:
        pass

    # Try to handle as list/dict format (e.g., from an intermediate whisperx output)
    try:
        if isinstance(diarize_output, list) or hasattr(diarize_output, '__iter__'):
            for seg in diarize_output:
                try:
                    start, end, speaker = None, None, None
                    if isinstance(seg, dict):
                        start = force_float(seg.get('start'))
                        end = force_float(seg.get('end'))
                        speaker = seg.get('speaker') or seg.get('label')
                    elif hasattr(seg, 'start') and hasattr(seg, 'end'):
                        start = force_float(seg.start)
                        end = force_float(seg.end)
                        speaker = getattr(seg, 'speaker', getattr(seg, 'label', None))
                    else:
                        continue
                    if start is not None and end is not None and speaker is not None:
                        if start >= 0 and end > start and end - start > 0.01:
                            ann[Segment(start, end)] = normalize_speaker(speaker)
                except Exception:
                    continue
        return ann if len(ann) > 0 else None
    except Exception as e:
        print(f"Error creating fallback hypothesis: {e}")
        return None


def analyze_audio(audio_file: str,
                  reference_rttm_file: Optional[str] = None,
                  preprocess: bool = True,
                  preprocess_params: Optional[Dict[str, Any]] = None,
                  language_code_input: Optional[str] = None) -> AnalysisResults:
    """Run the full pipeline: preprocess -> transcribe -> align -> diarize -> DER.

    Never raises: all failures are folded into the returned AnalysisResults
    (success flag, message, warnings).
    """
    results = AnalysisResults()
    if not os.path.exists(audio_file):
        results.message = f"Error: Input audio file '{audio_file}' not found."
        return results

    audio_for_model = audio_file
    temp_preproc = None
    if preprocess:
        params = {
            "target_sr": 16000, "normalize_rms": True, "target_rms": 0.08,
            "denoise": False, "highpass": None, "lowpass": None,
            "output_subtype": "PCM_16", "verbose": False
        }
        if isinstance(preprocess_params, dict):
            params.update(preprocess_params)
        if params.get("denoise") and not HAVE_NOISEREDUCE:
            warn(results, "DENOISE_SKIP",
                 "Denoise requested but noisereduce not installed; skipping denoise.")
            params["denoise"] = False
        try:
            temp_preproc = preprocess_audio(audio_file, **params)
            audio_for_model = temp_preproc
        except Exception as e:
            warn(results, "PREP_FAIL",
                 f"Preprocessing failed: {e}. Falling back to original audio.")
            audio_for_model = audio_file
            temp_preproc = None

    start_ml_time = time.time()
    try:
        print(f"Loading Whisper model '{model_name}' on {device}...")
        model = whisperx.load_model(model_name, device, compute_type="float32")
        audio_loaded = whisperx.load_audio(audio_for_model)

        print("Running robust automatic language detection...")
        language_code, language_prob = model.detect_language(audio_loaded)
        language_to_use = language_code if language_code else "en"
        results.languageCode = language_to_use
        print(f"Detected language: {language_to_use} (Prob: {language_prob:.2f}). "
              f"Starting transcription...")

        result = model.transcribe(
            audio_loaded,
            batch_size=4,
            language=language_to_use,      # <-- CRITICAL: Uses the detected code
            multilingual=True,             # Must be True since the model is multilingual
            max_new_tokens=448,            # Required positional argument
            clip_timestamps=None,
            hallucination_silence_threshold=0.1,
            hotwords=None
        )
        print(f"Transcription successful in {language_to_use}. Proceeding to alignment...")

        # Default to raw Whisper output so `aligned` is always bound, even if
        # both alignment attempts below fail (fixes a NameError).
        aligned = result
        try:
            align_model, metadata = whisperx.load_align_model(
                language_code=language_to_use, device=device)
            aligned = whisperx.align(result["segments"], align_model, metadata,
                                     audio_loaded, device)
            print("Alignment successful.")
        except Exception as e:
            warn(results, "ALIGN_SKIP",
                 f"Alignment unavailable ({type(e).__name__}: {e}); "
                 f"using raw Whisper segments.")
            if language_to_use != 'en':
                try:
                    print("Trying English alignment model as fallback...")
                    align_model, metadata = whisperx.load_align_model(
                        language_code="en", device=device)
                    aligned = whisperx.align(result["segments"], align_model,
                                             metadata, audio_loaded, device)
                except Exception:
                    warn(results, "ALIGN_FAIL",
                         "English fallback alignment also failed. "
                         "Proceeding with raw unaligned segments.")
            else:
                warn(results, "ALIGN_FAIL",
                     "Alignment failed. Proceeding with raw unaligned segments.")

        # --- Speaker Diarization ---
        diarize_output = None
        if perform_diarization:
            print("Performing speaker diarization (Requires HF_TOKEN)...")
            try:
                diarize_model = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization-3.1", use_auth_token=token)
                diarize_output = diarize_model(audio_for_model)
                for segment, _, label in diarize_output.itertracks(yield_label=True):
                    print(f"start={segment.start:.1f}s stop={segment.end:.1f}s {label}")
            except Exception as e:
                warn(results, "DIAR_SKIP",
                     f"Error during diarization (likely token/model failure): "
                     f"{type(e).__name__}: {e}. Skipping diarization.")
                diarize_output = None
        else:
            warn(results, "DIAR_SKIP",
                 "HF_TOKEN not set or pyannote not available. "
                 "Skipping speaker diarization.")

        # --- Speaker Assignment ---
        final = aligned  # Default final to aligned output
        diarize_segments_for_assignment = []
        if diarize_output is not None:
            # 1. Convert pyannote output (Annotation) to list of dicts for whisperx
            if hasattr(diarize_output, "itertracks"):
                for segment, _, label in diarize_output.itertracks(yield_label=True):
                    diarize_segments_for_assignment.append({
                        "start": force_float(segment.start) or 0.0,
                        "end": force_float(segment.end) or 0.0,
                        "speaker": normalize_speaker(label)
                    })
            # 2. Perform speaker assignment
            if diarize_segments_for_assignment:
                print("Assigning speakers to words...")
                try:
                    # whisperx.assign_word_speakers expects a DataFrame with
                    # start/end/speaker columns, not a plain list of dicts.
                    diarize_df = pd.DataFrame(diarize_segments_for_assignment)
                    final = whisperx.assign_word_speakers(diarize_df, aligned)
                except Exception as e:
                    warn(results, "ASSIGN_SPEAKERS_ERROR",
                         f"Error assigning speakers: {e}. "
                         f"Falling back to unassigned segments.")
            else:
                warn(results, "ASSIGN_SKIP",
                     "Diarization conversion failed; using raw aligned segments.")

        # Ensure segments have a default speaker label if assignment failed/skipped
        for seg in final.get("segments", []):
            if "speaker" not in seg and "speaker_label" not in seg:
                seg["speaker"] = "Speaker_1"

        # --- Robust Word Processing and Timeline Creation ---
        rows: List[Dict[str, Any]] = []
        for seg in final.get("segments", []):
            seg_speaker = normalize_speaker(
                seg.get("speaker") or seg.get("speaker_label") or "Speaker_1")
            word_list = seg.get("words") or seg.get("tokens") or seg.get("items") or []

            # If no word list, use segment as a whole
            if not word_list:
                word_start = _get_time_field(seg, ["start", "s", "timestamp", "t0"])
                word_end = _get_time_field(seg, ["end", "e", "t1"])
                if word_start is None:
                    continue
                if word_end is None:
                    word_end = word_start
                rows.append({
                    "start": float(word_start),
                    "end": float(word_end),
                    "text": str(seg.get("text", "")).strip(),
                    "speaker": str(seg_speaker),
                })
                continue

            # Process each word (uses robust global _get_time_field)
            for w in word_list:
                if not isinstance(w, dict):
                    continue
                word_start = _get_time_field(w, ["start", "s", "timestamp", "t0"])
                word_end = _get_time_field(w, ["end", "e", "t1"])
                # Fall back to the segment's own timing when the word lacks one.
                if word_start is None:
                    word_start = _get_time_field(seg, ["start", "s"])
                if word_end is None:
                    word_end = _get_time_field(seg, ["end", "e"])
                if word_start is None:
                    continue
                if word_end is None:
                    word_end = word_start
                word_speaker = normalize_speaker(w.get("speaker") or seg_speaker)
                word_text = (w.get("text") or w.get("word") or w.get("label") or "").strip()
                rows.append({
                    "start": force_float(word_start) or 0.0,
                    "end": force_float(word_end) or 0.0,
                    "text": str(word_text),
                    "speaker": str(word_speaker),
                })

        # Sort and assign to results
        rows = sorted(rows, key=lambda r: r.get("start", 0.0))
        results.timelineData = rows

        # Calculate duration (drop Nones so max() never compares None to float)
        ends = [e for e in (force_float(w.get("end")) for w in rows) if e is not None]
        results.duration = (max(ends) if ends else 0.0) or 0.0

        # --- Compute DER if reference RTTM provided ---
        if pyannote_available and reference_rttm_file and diarize_output is not None:
            print(f"Computing DER using reference RTTM: {reference_rttm_file}...")
            reference = load_rttm(reference_rttm_file)
            hypothesis = create_fallback_hypothesis(diarize_output)

            if hypothesis is None:
                print("DEBUG DER: Primary hypothesis failed. "
                      "Falling back to using final timeline data...")
                hypothesis = Annotation()
                for item in results.timelineData:
                    try:
                        start = force_float(item.get('start'))
                        end = force_float(item.get('end'))
                        speaker = item.get('speaker')
                        # Use a smaller minimum duration check for timeline data
                        if (start is not None and end is not None and speaker is not None
                                and start >= 0 and end > start):
                            hypothesis[Segment(start, end)] = normalize_speaker(speaker)
                    except Exception:
                        continue

            # --- CRITICAL DEBUGGING LOGS ---
            ref_len = len(reference) if reference else 0
            hyp_len = len(hypothesis) if hypothesis else 0
            print(f"DEBUG DER: Reference segments loaded: {ref_len}. "
                  f"Hypothesis segments loaded: {hyp_len}.")

            if ref_len == 0:
                warn(results, "DER_FAIL",
                     "Reference RTTM loaded but contains 0 valid segments.")
            elif hyp_len == 0:
                warn(results, "DER_FAIL",
                     "Diarization hypothesis conversion failed and timeline "
                     "fallback failed (0 segments).")
            else:
                try:
                    metric = DiarizationErrorRate(collar=0.25, skip_overlap=False)
                    # detailed=True yields the overall rate plus its components.
                    # (compute_components() does NOT return a
                    # 'diarization error rate' key, and the component keys are
                    # 'confusion' / 'missed detection' / 'false alarm'.)
                    der_report = metric(reference, hypothesis, detailed=True)
                    results.diarizationErrorRate = force_float(
                        der_report.get('diarization error rate'))
                    results.speakerError = force_float(der_report.get('confusion'))
                    results.missedSpeech = force_float(der_report.get('missed detection'))
                    results.falseAlarm = force_float(der_report.get('false alarm'))
                    if results.diarizationErrorRate is not None:
                        print(f"DER computed: {results.diarizationErrorRate:.4f}")
                except Exception as e:
                    warn(results, "DER_ERROR",
                         f"Error computing DER (pyannote metrics failure): "
                         f"{type(e).__name__}: {e}. Check compatibility.")
        elif reference_rttm_file:
            warn(results, "DER_MISSING",
                 "Reference RTTM provided but DER calculation skipped "
                 "(check pyannote availability or diarization output).")

        results.success = True
        set_message(results, "Analysis complete.")
    except Exception as e:
        results.message = f"Core analysis failed: {type(e).__name__}: {e}"
        results.success = False
        warn(results, "FATAL_ERROR", f"Fatal analysis error: {type(e).__name__}: {e}")
    finally:
        # Clean up the preprocessed temp WAV (was previously leaked).
        if temp_preproc and os.path.exists(temp_preproc):
            try:
                os.remove(temp_preproc)
            except OSError:
                pass
    return results


@app.post("/upload", response_model=AnalysisResult)
async def upload_file(audio_file: UploadFile = File(...),
                      rttm_file: UploadFile = File(None)):
    """Accept an audio upload (plus optional RTTM) and return the analysis."""
    start_time = time.time()
    audio_path: Optional[str] = None
    rttm_path: Optional[str] = None
    try:
        print("Incoming upload:", getattr(audio_file, "filename", None),
              getattr(rttm_file, "filename", None))

        suffix = audio_file.filename.split(".")[-1] if audio_file.filename else "tmp"
        with tempfile.NamedTemporaryFile(suffix=f".{suffix}", delete=False) as tmp_audio:
            # Ensure to rewind the file-like object before copying (Best Practice)
            audio_file.file.seek(0)
            shutil.copyfileobj(audio_file.file, tmp_audio)
            audio_path = tmp_audio.name
        print(f"Received audio file: {audio_file.filename} (saved to {audio_path}), "
              f"size: {os.path.getsize(audio_path)} bytes")

        if rttm_file and rttm_file.filename:
            with tempfile.NamedTemporaryFile(suffix=".rttm", delete=False) as tmp_rttm:
                # Ensure to rewind the file-like object before copying (Best Practice)
                rttm_file.file.seek(0)
                shutil.copyfileobj(rttm_file.file, tmp_rttm)
                rttm_path = tmp_rttm.name
            print(f"Received rttm file: {rttm_file.filename} (saved to {rttm_path}), "
                  f"size: {os.path.getsize(rttm_path)} bytes")
        else:
            print("No RTTM file received in this request.")

        preprocessing_config = {"denoise": False}
        print(f"Starting ML processing with audio: {audio_path}, rttm: {rttm_path}")
        analysis_result = analyze_audio(
            audio_file=audio_path,
            reference_rttm_file=rttm_path,
            preprocess_params=preprocessing_config,
        )
        print("FAILURE MESSAGE:", analysis_result.message)
        if not analysis_result.success:
            raise HTTPException(status_code=500, detail=analysis_result.message)

        print("DURATION BEFORE RETURN:", analysis_result.duration)
        if analysis_result.duration is None:
            analysis_result.duration = 0.0

        return AnalysisResult(
            duration=force_float(analysis_result.duration) or 0.0,
            language=analysis_result.languageCode,
            der=analysis_result.diarizationErrorRate,
            speaker_error=analysis_result.speakerError,
            missed_speech=analysis_result.missedSpeech,
            false_alarm=analysis_result.falseAlarm,
            timeline_data=[
                TimelineItem(
                    start=force_float(item.get('start')) or 0.0,
                    end=force_float(item.get('end')) or 0.0,
                    speaker=str(item.get('speaker')) if item.get('speaker') else None,
                    text=str(item.get('text', ""))
                )
                for item in analysis_result.timelineData
            ]
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error during upload process: {type(e).__name__}: {e}")
    finally:
        # Always remove the uploaded temp files.
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
        if rttm_path and os.path.exists(rttm_path):
            os.remove(rttm_path)
        end_time = time.time()
        print(f"API Request processed in {end_time - start_time:.2f} seconds.")


@app.get("/")
def root():
    """Health-check endpoint."""
    return {"message": "Audio Analyzer Backend is running."}