# Audio-Analyzer / app.py
# Author: hafsaabd82 · commit 949a22a ("Update app.py", verified) · 24.8 kB
import asyncio
import os
import shutil
import tempfile
import time
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Any

import whisperx
from pyannote.audio import Pipeline
import pandas as pd
import librosa
import soundfile as sf
import numpy as np
from scipy.signal import butter, filtfilt
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Optional dependency: noisereduce (used by preprocess_audio when denoise=True).
try:
    import noisereduce as nr
except ImportError:
    HAVE_NOISEREDUCE = False
else:
    HAVE_NOISEREDUCE = True

# Optional dependency: pyannote.core / pyannote.metrics (RTTM parsing and DER).
# The placeholders keep these names bound even when the import fails.
Annotation: Any = None
Segment: Any = None
DiarizationErrorRate: Any = None
try:
    from pyannote.core import Annotation, Segment
    from pyannote.metrics.diarization import DiarizationErrorRate
except Exception:
    pyannote_available = False
else:
    pyannote_available = True

device = "cuda" if torch.cuda.is_available() else "cpu"

# The diarization pipeline is downloaded with this token; without it (or without
# pyannote) diarization is switched off up front.
token = os.environ.get("HF_TOKEN")
if not token:
    print("Warning: HF_TOKEN not set. Diarization will be skipped.")
perform_diarization = bool(token and pyannote_available)

# Whisper checkpoint passed to whisperx.load_model.
model_name = "large-v2"
class TimelineItem(BaseModel):
    """One timeline entry (a word, or a whole segment without word timings) in the API response."""

    start: float
    end: float
    # None when no speaker label is available for the entry.
    speaker: Optional[str] = None
    text: str
class AnalysisResult(BaseModel):
    """Response body of ``POST /upload``.

    The DER-related fields stay ``None`` unless a reference RTTM was supplied
    and the metric could be computed.
    """

    duration: float
    language: str
    der: Optional[float] = None
    speaker_error: Optional[float] = None
    missed_speech: Optional[float] = None
    false_alarm: Optional[float] = None
    timeline_data: List[TimelineItem]
app = FastAPI(title="Audio Analyzer Backend")

# Browser access is limited to the deployed frontend origin; any method/header
# and credentials are allowed from that origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://frontend-audio-analyzer.vercel.app"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@dataclass
class AnalysisResults:
    """Mutable accumulator filled in by one ``analyze_audio`` run.

    Field names are camelCase; ``upload_file`` maps them onto the snake_case
    ``AnalysisResult`` response model.
    """

    # Timeline rows: dicts with "start", "end", "text" and "speaker" keys.
    timelineData: list[dict[str, Any]] = field(default_factory=list)
    duration: float = 0.0
    languageCode: str = "unknown"
    # Diarization metrics; left as None unless a reference RTTM was scored.
    diarizationErrorRate: Optional[float] = None
    speakerError: Optional[float] = None
    missedSpeech: Optional[float] = None
    falseAlarm: Optional[float] = None
    # Non-fatal issues, recorded as "<CODE>: <detail>" strings by warn().
    warnings: list[str] = field(default_factory=list)
    success: bool = False
    message: str = "Analysis initiated."
def warn(results: AnalysisResults, code: str, detail: str) -> None:
    """Record a non-fatal warning on ``results`` as ``"<code>: <detail>"``.

    An identical message that is already present is not added again.
    """
    entry = f"{code}: {detail}"
    if entry in results.warnings:
        return
    results.warnings.append(entry)
def set_message(results: AnalysisResults, msg: str) -> None:
    """Set ``results.message`` to ``msg``.

    If a non-empty message other than the initial placeholder is already
    present, ``msg`` is appended to it with a ``" | "`` separator instead.
    """
    current = results.message
    if not current or current == "Analysis initiated.":
        results.message = msg
        return
    results.message = current + f" | {msg}"
