File size: 4,096 Bytes
4845a07
 
 
d0e18c9
4845a07
 
 
da466fe
4845a07
d0e18c9
4845a07
 
da466fe
 
4845a07
 
da466fe
 
4845a07
 
 
 
da466fe
 
4845a07
 
 
da466fe
 
 
 
 
4845a07
 
 
 
 
 
 
 
 
da466fe
 
 
 
 
4845a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0e18c9
da466fe
d0e18c9
 
 
 
 
4845a07
 
d0e18c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da466fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import io

import librosa
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    AutoProcessor,
    Wav2Vec2ForCTC,
)

app = FastAPI(
    title="MMS Speech-to-Text API",
    version="2.0.0",
)

# Hugging Face checkpoints loaded at startup: CTC speech recognition and
# spoken-language identification.
MODEL_ID = "facebook/mms-1b-all"
LID_MODEL_ID = "facebook/mms-lid-256"

# Populated by the startup hook. ``None`` means "not loaded yet", which the
# endpoints report as HTTP 503.
processor = model = None  # ASR processor / Wav2Vec2ForCTC model
lid_extractor = lid_model = None  # LID feature extractor / audio classifier


@app.on_event("startup")
async def load_model():
    """Load the ASR and language-ID checkpoints into the module-level globals.

    Registered as a FastAPI startup hook. Both models are switched to eval
    mode before they are published to the request handlers.
    """
    global processor, model, lid_extractor, lid_model

    print("Loading MMS ASR model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    # nn.Module.eval() returns the module itself, so load and eval chain.
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).eval()

    print("Loading MMS LID model...")
    lid_extractor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)
    lid_model = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID).eval()

    print("All models loaded.")


@app.get("/")
def root():
    """Identify the service and the ASR checkpoint it serves."""
    payload = {"message": "MMS Speech-to-Text API"}
    payload["model"] = MODEL_ID
    return payload


@app.get("/health")
def health():
    """Report liveness plus whether each model has finished loading."""
    asr_ready = model is not None
    lid_ready = lid_model is not None
    return dict(
        status="ok",
        asr_model_loaded=asr_ready,
        lid_model_loaded=lid_ready,
    )


def _ctc_word_confidences(tokens, confidences, pad_token="<pad>", word_delimiter="|"):
    """Group greedy-CTC frame tokens into words, each with a mean confidence.

    Consecutive identical frames are merged first, and only then are blank
    (pad) frames dropped. This is the same order the Wav2Vec2 CTC tokenizer's
    ``decode`` uses, so a blank between two identical characters
    (``a <pad> a``) yields two characters and the words agree with the
    decoded transcription.

    Args:
        tokens: one token string per output frame (the argmax path).
        confidences: per-frame probability of the chosen token, aligned
            with ``tokens``.
        pad_token: the CTC blank token.
        word_delimiter: the token that separates words.

    Returns:
        A list of ``{"word": str, "confidence": float}`` dicts. The confidence
        is the mean, over the word's characters, of the first frame of each
        collapsed run.
    """
    words = []
    chars = []
    char_confs = []
    prev_token = None

    for tok, conf in zip(tokens, confidences):
        # Collapse repeats BEFORE dropping blanks: a blank must still reset
        # prev_token so it can separate genuine double letters.
        if tok == prev_token:
            continue
        prev_token = tok
        if tok == pad_token:
            continue
        if tok == word_delimiter:
            if chars:
                words.append({
                    "word": "".join(chars),
                    "confidence": float(sum(char_confs) / len(char_confs)),
                })
            chars, char_confs = [], []
        else:
            chars.append(tok)
            char_confs.append(float(conf))

    if chars:
        words.append({
            "word": "".join(chars),
            "confidence": float(sum(char_confs) / len(char_confs)),
        })
    return words


