# Author: rverma0631
# Last commit: be93ae0 "Update app.py" (verified)
import gradio as gr
import os
import torch
import torchaudio
from transformers import AutoModel
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_google_genai import GoogleGenerativeAI
from langchain.schema import BaseOutputParser
import google.generativeai as genai
# Load settings such as GEMINI_API_KEY and Token from a local .env file into os.environ.
load_dotenv()

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"βœ… Using device: {device}")
# Hugging Face access token for the model download; None when the "Token" env var is unset.
TOKEN = os.getenv("Token")
print("πŸ”§ Loading IndicConformer model...")
try:
    ASR_MODEL = AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual",
        token=TOKEN,
        # The checkpoint ships its own modelling code on the Hub.
        trust_remote_code=True,
    )
    # Weights load on the CPU; move them only when a GPU was selected above.
    if device == "cuda":
        ASR_MODEL = ASR_MODEL.to(device)
    print("βœ… IndicConformer loaded successfully")
except Exception as e:
    # Keep the app running without ASR; transcription reports the missing model.
    print(f"❌ Failed to load IndicConformer: {e}")
    ASR_MODEL = None
# Gemini LLM used to post-correct ASR output; llm stays None when unavailable.
gemini_key = os.getenv("GEMINI_API_KEY")
# Model id can be overridden without a code change.
# NOTE(review): "-exp" model ids are experimental and may be retired; confirm availability.
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp")
if gemini_key:
    try:
        genai.configure(api_key=gemini_key)
        llm = GoogleGenerativeAI(
            model=GEMINI_MODEL,
            google_api_key=gemini_key,
            # Low temperature: we want a faithful correction, not creative rewriting.
            temperature=0.1,
        )
        print("βœ… Gemini initialized")
    except Exception as e:
        print(f"❌ Gemini init failed: {e}")
        llm = None
else:
    # Make the disabled-correction state visible in the logs instead of failing silently.
    print("GEMINI_API_KEY not set: Gemini Sanskrit correction is disabled")
    llm = None
class SanskritTextParser(BaseOutputParser):
    """Output parser that strips Markdown code fences from the LLM reply.

    The correction prompt asks for bare text, but chat models sometimes still
    wrap the answer in triple-backtick fences; this removes them.
    """

    def parse(self, text: str) -> str:
        """Return ``text`` without surrounding whitespace or code fences.

        Handles:
        - multi-line fences with an optional info string on the opening line
          (e.g. a "sanskrit" language tag);
        - a closing fence that shares a line with the content;
        - single-line fences.
        Text that is not fenced is returned stripped but otherwise unchanged.
        """
        cleaned = text.strip()
        if cleaned.startswith("```") and cleaned.endswith("```"):
            # Drop the fence markers themselves. For degenerate input shorter
            # than two fences (e.g. a lone "```"), slicing yields "".
            body = cleaned[3:-3]
            first_line, newline, rest = body.partition("\n")
            # An opening fence on its own line is followed by an info string
            # (an ASCII word such as "sanskrit") or nothing. Drop that line,
            # but keep a first line that carries real content
            # (non-ASCII or multi-word).
            if newline and first_line.isascii() and len(first_line.split()) <= 1:
                body = rest
            cleaned = body
        return cleaned.strip()
# Instruction template for turning a (possibly mixed-language) ASR transcript
# into grammatical Sanskrit; {raw_text} is filled with the ASR output.
_SANSKRIT_CORRECTION_TEMPLATE = """You are an expert Sanskrit linguist.
User speech may contain Hindi, English, or other Indian languages,
but the goal is to produce correct, grammatical Sanskrit in Devanagari.
Rules:
- Convert meaning into proper Sanskrit, not transliteration.
- Output ONLY corrected Sanskrit text.
- No explanation, no markdown, no translation.
Raw ASR text: {raw_text}
Corrected Sanskrit:"""

