"""Sanskrit speech-recognition Space: IndicConformer ASR + Gemini-based correction."""
import os

import google.generativeai as genai
import gradio as gr
import torch
import torchaudio
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain_google_genai import GoogleGenerativeAI
from transformers import AutoModel
# Load environment variables (HF token, Gemini key) from a local .env file.
load_dotenv()

# Prefer GPU when available; the 600M-parameter ASR model is heavy on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device}")

print("[INFO] Loading IndicConformer model...")
# NOTE(review): the HF access token is read from an env var literally named
# "Token" — confirm this matches the Space's configured secret name.
TOKEN = os.getenv("Token")
try:
    ASR_MODEL = AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual",
        trust_remote_code=True,
        token=TOKEN,
    )
    if device == "cuda":
        ASR_MODEL = ASR_MODEL.to(device)
    print("[OK] IndicConformer loaded successfully")
except Exception as e:
    # Keep the app alive without ASR; transcription reports the failure per call.
    print(f"[ERROR] Failed to load IndicConformer: {e}")
    ASR_MODEL = None
# Initialise the Gemini LLM used for post-ASR Sanskrit correction.
# `llm` ends up None when the key is missing or initialisation fails.
gemini_key = os.getenv("GEMINI_API_KEY")
llm = None
if gemini_key:
    try:
        genai.configure(api_key=gemini_key)
        llm = GoogleGenerativeAI(
            model="gemini-2.0-flash-exp",
            google_api_key=gemini_key,
            # Low temperature: we want deterministic correction, not creativity.
            temperature=0.1,
        )
        print("[OK] Gemini initialized")
    except Exception as e:
        print(f"[ERROR] Gemini init failed: {e}")
        llm = None
class SanskritTextParser(BaseOutputParser):
    """Normalise LLM output: trim whitespace and drop a surrounding ``` fence."""

    def parse(self, text: str) -> str:
        stripped = text.strip()
        fenced = stripped.startswith('```') and stripped.endswith('```')
        if fenced:
            # Remove the opening and closing fence lines, keep the body.
            body_lines = stripped.split('\n')[1:-1]
            stripped = "\n".join(body_lines)
        return stripped.strip()
# Prompt for turning raw (possibly code-mixed Hindi/English) ASR output into
# grammatical Devanagari Sanskrit; the single template variable is {raw_text}.
sanskrit_correction_prompt = PromptTemplate(
    input_variables=["raw_text"],
    template="""You are an expert Sanskrit linguist.
User speech may contain Hindi, English, or other Indian languages,
but the goal is to produce correct, grammatical Sanskrit in Devanagari.
Rules:
- Convert meaning into proper Sanskrit, not transliteration.
- Output ONLY corrected Sanskrit text.
- No explanation, no markdown, no translation.
Raw ASR text: {raw_text}
Corrected Sanskrit:"""
)
def create_langchain_chain():
    """Build the Sanskrit-correction chain, or return None when Gemini is unavailable."""
    if llm is None:
        return None
    return LLMChain(
        llm=llm,
        prompt=sanskrit_correction_prompt,
        output_parser=SanskritTextParser(),
    )


# Built once at import time; pipeline() checks for None on every request.
CHAIN = create_langchain_chain()
def transcribe_with_indic_conformer(audio_path):
    """Transcribe an audio file with IndicConformer using CTC decoding.

    Returns the decoded text, or an "[ERROR] ..." string when the model is
    unavailable or decoding fails.
    """
    # Explicit None check: truthiness of a model object is ambiguous.
    if ASR_MODEL is None:
        return "[ERROR] ASR model not loaded"
    try:
        wav, sr = torchaudio.load(audio_path)
        # Downmix multi-channel audio to mono.
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        # The model expects 16 kHz input; skip the resample when already there.
        if sr != 16000:
            wav = torchaudio.transforms.Resample(sr, 16000)(wav)
        if device == "cuda":
            wav = wav.to(device)
        print("[INFO] Transcribing (CTC)...")
        # Model call signature: (waveform, language code, decoding strategy).
        return ASR_MODEL(wav, "sa", "ctc")
    except Exception as e:
        return f"[ERROR] ASR Error: {e}"
def pipeline(audio_file):
    """Run the full flow: ASR transcription, then LLM-based correction.

    Returns a (raw_transcription, corrected_text) pair; either element may be
    an "[ERROR] ..." message instead of text.
    """
    raw = transcribe_with_indic_conformer(audio_file)
    if CHAIN is None:
        # ASR output is still useful on its own; only correction is skipped.
        return raw, "[ERROR] Gemini not available"
    try:
        corrected = CHAIN.run(raw_text=raw)
    except Exception as e:
        corrected = f"[ERROR] Correction failed: {e}"
    return raw, corrected
def create_gradio_interface():
    """Build the two-column Gradio UI: audio upload left, ASR/correction output right."""
    with gr.Blocks(title="Sanskrit Speech Recognition") as interface:
        gr.Markdown("# Sanskrit Speech Recognition (CTC + Gemini)")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(type="filepath", label="Upload Audio")
                process_btn = gr.Button("Process Audio", variant="primary")
            with gr.Column():
                raw_output = gr.Textbox(label="Raw ASR Output", lines=5)
                corrected_output = gr.Textbox(label="Corrected Sanskrit", lines=5)
        process_btn.click(
            fn=pipeline,
            inputs=[audio_input],
            outputs=[raw_output, corrected_output],
        )
    return interface
if __name__ == "__main__":
    # Build the UI and serve it; 0.0.0.0:7860 is the Hugging Face Spaces default.
    interface = create_gradio_interface()
    interface.launch(
        show_error=True,      # surface Python tracebacks in the browser UI
        share=False,          # no public gradio.live tunnel
        server_name="0.0.0.0",
        server_port=7860
    )