# bm_speech / app.py
# (Hugging Face Space page header: author Gaoussin, commit 315d5df "Update app.py", verified)
import io
import os
import subprocess
import tempfile
import threading

import librosa
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForCTC, AutoProcessor
# Writable cache location for Hugging Face downloads.
# NOTE: `transformers` (and the `huggingface_hub` library it pulls in) has
# already been imported above, and the hub library resolves HF_HOME when it is
# imported. Setting the variable here therefore does not, by itself, redirect
# downloads. The directory is also passed explicitly as `cache_dir` to every
# download call below.
CACHE_DIR = "/tmp/hf"
os.environ["HF_HOME"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

app = FastAPI(title="Bambara ASR Dedicated API")

# Enable CORS for the browser frontend: every origin, method and header is
# accepted, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load ASR components once at import time so every request reuses them.
device = "cpu"
model_id = "facebook/mms-1b-all"
print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id, cache_dir=CACHE_DIR)
model = Wav2Vec2ForCTC.from_pretrained(model_id, cache_dir=CACHE_DIR).to(device)

# Switch the tokenizer vocabulary and the model's adapter weights to Bambara
# ("bam") up front, so the first request does not have to load them.
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam", cache_dir=CACHE_DIR)
print("Bambara adapter loaded. System Ready.")
# Serialises model forward passes. Requests are processed in a worker thread
# pool (see `transcribe`), so without this lock several forward passes of the
# large model could run at once on the CPU and multiply peak memory use.
_inference_lock = threading.Lock()


def _decode_to_waveform(content: bytes) -> np.ndarray:
    """Decode uploaded audio bytes into a mono, 16 kHz waveform using ffmpeg.

    The bytes are written to a temporary ``.webm`` file (the format the
    frontend sends; ffmpeg detects the real input format from the content) and
    converted to a temporary WAV file. That WAV file is then read back with
    soundfile.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        OSError: if the ffmpeg binary cannot be started.
    """
    with tempfile.NamedTemporaryFile(suffix=".webm") as f_in, \
            tempfile.NamedTemporaryFile(suffix=".wav") as f_out:
        f_in.write(content)
        f_in.flush()
        subprocess.run(
            [
                "ffmpeg", "-y",
                "-i", f_in.name,
                "-ac", "1",
                "-ar", "16000",
                f_out.name,
            ],
            # ffmpeg reads stdin for interactive commands by default; detach it
            # so it can never block on (or consume) the server's stdin.
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
        audio, _ = sf.read(f_out.name)
    return audio


def _run_asr(audio: np.ndarray) -> str:
    """Return the argmax (greedy) CTC transcription of a mono 16 kHz waveform."""
    inputs = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
    ).to(device)
    with torch.inference_mode():
        with _inference_lock:
            logits = model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]


def _transcribe_bytes(content: bytes) -> str:
    """Decode *content* and run ASR on it.

    This is blocking work, intended to run in a worker thread.
    """
    audio = _decode_to_waveform(content)
    if audio.size == 0:
        # ffmpeg succeeded but produced no samples. The convolutional feature
        # encoder cannot run on a zero-length input, so report this the same
        # way as an empty upload.
        return "Empty audio"
    return _run_asr(audio)


@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded Bambara audio clip.

    Always responds with ``{"text": ...}``. On failure, ``text`` carries a
    short error description instead of a transcription; this shape is kept
    for compatibility with existing clients.
    """
    try:
        content = await audio_file.read()
        if not content:
            return {"text": "Empty audio"}
        # The ffmpeg subprocess and the model forward pass block for a long
        # time. Run them in a worker thread so the event loop keeps serving
        # other requests in the meantime.
        text = await run_in_threadpool(_transcribe_bytes, content)
        return {"text": text}
    except subprocess.CalledProcessError:
        return {"text": "FFmpeg conversion failed"}
    except Exception as e:
        print("Server Error:", e)
        return {"text": str(e)}