# bm_speech/app.py
import os

# Set the HF cache to a writable directory. This must happen BEFORE
# transformers is imported: huggingface_hub resolves HF_HOME at import
# time, so setting it after the import has no effect.
os.environ["HF_HOME"] = "/tmp/hf"
os.makedirs("/tmp/hf", exist_ok=True)

import io

import librosa
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoProcessor, Wav2Vec2ForCTC

app = FastAPI(title="Bambara ASR Dedicated API")

# Enable CORS for your frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
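
# NOTE: allow_origins=["*"] accepts requests from any site. In production
# you would typically pin the frontend's origin instead, e.g. (hypothetical
# URL):
#   allow_origins=["https://your-frontend.example.com"]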

# Load ASR components globally, once at startup
device = "cpu"
model_id = "facebook/mms-1b-all"

print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

# Pre-load the Bambara adapter ("bam" is its ISO 639-3 code) so the first
# request does not pay the lag/OOM cost of a lazy load
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam")
print("Bambara adapter loaded. System ready.")


@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    # Read the uploaded file into memory
    content = await audio_file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Empty audio file")
    try:
        # Load & resample (critical: the model expects 16,000 Hz audio)
        audio_data, _ = librosa.load(io.BytesIO(content), sr=16000)
        # Prepare model inputs
        inputs = processor(audio_data, sampling_rate=16000, return_tensors="pt").to(device)
        # Inference (inference_mode is more memory-efficient than no_grad)
        with torch.inference_mode():
            logits = model(**inputs).logits
        # Greedy CTC decode: take the most likely token per frame, then let
        # the tokenizer collapse repeats and blanks
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return {"text": transcription}
    except Exception as e:
        print(f"Server Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))