import os
import io
import torch
import librosa
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForCTC, AutoProcessor

# Set cache to writable directory
os.environ["HF_HOME"] = "/tmp/hf"
os.makedirs("/tmp/hf", exist_ok=True)

app = FastAPI(title="Bambara ASR Dedicated API")

# Enable CORS for your frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load ASR components globally
device = "cpu"
model_id = "facebook/mms-1b-all"

print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

# Pre-load Bambara adapter to prevent lag/OOM on first request
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam")
print("Bambara adapter loaded. System Ready.")

@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    try:
        # Read file stream
        content = await audio_file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Empty audio file")

        # Load & Resample (Critical: Model expects 16,000Hz)
        audio_data, _ = librosa.load(io.BytesIO(content), sr=16000)

        # Prepare inputs
        inputs = processor(audio_data, sampling_rate=16000, return_tensors="pt").to(device)
        
        # Inference (inference_mode is more memory efficient than no_grad)
        with torch.inference_mode():
            logits = model(**inputs).logits

        # Decode output
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        return {"text": transcription}

    except HTTPException:
        raise
    except Exception as e:
        print(f"Server Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e