# Bambara ASR — dedicated FastAPI service (Hugging Face Space)
| import os | |
| import io | |
| import torch | |
| import librosa | |
| import subprocess | |
| import tempfile | |
| import soundfile as sf | |
| import numpy as np | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import Wav2Vec2ForCTC, AutoProcessor | |
# The Hugging Face cache must live in a writable directory inside the
# container (the default home dir is read-only on Spaces).
_CACHE_DIR = "/tmp/hf"
os.makedirs(_CACHE_DIR, exist_ok=True)
os.environ["HF_HOME"] = _CACHE_DIR
| app = FastAPI(title="Bambara ASR Dedicated API") | |
| # Enable CORS for your frontend | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
# --- Global ASR initialization (runs once at module import / app startup) ---
# Load ASR components globally so every request reuses the same model.
device = "cpu"  # Space has no GPU; inference runs on CPU
model_id = "facebook/mms-1b-all"  # Meta MMS multilingual wav2vec2 checkpoint
print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)
# Pre-load Bambara adapter to prevent lag/OOM on first request
# (load_adapter swaps in the language-specific CTC head; the tokenizer must
# be switched to the matching vocabulary first).
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam")
print("Bambara adapter loaded. System Ready.")
@app.post("/transcribe")  # FIX: handler was never registered with the app — without
# a route decorator the service exposed no endpoint at all.
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio clip (WebM from the browser) to Bambara text.

    Pipeline: save upload to a temp file -> ffmpeg converts to mono 16 kHz WAV
    -> wav2vec2 CTC inference -> greedy argmax decode.

    Returns:
        dict: ``{"text": <transcription>}``. On failure the error message is
        returned in the same ``text`` field (HTTP 200) — the frontend relies on
        this shape, so errors are not raised as HTTPException.
    """
    try:
        content = await audio_file.read()
        if not content:
            return {"text": "Empty audio"}

        # Both temp files are cleaned up automatically when the `with` exits.
        with tempfile.NamedTemporaryFile(suffix=".webm") as f_webm, \
             tempfile.NamedTemporaryFile(suffix=".wav") as f_wav:
            f_webm.write(content)
            f_webm.flush()  # ensure bytes hit disk before ffmpeg reads the path

            # Convert WebM -> WAV: mono (-ac 1), 16 kHz (-ar 16000), matching
            # the sampling rate the MMS model was trained on. List-form argv,
            # shell=False: the filename never touches a shell.
            subprocess.run(
                [
                    "ffmpeg", "-y",
                    "-i", f_webm.name,
                    "-ac", "1",
                    "-ar", "16000",
                    f_wav.name,
                ],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True,
            )

            audio_data, sr = sf.read(f_wav.name)

            # ASR inference: greedy CTC decoding (argmax over logits).
            inputs = processor(
                audio_data,
                sampling_rate=16000,
                return_tensors="pt",
            ).to(device)
            with torch.inference_mode():
                logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            text = processor.batch_decode(predicted_ids)[0]

        return {"text": text}
    except subprocess.CalledProcessError:
        # ffmpeg exited non-zero: input was not decodable audio.
        return {"text": "FFmpeg conversion failed"}
    except Exception as e:
        # Top-level boundary: log and report rather than crash the worker.
        print("Server Error:", e)
        return {"text": str(e)}