"""MMS Speech-to-Text API.

Serves two models from Meta's Massively Multilingual Speech (MMS) project:

- ``/transcribe`` — CTC speech recognition with per-word confidences,
  a global confidence score, and an entropy-based uncertainty score.
- ``/lid`` — spoken language identification.

Models are loaded once at process startup; endpoints return HTTP 503
until loading has completed.
"""
import io

import librosa
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    AutoProcessor,
    Wav2Vec2ForCTC,
)

app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")

MODEL_ID = "facebook/mms-1b-all"
LID_MODEL_ID = "facebook/mms-lid-256"

# All audio is resampled to the models' expected rate on load.
TARGET_SAMPLE_RATE = 16000

# Populated by load_model() on startup; endpoints check for None and
# answer 503 until everything is in place.
processor = None
model = None
lid_extractor = None
lid_model = None


# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favor of lifespan handlers; kept as-is because the installed FastAPI
# version is not visible from this file — confirm before migrating.
@app.on_event("startup")
async def load_model():
    """Load the ASR and LID models once at process startup."""
    global processor, model, lid_extractor, lid_model
    print("Loading MMS ASR model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    model.eval()
    print("Loading MMS LID model...")
    lid_extractor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)
    lid_model = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID)
    lid_model.eval()
    print("All models loaded.")


def _load_audio(audio_bytes):
    """Decode uploaded bytes into a 16 kHz mono float waveform.

    Raises:
        HTTPException(400): if the payload cannot be decoded or contains
            no samples.

    Returns:
        Tuple of (waveform ndarray, sampling rate).
    """
    try:
        audio, sampling_rate = librosa.load(
            io.BytesIO(audio_bytes), sr=TARGET_SAMPLE_RATE, mono=True
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}")
    # Robustness: a decodable container with zero samples would otherwise
    # fail deep inside the feature extractor with an opaque 500.
    if audio.size == 0:
        raise HTTPException(status_code=400, detail="Audio file contains no samples")
    return audio, sampling_rate


def _group_words(tokens, token_probs, special_tokens, delimiter="|"):
    """CTC-collapse frame-level tokens into words with mean confidences.

    Standard CTC decoding order: merge consecutive duplicate frames FIRST,
    then drop blank/special tokens. ``delimiter`` ("|") is Wav2Vec2's word
    separator token.

    Args:
        tokens: frame-level token strings (one per time step).
        token_probs: 1-D tensor of the max probability at each time step.
        special_tokens: set of tokens to drop (pad/blank, bos, eos, unk).
        delimiter: word-boundary token.

    Returns:
        List of ``{"word": str, "confidence": float}`` dicts.
    """
    words = []
    current_word = ""
    current_confs = []
    prev_token = None
    for tok, conf in zip(tokens, token_probs):
        # BUG FIX: prev_token must be updated on blank frames too. The
        # original skipped blanks without recording them, so a genuinely
        # repeated character separated by a blank (e.g. the "ll" in
        # "hello") was wrongly collapsed into one.
        if tok == prev_token:
            continue
        prev_token = tok
        # BUG FIX: the original compared against "" — the Wav2Vec2/MMS CTC
        # blank is the "<pad>" token, so blanks leaked into the words.
        if tok in special_tokens or not tok:
            continue
        if tok == delimiter:
            if current_word:
                words.append({
                    "word": current_word,
                    "confidence": float(sum(current_confs) / len(current_confs)),
                })
            current_word = ""
            current_confs = []
        else:
            current_word += tok
            current_confs.append(conf.item())
    # Flush the trailing word (no delimiter after the final token).
    if current_word:
        words.append({
            "word": current_word,
            "confidence": float(sum(current_confs) / len(current_confs)),
        })
    return words


@app.get("/")
def root():
    """Service banner."""
    return {"message": "MMS Speech-to-Text API", "model": MODEL_ID}


@app.get("/health")
def health():
    """Liveness/readiness probe reporting model-load state."""
    return {
        "status": "ok",
        "asr_model_loaded": model is not None,
        "lid_model_loaded": lid_model is not None,
    }


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with MMS ASR.

    Returns JSON with the transcription, a global confidence, an
    entropy-based uncertainty score, and per-word confidences.

    Raises:
        HTTPException(503): models not loaded yet.
        HTTPException(400): undecodable or empty audio.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    audio, sampling_rate = _load_audio(await file.read())

    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = F.softmax(logits, dim=-1)
    # Greedy CTC path for the single batch item.
    predicted_ids = torch.argmax(probs, dim=-1)[0]
    token_probs = torch.max(probs, dim=-1).values[0]

    transcription = processor.decode(predicted_ids)
    tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids.tolist())
    words = _group_words(
        tokens, token_probs, set(processor.tokenizer.all_special_tokens)
    )

    # Mean per-frame max probability; includes blank frames, so this tracks
    # overall decoder certainty rather than word-level accuracy.
    global_conf = float(token_probs.mean().item())
    # Mean Shannon entropy of the frame distributions (epsilon guards log(0)).
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
    uncertainty = float(entropy.mean().item())

    return JSONResponse({
        "transcription": transcription,
        "confidence": global_conf,
        "uncertainty": uncertainty,
        "words": words,
    })


@app.post("/lid")
async def language_identification(file: UploadFile = File(...)):
    """Identify the spoken language of an uploaded audio file.

    Raises:
        HTTPException(503): LID model not loaded yet.
        HTTPException(400): undecodable or empty audio.
    """
    if lid_model is None or lid_extractor is None:
        raise HTTPException(status_code=503, detail="LID model not loaded yet")

    audio_input, _ = _load_audio(await file.read())

    inputs = lid_extractor(
        audio_input, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt"
    )
    with torch.no_grad():
        logits = lid_model(**inputs).logits

    predicted_id = torch.argmax(logits, dim=-1).item()
    predicted_lang = lid_model.config.id2label[predicted_id]
    return JSONResponse({"language": predicted_lang})