"""Indic ASR API: language-ID + speech-to-text service.

POST an audio file to /api/stt; the service detects the spoken language
with MMS-LID, then transcribes with AI4Bharat's IndicConformer.
"""

import os
import tempfile

import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import (
    AutoModel,
    AutoModelForAudioClassification,
    Wav2Vec2FeatureExtractor,
)

# --- Device ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Both models expect 16 kHz mono audio.
TARGET_SR = 16000

# --- Load ASR model (AI4Bharat IndicConformer, multilingual) ---
ASR_MODEL_ID = "ai4bharat/indic-conformer-600m-multilingual"
asr_model = AutoModel.from_pretrained(ASR_MODEL_ID, trust_remote_code=True).to(DEVICE)
asr_model.eval()

# --- Load language-ID model (Meta MMS LID, 1024 languages) ---
LID_MODEL_ID = "facebook/mms-lid-1024"
lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(LID_MODEL_ID)
lid_model = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID).to(DEVICE)
lid_model.eval()

# --- Language mappings: MMS-LID label -> IndicConformer language code ---
# NOTE(review): "brx_Deva" maps to "br" here; IndicConformer typically uses
# "brx" for Bodo — confirm against the ASR model card before relying on it.
LID_TO_ASR_LANG_MAP = {
    "asm_Beng": "as",
    "ben_Beng": "bn",
    "brx_Deva": "br",
    "doi_Deva": "doi",
    "guj_Gujr": "gu",
    "hin_Deva": "hi",
    "kan_Knda": "kn",
    "kas_Arab": "ks",
    "kas_Deva": "ks",
    "gom_Deva": "kok",
    "mai_Deva": "mai",
    "mal_Mlym": "ml",
    "mni_Beng": "mni",
    "mar_Deva": "mr",
    "nep_Deva": "ne",
    "ory_Orya": "or",
    "pan_Guru": "pa",
    "san_Deva": "sa",
    "sat_Olck": "sat",
    "snd_Arab": "sd",
    "tam_Taml": "ta",
    "tel_Telu": "te",
    "urd_Arab": "ur",
}

# --- FastAPI app ---
app = FastAPI(title="Indic ASR API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Utility ---
def _save_temp(upload: UploadFile) -> str:
    """Persist an uploaded file to a named temp file; return its path.

    The caller owns the file and must delete it when done.
    """
    suffix = os.path.splitext(upload.filename or "audio")[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(upload.file.read())
        return tmp.name


# --- Endpoint ---
@app.post("/api/stt")
async def transcribe(audio: UploadFile = File(...)):
    """Detect the spoken language of an uploaded clip, then transcribe it.

    Returns ``{"text": ..., "language": ...}`` where ``language`` is the
    IndicConformer code actually used (falls back to Hindi when the LID
    label is not in the mapping). Raises 400 for missing/undecodable audio.
    """
    if not audio:
        raise HTTPException(status_code=400, detail="No audio provided")
    path = _save_temp(audio)
    try:
        try:
            waveform, sr = torchaudio.load(path)
        except Exception as exc:
            # An upload torchaudio cannot decode is a client error, not a 500.
            raise HTTPException(
                status_code=400, detail="Could not decode audio file"
            ) from exc
        if waveform.numel() == 0:
            raise HTTPException(status_code=400, detail="Audio file is empty")

        # Downmix to mono. The previous code squeezed the tensor for LID,
        # which only yields 1-D audio for mono input, and fed a possibly
        # multi-channel tensor to the ASR model; averaging channels makes
        # stereo uploads work for both models.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform_16k = torchaudio.functional.resample(waveform, sr, TARGET_SR)

        # --- Language detection ---
        inputs = lid_processor(
            waveform_16k.squeeze(0), sampling_rate=TARGET_SR, return_tensors="pt"
        ).to(DEVICE)
        with torch.no_grad():
            logits = lid_model(**inputs).logits
        lid_code = lid_model.config.id2label[logits.argmax(-1).item()]
        asr_lang = LID_TO_ASR_LANG_MAP.get(lid_code, "hi")  # default to Hindi

        # --- ASR transcription (RNN-T decoding) ---
        with torch.no_grad():
            transcription = asr_model(waveform_16k.to(DEVICE), asr_lang, "rnnt").strip()

        return {"text": transcription, "language": asr_lang}
    finally:
        os.remove(path)


# --- Test root endpoint ---
@app.get("/")
async def root():
    # Fixed: the message previously pointed at /api/transcribe, but the
    # actual transcription route is /api/stt.
    return {"message": "Indic ASR API running. POST audio to /api/stt."}


# --- Run locally ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)