hafsaabd82 committed on
Commit
8d49f81
·
verified ·
1 Parent(s): 1e73f01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -13
app.py CHANGED
@@ -158,20 +158,26 @@ def preprocess_audio(input_path,
158
  y = rms_normalize(y, target_rms=target_rms)
159
  sf.write(output_path, y, sr, subtype=output_subtype)
160
  return output_path
161
- def analyze_audio(audio_file: str,
162
- reference_rttm_file: Optional[str] = None,
163
- preprocess: bool = True,
164
  preprocess_params: Optional[Dict[str, Any]] = None) -> AnalysisResults:
 
165
  results = AnalysisResults()
166
- global global_align_model_cache, ALIGN_MODEL_MAP
167
- ends: List[float] = []
 
 
168
  rows: List[Dict[str, Any]] = []
 
169
  if not os.path.exists(audio_file):
170
  results.message = f"Error: Input audio file '{audio_file}' not found."
171
  return results
172
 
173
  audio_for_model = audio_file
174
  temp_preproc = None
 
 
175
  if preprocess:
176
  params = {
177
  "target_sr": 16000, "normalize_rms": True, "target_rms": 0.08,
@@ -190,22 +196,28 @@ def analyze_audio(audio_file: str,
190
  warn(results, "PREP_FAIL", f"Preprocessing failed: {e}. Falling back to original audio.")
191
  audio_for_model = audio_file
192
  temp_preproc = None
 
193
  start_ml_time = time.time()
194
  try:
 
195
  print(f"Loading Whisper model '{model_name}' on {device}...")
196
  model = whisperx.load_model(model_name, device, compute_type="float32")
197
  audio_loaded = whisperx.load_audio(audio_for_model)
198
  print("Transcribing audio...")
199
  result = model.transcribe(audio_loaded, batch_size=4 )
 
200
  language_code = result.get("language") or result.get("detected_language") or "en"
201
  results.languageCode = language_code
202
- global global_align_model_cache
 
 
203
  print(f"Detected language: {language_code}. Aligning transcription...")
204
- aligned = {"segments": result["segments"]}
205
  align_model = None
206
  metadata = None
 
207
  if language_code not in global_align_model_cache:
208
  align_model_name = ALIGN_MODEL_MAP.get(language_code)
 
209
  try:
210
  if align_model_name:
211
  print(f"Loading custom alignment model for {language_code}: {align_model_name}...")
@@ -216,16 +228,18 @@ def analyze_audio(audio_file: str,
216
  )
217
  global_align_model_cache[language_code] = (align_model, metadata)
218
  print(f"Alignment model loaded/cached for language: {language_code}")
219
-
220
  except Exception as e:
221
  warn(results, "ALIGN_LOAD_FAIL", f"Failed to load alignment model for {language_code}: {e}. Alignment skipped.")
222
- global_align_model_cache[language_code] = (None, None) # Cache the failure/skip
223
  else:
224
  align_model, metadata = global_align_model_cache[language_code]
225
  if align_model:
226
  print(f"Alignment model loaded from cache for language: {language_code}")
 
227
  if align_model:
228
  try:
 
229
  aligned = whisperx.align(
230
  result["segments"],
231
  align_model,
@@ -236,7 +250,26 @@ def analyze_audio(audio_file: str,
236
  except Exception as e:
237
  warn(results, "ALIGN_RUN_FAIL", f"Alignment execution failed: {type(e).__name__}: {e}. Using raw segments.")
238
  else:
239
- warn(results, "ALIGN_SKIP", "Alignment model unavailable; using raw Whisper segments.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  try:
241
  diarize_segments_for_assignment = []
242
  if diarize_output is not None and hasattr(diarize_output, "itertracks"):
@@ -247,6 +280,7 @@ def analyze_audio(audio_file: str,
247
  "speaker": normalize_speaker(label)
248
  })
249
  print(f"DEBUG: Converted {len(diarize_segments_for_assignment)} diarization segments.")
 
250
  if diarize_segments_for_assignment:
251
  diarize_df = pd.DataFrame(diarize_segments_for_assignment)
252
  final = whisperx.assign_word_speakers(diarize_df, aligned)
@@ -254,12 +288,15 @@ def analyze_audio(audio_file: str,
254
  warn(results, "ASSIGN_FAIL", "Diarization segments were empty or unavailable. Defaulting all to Speaker_1.")
255
  final = aligned
256
  for seg in final.get("segments", []):
257
- seg["speaker"] = "Speaker_1"
 
258
  except Exception as e:
259
  warn(results, "ASSIGN_SPEAKERS_ERROR", f"Error assigning speakers: {type(e).__name__}: {e}. Falling back to unassigned segments.")
260
  final = aligned
261
  for seg in final.get("segments", []):
262
  seg["speaker"] = "Speaker_1"
 
 
263
  def _get_time_field(d: Dict[str, Any], keys: List[str]) -> Optional[float]:
264
  """Try multiple possible keys and coerce to native float, returning None if not possible."""
265
  for k in keys:
@@ -275,6 +312,7 @@ def analyze_audio(audio_file: str,
275
  except (TypeError, ValueError):
276
  continue
277
  return None
 
278
  for seg in final.get("segments", []):
279
  seg_speaker = normalize_speaker(seg.get("speaker") or seg.get("speaker_label") or "Speaker_1")
280
  word_list = seg.get("words") or seg.get("tokens") or seg.get("items") or []
@@ -301,20 +339,21 @@ def analyze_audio(audio_file: str,
301
  word_start = _get_time_field(seg, ["start", "s"])
302
  if word_end is None:
303
  word_end = _get_time_field(seg, ["end", "e"])
304
-
305
  if word_start is None:
306
  continue
307
  if word_end is None:
308
  word_end = word_start
309
  word_speaker = normalize_speaker(w.get("speaker") or seg_speaker)
310
  word_text = (w.get("text") or w.get("word") or w.get("label") or "").strip()
311
-
312
  rows.append({
313
  "start": float(word_start),
314
  "end": float(word_end),
315
  "text": str(word_text),
316
  "speaker": str(word_speaker),
317
  })
 
318
  rows = sorted(rows, key=lambda r: r.get("start", 0.0))
319
  results.timelineData = rows
320
  for w in rows:
@@ -322,9 +361,21 @@ def analyze_audio(audio_file: str,
322
  f_e = force_float(e)
323
  if f_e is not None:
324
  ends.append(f_e)
 
325
  except Exception as e:
326
  results.message = f"Error during ML processing: {type(e).__name__}: {e}"
327
  return results
 
 
 
 
 
 
 
 
 
 
 
328
  finally:
329
  if temp_preproc and os.path.exists(temp_preproc):
330
  os.remove(temp_preproc)
 
158
  y = rms_normalize(y, target_rms=target_rms)
159
  sf.write(output_path, y, sr, subtype=output_subtype)
160
  return output_path
161
+ def analyze_audio(audio_file: str,
162
+ reference_rttm_file: Optional[str] = None,
163
+ preprocess: bool = True,
164
  preprocess_params: Optional[Dict[str, Any]] = None) -> AnalysisResults:
165
+
166
  results = AnalysisResults()
167
+ # Ensure access to global variables for reading/writing
168
+ global global_align_model_cache, ALIGN_MODEL_MAP
169
+
170
+ ends: List[float] = []
171
  rows: List[Dict[str, Any]] = []
172
+
173
  if not os.path.exists(audio_file):
174
  results.message = f"Error: Input audio file '{audio_file}' not found."
175
  return results
176
 
177
  audio_for_model = audio_file
178
  temp_preproc = None
179
+
180
+ # --- Preprocessing ---
181
  if preprocess:
182
  params = {
183
  "target_sr": 16000, "normalize_rms": True, "target_rms": 0.08,
 
196
  warn(results, "PREP_FAIL", f"Preprocessing failed: {e}. Falling back to original audio.")
197
  audio_for_model = audio_file
198
  temp_preproc = None
199
+
200
  start_ml_time = time.time()
201
  try:
202
+ # --- Transcription ---
203
  print(f"Loading Whisper model '{model_name}' on {device}...")
204
  model = whisperx.load_model(model_name, device, compute_type="float32")
205
  audio_loaded = whisperx.load_audio(audio_for_model)
206
  print("Transcribing audio...")
207
  result = model.transcribe(audio_loaded, batch_size=4 )
208
+
209
  language_code = result.get("language") or result.get("detected_language") or "en"
210
  results.languageCode = language_code
211
+ aligned = {"segments": result["segments"]} # Default fallback
212
+
213
+ # --- Alignment Loading and Execution (Language-Specific) ---
214
  print(f"Detected language: {language_code}. Aligning transcription...")
 
215
  align_model = None
216
  metadata = None
217
+
218
  if language_code not in global_align_model_cache:
219
  align_model_name = ALIGN_MODEL_MAP.get(language_code)
220
+
221
  try:
222
  if align_model_name:
223
  print(f"Loading custom alignment model for {language_code}: {align_model_name}...")
 
228
  )
229
  global_align_model_cache[language_code] = (align_model, metadata)
230
  print(f"Alignment model loaded/cached for language: {language_code}")
231
+
232
  except Exception as e:
233
  warn(results, "ALIGN_LOAD_FAIL", f"Failed to load alignment model for {language_code}: {e}. Alignment skipped.")
234
+ global_align_model_cache[language_code] = (None, None)
235
  else:
236
  align_model, metadata = global_align_model_cache[language_code]
237
  if align_model:
238
  print(f"Alignment model loaded from cache for language: {language_code}")
239
+
240
  if align_model:
241
  try:
242
+ print("Performing word-level alignment...")
243
  aligned = whisperx.align(
244
  result["segments"],
245
  align_model,
 
250
  except Exception as e:
251
  warn(results, "ALIGN_RUN_FAIL", f"Alignment execution failed: {type(e).__name__}: {e}. Using raw segments.")
252
  else:
253
+ warn(results, "ALIGN_SKIP", "Alignment model unavailable; using raw Whisper segments.")
254
+
255
+
256
+ # --- DIARIZATION EXECUTION (The missing block, now re-inserted) ---
257
+ diarize_output = None
258
+ if global_diarizer is not None:
259
+ print("Performing speaker diarization (Requires HF_TOKEN)...")
260
+ try:
261
+ diarize_output = global_diarizer(audio_for_model)
262
+ for segment, _, label in diarize_output.itertracks(yield_label=True):
263
+ print(f"start={segment.start:.1f}s stop={segment.end:.1f}s {label}")
264
+ except Exception as e:
265
+ warn(results, "DIAR_SKIP", f"Error during diarization (likely token/model failure): {type(e).__name__}: {e}. Skipping diarization.")
266
+ diarize_output = None
267
+ else:
268
+ warn(results, "DIAR_SKIP", "HF_TOKEN not set or Diarization Pipeline failed to load globally. Skipping speaker diarization.")
269
+
270
+
271
+ # --- Speaker Assignment ---
272
+ print("Assigning speakers to words...")
273
  try:
274
  diarize_segments_for_assignment = []
275
  if diarize_output is not None and hasattr(diarize_output, "itertracks"):
 
280
  "speaker": normalize_speaker(label)
281
  })
282
  print(f"DEBUG: Converted {len(diarize_segments_for_assignment)} diarization segments.")
283
+
284
  if diarize_segments_for_assignment:
285
  diarize_df = pd.DataFrame(diarize_segments_for_assignment)
286
  final = whisperx.assign_word_speakers(diarize_df, aligned)
 
288
  warn(results, "ASSIGN_FAIL", "Diarization segments were empty or unavailable. Defaulting all to Speaker_1.")
289
  final = aligned
290
  for seg in final.get("segments", []):
291
+ seg["speaker"] = "Speaker_1"
292
+
293
  except Exception as e:
294
  warn(results, "ASSIGN_SPEAKERS_ERROR", f"Error assigning speakers: {type(e).__name__}: {e}. Falling back to unassigned segments.")
295
  final = aligned
296
  for seg in final.get("segments", []):
297
  seg["speaker"] = "Speaker_1"
298
+
299
+ # ... (rest of the timeline generation logic) ...
300
  def _get_time_field(d: Dict[str, Any], keys: List[str]) -> Optional[float]:
301
  """Try multiple possible keys and coerce to native float, returning None if not possible."""
302
  for k in keys:
 
312
  except (TypeError, ValueError):
313
  continue
314
  return None
315
+
316
  for seg in final.get("segments", []):
317
  seg_speaker = normalize_speaker(seg.get("speaker") or seg.get("speaker_label") or "Speaker_1")
318
  word_list = seg.get("words") or seg.get("tokens") or seg.get("items") or []
 
339
  word_start = _get_time_field(seg, ["start", "s"])
340
  if word_end is None:
341
  word_end = _get_time_field(seg, ["end", "e"])
342
+
343
  if word_start is None:
344
  continue
345
  if word_end is None:
346
  word_end = word_start
347
  word_speaker = normalize_speaker(w.get("speaker") or seg_speaker)
348
  word_text = (w.get("text") or w.get("word") or w.get("label") or "").strip()
349
+
350
  rows.append({
351
  "start": float(word_start),
352
  "end": float(word_end),
353
  "text": str(word_text),
354
  "speaker": str(word_speaker),
355
  })
356
+
357
  rows = sorted(rows, key=lambda r: r.get("start", 0.0))
358
  results.timelineData = rows
359
  for w in rows:
 
361
  f_e = force_float(e)
362
  if f_e is not None:
363
  ends.append(f_e)
364
+
365
  except Exception as e:
366
  results.message = f"Error during ML processing: {type(e).__name__}: {e}"
367
  return results
368
+
369
+ finally:
370
+ if temp_preproc and os.path.exists(temp_preproc):
371
+ os.remove(temp_preproc)
372
+
373
+ results.duration = force_float(max(ends) if ends else 0.0) or 0.0
374
+ end_ml_time = time.time()
375
+ print(f"ML Processing finished in {end_ml_time - start_ml_time:.2f} seconds.")
376
+
377
+ results.success = True
378
+ return results
379
  finally:
380
  if temp_preproc and os.path.exists(temp_preproc):
381
  os.remove(temp_preproc)