# Source scraped from a Hugging Face Space (UI status header: "Spaces: Running").
| import io | |
| import torch | |
| import torch.nn.functional as F | |
| import librosa | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from transformers import Wav2Vec2ForCTC, AutoProcessor, AutoFeatureExtractor, AutoModelForAudioClassification | |
app = FastAPI(title="MMS Speech-to-Text API", version="2.0.0")

# Massively Multilingual Speech checkpoints (Hugging Face Hub ids).
MODEL_ID = "facebook/mms-1b-all"        # Wav2Vec2 CTC model used for transcription
LID_MODEL_ID = "facebook/mms-lid-256"   # audio classifier used for language ID

# Model handles, populated once by load_model(); the endpoints answer
# 503 while any of these is still None.
processor = None
model = None
lid_extractor = None
lid_model = None
# NOTE(review): startup registration reconstructed — the scraped source lost all
# decorator lines; without this hook the models never load and every endpoint
# permanently returns 503. Confirm against the original Space.
@app.on_event("startup")
async def load_model():
    """Load the ASR and language-ID models once at application startup.

    Populates the module-level globals (``processor``, ``model``,
    ``lid_extractor``, ``lid_model``) that the endpoints read.
    """
    global processor, model, lid_extractor, lid_model
    print("Loading MMS ASR model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    model.eval()  # inference mode (disables dropout)
    print("Loading MMS LID model...")
    lid_extractor = AutoFeatureExtractor.from_pretrained(LID_MODEL_ID)
    lid_model = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID)
    lid_model.eval()
    print("All models loaded.")
# NOTE(review): route decorator reconstructed — the scrape dropped decorators.
@app.get("/")
def root():
    """Service banner: API name and the ASR model id in use."""
    return {"message": "MMS Speech-to-Text API", "model": MODEL_ID}
# NOTE(review): route decorator reconstructed — the scrape dropped decorators;
# confirm the path ("/health" inferred from the function name).
@app.get("/health")
def health():
    """Liveness/readiness probe: reports whether each model is loaded."""
    return {
        "status": "ok",
        "asr_model_loaded": model is not None,
        "lid_model_loaded": lid_model is not None,
    }
# NOTE(review): route decorator reconstructed — the scrape dropped decorators;
# confirm the path ("/transcribe" inferred from the function name).
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the MMS CTC model.

    Returns JSON with:
      - transcription: decoded text,
      - confidence: mean of per-frame max softmax probabilities,
      - uncertainty: mean per-frame entropy of the softmax distribution,
      - words: per-word mean character confidences.

    Raises 503 if models are not loaded yet, 400 if the audio cannot be decoded.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    audio_bytes = await file.read()
    try:
        # Decode + resample to the 16 kHz mono input the model expects.
        audio, sampling_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}")
    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = F.softmax(logits, dim=-1)
    predicted_ids = torch.argmax(probs, dim=-1)[0]       # best token per frame
    token_probs = torch.max(probs, dim=-1).values[0]     # its probability per frame
    transcription = processor.decode(predicted_ids)
    tokens = processor.tokenizer.convert_ids_to_tokens(predicted_ids)

    # Rebuild per-word confidences by replaying CTC decoding over the
    # frame-level tokens ("|" is the word delimiter, "<pad>" the CTC blank).
    words = []
    current_word = ""
    current_confs = []
    prev_token = None
    for tok, conf in zip(tokens, token_probs):
        # BUGFIX: CTC collapse must merge repeats BEFORE dropping blanks.
        # The previous code skipped "<pad>" without updating prev_token, so a
        # genuine double letter separated by a blank frame (e.g. the "ll" in
        # "hello" emitted as "l <pad> l") was wrongly merged into one letter,
        # making the word list disagree with processor.decode().
        if tok == prev_token:
            continue
        prev_token = tok
        if tok == "<pad>":
            continue
        if tok == "|":
            if current_word:
                words.append({
                    "word": current_word,
                    "confidence": float(sum(current_confs) / len(current_confs))
                })
            current_word = ""
            current_confs = []
        else:
            current_word += tok
            current_confs.append(conf.item())
    if current_word:  # flush the trailing word (no closing "|")
        words.append({
            "word": current_word,
            "confidence": float(sum(current_confs) / len(current_confs))
        })

    global_conf = float(token_probs.mean().item())
    # Mean entropy over all frames; higher means the model was less certain.
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
    uncertainty = float(entropy.mean().item())
    return JSONResponse({
        "transcription": transcription,
        "confidence": global_conf,
        "uncertainty": uncertainty,
        "words": words
    })
# NOTE(review): route decorator reconstructed — the scrape dropped decorators;
# the path is inferred from the function name, confirm against clients.
@app.post("/language-identification")
async def language_identification(file: UploadFile = File(...)):
    """Identify the spoken language of an uploaded audio file.

    Runs the MMS LID classifier and returns ``{"language": <label>}`` with
    the argmax label from the model config.

    Raises 503 if the LID model is not loaded yet, 400 on undecodable audio.
    """
    if lid_model is None or lid_extractor is None:
        raise HTTPException(status_code=503, detail="LID model not loaded yet")
    audio_bytes = await file.read()
    try:
        # Decode + resample to 16 kHz mono as the extractor expects.
        audio_input, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load audio: {e}")
    inputs = lid_extractor(audio_input, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        outputs = lid_model(**inputs)
    logits = outputs.logits
    predicted_id = torch.argmax(logits, dim=-1).item()
    predicted_lang = lid_model.config.id2label[predicted_id]
    return JSONResponse({"language": predicted_lang})