inayatarshad commited on
Commit
bc14aa5
·
1 Parent(s): a1de44a

Convert browser audio before inference

Browse files
Files changed (1) hide show
  1. app.py +31 -3
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import base64
2
  import re
3
  import shutil
 
4
  import tempfile
5
  import zipfile
6
  from pathlib import Path
@@ -332,6 +333,26 @@ def decode_audio_to_tempfile(audio_payload: str) -> str:
332
  return temp_file.name
333
 
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  def transcribe_audio(temp_path: str) -> str:
336
  asr = load_asr_pipeline()
337
  result = asr(
@@ -347,9 +368,11 @@ def transcribe_audio(temp_path: str) -> str:
347
  def predict_audio(audio_payload: str) -> dict:
348
  processor, wav2vec_model, audio_classifier, labels = load_audio_model()
349
  temp_path = decode_audio_to_tempfile(audio_payload)
 
350
 
351
  try:
352
- transcript = transcribe_audio(temp_path)
 
353
  span_result = predict_text(transcript) if transcript else {
354
  "isToxic": False,
355
  "confidence": 0.0,
@@ -365,7 +388,7 @@ def predict_audio(audio_payload: str) -> dict:
365
  },
366
  }
367
 
368
- waveform, sample_rate = torchaudio.load(temp_path)
369
  if waveform.shape[0] > 1:
370
  waveform = waveform.mean(dim=0, keepdim=True)
371
 
@@ -423,6 +446,8 @@ def predict_audio(audio_payload: str) -> dict:
423
  }
424
  finally:
425
  Path(temp_path).unlink(missing_ok=True)
 
 
426
 
427
 
428
  def audio_fallback_prediction(message: str = "Audio inference could not run.") -> dict:
@@ -461,7 +486,10 @@ def detect(payload: DetectRequest):
461
  if payload.mode == "audio":
462
  if not payload.audio:
463
  return audio_fallback_prediction("No audio payload was provided.")
464
- return predict_audio(payload.audio)
 
 
 
465
 
466
  text = payload.text or "yeh toxic span detection result hai"
467
  return predict_text(text)
 
1
  import base64
2
  import re
3
  import shutil
4
+ import subprocess
5
  import tempfile
6
  import zipfile
7
  from pathlib import Path
 
333
  return temp_file.name
334
 
335
 
336
+ def convert_audio_to_wav(input_path: str) -> str:
337
+ output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
338
+ output_file.close()
339
+ command = [
340
+ "ffmpeg",
341
+ "-y",
342
+ "-i",
343
+ input_path,
344
+ "-ac",
345
+ "1",
346
+ "-ar",
347
+ "16000",
348
+ "-t",
349
+ "10",
350
+ output_file.name,
351
+ ]
352
+ subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
353
+ return output_file.name
354
+
355
+
356
  def transcribe_audio(temp_path: str) -> str:
357
  asr = load_asr_pipeline()
358
  result = asr(
 
368
  def predict_audio(audio_payload: str) -> dict:
369
  processor, wav2vec_model, audio_classifier, labels = load_audio_model()
370
  temp_path = decode_audio_to_tempfile(audio_payload)
371
+ wav_path = None
372
 
373
  try:
374
+ wav_path = convert_audio_to_wav(temp_path)
375
+ transcript = transcribe_audio(wav_path)
376
  span_result = predict_text(transcript) if transcript else {
377
  "isToxic": False,
378
  "confidence": 0.0,
 
388
  },
389
  }
390
 
391
+ waveform, sample_rate = torchaudio.load(wav_path)
392
  if waveform.shape[0] > 1:
393
  waveform = waveform.mean(dim=0, keepdim=True)
394
 
 
446
  }
447
  finally:
448
  Path(temp_path).unlink(missing_ok=True)
449
+ if wav_path:
450
+ Path(wav_path).unlink(missing_ok=True)
451
 
452
 
453
  def audio_fallback_prediction(message: str = "Audio inference could not run.") -> dict:
 
486
  if payload.mode == "audio":
487
  if not payload.audio:
488
  return audio_fallback_prediction("No audio payload was provided.")
489
+ try:
490
+ return predict_audio(payload.audio)
491
+ except Exception as exc:
492
+ return audio_fallback_prediction(f"Audio inference failed: {exc}")
493
 
494
  text = payload.text or "yeh toxic span detection result hai"
495
  return predict_text(text)