def _transcribe_bytes(audio_bytes, asr_processor, asr_model):
    """Decode raw audio bytes and run greedy CTC transcription (blocking).

    Args:
        audio_bytes: the uploaded file contents in any format librosa can read.
        asr_processor: the loaded ASR processor (feature extractor + tokenizer).
        asr_model: the loaded Wav2Vec2ForCTC model.

    Returns:
        A dict with ``transcription``, the mean frame ``confidence``, the
        mean per-frame entropy as ``uncertainty``, and per-word ``words``.

    Raises:
        HTTPException: 400 if the audio cannot be decoded, or is shorter than
            the model's convolutional receptive field.
    """
    try:
        audio, sampling_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}") from e

    # The conv feature encoder needs at least one full receptive field of
    # samples. Shorter input (including empty audio) makes the forward pass
    # raise, which would surface as a 500.
    config = asr_model.config
    min_samples = 1
    for kernel, stride in zip(
        reversed(getattr(config, "conv_kernel", ())),
        reversed(getattr(config, "conv_stride", ())),
    ):
        min_samples = (min_samples - 1) * stride + kernel
    if audio.shape[0] < min_samples:
        raise HTTPException(
            status_code=400,
            detail=f"Audio too short: need at least {min_samples} samples at 16 kHz",
        )

    inputs = asr_processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        logits = asr_model(**inputs).logits

    # log_softmax is numerically stable; exp() recovers the probabilities.
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()

    # Greedy path for the single batch item: best token and its probability.
    token_probs, predicted_ids = probs[0].max(dim=-1)

    tokenizer = asr_processor.tokenizer
    transcription = asr_processor.decode(predicted_ids)
    tokens = tokenizer.convert_ids_to_tokens(predicted_ids.tolist())
    words = _ctc_word_confidences(
        tokens,
        token_probs.tolist(),
        pad_token=tokenizer.pad_token,
        word_delimiter=tokenizer.word_delimiter_token,
    )

    entropy = -(probs * log_probs).sum(dim=-1)
    return {
        "transcription": transcription,
        "confidence": float(token_probs.mean().item()),
        "uncertainty": float(entropy.mean().item()),
        "words": words,
    }


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with greedy CTC decoding.

    Returns JSON with ``transcription``, ``confidence`` (mean per-frame max
    probability), ``uncertainty`` (mean per-frame entropy, in nats) and
    ``words`` (per-word mean character confidence).

    Raises:
        HTTPException: 503 if the models are not loaded yet; 400 if the audio
            cannot be decoded or is too short.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    audio_bytes = await file.read()
    # Audio decoding and the model forward pass are synchronous and heavy.
    # Run them in the threadpool so the event loop keeps serving other
    # requests (including /health) during inference.
    result = await run_in_threadpool(_transcribe_bytes, audio_bytes, processor, model)
    return JSONResponse(result)


def _identify_language_bytes(audio_bytes, extractor, classifier):
    """Decode raw audio bytes and classify the spoken language (blocking).

    Args:
        audio_bytes: the uploaded file contents in any format librosa can read.
        extractor: the loaded LID feature extractor.
        classifier: the loaded audio-classification model.

    Returns:
        A dict with the top ``language`` label and its softmax ``confidence``.

    Raises:
        HTTPException: 400 if the audio cannot be decoded, or is shorter than
            the model's convolutional receptive field.
    """
    try:
        audio_input, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}") from e

    # Reject input shorter than the conv encoder's receptive field (including
    # empty audio): the forward pass would otherwise raise and yield a 500.
    config = classifier.config
    min_samples = 1
    for kernel, stride in zip(
        reversed(getattr(config, "conv_kernel", ())),
        reversed(getattr(config, "conv_stride", ())),
    ):
        min_samples = (min_samples - 1) * stride + kernel
    if audio_input.shape[0] < min_samples:
        raise HTTPException(
            status_code=400,
            detail=f"Audio too short: need at least {min_samples} samples at 16 kHz",
        )

    inputs = extractor(audio_input, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = classifier(**inputs).logits

    # Single batch item: highest-probability label and its probability.
    confidence, predicted_id = F.softmax(logits, dim=-1)[0].max(dim=-1)
    return {
        "language": classifier.config.id2label[predicted_id.item()],
        "confidence": float(confidence.item()),
    }


@app.post("/lid")
async def language_identification(file: UploadFile = File(...)):
    """Identify the spoken language of an uploaded audio file.

    Returns JSON with ``language`` (the classifier's top label) and
    ``confidence`` (its softmax probability).

    Raises:
        HTTPException: 503 if the LID model is not loaded yet; 400 if the
            audio cannot be decoded or is too short.
    """
    if lid_model is None or lid_extractor is None:
        raise HTTPException(status_code=503, detail="LID model not loaded yet")

    audio_bytes = await file.read()
    # Decoding and inference are synchronous. Run them in the threadpool so
    # they do not block the event loop for other requests.
    result = await run_in_threadpool(
        _identify_language_bytes, audio_bytes, lid_extractor, lid_model
    )
    return JSONResponse(result)