"""Dedicated FastAPI service exposing Bambara ("bam") speech-to-text.

Pipeline: browser uploads WebM -> ffmpeg converts to 16 kHz mono WAV ->
facebook/mms-1b-all (Wav2Vec2 CTC) with the Bambara adapter decodes text.
"""

import io
import os
import subprocess
import tempfile

# IMPORTANT: HF_HOME must be set BEFORE transformers/huggingface_hub are
# imported — the hub resolves its cache directory at import time, so setting
# it after the import (as the original code did) has no effect and the model
# would be cached in the default, possibly read-only, location.
os.environ["HF_HOME"] = "/tmp/hf"
os.makedirs("/tmp/hf", exist_ok=True)

import librosa  # noqa: E402  (unused here; kept — may be used elsewhere)
import numpy as np  # noqa: E402  (unused here; kept)
import soundfile as sf  # noqa: E402
import torch  # noqa: E402
from fastapi import FastAPI, File, HTTPException, UploadFile  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from transformers import AutoProcessor, Wav2Vec2ForCTC  # noqa: E402

app = FastAPI(title="Bambara ASR Dedicated API")

# Enable CORS for the frontend.
# NOTE(review): per the CORS spec, browsers reject allow_origins=["*"]
# combined with allow_credentials=True — confirm whether credentials are
# actually needed, or pin the concrete frontend origin(s) here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load ASR components once at startup so requests never pay model-load cost.
device = "cpu"
model_id = "facebook/mms-1b-all"

print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

# Pre-load the Bambara adapter to prevent lag/OOM on the first request.
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam")
print("Bambara adapter loaded. System Ready.")


@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded WebM audio clip to Bambara text.

    The upload is written to a temp file, converted by ffmpeg to the
    16 kHz mono WAV the MMS model expects, then greedily CTC-decoded.

    Returns:
        {"text": <transcription>} on success.  Failures are reported
        in-band as {"text": <error message>} with HTTP 200 — this is the
        existing frontend contract, so errors are NOT raised as
        HTTPException.
    """
    try:
        content = await audio_file.read()
        if not content:
            return {"text": "Empty audio"}

        # Write the WebM payload to a temp file so ffmpeg can read it;
        # both temp files are auto-deleted when the `with` block exits.
        with tempfile.NamedTemporaryFile(suffix=".webm") as f_webm, \
                tempfile.NamedTemporaryFile(suffix=".wav") as f_wav:
            f_webm.write(content)
            f_webm.flush()  # ensure bytes hit disk before ffmpeg reads

            # Convert WebM -> WAV (mono, 16 kHz) — the model's input rate.
            subprocess.run(
                [
                    "ffmpeg", "-y",
                    "-i", f_webm.name,
                    "-ac", "1",
                    "-ar", "16000",
                    f_wav.name,
                ],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True,
            )

            # Read the converted audio while the temp file still exists.
            audio_data, sr = sf.read(f_wav.name)

        # ASR inference; no gradients are needed at serving time.
        inputs = processor(
            audio_data, sampling_rate=16000, return_tensors="pt"
        ).to(device)
        with torch.inference_mode():
            logits = model(**inputs).logits

        # Greedy CTC decode: argmax over the vocab at each frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        text = processor.batch_decode(predicted_ids)[0]
        return {"text": text}

    except subprocess.CalledProcessError:
        return {"text": "FFmpeg conversion failed"}
    except Exception as e:
        print("Server Error:", e)
        return {"text": str(e)}