FredyHoundayi commited on
Commit
da466fe
·
1 Parent(s): d0e18c9

Add LID endpoint using facebook/mms-lid-256

Browse files
Files changed (1) hide show
  1. app.py +42 -8
app.py CHANGED
@@ -5,23 +5,31 @@ import torch.nn.functional as F
5
  import librosa
6
  from fastapi import FastAPI, File, UploadFile, HTTPException
7
  from fastapi.responses import JSONResponse
8
- from transformers import Wav2Vec2ForCTC, AutoProcessor
9
 
10
  app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")
11
 
12
  MODEL_ID = "facebook/mms-1b-all"
 
 
13
  processor = None
14
  model = None
 
 
15
 
16
 
17
  @app.on_event("startup")
18
  async def load_model():
19
- global processor, model
20
- print("Loading MMS model...")
21
  processor = AutoProcessor.from_pretrained(MODEL_ID)
22
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
23
  model.eval()
24
- print("Model loaded.")
 
 
 
 
25
 
26
 
27
  @app.get("/")
@@ -31,7 +39,11 @@ def root():
31
 
32
  @app.get("/health")
33
  def health():
34
- return {"status": "ok", "model_loaded": model is not None}
 
 
 
 
35
 
36
 
37
  @app.post("/transcribe")
@@ -50,7 +62,7 @@ async def transcribe(file: UploadFile = File(...)):
50
 
51
  with torch.no_grad():
52
  outputs = model(**inputs)
53
- logits = outputs.logits # (batch, time, vocab)
54
 
55
  probs = F.softmax(logits, dim=-1)
56
 
@@ -58,7 +70,6 @@ async def transcribe(file: UploadFile = File(...)):
58
  token_probs = torch.max(probs, dim=-1).values[0]
59
 
60
  transcription = processor.decode(predicted_ids)
61
-
62
  tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
63
 
64
  words = []
@@ -92,7 +103,6 @@ async def transcribe(file: UploadFile = File(...)):
92
  })
93
 
94
  global_conf = float(token_probs.mean().item())
95
-
96
  entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
97
  uncertainty = float(entropy.mean().item())
98
 
@@ -102,3 +112,27 @@ async def transcribe(file: UploadFile = File(...)):
102
  "uncertainty": uncertainty,
103
  "words": words
104
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import librosa
6
  from fastapi import FastAPI, File, UploadFile, HTTPException
7
  from fastapi.responses import JSONResponse
8
+ from transformers import Wav2Vec2ForCTC, AutoProcessor, AutoFeatureExtractor, AutoModelForAudioClassification
9
 
10
  app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")
11
 
12
  MODEL_ID = "facebook/mms-1b-all"
13
+ LID_MODEL_ID = "facebook/mms-lid-256"
14
+
15
  processor = None
16
  model = None
17
+ lid_extractor = None
18
+ lid_model = None
19
 
20
 
21
  @app.on_event("startup")
22
  async def load_model():
23
+ global processor, model, lid_extractor, lid_model
24
+ print("Loading MMS ASR model...")
25
  processor = AutoProcessor.from_pretrained(MODEL_ID)
26
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
27
  model.eval()
28
+ print("Loading MMS LID model...")
29
+ lid_extractor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)
30
+ lid_model = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID)
31
+ lid_model.eval()
32
+ print("All models loaded.")
33
 
34
 
35
  @app.get("/")
 
39
 
40
  @app.get("/health")
41
  def health():
42
+ return {
43
+ "status": "ok",
44
+ "asr_model_loaded": model is not None,
45
+ "lid_model_loaded": lid_model is not None,
46
+ }
47
 
48
 
49
  @app.post("/transcribe")
 
62
 
63
  with torch.no_grad():
64
  outputs = model(**inputs)
65
+ logits = outputs.logits
66
 
67
  probs = F.softmax(logits, dim=-1)
68
 
 
70
  token_probs = torch.max(probs, dim=-1).values[0]
71
 
72
  transcription = processor.decode(predicted_ids)
 
73
  tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
74
 
75
  words = []
 
103
  })
104
 
105
  global_conf = float(token_probs.mean().item())
 
106
  entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
107
  uncertainty = float(entropy.mean().item())
108
 
 
112
  "uncertainty": uncertainty,
113
  "words": words
114
  })
115
+
116
+
117
+ @app.post("/lid")
118
+ async def language_identification(file: UploadFile = File(...)):
119
+ if lid_model is None or lid_extractor is None:
120
+ raise HTTPException(status_code=503, detail="LID model not loaded yet")
121
+
122
+ audio_bytes = await file.read()
123
+
124
+ try:
125
+ audio_input, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
126
+ except Exception as e:
127
+ raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}")
128
+
129
+ inputs = lid_extractor(audio_input, sampling_rate=16000, return_tensors="pt")
130
+
131
+ with torch.no_grad():
132
+ outputs = lid_model(**inputs)
133
+ logits = outputs.logits
134
+
135
+ predicted_id = torch.argmax(logits, dim=-1).item()
136
+ predicted_lang = lid_model.config.id2label[predicted_id]
137
+
138
+ return JSONResponse({"language": predicted_lang})