Spaces:
Paused
Paused
Replaced librosa with torchaudio for audio loading and resampling. Added speech detection (energy-based or webrtcvad for accuracy). Improved /translate-audio endpoint to handle silent audio gracefully.
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ import uuid
|
|
| 9 |
import torch
|
| 10 |
import numpy as np
|
| 11 |
import soundfile as sf
|
| 12 |
-
import
|
| 13 |
import wave
|
| 14 |
import time
|
| 15 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
|
|
@@ -80,6 +80,30 @@ def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
|
|
| 80 |
# Write the 16-bit PCM data as bytes (little-endian)
|
| 81 |
wav_file.writeframes(pcm_array.tobytes())
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# Function to clean up old audio files
|
| 84 |
def cleanup_old_audio_files():
|
| 85 |
logger.info("Starting cleanup of old audio files...")
|
|
@@ -417,17 +441,33 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 417 |
output_audio_url = None
|
| 418 |
|
| 419 |
try:
|
| 420 |
-
# Step 1:
|
| 421 |
logger.info(f"Reading audio file: {temp_path}")
|
| 422 |
-
waveform, sample_rate =
|
| 423 |
logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
|
|
|
|
|
|
|
| 424 |
if sample_rate != 16000:
|
| 425 |
logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
|
| 426 |
-
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 429 |
logger.info(f"Using device: {device}")
|
| 430 |
-
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
| 431 |
logger.info("Audio processed, generating transcription...")
|
| 432 |
|
| 433 |
with torch.no_grad():
|
|
@@ -442,7 +482,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 442 |
transcription = stt_processor.batch_decode(predicted_ids)[0]
|
| 443 |
logger.info(f"Transcription completed: {transcription}")
|
| 444 |
|
| 445 |
-
# Step
|
| 446 |
source_code = LANGUAGE_MAPPING[source_lang]
|
| 447 |
target_code = LANGUAGE_MAPPING[target_lang]
|
| 448 |
|
|
@@ -466,7 +506,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 466 |
else:
|
| 467 |
logger.warning("MT model not loaded, skipping translation")
|
| 468 |
|
| 469 |
-
# Step
|
| 470 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
| 471 |
try:
|
| 472 |
inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
|
|
|
|
| 9 |
import torch
|
| 10 |
import numpy as np
|
| 11 |
import soundfile as sf
|
| 12 |
+
import torchaudio
|
| 13 |
import wave
|
| 14 |
import time
|
| 15 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
|
|
|
|
| 80 |
# Write the 16-bit PCM data as bytes (little-endian)
|
| 81 |
wav_file.writeframes(pcm_array.tobytes())
|
| 82 |
|
| 83 |
+
# Function to detect speech using an energy-based approach
|
| 84 |
+
def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
|
| 85 |
+
"""
|
| 86 |
+
Detects if the audio contains speech using an energy-based approach.
|
| 87 |
+
Returns True if speech is detected, False otherwise.
|
| 88 |
+
"""
|
| 89 |
+
# Convert waveform to numpy array
|
| 90 |
+
waveform_np = waveform.numpy()
|
| 91 |
+
if waveform_np.ndim > 1:
|
| 92 |
+
waveform_np = waveform_np.mean(axis=0) # Convert stereo to mono
|
| 93 |
+
|
| 94 |
+
# Compute RMS energy
|
| 95 |
+
rms = np.sqrt(np.mean(waveform_np**2))
|
| 96 |
+
logger.info(f"RMS energy: {rms}")
|
| 97 |
+
|
| 98 |
+
# Check if RMS energy exceeds the threshold
|
| 99 |
+
if rms < threshold:
|
| 100 |
+
logger.info("No speech detected: RMS energy below threshold")
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
# Optionally, check for minimum speech duration (requires more sophisticated VAD)
|
| 104 |
+
# For now, we assume if RMS is above threshold, there is speech
|
| 105 |
+
return True
|
| 106 |
+
|
| 107 |
# Function to clean up old audio files
|
| 108 |
def cleanup_old_audio_files():
|
| 109 |
logger.info("Starting cleanup of old audio files...")
|
|
|
|
| 441 |
output_audio_url = None
|
| 442 |
|
| 443 |
try:
|
| 444 |
+
# Step 1: Load and resample the audio using torchaudio
|
| 445 |
logger.info(f"Reading audio file: {temp_path}")
|
| 446 |
+
waveform, sample_rate = torchaudio.load(temp_path)
|
| 447 |
logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
|
| 448 |
+
|
| 449 |
+
# Resample to 16 kHz if needed (required by Whisper and MMS models)
|
| 450 |
if sample_rate != 16000:
|
| 451 |
logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
|
| 452 |
+
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
|
| 453 |
+
waveform = resampler(waveform)
|
| 454 |
+
sample_rate = 16000
|
| 455 |
+
|
| 456 |
+
# Step 2: Detect speech
|
| 457 |
+
if not detect_speech(waveform, sample_rate):
|
| 458 |
+
return {
|
| 459 |
+
"request_id": request_id,
|
| 460 |
+
"status": "failed",
|
| 461 |
+
"message": "No speech detected in the audio.",
|
| 462 |
+
"source_text": "No speech detected",
|
| 463 |
+
"translated_text": "No translation available",
|
| 464 |
+
"output_audio": None
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
# Step 3: Transcribe the audio (STT)
|
| 468 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 469 |
logger.info(f"Using device: {device}")
|
| 470 |
+
inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
|
| 471 |
logger.info("Audio processed, generating transcription...")
|
| 472 |
|
| 473 |
with torch.no_grad():
|
|
|
|
| 482 |
transcription = stt_processor.batch_decode(predicted_ids)[0]
|
| 483 |
logger.info(f"Transcription completed: {transcription}")
|
| 484 |
|
| 485 |
+
# Step 4: Translate the transcribed text (MT)
|
| 486 |
source_code = LANGUAGE_MAPPING[source_lang]
|
| 487 |
target_code = LANGUAGE_MAPPING[target_lang]
|
| 488 |
|
|
|
|
| 506 |
else:
|
| 507 |
logger.warning("MT model not loaded, skipping translation")
|
| 508 |
|
| 509 |
+
# Step 5: Convert translated text to speech (TTS)
|
| 510 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
| 511 |
try:
|
| 512 |
inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
|