Spaces:
Running
Running
Commit ·
d0e18c9
1
Parent(s): 4845a07
Add per-word confidence, global confidence and uncertainty
Browse files
app.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
import io
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 4 |
import librosa
|
| 5 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 6 |
from fastapi.responses import JSONResponse
|
| 7 |
from transformers import Wav2Vec2ForCTC, AutoProcessor
|
| 8 |
|
| 9 |
-
app = FastAPI(title="MMS Speech-to-Text API", version="
|
| 10 |
|
| 11 |
MODEL_ID = "facebook/mms-1b-all"
|
| 12 |
processor = None
|
|
@@ -48,9 +49,56 @@ async def transcribe(file: UploadFile = File(...)):
|
|
| 48 |
inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
|
| 49 |
|
| 50 |
with torch.no_grad():
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
predicted_ids = torch.argmax(logits, dim=-1)[0]
|
| 54 |
transcription = processor.decode(predicted_ids)
|
| 55 |
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import io
|
| 2 |
|
| 3 |
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
import librosa
|
| 6 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 7 |
from fastapi.responses import JSONResponse
|
| 8 |
from transformers import Wav2Vec2ForCTC, AutoProcessor
|
| 9 |
|
| 10 |
+
app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")
|
| 11 |
|
| 12 |
MODEL_ID = "facebook/mms-1b-all"
|
| 13 |
processor = None
|
|
|
|
| 49 |
inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
|
| 50 |
|
| 51 |
with torch.no_grad():
|
| 52 |
+
outputs = model(**inputs)
|
| 53 |
+
logits = outputs.logits # (batch, time, vocab)
|
| 54 |
+
|
| 55 |
+
probs = F.softmax(logits, dim=-1)
|
| 56 |
+
|
| 57 |
+
predicted_ids = torch.argmax(probs, dim=-1)[0]
|
| 58 |
+
token_probs = torch.max(probs, dim=-1).values[0]
|
| 59 |
|
|
|
|
| 60 |
transcription = processor.decode(predicted_ids)
|
| 61 |
|
| 62 |
+
tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
|
| 63 |
+
|
| 64 |
+
words = []
|
| 65 |
+
current_word = ""
|
| 66 |
+
current_confs = []
|
| 67 |
+
prev_token = None
|
| 68 |
+
|
| 69 |
+
for tok, conf in zip(tokens, token_probs):
|
| 70 |
+
if tok == "<pad>":
|
| 71 |
+
continue
|
| 72 |
+
if tok == prev_token:
|
| 73 |
+
continue
|
| 74 |
+
prev_token = tok
|
| 75 |
+
|
| 76 |
+
if tok == "|":
|
| 77 |
+
if current_word:
|
| 78 |
+
words.append({
|
| 79 |
+
"word": current_word,
|
| 80 |
+
"confidence": float(sum(current_confs) / len(current_confs))
|
| 81 |
+
})
|
| 82 |
+
current_word = ""
|
| 83 |
+
current_confs = []
|
| 84 |
+
else:
|
| 85 |
+
current_word += tok
|
| 86 |
+
current_confs.append(conf.item())
|
| 87 |
+
|
| 88 |
+
if current_word:
|
| 89 |
+
words.append({
|
| 90 |
+
"word": current_word,
|
| 91 |
+
"confidence": float(sum(current_confs) / len(current_confs))
|
| 92 |
+
})
|
| 93 |
+
|
| 94 |
+
global_conf = float(token_probs.mean().item())
|
| 95 |
+
|
| 96 |
+
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
|
| 97 |
+
uncertainty = float(entropy.mean().item())
|
| 98 |
+
|
| 99 |
+
return JSONResponse({
|
| 100 |
+
"transcription": transcription,
|
| 101 |
+
"confidence": global_conf,
|
| 102 |
+
"uncertainty": uncertainty,
|
| 103 |
+
"words": words
|
| 104 |
+
})
|