import io
import os

# Point the Hugging Face cache at a writable directory.
# This must happen *before* importing transformers, because
# huggingface_hub reads HF_HOME at import time.
os.environ["HF_HOME"] = "/tmp/hf"
os.makedirs("/tmp/hf", exist_ok=True)

import librosa
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoProcessor, Wav2Vec2ForCTC

app = FastAPI(title="Bambara ASR Dedicated API")

# Enable CORS for the frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load ASR components globally so the model is ready before the first request
device = "cpu"
model_id = "facebook/mms-1b-all"

print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

# Pre-load the Bambara adapter to prevent lag/OOM on the first request
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam")
print("Bambara adapter loaded. System ready.")


@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    try:
        # Read the uploaded file into memory
        content = await audio_file.read()
        if not content:
            return {"text": "Error: Empty audio file"}

        # Load & resample (critical: the model expects 16,000 Hz mono audio)
        audio_data, _ = librosa.load(io.BytesIO(content), sr=16000)

        # Prepare model inputs
        inputs = processor(
            audio_data, sampling_rate=16000, return_tensors="pt"
        ).to(device)

        # Inference (inference_mode is more memory-efficient than no_grad)
        with torch.inference_mode():
            logits = model(**inputs).logits

        # Greedy CTC decode: take the highest-probability token at each frame
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        return {"text": transcription}
    except Exception as e:
        print(f"Server Error: {e}")
        return {"text": f"Error: {str(e)}"}
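
# --- Usage sketch ---
# Assumptions (not part of the original script): the file is saved as
# `main.py`, uvicorn and requests are installed, and the port and sample
# file name below are illustrative.
#
# Run the server:
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Post an audio file to the endpoint with curl (the form field name must
# match the `audio_file` parameter of the /transcribe handler):
#   curl -X POST -F "audio_file=@sample.wav" http://localhost:8000/transcribe
#
# Or from Python with the requests library:
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/transcribe",
#           files={"audio_file": ("sample.wav", f, "audio/wav")},
#       )
#   print(resp.json()["text"])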