import gradio as gr
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoModelForCTC,
    AutoModel,
)
import librosa
from jiwer import wer, cer
import time
# Model configurations
MODEL_CONFIGS = {
    "AudioX-North (Jivi AI)": {
        "repo": "jiviai/audioX-north-v1",
        "model_type": "seq2seq",
        "description": "Supports Hindi, Gujarati, Marathi",
    },
    "IndicConformer (AI4Bharat)": {
        "repo": "ai4bharat/indic-conformer-600m-multilingual",
        "model_type": "ctc_rnnt",
        "description": "Supports 22 Indian languages",
        "trust_remote_code": True,
    },
    "MMS (Facebook)": {
        "repo": "facebook/mms-1b-all",
        "model_type": "ctc",
        "description": "Supports over 1,400 languages (fine-tuning recommended)",
    },
}
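# To benchmark another checkpoint, add an entry with the same keys. A hedged
# example (not part of the original Space): "openai/whisper-small" is a public
# seq2seq checkpoint that would flow through the same
# AutoModelForSpeechSeq2Seq path below.
# MODEL_CONFIGS["Whisper Small (OpenAI)"] = {
#     "repo": "openai/whisper-small",
#     "model_type": "seq2seq",
#     "description": "Multilingual zero-shot baseline",
# }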
# Load model and processor
def load_model_and_processor(model_name):
    config = MODEL_CONFIGS[model_name]
    repo = config["repo"]
    model_type = config["model_type"]
    trust_remote_code = config.get("trust_remote_code", False)
    try:
        if model_name == "IndicConformer (AI4Bharat)":
            # The AI4Bharat model ships custom code and no standard processor
            print(f"Loading {model_name}...")
            try:
                model = AutoModel.from_pretrained(
                    repo,
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                )
            except Exception as e1:
                print(f"Primary loading failed, trying fallback: {e1}")
                model = AutoModel.from_pretrained(repo, trust_remote_code=True)
            processor = None
            return model, processor, model_type
        elif model_name == "MMS (Facebook)":
            model = AutoModelForCTC.from_pretrained(repo)
            processor = AutoProcessor.from_pretrained(repo)
        else:  # AudioX-North
            processor = AutoProcessor.from_pretrained(repo, trust_remote_code=trust_remote_code)
            if model_type == "seq2seq":
                model = AutoModelForSpeechSeq2Seq.from_pretrained(repo, trust_remote_code=trust_remote_code)
            else:
                model = AutoModelForCTC.from_pretrained(repo, trust_remote_code=trust_remote_code)
        return model, processor, model_type
    except Exception as e:
        # On failure, the error string is returned in the model_type slot;
        # the caller checks for this prefix
        return None, None, f"Error loading model: {str(e)}"
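# Models reload from disk on every benchmark click. A minimal in-process cache
# sketch (illustrative: _MODEL_CACHE and the wrapper below are additions, not
# part of the original app). Swapping it in for load_model_and_processor
# inside transcribe_audio would reuse loaded weights across runs.
_MODEL_CACHE = {}

def load_model_and_processor_cached(model_name):
    # Entries are (model, processor, model_type) tuples keyed by display name.
    # Note: failed loads are cached too; clear the dict to retry a model.
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = load_model_and_processor(model_name)
    return _MODEL_CACHE[model_name]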
# Compute metrics (WER, CER, RTF)
def compute_metrics(reference, hypothesis, audio_duration, total_time):
    if not reference or not hypothesis:
        return None, None, None, None
    try:
        reference = reference.strip().lower()
        hypothesis = hypothesis.strip().lower()
        wer_score = wer(reference, hypothesis)
        cer_score = cer(reference, hypothesis)
        # RTF = processing time / audio duration; values below 1.0 mean
        # faster than real time
        rtf = total_time / audio_duration if audio_duration > 0 else None
        return wer_score, cer_score, rtf, total_time
    except Exception:
        return None, None, None, None
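# Quick sanity check for the metric helper (values assume jiwer's standard
# edit-distance behavior):
#   compute_metrics("hello world", "hello word", audio_duration=2.0, total_time=1.0)
#   -> WER 0.5 (1 of 2 words wrong), CER ~0.091 (1 of 11 chars), RTF 0.5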
# Main transcription function
def transcribe_audio(audio_file, selected_models, reference_text=""):
    if not audio_file:
        return "Please upload an audio file.", [], ""
    if not selected_models:
        return "Please select at least one model.", [], ""
    table_data = []
    try:
        # Load and resample audio once; every model here expects 16 kHz input
        audio, sr = librosa.load(audio_file, sr=16000)
        audio_duration = len(audio) / sr
        for model_name in selected_models:
            model, processor, model_type = load_model_and_processor(model_name)
            if isinstance(model_type, str) and model_type.startswith("Error"):
                # model_type already carries the "Error loading model: ..." text
                table_data.append([model_name, model_type, "-", "-", "-", "-"])
                continue
            start_time = time.time()
            # Handle different model types
            try:
                if model_name == "IndicConformer (AI4Bharat)":
                    # AI4Bharat-specific path: the custom model takes a raw
                    # waveform plus language and decoder arguments
                    wav = torch.from_numpy(audio).unsqueeze(0)  # Add batch dimension
                    if torch.max(torch.abs(wav)) > 0:
                        wav = wav / torch.max(torch.abs(wav))  # Peak-normalize
                    with torch.no_grad():
                        # Default to Hindi ("hi") with RNNT decoding
                        transcription = model(wav, "hi", "rnnt")
                    if isinstance(transcription, list):
                        transcription = transcription[0] if transcription else ""
                    transcription = str(transcription).strip()
                else:
                    # Standard processing for the other models
                    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
                    with torch.no_grad():
                        if model_type == "seq2seq":
                            input_features = inputs["input_features"]
                            outputs = model.generate(input_features)
                            transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
                        else:
                            # CTC decoding; note that MMS uses its default
                            # (English) adapter unless load_adapter is called
                            input_values = inputs["input_values"]
                            logits = model(input_values).logits
                            predicted_ids = torch.argmax(logits, dim=-1)
                            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            except Exception as e:
                transcription = f"Processing error: {str(e)}"
            total_time = time.time() - start_time
            # Compute metrics
            wer_score, cer_score, rtf = "-", "-", "-"
            if reference_text and transcription and not transcription.startswith("Processing error"):
                wer_val, cer_val, rtf_val, _ = compute_metrics(
                    reference_text, transcription, audio_duration, total_time
                )
                wer_score = f"{wer_val:.3f}" if wer_val is not None else "-"
                cer_score = f"{cer_val:.3f}" if cer_val is not None else "-"
                rtf = f"{rtf_val:.3f}" if rtf_val is not None else "-"
            # Add row to table
            table_data.append([
                model_name,
                transcription,
                wer_score,
                cer_score,
                rtf,
                f"{total_time:.2f}s",
            ])
        # Create summary text
        summary = f"**Audio Duration:** {audio_duration:.2f}s\n"
        summary += f"**Models Tested:** {len(selected_models)}\n"
        if reference_text:
            summary += f"**Reference Text:** {reference_text[:100]}{'...' if len(reference_text) > 100 else ''}\n"
        # Create copyable text output
        copyable_text = "SPEECH-TO-TEXT BENCHMARK RESULTS\n" + "=" * 50 + "\n\n"
        copyable_text += f"Audio Duration: {audio_duration:.2f}s\n"
        copyable_text += f"Models Tested: {len(selected_models)}\n"
        if reference_text:
            copyable_text += f"Reference Text: {reference_text}\n"
        copyable_text += "\n" + "-" * 50 + "\n\n"
        for i, row in enumerate(table_data):
            copyable_text += f"MODEL {i + 1}: {row[0]}\n"
            copyable_text += f"Transcription: {row[1]}\n"
            copyable_text += f"WER: {row[2]}\n"
            copyable_text += f"CER: {row[3]}\n"
            copyable_text += f"RTF: {row[4]}\n"
            copyable_text += f"Time Taken: {row[5]}\n"
            copyable_text += "\n" + "-" * 30 + "\n\n"
        return summary, table_data, copyable_text
    except Exception as e:
        error_msg = f"Error during transcription: {str(e)}"
        return error_msg, [], error_msg
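# Headless use (illustrative; the file path and reference text are
# placeholders, not from the original Space):
#   summary, rows, text = transcribe_audio(
#       "sample.wav", ["MMS (Facebook)"], reference_text="hello world"
#   )
#   # rows holds one [model, transcription, WER, CER, RTF, time] entry per model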
# Create Gradio interface with Blocks for better layout control
def create_interface():
    model_choices = list(MODEL_CONFIGS.keys())
    with gr.Blocks(title="Multilingual Speech-to-Text Benchmark", css="""
        .paste-button { margin: 5px 0; }
        .copy-area { font-family: monospace; font-size: 12px; }
    """) as iface:
        gr.Markdown("""
        # Multilingual Speech-to-Text Benchmark
        Upload an audio file, select one or more models, and optionally provide reference text.
        The app benchmarks WER, CER, RTF, and time taken for each model.
        """)
        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    label="Upload Audio File (16 kHz recommended)",
                    type="filepath"
                )
                model_selection = gr.CheckboxGroup(
                    choices=model_choices,
                    label="Select Models",
                    value=[model_choices[0]],  # Default to the first model
                    interactive=True
                )
                # Reference text input for WER/CER computation
                with gr.Group():
                    gr.Markdown("### Reference Text (Optional for WER/CER)")
                    reference_input = gr.Textbox(
                        label="Reference Text (optional, paste supported)",
                        placeholder="Paste reference transcription here...",
                        lines=4,
                        interactive=True
                    )
                submit_btn = gr.Button("Transcribe", variant="primary", size="lg")
            with gr.Column(scale=2):
                summary_output = gr.Markdown(label="Summary", value="Upload an audio file and select models to begin...")
                results_table = gr.Dataframe(
                    headers=["Model", "Transcription", "WER", "CER", "RTF", "Time Taken"],
                    datatype=["str", "str", "str", "str", "str", "str"],
                    label="Results Comparison",
                    interactive=False,
                    wrap=True,
                    column_widths=[150, 400, 80, 80, 80, 100]
                )
                # Copyable results section
                with gr.Group():
                    gr.Markdown("### Copy Results")
                    copyable_output = gr.Textbox(
                        label="Copy-Paste Friendly Results",
                        lines=15,
                        max_lines=30,
                        show_copy_button=True,
                        interactive=False,
                        elem_classes="copy-area",
                        placeholder="Results will appear here in a copy-paste friendly format..."
                    )
        # Wire the button to the transcription function
        submit_btn.click(
            fn=transcribe_audio,
            inputs=[audio_input, model_selection, reference_input],
            outputs=[summary_output, results_table, copyable_output]
        )
        # Also allow triggering on Enter in the reference text box
        reference_input.submit(
            fn=transcribe_audio,
            inputs=[audio_input, model_selection, reference_input],
            outputs=[summary_output, results_table, copyable_output]
        )
        # Usage tips
        gr.Markdown("""
        ---
        ### Tips:
        - **Reference Text**: Paste your ground truth text to calculate WER/CER metrics
        - **Copy Results**: Use the copy button in the results section to copy formatted results
        - **AI4Bharat Model**: Automatically uses Hindi with RNNT decoding
        - **Supported Formats**: WAV, MP3, FLAC, M4A (16 kHz recommended for best results)
        """)
    return iface
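# For long audio on a busy Space, enabling Gradio's request queue with
# iface.queue() before launch() can help avoid request timeouts; it is left
# out here to match the original behavior.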
if __name__ == "__main__":
    iface = create_interface()
    iface.launch(
        share=False,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
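# To run locally (a typical workflow; the dependency list is inferred from the
# imports above and is not pinned by the original Space):
#   pip install gradio torch transformers librosa jiwer
#   python app.py
# Then open http://localhost:7860 in a browser.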