hafsaabd82 commited on
Commit
eb8b754
·
verified ·
1 Parent(s): a1ffca1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +361 -96
app.py CHANGED
@@ -1,126 +1,391 @@
1
- import os
2
  import tempfile
3
  import whisperx
4
- from pyannote.audio import Pipeline
5
  import pandas as pd
6
  import librosa
 
 
 
 
 
 
7
  from fastapi import FastAPI, UploadFile, File, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
- import torch
10
- if not hasattr(torch.utils._pytree, "register_pytree_node"):
11
- torch.utils._pytree.register_pytree_node = torch.utils._pytree._register_pytree_node
12
- import traceback
 
 
 
 
 
 
13
 
14
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
- allow_origins=["https://frontend-audio-analyzer.vercel.app/"],
 
19
  allow_methods=["*"],
20
- allow_headers=["*"]
21
  )
22
- device = "cpu"
23
- compute_type = "float16" if device == "cuda" else "float32"
24
- hf_token = os.environ.get("HF_TOKEN")
25
- try:
26
- whisper_model = whisperx.load_model("large-v2", device=device, compute_type=compute_type)
27
- except Exception as e:
28
- print(f"Error loading WhisperX model: {e}")
29
- whisper_model = None
30
- try:
31
- diarize_pipeline = Pipeline.from_pretrained(
32
- "pyannote/speaker-diarization-2.1",
33
- use_auth_token=hf_token
34
- )
35
- except Exception as e:
36
- print(f"Error loading Pyannote pipeline. Check HF_TOKEN: {e}")
37
- diarize_pipeline = None
38
- try:
39
- align_model, metadata = whisperx.load_align_model(
40
- language_code=None,
41
- device=device
42
- )
43
- except Exception as e:
44
- print(f"Error loading WhisperX alignment model: {e}")
45
- align_model, metadata = None, None
46
- @app.post("/process-audio")
47
- async def process_audio(file: UploadFile = File(...)):
48
- if not whisper_model or not diarize_pipeline or not align_model:
49
- raise HTTPException(status_code=503, detail="Model loading failed on server.")
50
- if not file.filename.endswith((".wav", ".mp3", ".m4a", ".flac")):
51
- raise HTTPException(status_code=400, detail="Invalid audio format")
52
- audio_path = None
 
53
  try:
