# app.py — FastAPI speech-to-text service (distil-whisper)
# NOTE: the lines above/around this header originally contained web-page
# residue from the hosting UI ("Update app.py", commit 736aad2), not code.
from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import pipeline
import numpy as np
import tempfile
import os
import subprocess
import soundfile as sf
app = FastAPI()


def _load_asr_pipeline():
    """Build the ASR pipeline once at import time; return None on failure.

    A failed load is tolerated here so the app can still start and report
    503 from the endpoint instead of crashing at boot.
    """
    try:
        loaded = pipeline(
            "automatic-speech-recognition",
            model="distil-whisper/distil-large-v3",
            torch_dtype=None,  # let pipeline pick sensible dtype
            device="cpu",
        )
    except Exception as e:
        print(f"❌ Error loading ASR model: {e}")
        return None
    print("✅ ASR model loaded successfully")
    return loaded


# Load the ASR pipeline on startup
asr_pipeline = _load_asr_pipeline()
@app.post("/transcribe")
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to text.

    Accepts any container/codec ffmpeg can decode, converts it to a
    16 kHz mono WAV via an ffmpeg subprocess, reads the waveform with
    soundfile, and runs it through the module-level ASR pipeline.

    Returns:
        dict: ``{"transcription": <recognized text>}``.

    Raises:
        HTTPException: 503 if the ASR model failed to load at startup;
            400 if the upload is empty or cannot be decoded/processed.
    """
    if not asr_pipeline:
        raise HTTPException(status_code=503, detail="ASR model is not available.")

    audio_bytes = await audio_file.read()
    # Reject empty uploads up front instead of spawning ffmpeg on nothing.
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Could not process audio file: empty upload.")

    tmp_in = None
    tmp_wav = None
    try:
        # 1) Save uploaded bytes to a temporary file. Preserve the extension
        #    if available; note audio_file.filename may be None for some
        #    clients, so guard before splitext (os.path.splitext(None) raises).
        suffix = os.path.splitext(audio_file.filename or "")[1] or ""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tf:
            tf.write(audio_bytes)
            tf.flush()
            tmp_in = tf.name

        # 2) Use ffmpeg to convert to 16kHz mono WAV PCM (stable, avoids
        #    librosa/numba as a decode dependency).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tfwav:
            tmp_wav = tfwav.name
        ffmpeg_cmd = [
            "ffmpeg",
            "-y",             # overwrite the already-created output file
            "-i", tmp_in,     # input file
            "-ar", "16000",   # sample rate 16k
            "-ac", "1",       # mono
            "-f", "wav",
            tmp_wav,
        ]
        proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            # include ffmpeg stderr for debugging
            raise RuntimeError(f"ffmpeg error: {proc.stderr.strip()}")

        # 3) Read WAV with soundfile into float32 waveform
        speech, sr = sf.read(tmp_wav, dtype="float32")
        if sr != 16000:
            # should not happen because ffmpeg forced 16k, but check anyway
            raise RuntimeError(f"Unexpected sample rate {sr}")

        # 4) Transcribe using the transformers ASR pipeline.
        #    The pipeline expects a 1-D numpy waveform; average channels if
        #    soundfile returned a 2-D (frames, channels) array.
        if speech.ndim > 1:
            speech = np.mean(speech, axis=1)

        # chunking options keep peak memory bounded on long recordings
        result = asr_pipeline(speech, chunk_length_s=30, stride_length_s=5)
        text = result.get("text", "") if isinstance(result, dict) else (
            result[0].get("text", "") if isinstance(result, list) and result else ""
        )
        return {"transcription": text}
    except Exception as e:
        # Return a 400 with a helpful message
        raise HTTPException(status_code=400, detail=f"Could not process audio file: {e}")
    finally:
        # cleanup temp files regardless of success or failure
        for path in (tmp_in, tmp_wav):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except Exception:
                    pass