# NOTE(review): the original paste carried platform status text here
# ("Spaces: Runtime error") captured during extraction — it is not source code.
import os
import shutil
import tempfile

import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
# --- Device selection ---
# Run on the GPU when one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# --- ASR model ---
# IndicConformer multilingual checkpoint; trust_remote_code is required because
# the model class ships with the repository, not with transformers itself.
# .eval() returns the module, so chaining keeps the same object in eval mode.
ASR_MODEL_ID = "ai4bharat/indic-conformer-600m-multilingual"
asr_model = (
    AutoModel.from_pretrained(ASR_MODEL_ID, trust_remote_code=True)
    .to(DEVICE)
    .eval()
)
# --- Language-identification model ---
# Facebook MMS-LID (1024 languages): feature extractor + audio classifier.
# .eval() returns the module, so chaining keeps the same object in eval mode.
LID_MODEL_ID = "facebook/mms-lid-1024"
lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(LID_MODEL_ID)
lid_model = (
    AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID)
    .to(DEVICE)
    .eval()
)
# --- Language mappings ---
# Maps MMS-LID output labels (ISO 639-3 code + script suffix) to the short
# language codes the IndicConformer ASR model expects. Note both Kashmiri
# script variants (kas_Arab / kas_Deva) collapse to the same ASR code "ks".
LID_TO_ASR_LANG_MAP = {
    "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
    "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
    "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
    "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
    "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
    "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
}
# --- FastAPI application ---
app = FastAPI(title="Indic ASR API")

# Allow any origin/method/header so browser clients can call the API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
| # --- Utility --- | |
| def _save_temp(upload: UploadFile) -> str: | |
| suffix = os.path.splitext(upload.filename or "audio")[1] or ".wav" | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: | |
| tmp.write(upload.file.read()) | |
| return tmp.name | |
# --- Endpoint ---
@app.post("/api/transcribe")  # FIX: route decorator was missing, so the endpoint was never registered
async def transcribe(audio: UploadFile = File(...)):
    """Identify the spoken language of an uploaded clip and transcribe it.

    Returns ``{"text": <transcription>, "language": <asr_lang_code>}``.
    Falls back to Hindi when MMS-LID predicts a language the ASR model
    does not cover. The temp file is always removed, even on failure.
    """
    if not audio:
        raise HTTPException(status_code=400, detail="No audio provided")
    path = _save_temp(audio)
    try:
        waveform, sr = torchaudio.load(path)
        # Downmix multi-channel audio to mono: the LID feature extractor and
        # the ASR model both expect a single channel, and squeeze() alone
        # does not collapse a (channels, samples) tensor.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Both models operate at 16 kHz; skip the resample when already there.
        if sr != 16000:
            waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
        else:
            waveform_16k = waveform
        # --- Language detection (MMS-LID) ---
        inputs = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = lid_model(**inputs)
        predicted_id = outputs.logits.argmax(-1).item()
        lid_code = lid_model.config.id2label[predicted_id]
        asr_lang = LID_TO_ASR_LANG_MAP.get(lid_code, "hi")  # default to Hindi if unknown
        # --- ASR transcription (RNN-T decoding per the IndicConformer API) ---
        with torch.no_grad():
            transcription = asr_model(waveform_16k.to(DEVICE), asr_lang, "rnnt").strip()
        return {"text": transcription, "language": asr_lang}
    finally:
        # Always clean up the temp file written by _save_temp.
        os.remove(path)
# --- Test root endpoint ---
@app.get("/")  # FIX: route decorator was missing, so the endpoint was never registered
async def root():
    """Health-check endpoint confirming the service is running."""
    return {"message": "Indic ASR API running. POST audio to /api/transcribe."}
# --- Run locally ---
# Script entry point: serve the app with uvicorn when executed directly
# (under a production ASGI launcher this guard is skipped).
if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)