# Hugging Face Spaces file view (status: Running; file size: 4,096 bytes).
import io
import torch
import torch.nn.functional as F
import librosa
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from transformers import Wav2Vec2ForCTC, AutoProcessor, AutoFeatureExtractor, AutoModelForAudioClassification
# FastAPI application object that the route decorators below attach to.
app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")
# Hugging Face model identifiers: speech recognition and language identification.
MODEL_ID = "facebook/mms-1b-all"
LID_MODEL_ID = "facebook/mms-lid-256"
# Model handles start as None and are filled in by the startup hook; the
# endpoints treat None as "not loaded yet" and answer 503.
processor = None
model = None
lid_extractor = None
lid_model = None
@app.on_event("startup")
async def load_model():
    """Populate the module-level model handles when the application starts.

    Loads the ASR processor and CTC model from ``MODEL_ID`` and the language-ID
    feature extractor and classifier from ``LID_MODEL_ID``. Both networks are
    switched to evaluation mode before being stored.
    """
    global processor, model, lid_extractor, lid_model

    # Speech recognition stack.
    print("Loading MMS ASR model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    model.eval()

    # Language identification stack.
    print("Loading MMS LID model...")
    lid_extractor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)
    lid_model = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID)
    lid_model.eval()

    print("All models loaded.")
@app.get("/")
def root():
    """Return a short banner naming the service and the ASR model it serves."""
    return dict(message="MMS Speech-to-Text API", model=MODEL_ID)
@app.get("/health")
def health():
    """Report service status and whether each model handle has been populated."""
    asr_ready = model is not None
    lid_ready = lid_model is not None
    return {
        "status": "ok",
        "asr_model_loaded": asr_ready,
        "lid_model_loaded": lid_ready,
    }
def _group_tokens_into_words(tokens, confidences, blank_token="<pad>", word_delimiter="|"):
    """Collapse frame-level CTC tokens into words with a mean confidence each.

    Applies the usual CTC rule: consecutive duplicate tokens are merged first,
    then blank tokens are dropped. A blank therefore *separates* two identical
    characters ("l <pad> l" yields "ll"), which keeps the words consistent
    with the decoded transcription.

    Args:
        tokens: Per-frame token strings.
        confidences: Per-frame probabilities (floats), aligned with ``tokens``.
        blank_token: Token treated as the CTC blank.
        word_delimiter: Token that ends the current word.

    Returns:
        A list of ``{"word": str, "confidence": float}`` dicts. The confidence
        is the mean over the first frame of each collapsed character.
    """
    words = []
    chars = []
    char_confs = []

    def flush():
        # Emit the word accumulated so far (if any) and reset the buffers.
        word = "".join(chars)
        if word:
            words.append({
                "word": word,
                "confidence": float(sum(char_confs) / len(char_confs)),
            })
        chars.clear()
        char_confs.clear()

    prev_token = None
    for tok, conf in zip(tokens, confidences):
        if tok == prev_token:
            continue
        # Record blanks as the previous token too, so that a character
        # repeated after a blank is not mistaken for a duplicate frame.
        prev_token = tok
        if tok == blank_token:
            continue
        if tok == word_delimiter:
            flush()
        else:
            chars.append(tok)
            char_confs.append(conf)
    flush()
    return words


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and report confidence estimates.

    The audio is decoded with librosa (requested at 16 kHz, mono), run through
    the CTC model, and greedily decoded.

    Args:
        file: Uploaded audio file.

    Returns:
        JSON with ``transcription``, ``confidence`` (mean of the per-frame
        maximum probability), ``uncertainty`` (mean per-frame entropy) and
        ``words`` (per-word text and confidence).

    Raises:
        HTTPException: 503 if the ASR model is not loaded yet; 400 if the
            audio cannot be decoded or contains no samples.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    audio_bytes = await file.read()
    try:
        audio, sampling_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}") from e
    # Reject empty audio here so the client gets a 400 instead of an
    # unhandled error from the model.
    if audio.size == 0:
        raise HTTPException(status_code=400, detail="Failed to load audio: no samples")

    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = F.softmax(logits, dim=-1)

    # Greedy decoding: best token per frame and its probability (batch item 0).
    predicted_ids = torch.argmax(probs, dim=-1)[0]
    token_probs = torch.max(probs, dim=-1).values[0]

    transcription = processor.decode(predicted_ids)
    tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids.tolist())
    words = _group_tokens_into_words(tokens, token_probs.tolist())

    global_conf = float(token_probs.mean().item())
    # Small epsilon keeps log() finite for probabilities that underflow to 0.
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
    uncertainty = float(entropy.mean().item())

    return JSONResponse({
        "transcription": transcription,
        "confidence": global_conf,
        "uncertainty": uncertainty,
        "words": words
    })
@app.post("/lid")
async def language_identification(file: UploadFile = File(...)):
    """Identify the spoken language of an uploaded audio file.

    The audio is decoded with librosa (requested at 16 kHz, mono) and passed
    through the language-ID classifier; the label with the highest logit wins.

    Args:
        file: Uploaded audio file.

    Returns:
        JSON of the form ``{"language": <label>}``, where the label comes from
        the classifier's ``id2label`` mapping.

    Raises:
        HTTPException: 503 if the LID model is not loaded yet; 400 if the
            audio cannot be decoded or contains no samples.
    """
    if lid_model is None or lid_extractor is None:
        raise HTTPException(status_code=503, detail="LID model not loaded yet")

    audio_bytes = await file.read()
    try:
        audio_input, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}") from e
    # Reject empty audio here so the client gets a 400 instead of an
    # unhandled error from the model.
    if audio_input.size == 0:
        raise HTTPException(status_code=400, detail="Failed to load audio: no samples")

    inputs = lid_extractor(audio_input, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        outputs = lid_model(**inputs)
    logits = outputs.logits

    predicted_id = torch.argmax(logits, dim=-1).item()
    predicted_lang = lid_model.config.id2label[predicted_id]
    return JSONResponse({"language": predicted_lang})
|