def normalize_speaker(lbl: str) -> str:
    """Return ``str(lbl)`` with every ``SPEAKER_`` and ``speaker_`` rewritten as ``Speaker_``.

    Example: ``"SPEAKER_00"`` -> ``"Speaker_00"``.
    """
    text = str(lbl)
    for prefix in ("SPEAKER_", "speaker_"):
        text = text.replace(prefix, "Speaker_")
    return text
def temp_wav_path() -> str:
    """Create an empty ``.wav`` file in the system temp directory and return its path.

    The file is not removed automatically; the caller owns it.
    """
    handle, path = tempfile.mkstemp(suffix=".wav")
    os.close(handle)
    return path
# GLOBAL HELPER: coerce NumPy scalars, numeric strings, etc. to a finite native float.
def force_float(value: Optional[Any]) -> Optional[float]:
    """Return ``value`` as a native Python ``float``, or ``None``.

    ``None`` is returned for ``None`` input, for values ``float()`` rejects
    (including integers too large to represent as a float), and for NaN/±Inf.
    """
    if value is None:
        return None
    try:
        f_val = float(value)
    except (TypeError, ValueError, AttributeError, OverflowError):
        # OverflowError: e.g. float(10**400) -- previously escaped to the caller.
        return None
    return f_val if np.isfinite(f_val) else None
# GLOBAL HELPER: first usable timestamp among several candidate keys.
def _get_time_field(d: Dict[str, Any], keys: List[str]) -> Optional[float]:
    """Return the first ``d[key]`` (in ``keys`` order) that ``force_float`` turns
    into a finite float, or ``None`` when no present key yields one."""
    candidates = (force_float(d[key]) for key in keys if key in d)
    return next((val for val in candidates if val is not None), None)
def butter_filter(y, sr, lowpass=None, highpass=None, order=4):
    """Apply zero-phase Butterworth filtering (``filtfilt``) to ``y`` sampled at ``sr`` Hz.

    The high-pass stage runs first, then the low-pass stage. A cutoff is ignored
    when it is falsy, non-positive, or not below the Nyquist frequency
    (``sr / 2``). If no stage is active, ``y`` itself is returned.
    """
    nyquist = 0.5 * sr
    for btype, cutoff in (("highpass", highpass), ("lowpass", lowpass)):
        if not cutoff or not (0 < cutoff < nyquist):
            continue
        b, a = butter(order, cutoff / nyquist, btype=btype, analog=False)
        y = filtfilt(b, a, y)
    return y
def rms_normalize(y, target_rms=0.8, eps=1e-6):
    """Scale ``y`` so its RMS level becomes (approximately) ``target_rms``.

    Input whose RMS is below ``eps`` is returned unchanged (same object), so
    near-silence is never amplified; ``eps`` is also added to the divisor.
    """
    level = (y ** 2).mean() ** 0.5
    return y if level < eps else y * (target_rms / (level + eps))
