# indic-conformer / app.py
# (Hugging Face Space by sohamchitimali — commit b4b6b98, "update")
import os
import shutil
import tempfile

import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
# --- Device ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --- Load ASR Model ---
ASR_MODEL_ID = "ai4bharat/indic-conformer-600m-multilingual"
asr_model = AutoModel.from_pretrained(ASR_MODEL_ID, trust_remote_code=True).to(DEVICE)
asr_model.eval()
# --- Load Language ID Model ---
LID_MODEL_ID = "facebook/mms-lid-1024"
lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(LID_MODEL_ID)
lid_model = AutoModelForAudioClassification.from_pretrained(LID_MODEL_ID).to(DEVICE)
lid_model.eval()
# --- Language mappings ---
LID_TO_ASR_LANG_MAP = {
"asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
"guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
"kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
"mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
"pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
"tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
}
# --- FastAPI app ---
app = FastAPI(title="Indic ASR API")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# --- Utility ---
def _save_temp(upload: UploadFile) -> str:
suffix = os.path.splitext(upload.filename or "audio")[1] or ".wav"
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
tmp.write(upload.file.read())
return tmp.name
# --- Endpoint ---
@app.post("/api/stt")
async def transcribe(audio: UploadFile = File(...)):
if not audio:
raise HTTPException(status_code=400, detail="No audio provided")
path = _save_temp(audio)
try:
waveform, sr = torchaudio.load(path)
waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
# Language Detection
inputs = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = lid_model(**inputs)
predicted_id = outputs.logits.argmax(-1).item()
lid_code = lid_model.config.id2label[predicted_id]
asr_lang = LID_TO_ASR_LANG_MAP.get(lid_code, "hi") # default to Hindi if unknown
# ASR Transcription
with torch.no_grad():
transcription = asr_model(waveform_16k.to(DEVICE), asr_lang, "rnnt").strip()
return {"text": transcription, "language": asr_lang}
finally:
os.remove(path)
# --- Test root endpoint ---
@app.get("/")
async def root():
return {"message": "Indic ASR API running. POST audio to /api/transcribe."}
# --- Run locally ---
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)