Spaces:
Running
Running
File size: 2,541 Bytes
1945a83 8bb18e8 315d5df 1945a83 7642f85 1945a83 7642f85 1945a83 7642f85 1945a83 7642f85 1945a83 2955b20 315d5df 1945a83 7642f85 1945a83 315d5df 7642f85 315d5df 2955b20 315d5df 7642f85 1945a83 8bb18e8 7642f85 8bb18e8 315d5df 7642f85 315d5df 7642f85 315d5df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import io
import os
import subprocess
import tempfile
import threading
import traceback

import librosa
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForCTC, AutoProcessor
# Keep Hugging Face downloads under /tmp, a writable location.
# NOTE(review): huggingface_hub appears to resolve HF_HOME once, when it is
# first imported, and `transformers` is imported above this point — so this
# assignment probably does not redirect the cache used by from_pretrained.
# Confirm, or pass cache_dir explicitly to the loaders.
os.makedirs("/tmp/hf", exist_ok=True)
os.environ["HF_HOME"] = "/tmp/hf"
# The ASAPI application object; routes are registered on it below.
app = FastAPI(title="Bambara ASR Dedicated API")

# Let a browser front-end on another origin call this API: every origin,
# method and header is accepted, and credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is
# generally discouraged; consider listing the real front-end origin instead.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Load the ASR processor and model once, at import time, so every request
# reuses the same objects. Inference runs on CPU.
device = "cpu"
model_id = "facebook/mms-1b-all"
# Writable cache location for model/adapter downloads. Passed explicitly as
# cache_dir because setting the HF_HOME env var after `transformers` has been
# imported does not reliably redirect huggingface_hub's cache (it resolves the
# path at its own import time).
hf_cache_dir = "/tmp/hf"

print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id, cache_dir=hf_cache_dir)
model = Wav2Vec2ForCTC.from_pretrained(model_id, cache_dir=hf_cache_dir).to(device)

# Pre-load the Bambara ("bam") tokenizer vocabulary and adapter weights at
# startup to prevent lag/OOM on the first request.
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam", cache_dir=hf_cache_dir)
print("Bambara adapter loaded. System Ready.")
# Serializes the model forward pass. Before inference was moved off the event
# loop, the blocking call implicitly ran one request at a time; keeping that
# guarantee bounds peak memory for the large model on CPU.
_inference_lock = threading.Lock()


def _convert_to_wav(content):
    """Decode uploaded audio bytes to a mono 16 kHz waveform via ffmpeg.

    The bytes are written to a temporary ``.webm`` file, ffmpeg converts it
    to a temporary ``.wav`` file, and that file is read back with soundfile.

    Args:
        content: Raw bytes of the uploaded audio file.

    Returns:
        ``(audio_data, sample_rate)`` as returned by ``soundfile.read``.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited non-zero; its stderr is
            available on the exception's ``stderr`` attribute.
    """
    with tempfile.NamedTemporaryFile(suffix=".webm") as f_webm, \
            tempfile.NamedTemporaryFile(suffix=".wav") as f_wav:
        f_webm.write(content)
        f_webm.flush()
        # Convert WebM -> WAV (mono, 16 kHz). Only errors are written to
        # stderr so the captured output stays small but still explains failures.
        subprocess.run(
            [
                "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
                "-i", f_webm.name,
                "-ac", "1",
                "-ar", "16000",
                f_wav.name,
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            check=True,
        )
        return sf.read(f_wav.name)


def _run_asr(audio_data, sample_rate):
    """Greedy CTC decode of a mono waveform; return the transcript string."""
    inputs = processor(
        audio_data,
        sampling_rate=sample_rate,
        return_tensors="pt",
    ).to(device)
    with _inference_lock, torch.inference_mode():
        logits = model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]


@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio clip to Bambara text.

    Args:
        audio_file: Uploaded audio (the front-end sends WebM; anything ffmpeg
            can decode is accepted).

    Returns:
        ``{"text": transcript}`` on success. Failures are reported in the
        same shape (with HTTP 200): ``"Empty audio"`` for an empty upload,
        ``"FFmpeg conversion failed"`` when decoding fails, otherwise the
        exception message.
    """
    try:
        content = await audio_file.read()
        if not content:
            return {"text": "Empty audio"}
        # ffmpeg and the model forward pass block for a long time; run them in
        # FastAPI's worker threadpool so the event loop keeps serving requests.
        audio_data, sr = await run_in_threadpool(_convert_to_wav, content)
        text = await run_in_threadpool(_run_asr, audio_data, sr)
        return {"text": text}
    except subprocess.CalledProcessError as e:
        print("FFmpeg error:", (e.stderr or b"").decode("utf-8", errors="replace"))
        return {"text": "FFmpeg conversion failed"}
    except Exception as e:
        print("Server Error:", e)
        traceback.print_exc()
        return {"text": str(e)}
|