def preprocess_audio(input_path,
                     target_sr=16000,
                     normalize_rms=True,
                     target_rms=0.08,
                     denoise=False,
                     highpass=None,
                     lowpass=None,
                     output_subtype="PCM_16",
                     verbose=False):
    """Load, clean up and re-encode an audio file for the ASR/diarization models.

    Steps:
        1. Downmix to mono.
        2. Resample to ``target_sr``.
        3. Optional Butterworth high-/low-pass filtering.
        4. Optional noise reduction (only when ``noisereduce`` is installed).
        5. Optional RMS normalisation.
        6. Write a new temporary WAV file.

    Args:
        input_path: Path of the source audio (any format ``soundfile`` can read).
        target_sr: Output sample rate in Hz.
        normalize_rms: Scale the signal to ``target_rms`` when true.
        target_rms: Target RMS level for normalisation.
        denoise: Run ``noisereduce`` using the first 0.5 s as the noise profile.
        highpass: Optional high-pass cutoff in Hz.
        lowpass: Optional low-pass cutoff in Hz.
        output_subtype: ``soundfile`` subtype of the written WAV.
        verbose: Print a diagnostic when best-effort denoising fails.

    Returns:
        Path of a new temporary WAV file; the caller is responsible for deleting it.

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
    """
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input audio not found: {input_path}")

    y_raw, sr = sf.read(input_path, dtype='float64')
    # Multi-channel data is (frames, channels); librosa.to_mono wants (channels, frames).
    y = librosa.to_mono(y_raw.T) if y_raw.ndim > 1 else y_raw

    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    if highpass or lowpass:
        y = butter_filter(y, sr, highpass=highpass, lowpass=lowpass)

    if denoise and HAVE_NOISEREDUCE:
        try:
            # Use (up to) the first half second as the noise profile.
            noise_clip = y[:min(len(y), int(0.5 * sr))]
            y = nr.reduce_noise(y=y, sr=sr, y_noise=noise_clip, prop_decrease=0.9, verbose=False)
        except Exception as e:
            # Denoising is best-effort: keep the un-denoised signal.
            if verbose:
                print(f"Denoise failed ({type(e).__name__}: {e}); continuing without it.")

    if normalize_rms:
        y = rms_normalize(y, target_rms=target_rms)

    # The RMS gain can push peaks outside [-1, 1]; clip explicitly for non-float
    # subtypes instead of relying on the encoder's out-of-range handling.
    if str(output_subtype).upper() not in ("FLOAT", "DOUBLE"):
        y = np.clip(y, -1.0, 1.0)

    # Create the output file only now, so a read/decode failure above leaks nothing.
    output_path = temp_wav_path()
    try:
        sf.write(output_path, y, sr, subtype=output_subtype)
    except Exception:
        os.remove(output_path)
        raise
    return output_path
def load_rttm(path: str) -> Optional["Annotation"]:
    """Parse a reference RTTM file into a pyannote ``Annotation``.

    Only ``SPEAKER`` lines with at least 8 whitespace-separated fields are used:
    field 4 is the turn onset, field 5 its duration and field 8 the speaker
    label (passed through ``normalize_speaker``). ``;;`` comment lines, lines
    with unparseable times and non-positive durations are skipped.

    Returns:
        The annotation, or ``None`` if pyannote is unavailable, the file does
        not exist, or reading it fails.
    """
    if not pyannote_available or not os.path.exists(path):
        return None
    ann = Annotation()
    try:
        with open(path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f):
                if line.startswith(";;"):
                    continue
                parts = line.split()
                if len(parts) < 8 or parts[0] != "SPEAKER":
                    continue
                start = force_float(parts[3])
                dur = force_float(parts[4])
                if start is None or dur is None or dur <= 0:
                    continue  # invalid or empty turn
                # ``ann[segment] = label`` always writes track "_", so two speakers
                # with identical boundaries (overlapping speech) would overwrite each
                # other; the line number gives every turn its own track.
                ann[Segment(start, start + dur), line_no] = normalize_speaker(parts[7])
        return ann
    except Exception as e:
        print(f"Error loading RTTM: {e}")
        return None
# Assuming normalize_speaker, force_float, Segment, Annotation, pyannote_available are defined globally
def _add_hypothesis_turn(ann: "Annotation", start: Any, end: Any, speaker: Any, track: Any) -> None:
    """Add one speaker turn to ``ann`` under ``track`` if it is valid.

    A turn is kept only if both times convert to finite floats, a speaker is
    given, ``start >= 0`` and it lasts more than 10 ms.
    """
    start_f = force_float(start)
    end_f = force_float(end)
    if start_f is None or end_f is None or speaker is None:
        return
    if start_f < 0 or end_f - start_f <= 0.01:
        return
    ann[Segment(start_f, end_f), track] = normalize_speaker(speaker)


