# NOTE(review): the lines "Spaces: / Running / Running" were Hugging Face
# Spaces page chrome captured along with the code, not part of the program.
# Standard library first.
import io
import os

# Point the Hugging Face cache at a writable directory BEFORE importing
# transformers: the hub library resolves its cache location (HF_HOME) at
# import time, so the original ordering — imports first, env var after —
# could leave the cache pointing at a non-writable default path.
os.environ["HF_HOME"] = "/tmp/hf"
os.makedirs("/tmp/hf", exist_ok=True)

import librosa
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoProcessor, Wav2Vec2ForCTC
app = FastAPI(title="Bambara ASR Dedicated API")

# Allow browser frontends on any origin to reach the API.
# NOTE(review): browsers reject credentialed requests when the allowed
# origin is the wildcard "*"; if cookies/auth headers are ever needed,
# replace the wildcard with an explicit origin list — confirm with the
# frontend's requirements.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Global ASR state, initialized once at process startup ---
device = "cpu"
model_id = "facebook/mms-1b-all"

print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

# Switch tokenizer and model to Bambara ("bam") eagerly, so the first
# incoming request does not pay the adapter-load cost (lag / possible OOM).
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam")
print("Bambara adapter loaded. System Ready.")
# NOTE(review): the original defined this handler without registering any
# route, leaving the endpoint unreachable; "/transcribe" is an assumed
# path — confirm it matches what the frontend calls.
@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to Bambara text.

    Accepts any audio container/codec librosa can decode, resamples it to
    the 16 kHz the MMS model expects, and runs greedy CTC decoding.

    Returns:
        dict: ``{"text": <transcription>}`` on success, or
        ``{"text": "Error: ..."}`` on failure — errors are reported
        in-band with HTTP 200 so a simple frontend only reads one field.
    """
    try:
        content = await audio_file.read()
        if not content:
            return {"text": "Error: Empty audio file"}

        # Decode + resample. Critical: the model expects 16,000 Hz input.
        audio_data, _ = librosa.load(io.BytesIO(content), sr=16000)

        inputs = processor(audio_data, sampling_rate=16000, return_tensors="pt").to(device)

        # inference_mode() skips all autograd bookkeeping — lighter on
        # memory than no_grad() for pure inference.
        with torch.inference_mode():
            logits = model(**inputs).logits

        # Greedy decode: most likely token per frame, CTC-collapsed by
        # batch_decode; take the single item of the batch.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return {"text": transcription}
    except Exception as e:
        # Best-effort boundary handler: log server-side, keep the
        # 200 + error-text contract the frontend already depends on.
        print(f"Server Error: {e}")
        return {"text": f"Error: {str(e)}"}