54
- with tempfile.NamedTemporaryFile(delete=False, suffix=file.filename) as tmp:
55
- tmp.write(await file.read())
56
- audio_path = tmp.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  try:
58
- audio = whisperx.load_audio(audio_path)
59
- duration = librosa.get_duration(path=audio_path)
60
- except Exception as e:
61
- raise HTTPException(status_code=400, detail=f"Failed to load or process audio file: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  try:
63
- result = whisper_model.transcribe(audio, batch_size=8)
 
64
  except Exception as e:
65
- raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
66
-
67
- language = result.get("language", "unknown")
 
 
 
 
 
 
 
 
 
 
 
68
  try:
69
- aligned_result = whisperx.align(
70
- result["segments"],
71
- align_model,
72
- metadata,
73
- audio,
74
- device
75
- )
76
- except Exception as e:
77
- raise HTTPException(status_code=500, detail=f"Alignment failed: {e}")
 
 
 
 
 
 
 
 
 
78
  try:
79
- diarization = diarize_pipeline(audio_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  except Exception as e:
81
- raise HTTPException(status_code=500, detail=f"Diarization failed: {e}")
82
- diar_map = []
83
- for turn in diarization.itertracks(yield_label=True):
84
- segment, _, speaker_label = turn
85
- diar_map.append({
86
- "start": segment.start,
87
- "end": segment.end,
88
- "speaker": speaker_label
89
- })
90
-
91
- diar_df = pd.DataFrame(diar_map)
92
- timeline = []
93
- for seg in aligned_result["segments"]:
94
- if "words" not in seg:
95
- continue
96
- for word in seg["words"]:
97
- if word["start"] is None or word["end"] is None:
 
 
 
 
 
 
 
 
 
 
98
  continue
99
- match = diar_df[
100
- (diar_df.start <= word["start"]) &
101
- (diar_df.end >= word["end"])
102
- ]
103
- speaker = match.iloc[0].speaker if not match.empty else "Unknown"
104
- timeline.append({
105
- "start": round(word["start"], 3),
106
- "end": round(word["end"], 3),
107
- "text": word["word"],
108
- "speaker": speaker
109
  })
110
- timeline = sorted(timeline, key=lambda x: x["start"])
111
- return {
112
- "duration": duration,
113
- "language": language,
114
- "timeline_data": timeline
115
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  except HTTPException:
117
  raise
118
  except Exception as e:
119
- traceback.print_exc()
120
- raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during processing: {e}")
121
  finally:
122
  if audio_path and os.path.exists(audio_path):
123
  os.remove(audio_path)
 
 
 
124
  @app.get("/")
125
  def root():
126
  return {"message": "Audio Analyzer Backend is running."}
 
 
1
import os
import shutil
import tempfile
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import whisperx
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from scipy.signal import butter, filtfilt
from whisperx import diarize

try:
    import noisereduce as nr
    HAVE_NOISEREDUCE = True
except ImportError:
    HAVE_NOISEREDUCE = False

# pyannote.core is not imported in this build; these placeholders keep the
# (optional) annotation-conversion path importable.
Annotation: Any = None
Segment: Any = None
 
25
# --- Runtime configuration -------------------------------------------------
# Prefer GPU when one is visible to torch; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# HF_TOKEN gates the (pyannote-backed) diarization models.
token = os.environ.get("HF_TOKEN")
if not token:
    print("Warning: HF_TOKEN not set. Diarization will be skipped.")
perform_diarization = bool(token)

# Whisper checkpoint size used for transcription.
model_name = "medium"
31
class TimelineItem(BaseModel):
    """One entry on the word-level timeline returned to the client."""
    # Times are seconds from the start of the uploaded audio.
    start: float
    end: float
    # None when diarization was skipped or failed.
    speaker: Union[str, None] = None
    text: str
36
class AnalysisResult(BaseModel):
    """Response schema for the /upload endpoint."""
    # Total audio duration in seconds (max word end time).
    duration: float
    # Whisper-detected language code, e.g. "en".
    language: str
    # Diarization scoring metrics; only populated when a reference is scored.
    der: Union[float, None] = None
    speaker_error: Union[float, None] = None
    missed_speech: Union[float, None] = None
    false_alarm: Union[float, None] = None
    timeline_data: List[TimelineItem]
44
 
45
# Application instance; CORS is restricted to the deployed frontend origin.
app = FastAPI(title="Audio Analyzer Backend")
app.add_middleware(
    CORSMiddleware,
    # Exact origin, no trailing slash — browsers compare origins literally.
    allow_origins=["https://frontend-audio-analyzer.vercel.app"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
53
@dataclass
class AnalysisResults:
    """Internal mutable result object threaded through the analysis pipeline."""
    # Word/segment rows: dicts with "start", "end", "text", "speaker".
    timelineData: List[Dict[str, Any]] = field(default_factory=list)
    duration: float = 0.0            # seconds; derived from the max word end time
    languageCode: str = "unknown"    # Whisper-detected language
    # Diarization scoring metrics (populated only when a reference is scored).
    diarizationErrorRate: Optional[float] = None
    speakerError: Optional[float] = None
    missedSpeech: Optional[float] = None
    falseAlarm: Optional[float] = None
    # Deduplicated "<CODE>: <detail>" warning strings collected during the run.
    warnings: List[str] = field(default_factory=list)
    # Set by the pipeline on completion; the API layer 500s when this is False.
    success: bool = False
    message: str = "Analysis initiated."
65
def warn(results: "AnalysisResults", code: str, detail: str) -> None:
    """Record a deduplicated '<CODE>: <detail>' warning on *results*."""
    entry = f"{code}: {detail}"
    if entry in results.warnings:
        return
    results.warnings.append(entry)
69
def set_message(results: "AnalysisResults", msg: str) -> None:
    """Set the status message, or append pipe-separated if one is already set.

    The placeholder default message is always replaced outright.
    """
    default = "Analysis initiated."
    current = results.message
    if current and current != default:
        results.message = f"{current} | {msg}"
    else:
        results.message = msg
75
def normalize_speaker(lbl: str) -> str:
    """Map raw 'SPEAKER_x' / 'speaker_x' labels to the canonical 'Speaker_x'."""
    text = str(lbl)
    for raw in ("SPEAKER_", "speaker_"):
        text = text.replace(raw, "Speaker_")
    return text
78
def temp_wav_path() -> str:
    """Create an empty temporary .wav file and return its path.

    The file is NOT deleted automatically; the caller owns cleanup.
    """
    handle = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        return handle.name
    finally:
        handle.close()
81
def force_float(value: Optional[Any]) -> Optional[float]:
    """Coerce *value* to a native Python float.

    Returns None for None, non-numeric input, NaN, and +/-Inf — so the result
    is always JSON-safe.
    """
    if value is None:
        return None
    try:
        result = float(value)
    except (TypeError, ValueError, AttributeError):
        return None
    return None if (np.isnan(result) or np.isinf(result)) else result
92
+
93
def butter_filter(y, sr, lowpass=None, highpass=None, order=4):
    """Apply optional zero-phase Butterworth high-pass/low-pass filters to *y*.

    Cutoff frequencies that are falsy or outside (0, Nyquist) are silently
    ignored, so passing both as None returns *y* unchanged.
    """
    nyquist = 0.5 * sr
    if highpass and 0 < highpass < nyquist:
        b, a = butter(order, highpass / nyquist, btype="highpass", analog=False)
        y = filtfilt(b, a, y)
    if lowpass and 0 < lowpass < nyquist:
        b, a = butter(order, lowpass / nyquist, btype="lowpass", analog=False)
        y = filtfilt(b, a, y)
    return y
102
+
103
def rms_normalize(y, target_rms=0.8, eps=1e-6):
    """Scale array *y* so its RMS level matches *target_rms*.

    Near-silent input (RMS below *eps*) is returned unchanged to avoid
    amplifying pure noise or dividing by ~zero.
    """
    rms = (y ** 2).mean() ** 0.5
    if rms < eps:
        return y
    return y * (target_rms / (rms + eps))
109
+
110
def preprocess_audio(input_path,
                     target_sr=16000,
                     normalize_rms=True,
                     target_rms=0.08,
                     denoise=False,
                     highpass=None,
                     lowpass=None,
                     output_subtype="PCM_16",
                     verbose=False) -> str:
    """Mono-mix, resample, filter, optionally denoise, and RMS-normalize audio.

    Writes the processed signal to a fresh temp WAV and returns its path;
    the caller owns deletion. Raises FileNotFoundError if *input_path*
    does not exist.

    NOTE(review): *verbose* is currently unused; kept for interface
    compatibility.
    """
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input audio not found: {input_path}")
    output_path = temp_wav_path()

    samples, sr = sf.read(input_path, dtype='float64')
    # Collapse multi-channel audio to mono before any DSP.
    y = librosa.to_mono(samples.T) if samples.ndim > 1 else samples

    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    if highpass or lowpass:
        y = butter_filter(y, sr, highpass=highpass, lowpass=lowpass)

    if denoise and HAVE_NOISEREDUCE:
        # Best-effort spectral denoise using the first 0.5 s as noise profile.
        try:
            noise_len = int(min(len(y), int(0.5 * sr)))
            y = nr.reduce_noise(y=y, sr=sr, y_noise=y[:noise_len],
                                prop_decrease=0.9, verbose=False)
        except Exception:
            pass  # denoising is optional; continue with the unfiltered signal

    if normalize_rms:
        y = rms_normalize(y, target_rms=target_rms)

    sf.write(output_path, y, sr, subtype=output_subtype)
    return output_path
143
def diarization_to_annotation(diarize_output) -> Optional[Any]:
    """Convert diarization output into a pyannote ``Annotation``.

    Accepts an iterable of ``{'segment': ..., 'label': ...}`` dicts,
    ``{'start', 'end', 'speaker'}`` dicts, or objects exposing
    ``.start``/``.end`` (and optionally ``.speaker``/``.label``).

    Returns None when pyannote types are unavailable (the module-level
    ``Annotation``/``Segment`` placeholders are None) or on any conversion
    failure.

    BUG FIX: the original body had lost its ``def`` header (leaving a
    module-level ``try:``/``return``, a SyntaxError), indexed every segment
    before checking it was a dict, and added {start, end, speaker} dicts
    twice via a duplicated membership branch.
    """
    if Annotation is None or Segment is None:
        # pyannote.core is not imported in this build; callers must treat
        # None as "no annotation available".
        return None
    try:
        ann = Annotation()
        if not hasattr(diarize_output, '__iter__'):
            return ann
        for seg in diarize_output:
            if isinstance(seg, dict):
                if 'segment' in seg and 'label' in seg:
                    s = float(seg['segment'].start)
                    e = float(seg['segment'].end)
                    lbl = normalize_speaker(seg['label'])
                elif all(k in seg for k in ("start", "end", "speaker")):
                    s = float(seg.get("start", 0.0))
                    e = float(seg.get("end", s))
                    lbl = normalize_speaker(seg.get("speaker", "Speaker_1"))
                else:
                    continue  # unrecognized dict shape
                ann[Segment(s, e)] = lbl
            elif hasattr(seg, 'start') and hasattr(seg, 'end'):
                s = float(seg.start)
                e = float(seg.end)
                lbl = normalize_speaker(
                    getattr(seg, 'speaker', getattr(seg, 'label', 'Speaker_1')))
                ann[Segment(s, e)] = lbl
        return ann
    except Exception as e:
        print(f"Error in diarization_to_annotation: {e}")
        return None
171
def _get_time_field(d: Dict[str, Any], keys: List[str]) -> Optional[float]:
    """Return the first of *keys* present in *d* that coerces to a finite
    native float; None when no key yields a usable value."""
    for k in keys:
        if k not in d:
            continue
        try:
            v = d[k]
            if v is None:
                continue
            f = float(v)
            if np.isnan(f) or np.isinf(f):
                return None
            return f
        except (TypeError, ValueError):
            continue
    return None


def _assign_speakers(diarize_output, aligned):
    """Merge diarization output into the aligned transcription.

    Accepts a whisperx diarization DataFrame, a pyannote-style object with
    ``itertracks``, or a list of {start, end, speaker} dicts. With no
    diarization output, every segment is labelled 'Speaker_1'.
    """
    if diarize_output is None:
        spans: Any = [{
            "start": seg.get("start", 0),
            "end": seg.get("end", seg.get("start", 0)),
            "speaker": "Speaker_1",
        } for seg in aligned.get("segments", [])]
    elif hasattr(diarize_output, "itertracks"):
        spans = [{
            "start": float(segment.start),
            "end": float(segment.end),
            "speaker": normalize_speaker(label),
        } for segment, _, label in diarize_output.itertracks(yield_label=True)]
    else:
        spans = diarize_output

    if isinstance(spans, list):
        if not spans:
            for seg in aligned.get("segments", []):
                seg["speaker"] = "Speaker_1"
            return aligned
        # BUG FIX: whisperx.assign_word_speakers expects a DataFrame with
        # start/end/speaker columns, not a plain list of dicts.
        spans = pd.DataFrame(spans)
    return whisperx.assign_word_speakers(spans, aligned)


def _build_timeline(final: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten speaker-annotated segments into a start-sorted word timeline."""
    rows: List[Dict[str, Any]] = []
    for seg in final.get("segments", []):
        seg_speaker = normalize_speaker(
            seg.get("speaker") or seg.get("speaker_label") or "Speaker_1")
        word_list = seg.get("words") or seg.get("tokens") or seg.get("items") or []
        if not word_list:
            # Segment-level fallback when word timestamps are missing.
            seg_start = _get_time_field(seg, ["start", "s", "timestamp", "t0"])
            seg_end = _get_time_field(seg, ["end", "e", "t1"])
            if seg_start is None:
                continue
            rows.append({
                "start": float(seg_start),
                "end": float(seg_end if seg_end is not None else seg_start),
                "text": str(seg.get("text", "")).strip(),
                "speaker": str(seg_speaker),
            })
            continue
        for w in word_list:
            if not isinstance(w, dict):
                continue
            w_start = _get_time_field(w, ["start", "s", "timestamp", "t0"])
            w_end = _get_time_field(w, ["end", "e", "t1"])
            # Borrow the segment's timing when the word has none.
            if w_start is None:
                w_start = _get_time_field(seg, ["start", "s"])
            if w_end is None:
                w_end = _get_time_field(seg, ["end", "e"])
            if w_start is None:
                continue
            if w_end is None:
                w_end = w_start
            rows.append({
                "start": float(w_start),
                "end": float(w_end),
                "text": str((w.get("text") or w.get("word") or w.get("label") or "").strip()),
                "speaker": str(normalize_speaker(w.get("speaker") or seg_speaker)),
            })
    return sorted(rows, key=lambda r: r.get("start", 0.0))


def analyze_audio(audio_file: str,
                  reference_rttm_file: Optional[str] = None,
                  preprocess: bool = True,
                  preprocess_params: Optional[Dict[str, Any]] = None) -> "AnalysisResults":
    """Transcribe, align, and (when HF_TOKEN is set) diarize *audio_file*.

    Parameters
    ----------
    audio_file: path of the audio file to analyze.
    reference_rttm_file: reserved for DER scoring; currently unused.
    preprocess: run resample/normalize preprocessing before inference.
    preprocess_params: overrides for :func:`preprocess_audio` keyword args.

    Returns an ``AnalysisResults``; processing errors never raise — they are
    reported through ``results.message`` with ``results.success`` False.

    BUG FIXES vs previous revision:
    - ``results.success`` was never set to True, so the API layer returned
      500 for every request, including successful analyses.
    - ``diarize(audio)`` called the ``whisperx.diarize`` *module* (not
      callable), so diarization always failed and was skipped; a
      ``DiarizationPipeline`` is now instantiated and invoked.
    """
    results = AnalysisResults()
    if not os.path.exists(audio_file):
        results.message = f"Error: Input audio file '{audio_file}' not found."
        return results

    audio_for_model = audio_file
    temp_preproc = None
    if preprocess:
        params = {
            "target_sr": 16000, "normalize_rms": True, "target_rms": 0.08,
            "denoise": False, "highpass": None, "lowpass": None,
            "output_subtype": "PCM_16", "verbose": False
        }
        if isinstance(preprocess_params, dict):
            params.update(preprocess_params)
        if params.get("denoise") and not HAVE_NOISEREDUCE:
            warn(results, "DENOISE_SKIP", "Denoise requested but noisereduce not installed; skipping denoise.")
            params["denoise"] = False
        try:
            temp_preproc = preprocess_audio(audio_file, **params)
            audio_for_model = temp_preproc
        except Exception as e:
            warn(results, "PREP_FAIL", f"Preprocessing failed: {e}. Falling back to original audio.")
            audio_for_model = audio_file
            temp_preproc = None

    start_ml_time = time.time()
    ends: List[float] = []
    try:
        print(f"Loading Whisper model '{model_name}' on {device}...")
        model = whisperx.load_model(model_name, device, compute_type="float32")
        audio_loaded = whisperx.load_audio(audio_for_model)

        print("Transcribing audio...")
        result = model.transcribe(audio_loaded, batch_size=4)
        language_code = result.get("language") or result.get("detected_language") or "en"
        results.languageCode = language_code

        print(f"Detected language: {language_code}. Aligning transcription...")
        try:
            align_model, metadata = whisperx.load_align_model(language_code=language_code, device=device)
            aligned = whisperx.align(result["segments"], align_model, metadata, audio_loaded, device)
        except Exception:
            aligned = {"segments": result["segments"]}
            warn(results, "ALIGN_SKIP", "Alignment unavailable; using raw Whisper segments.")

        diarize_output = None
        if perform_diarization:
            print("Performing speaker diarization (Requires HF_TOKEN)...")
            try:
                diarize_pipeline = diarize.DiarizationPipeline(use_auth_token=token, device=device)
                diarize_output = diarize_pipeline(audio_for_model)
            except Exception as e:
                warn(results, "DIAR_SKIP", f"Error during diarization (likely token/model failure): {type(e).__name__}: {e}. Skipping diarization.")
                diarize_output = None
        else:
            warn(results, "DIAR_SKIP", "HF_TOKEN not set. Skipping speaker diarization.")

        print("Assigning speakers to words...")
        try:
            final = _assign_speakers(diarize_output, aligned)
        except Exception as e:
            warn(results, "ASSIGN_SPEAKERS_ERROR", f"Error assigning speakers: {e}. Falling back to unassigned segments.")
            final = aligned
            for seg in final.get("segments", []):
                seg["speaker"] = "Speaker_1"

        rows = _build_timeline(final)
        results.timelineData = rows
        ends = [f for f in (force_float(r.get("end")) for r in rows) if f is not None]
    except Exception as e:
        results.message = f"Error during ML processing: {type(e).__name__}: {e}"
        return results
    finally:
        if temp_preproc and os.path.exists(temp_preproc):
            os.remove(temp_preproc)

    results.duration = force_float(max(ends) if ends else 0.0) or 0.0
    results.success = True
    set_message(results, "Analysis completed.")
    end_ml_time = time.time()
    print(f"ML Processing finished in {end_ml_time - start_ml_time:.2f} seconds.")
    return results
337
@app.post("/upload", response_model=AnalysisResult)
async def upload_file(audio_file: UploadFile = File(...)):
    """Accept an uploaded audio file, run the analysis pipeline, and return
    the word-level timeline.

    Responds 500 with the pipeline's own message when processing fails; the
    temp copy of the upload is always removed.

    FIXES vs previous revision: "FAILURE MESSAGE" was printed on every
    request, including successes; the suffix extraction mangled dotless
    filenames (now via os.path.splitext); removed the redundant duration
    None-check (force_float already handles it).
    """
    start_time = time.time()
    audio_path: Optional[str] = None
    try:
        print("Incoming upload:", getattr(audio_file, "filename", None))

        # Preserve the original extension so downstream decoders can sniff it.
        _, ext = os.path.splitext(audio_file.filename or "")
        with tempfile.NamedTemporaryFile(suffix=ext or ".tmp", delete=False) as tmp_audio:
            shutil.copyfileobj(audio_file.file, tmp_audio)
            audio_path = tmp_audio.name
        print(f"Received audio file: {audio_file.filename} (saved to {audio_path}), size: {os.path.getsize(audio_path)} bytes")

        preprocessing_config = {"denoise": False}
        print(f"Starting ML processing with audio: {audio_path}, preprocess_params: {preprocessing_config}")

        analysis_result = analyze_audio(
            audio_file=audio_path,
            preprocess_params=preprocessing_config
        )
        if not analysis_result.success:
            # Surface the pipeline's diagnostic to both logs and the client.
            print("FAILURE MESSAGE:", analysis_result.message)
            raise HTTPException(status_code=500, detail=analysis_result.message)

        return AnalysisResult(
            duration=force_float(analysis_result.duration) or 0.0,
            language=analysis_result.languageCode,
            timeline_data=[
                TimelineItem(
                    start=force_float(item.get('start')) or 0.0,
                    end=force_float(item.get('end')) or 0.0,
                    speaker=str(item.get('speaker')) if item.get('speaker') else None,
                    text=str(item.get('text', "")),
                )
                for item in analysis_result.timelineData
            ],
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error during upload process: {type(e).__name__}: {e}")
    finally:
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
        end_time = time.time()
        print(f"API Request processed in {end_time - start_time:.2f} seconds.")
388
+
389
@app.get("/")
def root():
    """Health-check endpoint used to verify the service is up."""
    return {"message": "Audio Analyzer Backend is running."}