def create_fallback_hypothesis(diarize_output: Any) -> Optional["Annotation"]:
    """Build a DER hypothesis ``Annotation`` from diarization output.

    Accepted inputs:
      * objects exposing ``itertracks(yield_label=True)`` (pyannote ``Annotation``);
      * a pandas ``DataFrame`` with ``start``/``end`` and ``speaker`` (or ``label``) columns;
      * any other iterable of dicts or objects carrying ``start``/``end`` and ``speaker``/``label``.

    Invalid turns are dropped (see ``_add_hypothesis_turn``).

    Returns:
        The annotation, or ``None`` when pyannote is unavailable or no valid turn was found.
    """
    if not pyannote_available or diarize_output is None:
        return None

    # 1) pyannote-native output. Track names are preserved so overlapping turns
    #    with identical boundaries do not overwrite each other.
    if hasattr(diarize_output, "itertracks"):
        try:
            ann = Annotation()
            for segment, track, label in diarize_output.itertracks(yield_label=True):
                try:
                    _add_hypothesis_turn(ann, segment.start, segment.end, label, track)
                except Exception:
                    continue  # skip a single malformed track
            return ann if len(ann) > 0 else None
        except Exception:
            pass  # not usable as an Annotation; try the generic formats below

    # 2) Tabular / list formats (e.g. whisperx-style diarization results).
    try:
        if isinstance(diarize_output, pd.DataFrame):
            # Iterating a DataFrame directly would yield column names, not rows.
            items = diarize_output.to_dict("records")
        elif isinstance(diarize_output, list) or hasattr(diarize_output, "__iter__"):
            items = diarize_output
        else:
            return None

        ann = Annotation()  # fresh: never mix in partial results from step 1
        for index, seg in enumerate(items):
            try:
                if isinstance(seg, dict):
                    start, end = seg.get("start"), seg.get("end")
                    speaker = seg.get("speaker") or seg.get("label")
                elif hasattr(seg, "start") and hasattr(seg, "end"):
                    start, end = seg.start, seg.end
                    speaker = getattr(seg, "speaker", getattr(seg, "label", None))
                else:
                    continue
                # The item index serves as a unique track name.
                _add_hypothesis_turn(ann, start, end, speaker, index)
            except Exception:
                continue
        return ann if len(ann) > 0 else None
    except Exception as e:
        print(f"Error creating fallback hypothesis: {e}")
        return None
def _preprocess_for_models(audio_file: str,
                           preprocess_params: Optional[Dict[str, Any]],
                           results: AnalysisResults) -> tuple[str, Optional[str]]:
    """Run ``preprocess_audio`` (defaults + overrides).

    Returns:
        ``(path_for_models, temp_path_or_None)``. On failure, a PREP_FAIL
        warning is recorded and the original file is used instead.
    """
    params: Dict[str, Any] = {
        "target_sr": 16000, "normalize_rms": True, "target_rms": 0.08,
        "denoise": False, "highpass": None, "lowpass": None,
        "output_subtype": "PCM_16", "verbose": False,
    }
    if isinstance(preprocess_params, dict):
        params.update(preprocess_params)
    if params.get("denoise") and not HAVE_NOISEREDUCE:
        warn(results, "DENOISE_SKIP", "Denoise requested but noisereduce not installed; skipping denoise.")
        params["denoise"] = False
    try:
        temp_path = preprocess_audio(audio_file, **params)
    except Exception as e:
        warn(results, "PREP_FAIL", f"Preprocessing failed: {e}. Falling back to original audio.")
        return audio_file, None
    return temp_path, temp_path


def _resolve_language(model: Any, audio: Any, language_code_input: Optional[str]) -> str:
    """Return the transcription language.

    Uses the caller's code if given, otherwise auto-detects; defaults to "en".
    """
    if language_code_input:
        print(f"Using caller-supplied language: {language_code_input}")
        return language_code_input
    print("Running robust automatic language detection...")
    detected = model.detect_language(audio)
    # NOTE(review): whisperx versions differ in what detect_language returns
    # (just the code vs. a (code, probability) pair); accept both -- confirm
    # against the installed version.
    if isinstance(detected, (tuple, list)):
        language_code = detected[0] if detected else None
        language_prob = force_float(detected[1]) if len(detected) > 1 else None
    else:
        language_code, language_prob = detected, None
    language_to_use = language_code if language_code else "en"
    prob_txt = f" (Prob: {language_prob:.2f})" if language_prob is not None else ""
    print(f"Detected language: {language_to_use}{prob_txt}.")
    return language_to_use


