File size: 2,541 Bytes
1945a83
 
8bb18e8
 
315d5df
 
 
 
1945a83
 
 
7642f85
1945a83
 
 
7642f85
1945a83
7642f85
1945a83
 
 
 
 
 
 
 
7642f85
1945a83
 
 
 
 
 
 
 
 
 
 
 
 
2955b20
315d5df
 
1945a83
 
7642f85
1945a83
 
315d5df
 
 
 
 
7642f85
315d5df
 
2955b20
315d5df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7642f85
1945a83
8bb18e8
7642f85
8bb18e8
315d5df
 
 
7642f85
315d5df
 
7642f85
 
315d5df
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import io
import torch
import librosa
import subprocess
import tempfile
import soundfile as sf
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForCTC, AutoProcessor

# Set cache to writable directory.
# Must happen before any from_pretrained() call below, since the HF hub
# reads HF_HOME at download/lookup time ("/tmp" is writable on most
# container hosts, e.g. HF Spaces / Cloud Run).
os.environ["HF_HOME"] = "/tmp/hf"
os.makedirs("/tmp/hf", exist_ok=True)

app = FastAPI(title="Bambara ASR Dedicated API")

# Enable CORS for your frontend.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True
# is rejected by browsers for credentialed requests — if the frontend
# sends cookies/auth headers, list its origin explicitly instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load ASR components globally (once, at import time) so every request
# shares one model instance instead of paying the multi-GB load per call.
device = "cpu"
model_id = "facebook/mms-1b-all"  # Meta MMS: 1b-param multilingual ASR

print("Loading processor and model...")
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

# Pre-load Bambara adapter to prevent lag/OOM on first request.
# "bam" is the ISO 639-3 code MMS uses for Bambara; tokenizer and model
# adapter must be switched together or decoding produces garbage.
processor.tokenizer.set_target_lang("bam")
model.load_adapter("bam")
print("Bambara adapter loaded. System Ready.")




@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded WebM audio clip to Bambara text.

    The upload is written to a temp file, converted to mono 16 kHz WAV
    with ffmpeg, then decoded greedily through the globally loaded MMS
    model with its "bam" adapter.

    Returns:
        ``{"text": <transcription>}`` on success. On failure the
        ``text`` field carries a short error description and the HTTP
        status stays 200 — the existing frontend contract.
    """
    try:
        content = await audio_file.read()
        if not content:
            return {"text": "Empty audio"}

        # Both temp files are deleted automatically when the context
        # exits, even if ffmpeg or sf.read raises.
        with tempfile.NamedTemporaryFile(suffix=".webm") as f_webm, \
             tempfile.NamedTemporaryFile(suffix=".wav") as f_wav:

            f_webm.write(content)
            f_webm.flush()  # make the full payload visible to ffmpeg

            # Convert WebM → WAV (mono, 16 kHz — the rate MMS was
            # trained on). List-form argv with shell=False, so the
            # untrusted upload cannot inject shell syntax.
            subprocess.run(
                [
                    "ffmpeg", "-y",
                    "-i", f_webm.name,
                    "-ac", "1",
                    "-ar", "16000",
                    f_wav.name,
                ],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True,
            )

            # Read WAV; sr is the actual sample rate ffmpeg produced.
            audio_data, sr = sf.read(f_wav.name)

        # ffmpeg can exit 0 yet emit zero samples for a truncated or
        # corrupt container — fail with a clear message, not a stack
        # trace string from the model.
        if len(audio_data) == 0:
            return {"text": "Empty audio"}

        # ASR inference. Pass the rate actually read from the file
        # instead of a hard-coded 16000, so a future change to the
        # ffmpeg arguments cannot silently desynchronize the two.
        inputs = processor(
            audio_data,
            sampling_rate=sr,
            return_tensors="pt"
        ).to(device)

        with torch.inference_mode():
            logits = model(**inputs).logits

        # Greedy CTC decoding: per-frame argmax, then the tokenizer
        # collapses repeats/blanks into text.
        predicted_ids = torch.argmax(logits, dim=-1)
        text = processor.batch_decode(predicted_ids)[0]

        return {"text": text}

    except subprocess.CalledProcessError:
        return {"text": "FFmpeg conversion failed"}

    except Exception as e:
        # Boundary handler so the endpoint never 500s on the frontend.
        # NOTE(review): returning str(e) can leak internal paths/details
        # to the client — consider a generic message + server-side log.
        print("Server Error:", e)
        return {"text": str(e)}