Spaces:
Sleeping
Sleeping
Commit ·
bc14aa5
1
Parent(s): a1de44a
Convert browser audio before inference
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import base64
|
| 2 |
import re
|
| 3 |
import shutil
|
|
|
|
| 4 |
import tempfile
|
| 5 |
import zipfile
|
| 6 |
from pathlib import Path
|
|
@@ -332,6 +333,26 @@ def decode_audio_to_tempfile(audio_payload: str) -> str:
|
|
| 332 |
return temp_file.name
|
| 333 |
|
| 334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
def transcribe_audio(temp_path: str) -> str:
|
| 336 |
asr = load_asr_pipeline()
|
| 337 |
result = asr(
|
|
@@ -347,9 +368,11 @@ def transcribe_audio(temp_path: str) -> str:
|
|
| 347 |
def predict_audio(audio_payload: str) -> dict:
|
| 348 |
processor, wav2vec_model, audio_classifier, labels = load_audio_model()
|
| 349 |
temp_path = decode_audio_to_tempfile(audio_payload)
|
|
|
|
| 350 |
|
| 351 |
try:
|
| 352 |
-
|
|
|
|
| 353 |
span_result = predict_text(transcript) if transcript else {
|
| 354 |
"isToxic": False,
|
| 355 |
"confidence": 0.0,
|
|
@@ -365,7 +388,7 @@ def predict_audio(audio_payload: str) -> dict:
|
|
| 365 |
},
|
| 366 |
}
|
| 367 |
|
| 368 |
-
waveform, sample_rate = torchaudio.load(
|
| 369 |
if waveform.shape[0] > 1:
|
| 370 |
waveform = waveform.mean(dim=0, keepdim=True)
|
| 371 |
|
|
@@ -423,6 +446,8 @@ def predict_audio(audio_payload: str) -> dict:
|
|
| 423 |
}
|
| 424 |
finally:
|
| 425 |
Path(temp_path).unlink(missing_ok=True)
|
|
|
|
|
|
|
| 426 |
|
| 427 |
|
| 428 |
def audio_fallback_prediction(message: str = "Audio inference could not run.") -> dict:
|
|
@@ -461,7 +486,10 @@ def detect(payload: DetectRequest):
|
|
| 461 |
if payload.mode == "audio":
|
| 462 |
if not payload.audio:
|
| 463 |
return audio_fallback_prediction("No audio payload was provided.")
|
| 464 |
-
|
|
|
|
|
|
|
|
|
|
| 465 |
|
| 466 |
text = payload.text or "yeh toxic span detection result hai"
|
| 467 |
return predict_text(text)
|
|
|
|
| 1 |
import base64
|
| 2 |
import re
|
| 3 |
import shutil
|
| 4 |
+
import subprocess
|
| 5 |
import tempfile
|
| 6 |
import zipfile
|
| 7 |
from pathlib import Path
|
|
|
|
| 333 |
return temp_file.name
|
| 334 |
|
| 335 |
|
| 336 |
+
def convert_audio_to_wav(input_path: str) -> str:
|
| 337 |
+
output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
| 338 |
+
output_file.close()
|
| 339 |
+
command = [
|
| 340 |
+
"ffmpeg",
|
| 341 |
+
"-y",
|
| 342 |
+
"-i",
|
| 343 |
+
input_path,
|
| 344 |
+
"-ac",
|
| 345 |
+
"1",
|
| 346 |
+
"-ar",
|
| 347 |
+
"16000",
|
| 348 |
+
"-t",
|
| 349 |
+
"10",
|
| 350 |
+
output_file.name,
|
| 351 |
+
]
|
| 352 |
+
subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
| 353 |
+
return output_file.name
|
| 354 |
+
|
| 355 |
+
|
| 356 |
def transcribe_audio(temp_path: str) -> str:
|
| 357 |
asr = load_asr_pipeline()
|
| 358 |
result = asr(
|
|
|
|
| 368 |
def predict_audio(audio_payload: str) -> dict:
|
| 369 |
processor, wav2vec_model, audio_classifier, labels = load_audio_model()
|
| 370 |
temp_path = decode_audio_to_tempfile(audio_payload)
|
| 371 |
+
wav_path = None
|
| 372 |
|
| 373 |
try:
|
| 374 |
+
wav_path = convert_audio_to_wav(temp_path)
|
| 375 |
+
transcript = transcribe_audio(wav_path)
|
| 376 |
span_result = predict_text(transcript) if transcript else {
|
| 377 |
"isToxic": False,
|
| 378 |
"confidence": 0.0,
|
|
|
|
| 388 |
},
|
| 389 |
}
|
| 390 |
|
| 391 |
+
waveform, sample_rate = torchaudio.load(wav_path)
|
| 392 |
if waveform.shape[0] > 1:
|
| 393 |
waveform = waveform.mean(dim=0, keepdim=True)
|
| 394 |
|
|
|
|
| 446 |
}
|
| 447 |
finally:
|
| 448 |
Path(temp_path).unlink(missing_ok=True)
|
| 449 |
+
if wav_path:
|
| 450 |
+
Path(wav_path).unlink(missing_ok=True)
|
| 451 |
|
| 452 |
|
| 453 |
def audio_fallback_prediction(message: str = "Audio inference could not run.") -> dict:
|
|
|
|
| 486 |
if payload.mode == "audio":
|
| 487 |
if not payload.audio:
|
| 488 |
return audio_fallback_prediction("No audio payload was provided.")
|
| 489 |
+
try:
|
| 490 |
+
return predict_audio(payload.audio)
|
| 491 |
+
except Exception as exc:
|
| 492 |
+
return audio_fallback_prediction(f"Audio inference failed: {exc}")
|
| 493 |
|
| 494 |
text = payload.text or "yeh toxic span detection result hai"
|
| 495 |
return predict_text(text)
|