def _align_transcript(result: Dict[str, Any], audio: Any, language: str,
                      results: AnalysisResults) -> Dict[str, Any]:
    """Word-align the Whisper segments.

    Falls back to the English aligner for non-English audio. If every attempt
    fails, the raw ``transcribe()`` result is returned, so callers always get
    a value.
    """
    try:
        align_model, metadata = whisperx.load_align_model(language_code=language, device=device)
        aligned = whisperx.align(result["segments"], align_model, metadata, audio, device)
        print("Alignment successful.")
        return aligned
    except Exception as e:
        warn(results, "ALIGN_SKIP", f"Alignment unavailable ({type(e).__name__}: {e}); using raw Whisper segments.")
    if language != 'en':
        try:
            print("Trying English alignment model as fallback...")
            align_model, metadata = whisperx.load_align_model(language_code="en", device=device)
            return whisperx.align(result["segments"], align_model, metadata, audio, device)
        except Exception:
            warn(results, "ALIGN_FAIL", "English fallback alignment also failed. Proceeding with raw unaligned segments.")
    else:
        warn(results, "ALIGN_FAIL", "Alignment failed. Proceeding with raw unaligned segments.")
    # Previously `aligned` was left unbound here, which aborted the whole analysis.
    return result


def _run_diarization(audio_path: str, results: AnalysisResults) -> Any:
    """Run the pyannote speaker-diarization-3.1 pipeline.

    Returns the pipeline output, or ``None`` when diarization is disabled or fails.
    """
    if not perform_diarization:
        warn(results, "DIAR_SKIP", "HF_TOKEN not set or pyannote not available. Skipping speaker diarization.")
        return None
    print("Performing speaker diarization (Requires HF_TOKEN)...")
    try:
        diarize_model = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=token)
        diarize_output = diarize_model(audio_path)
        for segment, _, label in diarize_output.itertracks(yield_label=True):
            print(f"start={segment.start:.1f}s stop={segment.end:.1f}s {label}")
        return diarize_output
    except Exception as e:
        warn(results, "DIAR_SKIP", f"Error during diarization (likely token/model failure): {type(e).__name__}: {e}. Skipping diarization.")
        return None


def _assign_speakers(diarize_output: Any, aligned: Dict[str, Any],
                     results: AnalysisResults) -> Dict[str, Any]:
    """Attach diarization speakers to the aligned transcript.

    Any segment still lacking a speaker label gets the default "Speaker_1".
    """
    final = aligned
    if diarize_output is not None:
        turns: List[Dict[str, Any]] = []
        if hasattr(diarize_output, "itertracks"):
            for segment, _, label in diarize_output.itertracks(yield_label=True):
                turns.append({
                    "start": force_float(segment.start) or 0.0,
                    "end": force_float(segment.end) or 0.0,
                    "speaker": normalize_speaker(label),
                })
        if turns:
            print("Assigning speakers to words...")
            try:
                # NOTE(review): whisperx.assign_word_speakers expects a pandas
                # DataFrame (it indexes diarize_df["start"] etc.); the plain list
                # passed before made it raise, so every word fell back to
                # Speaker_1. Confirm against the installed whisperx version.
                final = whisperx.assign_word_speakers(pd.DataFrame(turns), aligned)
            except Exception as e:
                warn(results, "ASSIGN_SPEAKERS_ERROR", f"Error assigning speakers: {e}. Falling back to unassigned segments.")
        else:
            warn(results, "ASSIGN_SKIP", "Diarization conversion failed; using raw aligned segments.")
    for seg in final.get("segments", []):
        if "speaker" not in seg and "speaker_label" not in seg:
            seg["speaker"] = "Speaker_1"
    return final