sanskrit_correction_prompt = PromptTemplate(
    template=_SANSKRIT_CORRECTION_TEMPLATE,
    input_variables=["raw_text"],
)
def create_langchain_chain():
    """Build the Gemini correction chain.

    Returns:
        An ``LLMChain`` that wires the correction prompt to ``llm`` and cleans
        the reply with ``SanskritTextParser``, or ``None`` when no LLM was
        configured.
    """
    if not llm:
        return None
    parser = SanskritTextParser()
    return LLMChain(
        prompt=sanskrit_correction_prompt,
        llm=llm,
        output_parser=parser,
    )


# Shared chain for all requests; None makes pipeline() skip correction.
CHAIN = create_langchain_chain()
def transcribe_with_indic_conformer(audio_path):
    """Transcribe an audio file as Sanskrit using IndicConformer's CTC decoder.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The model output for language "sa" in "ctc" mode. If the model is not
        loaded, or loading/inference fails, returns an error string instead.
        Errors are returned rather than raised so the UI can display them.
    """
    if ASR_MODEL is None:
        return "❌ ASR model not loaded"
    try:
        wav, sr = torchaudio.load(audio_path)
        # Down-mix multi-channel audio to a single channel, keeping dim 0.
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        # This pipeline feeds the model 16 kHz audio.
        # Skip the resampler when the file is already at that rate.
        if sr != 16000:
            wav = torchaudio.transforms.Resample(sr, 16000)(wav)
        # No-op on CPU; matches the device the model was moved to at load time.
        wav = wav.to(device)
        print("πŸ”„ Transcribing (CTC)...")
        # Inference only: don't record an autograd graph for the forward pass.
        with torch.no_grad():
            return ASR_MODEL(wav, "sa", "ctc")
    except Exception as e:
        return f"❌ ASR Error: {e}"
def pipeline(audio_file):
    """Transcribe ``audio_file`` and post-correct the transcript with Gemini.

    Args:
        audio_file: Filepath from the Gradio ``Audio`` component, or ``None``
            when the button is clicked without an upload.

    Returns:
        A ``(raw_transcript, corrected_sanskrit)`` pair for the two output
        textboxes. Either element may be an error/status message instead.
    """
    if audio_file is None:
        # Nothing uploaded. Don't hand torchaudio a None path and then send
        # the resulting error text to the LLM for "correction".
        return "", "Please upload an audio file first."
    raw = transcribe_with_indic_conformer(audio_file)
    if not CHAIN:
        return raw, "❌ Gemini not available"
    if not str(raw).strip():
        # Empty transcript (e.g. silence): asking the LLM to correct nothing
        # invites a hallucinated sentence, so return nothing instead.
        return raw, ""
    # NOTE(review): ASR failures come back as error strings and are still sent
    # to Gemini here; consider having the ASR step signal failure explicitly.
    try:
        corrected = CHAIN.run(raw_text=raw)
    except Exception as e:
        corrected = f"❌ Correction failed: {e}"
    return raw, corrected
def create_gradio_interface():
    """Build the Gradio Blocks UI.

    Layout: audio upload and process button in the left column, raw and
    corrected text in the right column. Clicking the button runs
    ``pipeline`` on the uploaded file.
    """
    with gr.Blocks(title="Sanskrit Speech Recognition") as demo:
        gr.Markdown("# πŸ•‰οΈ Sanskrit Speech Recognition (CTC + Gemini)")
        with gr.Row():
            # Left column: input and trigger.
            with gr.Column():
                audio_in = gr.Audio(type="filepath", label="πŸ“ Upload Audio")
                run_button = gr.Button("πŸš€ Process Audio", variant="primary")
            # Right column: raw ASR text and the Gemini-corrected version.
            with gr.Column():
                raw_box = gr.Textbox(label="🎯 Raw ASR Output", lines=5)
                corrected_box = gr.Textbox(label="✨ Corrected Sanskrit", lines=5)
        # Event listeners must be attached inside the Blocks context.
        run_button.click(
            fn=pipeline,
            inputs=[audio_in],
            outputs=[raw_box, corrected_box],
        )
    return demo
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860; no public share link.
    # Surface Python errors in the UI.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    interface = create_gradio_interface()
    interface.launch(**launch_kwargs)