# STT_Api / app.py
import os
import io
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from huggingface_hub import snapshot_download
# ---- Robust HF cache setup (writable in Docker/Spaces) ----
HF_HOME = os.environ.get("HF_HOME", "/tmp/hf")
os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = os.path.join(HF_HOME, "transformers")
os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
MODEL_ID = os.environ.get("MODEL_ID", "Mustafaa4a/ASR-Somali")
HF_TOKEN = os.environ.get("HF_TOKEN") # only needed for private repos
app = FastAPI(title="Somali ASR API")
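
# Allow browser clients from any origin to call the API (tighten allow_origins in production).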
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
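
# Model and processor are loaded once at startup and shared by all request handlers.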
processor = None
model = None

@app.on_event("startup")
def _load_model():
    global processor, model
    try:
        # Download the repo snapshot to a local, writable dir
        local_dir = snapshot_download(
            repo_id=MODEL_ID,
            token=HF_TOKEN,
            cache_dir=HF_HOME,
        )
        processor = Wav2Vec2Processor.from_pretrained(local_dir)
        model = Wav2Vec2ForCTC.from_pretrained(local_dir)
        model.eval()
    except Exception as e:
        # Surface a clear error instead of crashing Uvicorn silently
        raise RuntimeError(f"Failed to load model '{MODEL_ID}': {e}") from e
@app.get("/health")
def health():
return {"status": "ok", "model_loaded": model is not None, "model_id": MODEL_ID}
@app.get("/")
def root():
return {"message": "Somali Speech-to-Text API is running."}

@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Try again shortly.")
    # Read bytes
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")
    # Load audio from bytes
    try:
        audio_stream = io.BytesIO(audio_bytes)
        # torchaudio can auto-detect many formats if system codecs are present
        waveform, sample_rate = torchaudio.load(audio_stream)
    except Exception:
        # As a fallback, try forcing WAV (in case the client always sends WAV)
        try:
            audio_stream = io.BytesIO(audio_bytes)
            waveform, sample_rate = torchaudio.load(audio_stream, format="wav")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Could not read audio: {e}")
    # Mono + 16k resample for Wav2Vec2
    if waveform.dim() == 2 and waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)  # convert to mono
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
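    # Feature extraction: the processor normalizes the 1-D waveform into model-ready input_values.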
    inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
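    # Greedy CTC decoding: pick the most likely token per frame, then collapse repeats and blanks.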
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return {"transcription": transcription}