def _build_timeline_rows(final: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten transcript segments into timeline rows sorted by start time.

    A segment with a word list ("words"/"tokens"/"items") yields one row per
    dict word; otherwise the segment becomes a single row. Words without their
    own times inherit the segment's times. Rows without any start time are
    dropped.
    """
    rows: List[Dict[str, Any]] = []
    for seg in final.get("segments", []):
        seg_speaker = normalize_speaker(seg.get("speaker") or seg.get("speaker_label") or "Speaker_1")
        word_list = seg.get("words") or seg.get("tokens") or seg.get("items") or []
        if not word_list:
            seg_start = _get_time_field(seg, ["start", "s", "timestamp", "t0"])
            seg_end = _get_time_field(seg, ["end", "e", "t1"])
            if seg_start is None:
                continue
            rows.append({
                "start": float(seg_start),
                "end": float(seg_end if seg_end is not None else seg_start),
                "text": str(seg.get("text", "")).strip(),
                "speaker": str(seg_speaker),
            })
            continue
        for w in word_list:
            if not isinstance(w, dict):
                continue
            word_start = _get_time_field(w, ["start", "s", "timestamp", "t0"])
            word_end = _get_time_field(w, ["end", "e", "t1"])
            if word_start is None:
                word_start = _get_time_field(seg, ["start", "s"])
            if word_end is None:
                word_end = _get_time_field(seg, ["end", "e"])
            if word_start is None:
                continue
            if word_end is None:
                word_end = word_start
            word_text = w.get("text") or w.get("word") or w.get("label") or ""
            rows.append({
                "start": force_float(word_start) or 0.0,
                "end": force_float(word_end) or 0.0,
                "text": str(word_text).strip(),
                "speaker": str(normalize_speaker(w.get("speaker") or seg_speaker)),
            })
    rows.sort(key=lambda r: r.get("start", 0.0))
    return rows


def _compute_der(reference_rttm_file: str, diarize_output: Any, results: AnalysisResults) -> None:
    """Score the diarization against the reference RTTM and store the metrics on ``results``.

    ``diarizationErrorRate`` is (false alarm + missed detection + confusion) /
    total reference speech. ``speakerError``, ``missedSpeech`` and
    ``falseAlarm`` are the individual components divided by the same total, so
    they sum to the DER.
    """
    print(f"Computing DER using reference RTTM: {reference_rttm_file}...")
    reference = load_rttm(reference_rttm_file)
    hypothesis = create_fallback_hypothesis(diarize_output)
    if hypothesis is None:
        print("DEBUG DER: Primary hypothesis failed. Falling back to using final timeline data...")
        hypothesis = Annotation()
        for item in results.timelineData:
            start = force_float(item.get('start'))
            end = force_float(item.get('end'))
            speaker = item.get('speaker')
            if start is None or end is None or speaker is None or start < 0 or end <= start:
                continue
            hypothesis[Segment(start, end)] = normalize_speaker(speaker)

    ref_len = len(reference) if reference else 0
    hyp_len = len(hypothesis) if hypothesis else 0
    print(f"DEBUG DER: Reference segments loaded: {ref_len}. Hypothesis segments loaded: {hyp_len}.")
    if ref_len == 0:
        warn(results, "DER_FAIL", "Reference RTTM loaded but contains 0 valid segments.")
        return
    if hyp_len == 0:
        warn(results, "DER_FAIL", "Diarization hypothesis conversion failed and timeline fallback failed (0 segments).")
        return
    try:
        metric = DiarizationErrorRate(collar=0.25, skip_overlap=False)
        # compute_components returns durations keyed "total", "false alarm",
        # "missed detection", "confusion" -- not the report column names the
        # old code looked up (which always yielded None).
        components = metric.compute_components(reference, hypothesis)
        total = force_float(components.get("total"))
        if not total:
            warn(results, "DER_FAIL", "Reference contains no scorable speech (total = 0).")
            return
        false_alarm = force_float(components.get("false alarm")) or 0.0
        missed = force_float(components.get("missed detection")) or 0.0
        confusion = force_float(components.get("confusion")) or 0.0
        results.diarizationErrorRate = force_float((false_alarm + missed + confusion) / total)
        results.speakerError = force_float(confusion / total)
        results.missedSpeech = force_float(missed / total)
        results.falseAlarm = force_float(false_alarm / total)
        print(f"DER computed: {results.diarizationErrorRate:.4f}")
    except Exception as e:
        warn(results, "DER_ERROR", f"Error computing DER (pyannote metrics failure): {type(e).__name__}: {e}. Check compatibility.")


def analyze_audio(audio_file: str,
                  reference_rttm_file: Optional[str] = None,
                  preprocess: bool = True,
                  preprocess_params: Optional[Dict[str, Any]] = None,
                  language_code_input: Optional[str] = None) -> AnalysisResults:
    """Transcribe, align, diarize and optionally score one audio file.

    Pipeline:
        1. Optional preprocessing.
        2. whisperx transcription.
        3. Word alignment.
        4. pyannote diarization (when enabled).
        5. Speaker assignment.
        6. Timeline rows.
        7. DER against ``reference_rttm_file`` (when given).

    Args:
        audio_file: Path of the audio to analyze.
        reference_rttm_file: Optional reference RTTM used to compute DER.
        preprocess: Run ``preprocess_audio`` first; the original file is used
            if that fails.
        preprocess_params: Overrides for ``preprocess_audio`` keyword arguments.
        language_code_input: Language code to force; auto-detected when empty.

    Returns:
        An ``AnalysisResults``. Errors in the model stages are caught and
        reported through ``success``, ``message`` and ``warnings``.
    """
    results = AnalysisResults()
    if not os.path.exists(audio_file):
        results.message = f"Error: Input audio file '{audio_file}' not found."
        return results

    audio_for_model, temp_preproc = audio_file, None
    if preprocess:
        audio_for_model, temp_preproc = _preprocess_for_models(audio_file, preprocess_params, results)

    start_ml_time = time.time()
    try:
        print(f"Loading Whisper model '{model_name}' on {device}...")
        model = whisperx.load_model(model_name, device, compute_type="float32")
        audio_loaded = whisperx.load_audio(audio_for_model)

        language_to_use = _resolve_language(model, audio_loaded, language_code_input)
        results.languageCode = language_to_use
        print(f"Starting transcription in '{language_to_use}'...")
        # NOTE(review): these extra decoding kwargs are passed exactly as before;
        # they must be accepted by the installed whisperx transcribe() -- confirm.
        result = model.transcribe(
            audio_loaded,
            batch_size=4,
            language=language_to_use,
            multilingual=True,
            max_new_tokens=448,
            clip_timestamps=None,
            hallucination_silence_threshold=0.1,
            hotwords=None,
        )
        print(f"Transcription successful in {language_to_use}. Proceeding to alignment...")

        aligned = _align_transcript(result, audio_loaded, language_to_use, results)
        diarize_output = _run_diarization(audio_for_model, results)
        final = _assign_speakers(diarize_output, aligned, results)

        rows = _build_timeline_rows(final)
        results.timelineData = rows
        results.duration = force_float(max((r["end"] for r in rows), default=0.0)) or 0.0

        if pyannote_available and reference_rttm_file and diarize_output is not None:
            _compute_der(reference_rttm_file, diarize_output, results)
        elif reference_rttm_file:
            warn(results, "DER_MISSING", "Reference RTTM provided but DER calculation skipped (check pyannote availability or diarization output).")

        results.success = True
        set_message(results, "Analysis complete.")
        print(f"ML processing finished in {time.time() - start_ml_time:.2f} seconds.")
    except Exception as e:
        results.message = f"Core analysis failed: {type(e).__name__}: {e}"
        results.success = False
        warn(results, "FATAL_ERROR", f"Fatal analysis error: {type(e).__name__}: {e}")
    finally:
        # The preprocessed WAV is a private temp file; it used to leak on every call.
        if temp_preproc and os.path.exists(temp_preproc):
            try:
                os.remove(temp_preproc)
            except OSError as cleanup_err:
                print(f"Could not remove temp file {temp_preproc}: {cleanup_err}")
    return results
# Serialise analyses: every request loads the Whisper (and pyannote) models itself,
# so running several at once in worker threads could exhaust GPU/CPU memory. This
# keeps the one-at-a-time behaviour the old blocking handler had implicitly, while
# the event loop stays free for other requests.
_ANALYSIS_LOCK = asyncio.Lock()


def _upload_suffix(filename: Optional[str]) -> str:
    """Return a temp-file suffix derived from an uploaded filename (e.g. ``".wav"``).

    Only short alphanumeric extensions are kept. Anything else (no extension,
    path separators, odd characters) becomes ``".tmp"``.
    """
    ext = os.path.splitext(os.path.basename(filename or ""))[1]
    if 1 < len(ext) <= 11 and ext[1:].isalnum():
        return ext
    return ".tmp"


@app.post("/upload", response_model=AnalysisResult)
async def upload_file(audio_file: UploadFile = File(...), rttm_file: UploadFile = File(None)):
    """Analyze an uploaded audio file (plus an optional reference RTTM).

    Uploads are copied to temp files, which are always deleted afterwards.

    Raises:
        HTTPException: 500 if the analysis reports failure or an unexpected
            error occurs.
    """
    from fastapi.concurrency import run_in_threadpool

    start_time = time.time()
    audio_path: Optional[str] = None
    rttm_path: Optional[str] = None
    try:
        print("Incoming upload:", getattr(audio_file, "filename", None), getattr(rttm_file, "filename", None))
        with tempfile.NamedTemporaryFile(suffix=_upload_suffix(audio_file.filename), delete=False) as tmp_audio:
            # Record the path first so the finally-block cleans up even if the copy fails.
            audio_path = tmp_audio.name
            audio_file.file.seek(0)
            shutil.copyfileobj(audio_file.file, tmp_audio)
        print(f"Received audio file: {audio_file.filename} (saved to {audio_path}), size: {os.path.getsize(audio_path)} bytes")

        if rttm_file and rttm_file.filename:
            with tempfile.NamedTemporaryFile(suffix=".rttm", delete=False) as tmp_rttm:
                rttm_path = tmp_rttm.name
                rttm_file.file.seek(0)
                shutil.copyfileobj(rttm_file.file, tmp_rttm)
            print(f"Received rttm file: {rttm_file.filename} (saved to {rttm_path}), size: {os.path.getsize(rttm_path)} bytes")
        else:
            print("No RTTM file received in this request.")

        preprocessing_config = {"denoise": False}
        print(f"Starting ML processing with audio: {audio_path}, rttm: {rttm_path}")
        # analyze_audio is long, blocking CPU/GPU work; running it directly in this
        # coroutine froze the event loop, so run it in a worker thread instead.
        async with _ANALYSIS_LOCK:
            analysis_result = await run_in_threadpool(
                analyze_audio,
                audio_file=audio_path,
                reference_rttm_file=rttm_path,
                preprocess_params=preprocessing_config,
            )
        print("ANALYSIS MESSAGE:", analysis_result.message)
        if not analysis_result.success:
            raise HTTPException(status_code=500, detail=analysis_result.message)

        print("DURATION BEFORE RETURN:", analysis_result.duration)
        return AnalysisResult(
            duration=force_float(analysis_result.duration) or 0.0,
            language=analysis_result.languageCode,
            der=analysis_result.diarizationErrorRate,
            speaker_error=analysis_result.speakerError,
            missed_speech=analysis_result.missedSpeech,
            false_alarm=analysis_result.falseAlarm,
            timeline_data=[
                TimelineItem(
                    start=force_float(item.get('start')) or 0.0,
                    end=force_float(item.get('end')) or 0.0,
                    speaker=str(item.get('speaker')) if item.get('speaker') else None,
                    text=str(item.get('text', "")),
                )
                for item in analysis_result.timelineData
            ],
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error during upload process: {type(e).__name__}: {e}") from e
    finally:
        for path in (audio_path, rttm_path):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as cleanup_err:
                    # Never let cleanup mask the real response/exception.
                    print(f"Could not remove temp file {path}: {cleanup_err}")
        print(f"API Request processed in {time.time() - start_time:.2f} seconds.")
@app.get("/")
def root():
    """Return a static status message confirming the API process is serving requests."""
    status = {"message": "Audio Analyzer Backend is running."}
    return status