Spaces:
Running
Running
File size: 1,908 Bytes
1945a83 8bb18e8 1945a83 7642f85 1945a83 7642f85 1945a83 7642f85 1945a83 7642f85 1945a83 7642f85 1945a83 7642f85 1945a83 7642f85 1945a83 8bb18e8 7642f85 1945a83 8bb18e8 7642f85 8bb18e8 7642f85 1945a83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import io
import logging
import os

import librosa
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForCTC, AutoProcessor
# Writable location for Hugging Face data. The "hub" sub-directory mirrors the
# layout the libraries use under HF_HOME, so a cache populated through either
# mechanism is shared.
HF_HOME_DIR = "/tmp/hf"
HF_CACHE_DIR = os.path.join(HF_HOME_DIR, "hub")

# NOTE: the Hugging Face libraries resolve HF_HOME when they are first
# imported, and the imports at the top of this file run before this line, so
# this assignment alone does not redirect their cache. It is kept for any code
# that reads the variable later; the download calls below additionally receive
# an explicit ``cache_dir`` so files land in the writable directory.
os.environ["HF_HOME"] = HF_HOME_DIR
os.makedirs(HF_CACHE_DIR, exist_ok=True)

app = FastAPI(title="Bambara ASR Dedicated API")

# Enable CORS for the frontend.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed requests; FastAPI's CORS docs advise listing origins
# explicitly when credentials are enabled. Left as-is to avoid breaking the
# existing frontend -- confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load ASR components once at import time so every request reuses them.
device = "cpu"
model_id = "facebook/mms-1b-all"
print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model = Wav2Vec2ForCTC.from_pretrained(model_id, cache_dir=HF_CACHE_DIR).to(device)

# Pre-load Bambara adapter to prevent lag/OOM on first request
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam", cache_dir=HF_CACHE_DIR)
print("Bambara adapter loaded. System Ready.")
logger = logging.getLogger(__name__)

# Sampling rate (Hz) the audio is resampled to on load and declared to the
# processor; both must agree.
TARGET_SAMPLE_RATE = 16000


def _decode_audio(content: bytes):
    """Decode raw upload bytes into a waveform resampled to TARGET_SAMPLE_RATE."""
    waveform, _ = librosa.load(io.BytesIO(content), sr=TARGET_SAMPLE_RATE)
    return waveform


def _transcribe_waveform(waveform) -> str:
    """Run the globally loaded processor/model on a waveform and return the text."""
    inputs = processor(
        waveform, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt"
    ).to(device)

    # Inference (inference_mode is more memory efficient than no_grad)
    with torch.inference_mode():
        logits = model(**inputs).logits

    # Greedy decode: most likely token per frame, then collapse to text.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]


@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Args:
        audio_file: The uploaded audio; its full contents are read into memory.

    Returns:
        A dict ``{"text": ...}``. On success the value is the transcription.
        On failure the value is a string starting with ``"Error: "``; the
        response shape is the same in both cases and no exception propagates.

    NOTE(review): decoding and inference are synchronous calls inside this
    coroutine, so the event loop is occupied for their duration. Failures are
    also reported in the body rather than via an HTTP error status. Both are
    kept to preserve the existing client contract.
    """
    try:
        content = await audio_file.read()
        if not content:
            return {"text": "Error: Empty audio file"}

        waveform = _decode_audio(content)
        # A file can be non-empty on disk yet decode to zero samples (e.g. a
        # header with no frames); report that clearly instead of letting the
        # model fail with an opaque tensor-shape error.
        if waveform.size == 0:
            return {"text": "Error: Empty audio file"}

        return {"text": _transcribe_waveform(waveform)}
    except Exception as e:
        # Top-level boundary for this endpoint: log with the full traceback,
        # then report the failure in the same response shape as before.
        logger.exception("Server Error: %s", e)
        return {"text": f"Error: {str(e)}"}
|