Spaces:
Running
Running
Commit ·
da466fe
1
Parent(s): d0e18c9
Add LID endpoint using facebook/mms-lid-256
Browse files
app.py
CHANGED
|
@@ -5,23 +5,31 @@ import torch.nn.functional as F
|
|
| 5 |
import librosa
|
| 6 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 7 |
from fastapi.responses import JSONResponse
|
| 8 |
-
from transformers import Wav2Vec2ForCTC, AutoProcessor
|
| 9 |
|
| 10 |
app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")
|
| 11 |
|
| 12 |
MODEL_ID = "facebook/mms-1b-all"
|
|
|
|
|
|
|
| 13 |
processor = None
|
| 14 |
model = None
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
@app.on_event("startup")
|
| 18 |
async def load_model():
|
| 19 |
-
global processor, model
|
| 20 |
-
print("Loading MMS model...")
|
| 21 |
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
| 22 |
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
|
| 23 |
model.eval()
|
| 24 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
@app.get("/")
|
|
@@ -31,7 +39,11 @@ def root():
|
|
| 31 |
|
| 32 |
@app.get("/health")
|
| 33 |
def health():
|
| 34 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
@app.post("/transcribe")
|
|
@@ -50,7 +62,7 @@ async def transcribe(file: UploadFile = File(...)):
|
|
| 50 |
|
| 51 |
with torch.no_grad():
|
| 52 |
outputs = model(**inputs)
|
| 53 |
-
logits = outputs.logits
|
| 54 |
|
| 55 |
probs = F.softmax(logits, dim=-1)
|
| 56 |
|
|
@@ -58,7 +70,6 @@ async def transcribe(file: UploadFile = File(...)):
|
|
| 58 |
token_probs = torch.max(probs, dim=-1).values[0]
|
| 59 |
|
| 60 |
transcription = processor.decode(predicted_ids)
|
| 61 |
-
|
| 62 |
tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
|
| 63 |
|
| 64 |
words = []
|
|
@@ -92,7 +103,6 @@ async def transcribe(file: UploadFile = File(...)):
|
|
| 92 |
})
|
| 93 |
|
| 94 |
global_conf = float(token_probs.mean().item())
|
| 95 |
-
|
| 96 |
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
|
| 97 |
uncertainty = float(entropy.mean().item())
|
| 98 |
|
|
@@ -102,3 +112,27 @@ async def transcribe(file: UploadFile = File(...)):
|
|
| 102 |
"uncertainty": uncertainty,
|
| 103 |
"words": words
|
| 104 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import librosa
|
| 6 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 7 |
from fastapi.responses import JSONResponse
|
| 8 |
+
from transformers import Wav2Vec2ForCTC, AutoProcessor, AutoFeatureExtractor, AutoModelForAudioClassification
|
| 9 |
|
| 10 |
# FastAPI application object that the route decorators below attach to.
app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")

# Hugging Face checkpoint identifiers passed to `from_pretrained` at startup.
MODEL_ID = "facebook/mms-1b-all"  # speech-to-text (CTC) checkpoint
LID_MODEL_ID = "facebook/mms-lid-256"  # language-identification checkpoint

# Model handles start out as None and are filled in by the startup handler;
# the endpoints check them against None to tell whether loading has finished.
processor = model = None
lid_extractor = lid_model = None
|
| 19 |
|
| 20 |
|
| 21 |
@app.on_event("startup")
async def load_model():
    """Load the ASR and LID checkpoints into the module-level globals.

    Registered as a FastAPI startup handler. Populates ``processor`` and
    ``model`` from ``MODEL_ID``, then ``lid_extractor`` and ``lid_model``
    from ``LID_MODEL_ID``. Both models are switched to eval mode.
    """
    global processor, model, lid_extractor, lid_model

    # Speech-to-text: processor plus CTC model.
    print("Loading MMS ASR model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).eval()

    # Language identification: feature extractor plus audio classifier.
    print("Loading MMS LID model...")
    lid_extractor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)
    lid_model = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID).eval()

    print("All models loaded.")
|
| 33 |
|
| 34 |
|
| 35 |
@app.get("/")
|
|
|
|
| 39 |
|
| 40 |
@app.get("/health")
def health():
    """Report service status and whether each model global is populated.

    Returns:
        A dict with a constant ``"status": "ok"`` plus two booleans that are
        true once ``model`` / ``lid_model`` are no longer ``None``.
    """
    asr_ready = model is not None
    lid_ready = lid_model is not None
    return dict(
        status="ok",
        asr_model_loaded=asr_ready,
        lid_model_loaded=lid_ready,
    )
|
| 47 |
|
| 48 |
|
| 49 |
@app.post("/transcribe")
|
|
|
|
| 62 |
|
| 63 |
with torch.no_grad():
|
| 64 |
outputs = model(**inputs)
|
| 65 |
+
logits = outputs.logits
|
| 66 |
|
| 67 |
probs = F.softmax(logits, dim=-1)
|
| 68 |
|
|
|
|
| 70 |
token_probs = torch.max(probs, dim=-1).values[0]
|
| 71 |
|
| 72 |
transcription = processor.decode(predicted_ids)
|
|
|
|
| 73 |
tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
|
| 74 |
|
| 75 |
words = []
|
|
|
|
| 103 |
})
|
| 104 |
|
| 105 |
global_conf = float(token_probs.mean().item())
|
|
|
|
| 106 |
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
|
| 107 |
uncertainty = float(entropy.mean().item())
|
| 108 |
|
|
|
|
| 112 |
"uncertainty": uncertainty,
|
| 113 |
"words": words
|
| 114 |
})
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@app.post("/lid")
async def language_identification(file: UploadFile = File(...)):
    """Identify the spoken language of an uploaded audio file.

    The upload is decoded with librosa as 16 kHz mono, passed through the
    LID feature extractor and classifier, and the label of the
    highest-scoring class is returned.

    Args:
        file: Uploaded audio file.

    Returns:
        JSONResponse of the form ``{"language": <label>}``, where the label
        is looked up in ``lid_model.config.id2label``.

    Raises:
        HTTPException: 503 if the LID model or extractor is not loaded yet;
            400 if the audio cannot be decoded or decodes to zero samples.
    """
    if lid_model is None or lid_extractor is None:
        raise HTTPException(status_code=503, detail="LID model not loaded yet")

    audio_bytes = await file.read()

    try:
        audio_input, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    except Exception as e:
        # Decoding failures are caused by the client's payload; keep the
        # original exception chained for server-side debugging.
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}") from e

    # A file with a valid container but no audio frames decodes to an empty
    # array. Presumably the extractor/model cannot score an empty waveform
    # (TODO confirm), so reject it here as a client error instead of letting
    # it surface as an unhandled server error.
    if audio_input.size == 0:
        raise HTTPException(status_code=400, detail="Audio contains no samples")

    inputs = lid_extractor(audio_input, sampling_rate=16000, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = lid_model(**inputs).logits

    predicted_id = torch.argmax(logits, dim=-1).item()
    predicted_lang = lid_model.config.id2label[predicted_id]

    return JSONResponse({"language": predicted_lang})
|