"""Somali speech-to-text HTTP API.

Exposes a FastAPI application that loads a Wav2Vec2 CTC model from the
Hugging Face Hub at startup and transcribes uploaded audio files.
"""

import io
import os

import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import snapshot_download
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# ---- Robust HF cache setup (writable in Docker/Spaces) ----
# NOTE(review): these variables are set after the Hugging Face libraries are
# imported; anything those libraries read at import time may not see them.
# The snapshot download below passes cache_dir explicitly, so it does not
# depend on them -- confirm before relying on the env vars elsewhere.
HF_HOME = os.environ.get("HF_HOME", "/tmp/hf")
os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = os.path.join(HF_HOME, "transformers")
os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)

MODEL_ID = os.environ.get("MODEL_ID", "Mustafaa4a/ASR-Somali")
HF_TOKEN = os.environ.get("HF_TOKEN")  # only needed for private repos

# Sample rate the audio is resampled to before being fed to the processor.
TARGET_SAMPLE_RATE = 16000

app = FastAPI(title="Somali ASR API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated by _load_model() at application startup; None until then.
processor = None
model = None


@app.on_event("startup")
def _load_model():
    """Download the model snapshot and populate ``processor`` and ``model``.

    Raises:
        RuntimeError: if the download or either ``from_pretrained`` call
            fails; the original exception is chained as ``__cause__``.
    """
    global processor, model
    try:
        # Download the repo snapshot to a local, writable dir
        local_dir = snapshot_download(
            repo_id=MODEL_ID,
            token=HF_TOKEN,
            cache_dir=HF_HOME,
        )
        loaded_processor = Wav2Vec2Processor.from_pretrained(local_dir)
        loaded_model = Wav2Vec2ForCTC.from_pretrained(local_dir)
        loaded_model.eval()
    except Exception as e:
        # Surface a clear error instead of crashing Uvicorn silently
        raise RuntimeError(f"Failed to load model '{MODEL_ID}': {e}") from e
    # Publish both globals only once everything loaded, so the module never
    # holds a processor without its matching model.
    processor = loaded_processor
    model = loaded_model


@app.get("/health")
def health():
    """Report service status and whether the model has been loaded."""
    return {"status": "ok", "model_loaded": model is not None, "model_id": MODEL_ID}


@app.get("/")
def root():
    """Return a static banner confirming the API is up."""
    return {"message": "Somali Speech-to-Text API is running."}


def _decode_audio(audio_bytes):
    """Decode raw audio bytes into ``(waveform, sample_rate)``.

    Raises:
        HTTPException: status 400 if neither the auto-detecting load nor the
            forced-WAV fallback can read the data.
    """
    try:
        # torchaudio can auto-detect many formats if system codecs are present
        return torchaudio.load(io.BytesIO(audio_bytes))
    except Exception:
        # As a fallback, try forcing WAV (in case the client always sends WAV).
        # A fresh BytesIO is used because the first attempt may have consumed
        # part of the stream.
        try:
            return torchaudio.load(io.BytesIO(audio_bytes), format="wav")
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Could not read audio: {e}"
            ) from e


def _prepare_waveform(waveform, sample_rate):
    """Down-mix to mono, resample to ``TARGET_SAMPLE_RATE`` and flatten to 1-D.

    Assumes a 2-D ``waveform`` is laid out as (channels, samples), which is
    what averaging over dim 0 for the mono mix-down implies.
    """
    if waveform.dim() == 2 and waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)  # convert to mono
    if sample_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
        )
        waveform = resampler(waveform)
    # reshape(-1) always yields a 1-D tensor; a bare squeeze() would collapse
    # a single-sample waveform to a 0-dim scalar.
    return waveform.reshape(-1)


def _transcribe_bytes(audio_bytes):
    """Run the full decode -> preprocess -> inference pipeline synchronously.

    Raises:
        HTTPException: status 400 if the audio cannot be decoded or decodes
            to zero samples.
    """
    waveform, sample_rate = _decode_audio(audio_bytes)
    if waveform.numel() == 0:
        raise HTTPException(status_code=400, detail="Audio contains no samples")

    samples = _prepare_waveform(waveform, sample_rate)
    inputs = processor(samples, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Returns:
        ``{"transcription": <text>}``.

    Raises:
        HTTPException: 503 if the model is not loaded yet; 400 if the upload
            is empty, cannot be decoded, or contains no samples.
    """
    if model is None or processor is None:
        raise HTTPException(
            status_code=503, detail="Model not loaded yet. Try again shortly."
        )

    # Read bytes
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")

    # Decoding, resampling and inference are CPU-bound; run them in a worker
    # thread so the event loop stays free to serve other requests.
    transcription = await run_in_threadpool(_transcribe_bytes, audio_bytes)
    return {"transcription": transcription}