Spaces:
Paused
Paused
Fix torch import error in translate-audio endpoint
Browse files
app.py
CHANGED
|
@@ -6,8 +6,10 @@ import logging
|
|
| 6 |
import threading
|
| 7 |
import tempfile
|
| 8 |
import uuid
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
import soundfile as sf
|
|
|
|
| 11 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
| 12 |
from fastapi.responses import JSONResponse
|
| 13 |
from typing import Dict, Any, Optional
|
|
@@ -50,12 +52,10 @@ def load_models_task():
|
|
| 50 |
try:
|
| 51 |
loading_in_progress = True
|
| 52 |
|
| 53 |
-
#
|
| 54 |
logger.info("Starting to load STT model...")
|
| 55 |
-
import torch
|
| 56 |
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
| 57 |
|
| 58 |
-
# Load STT model
|
| 59 |
try:
|
| 60 |
logger.info("Loading Whisper model...")
|
| 61 |
model_status["stt"] = "loading"
|
|
@@ -177,15 +177,18 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 177 |
|
| 178 |
try:
|
| 179 |
# Read and preprocess the audio
|
|
|
|
| 180 |
waveform, sample_rate = sf.read(temp_path)
|
|
|
|
| 181 |
if sample_rate != 16000:
|
| 182 |
logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
|
| 183 |
-
import librosa
|
| 184 |
waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
|
| 185 |
|
| 186 |
# Process the audio with Whisper
|
| 187 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 188 |
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
|
|
|
| 189 |
with torch.no_grad():
|
| 190 |
generated_ids = stt_model.generate(**inputs)
|
| 191 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
@@ -210,6 +213,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 210 |
"output_audio": None
|
| 211 |
}
|
| 212 |
finally:
|
|
|
|
| 213 |
os.unlink(temp_path)
|
| 214 |
|
| 215 |
if __name__ == "__main__":
|
|
|
|
| 6 |
import threading
|
| 7 |
import tempfile
|
| 8 |
import uuid
|
| 9 |
+
import torch
|
| 10 |
import numpy as np
|
| 11 |
import soundfile as sf
|
| 12 |
+
import librosa
|
| 13 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
| 14 |
from fastapi.responses import JSONResponse
|
| 15 |
from typing import Dict, Any, Optional
|
|
|
|
| 52 |
try:
|
| 53 |
loading_in_progress = True
|
| 54 |
|
| 55 |
+
# Load STT model
|
| 56 |
logger.info("Starting to load STT model...")
|
|
|
|
| 57 |
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
| 58 |
|
|
|
|
| 59 |
try:
|
| 60 |
logger.info("Loading Whisper model...")
|
| 61 |
model_status["stt"] = "loading"
|
|
|
|
| 177 |
|
| 178 |
try:
|
| 179 |
# Read and preprocess the audio
|
| 180 |
+
logger.info(f"Reading audio file: {temp_path}")
|
| 181 |
waveform, sample_rate = sf.read(temp_path)
|
| 182 |
+
logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
|
| 183 |
if sample_rate != 16000:
|
| 184 |
logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
|
|
|
|
| 185 |
waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
|
| 186 |
|
| 187 |
# Process the audio with Whisper
|
| 188 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 189 |
+
logger.info(f"Using device: {device}")
|
| 190 |
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
| 191 |
+
logger.info("Audio processed, generating transcription...")
|
| 192 |
with torch.no_grad():
|
| 193 |
generated_ids = stt_model.generate(**inputs)
|
| 194 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
|
| 213 |
"output_audio": None
|
| 214 |
}
|
| 215 |
finally:
|
| 216 |
+
logger.info(f"Cleaning up temporary file: {temp_path}")
|
| 217 |
os.unlink(temp_path)
|
| 218 |
|
| 219 |
if __name__ == "__main__":
|