FredyHoundayi committed on
Commit
d0e18c9
·
1 Parent(s): 4845a07

Add per-word confidence, global confidence and uncertainty

Browse files
Files changed (1) hide show
  1. app.py +52 -4
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import io
2
 
3
  import torch
 
4
  import librosa
5
  from fastapi import FastAPI, File, UploadFile, HTTPException
6
  from fastapi.responses import JSONResponse
7
  from transformers import Wav2Vec2ForCTC, AutoProcessor
8
 
9
- app = FastAPI(title="MMS Speech-to-Text API", version="1.0.0")
10
 
11
  MODEL_ID = "facebook/mms-1b-all"
12
  processor = None
@@ -48,9 +49,56 @@ async def transcribe(file: UploadFile = File(...)):
48
  inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
49
 
50
  with torch.no_grad():
51
- logits = model(**inputs).logits
 
 
 
 
 
 
52
 
53
- predicted_ids = torch.argmax(logits, dim=-1)[0]
54
  transcription = processor.decode(predicted_ids)
55
 
56
- return JSONResponse({"transcription": transcription})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
2
 
3
  import torch
4
+ import torch.nn.functional as F
5
  import librosa
6
  from fastapi import FastAPI, File, UploadFile, HTTPException
7
  from fastapi.responses import JSONResponse
8
  from transformers import Wav2Vec2ForCTC, AutoProcessor
9
 
10
+ app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")
11
 
12
  MODEL_ID = "facebook/mms-1b-all"
13
  processor = None
 
49
  inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
50
 
51
  with torch.no_grad():
52
+ outputs = model(**inputs)
53
+ logits = outputs.logits # (batch, time, vocab)
54
+
55
+ probs = F.softmax(logits, dim=-1)
56
+
57
+ predicted_ids = torch.argmax(probs, dim=-1)[0]
58
+ token_probs = torch.max(probs, dim=-1).values[0]
59
 
 
60
  transcription = processor.decode(predicted_ids)
61
 
62
+ tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
63
+
64
+ words = []
65
+ current_word = ""
66
+ current_confs = []
67
+ prev_token = None
68
+
69
+ for tok, conf in zip(tokens, token_probs):
70
+ if tok == "<pad>":
71
+ continue
72
+ if tok == prev_token:
73
+ continue
74
+ prev_token = tok
75
+
76
+ if tok == "|":
77
+ if current_word:
78
+ words.append({
79
+ "word": current_word,
80
+ "confidence": float(sum(current_confs) / len(current_confs))
81
+ })
82
+ current_word = ""
83
+ current_confs = []
84
+ else:
85
+ current_word += tok
86
+ current_confs.append(conf.item())
87
+
88
+ if current_word:
89
+ words.append({
90
+ "word": current_word,
91
+ "confidence": float(sum(current_confs) / len(current_confs))
92
+ })
93
+
94
+ global_conf = float(token_probs.mean().item())
95
+
96
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
97
+ uncertainty = float(entropy.mean().item())
98
+
99
+ return JSONResponse({
100
+ "transcription": transcription,
101
+ "confidence": global_conf,
102
+ "uncertainty": uncertainty,
103
+ "words": words
104
+ })