# app.py — Multilingual Speech-to-Text Benchmark (Hugging Face Space)
# Last updated by fischerman, commit 0fa2ea9 (verified)
import gradio as gr
import torch
import torchaudio
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoModelForCTC,
AutoModel,
WhisperProcessor,
WhisperForConditionalGeneration,
)
import librosa
import numpy as np
from jiwer import wer, cer
import time
# Languages supported by this app, keyed by display name.
# Each entry records the ISO 639-1 code, the writing script, and the
# models that can transcribe it (only IndicConformer in this build).
_SUPPORTED_LANGUAGES = [
    ("Hindi", "hi", "Devanagari"),
    ("Gujarati", "gu", "Gujarati"),
    ("Marathi", "mr", "Devanagari"),
    ("Tamil", "ta", "Tamil"),
    ("Telugu", "te", "Telugu"),
    ("Kannada", "kn", "Kannada"),
    ("Malayalam", "ml", "Malayalam"),
]

LANGUAGE_CONFIGS = {
    display_name: {
        "code": iso_code,
        "script": script,
        "models": ["IndicConformer"],
    }
    for display_name, iso_code, script in _SUPPORTED_LANGUAGES
}
# Model configurations
# Simplified to only include IndicConformer
MODEL_CONFIGS = {
    "IndicConformer": {
        # Hugging Face repo id for AI4Bharat's multilingual Conformer ASR model.
        "repo": "ai4bharat/indic-conformer-600m-multilingual",
        # The remote-code model exposes both CTC and RNN-T decoding heads.
        "model_type": "ctc_rnnt",
        "description": "Supports 22 Indian languages",
        # Model class is shipped inside the repo, so remote code must be trusted.
        "trust_remote_code": True,
        # ISO 639-1 codes handled here (subset of the model's 22 languages).
        "languages": ["hi", "gu", "mr", "ta", "te", "kn", "ml", "bn", "pa", "or", "as", "ur"]
    }
}
# Load model and processor
def load_model_and_processor(model_name):
    """Load the ASR model registered under *model_name* in MODEL_CONFIGS.

    Returns:
        (model, processor, model_type) on success — processor is None for
        IndicConformer, which handles its own preprocessing via remote code.
        (None, None, "Error ...") on any failure; callers detect this with
        ``model_type.startswith("Error")``.
    """
    # Bug fix: an unknown key previously raised KeyError instead of
    # returning the documented error tuple.
    config = MODEL_CONFIGS.get(model_name)
    if config is None:
        return None, None, f"Error loading model: unknown model '{model_name}'"
    repo = config["repo"]
    model_type = config["model_type"]
    try:
        if model_name == "IndicConformer":
            print(f"Loading {model_name}...")
            try:
                model = AutoModel.from_pretrained(
                    repo,
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
            except Exception as e1:
                # Some transformers versions reject the extra kwargs for
                # remote-code models; retry with the minimal call.
                print(f"Primary loading failed, trying fallback: {e1}")
                model = AutoModel.from_pretrained(repo, trust_remote_code=True)
            processor = None
            return model, processor, model_type
        # Bug fix: the original fell through here and implicitly returned
        # None, crashing callers that unpack three values.
        return None, None, f"Error loading model: no loader implemented for '{model_name}'"
    except Exception as e:
        return None, None, f"Error loading model: {str(e)}"
# Compute metrics (WER, CER, RTF)
def compute_metrics(reference, hypothesis, audio_duration, total_time):
    """Compute WER, CER and real-time factor for one transcription.

    Both texts are stripped and lower-cased before scoring. RTF is
    wall-clock time divided by audio duration (None for zero duration).

    Returns:
        (wer, cer, rtf, total_time), or (None, None, None, None) when
        either text is missing or metric computation fails.
    """
    if not reference or not hypothesis:
        return None, None, None, None
    try:
        ref_norm = reference.strip().lower()
        hyp_norm = hypothesis.strip().lower()
        rtf_value = total_time / audio_duration if audio_duration > 0 else None
        return wer(ref_norm, hyp_norm), cer(ref_norm, hyp_norm), rtf_value, total_time
    except Exception:
        # Best-effort: jiwer can raise on degenerate inputs (e.g. an empty
        # reference after normalization); report "no metrics" instead.
        return None, None, None, None
# Main transcription function
def transcribe_audio(audio_file, selected_language, selected_models, reference_text=""):
    """Run the IndicConformer benchmark on one uploaded audio file.

    Args:
        audio_file: path to the audio file (Gradio "filepath" component).
        selected_language: display name, a key of LANGUAGE_CONFIGS.
        selected_models: list from the (fixed) model checkbox group; only
            used to validate that at least one model is selected.
        reference_text: optional ground-truth transcription; when given,
            WER/CER/RTF are computed against it.

    Returns:
        (summary_markdown, table_rows, copyable_text) — table_rows is a
        list of [model, transcription, WER, CER, RTF, time] rows.
    """
    if not audio_file:
        return "Please upload an audio file.", [], ""
    if not selected_models:
        return "Please select at least one model.", [], ""
    if not selected_language:
        return "Please select a language.", [], ""
    # Get language info
    lang_info = LANGUAGE_CONFIGS[selected_language]
    lang_code = lang_info["code"]
    table_data = []
    try:
        # Load and resample once; IndicConformer expects 16 kHz mono audio.
        audio, sr = librosa.load(audio_file, sr=16000)
        audio_duration = len(audio) / sr
        # We only use one model now: IndicConformer
        model_name = "IndicConformer"
        # Check if model supports the selected language.
        # Bug fix: the original appended an error row here but then fell
        # through and ran the model anyway; return early instead.
        if model_name not in lang_info["models"]:
            table_data.append([
                model_name,
                f"Language {selected_language} not supported by this model",
                "-", "-", "-", "-"
            ])
            return (
                f"**{model_name}** does not support {selected_language}.",
                table_data,
                ""
            )
        model, processor, model_type = load_model_and_processor(model_name)
        if isinstance(model_type, str) and model_type.startswith("Error"):
            table_data.append([
                model_name,
                f"Error: {model_type}",
                "-", "-", "-", "-"
            ])
            return "Error loading model.", [], ""  # Exit on model error
        start_time = time.time()
        try:
            # AI4Bharat specific processing for IndicConformer:
            # (1, samples) float tensor, peak-normalized to [-1, 1].
            wav = torch.from_numpy(audio).unsqueeze(0)
            peak = torch.max(torch.abs(wav))
            if peak > 0:
                wav = wav / peak
            with torch.no_grad():
                # Remote-code forward signature: (waveform, lang_code, decoder).
                transcription = model(wav, lang_code, "rnnt")
            if isinstance(transcription, list):
                transcription = transcription[0] if transcription else ""
            transcription = str(transcription).strip()
        except Exception as e:
            transcription = f"Processing error: {str(e)}"
        total_time = time.time() - start_time
        # Compute metrics only when a reference exists and inference succeeded.
        wer_score, cer_score, rtf = "-", "-", "-"
        if reference_text and transcription and not transcription.startswith("Processing error"):
            wer_val, cer_val, rtf_val, _ = compute_metrics(
                reference_text, transcription, audio_duration, total_time
            )
            wer_score = f"{wer_val:.3f}" if wer_val is not None else "-"
            cer_score = f"{cer_val:.3f}" if cer_val is not None else "-"
            rtf = f"{rtf_val:.3f}" if rtf_val is not None else "-"
        # Add row to table
        table_data.append([
            model_name,
            transcription,
            wer_score,
            cer_score,
            rtf,
            f"{total_time:.2f}s"
        ])
        # Create summary text (Markdown panel)
        summary = f"**Language:** {selected_language} ({lang_code})\n"
        summary += f"**Audio Duration:** {audio_duration:.2f}s\n"
        summary += f"**Model Tested:** {model_name}\n"
        if reference_text:
            summary += f"**Reference Text:** {reference_text[:100]}{'...' if len(reference_text) > 100 else ''}\n"
        # Create copyable text output (plain-text export block)
        copyable_text = "MULTILINGUAL SPEECH-TO-TEXT BENCHMARK RESULTS\n" + "="*55 + "\n\n"
        copyable_text += f"Language: {selected_language} ({lang_code})\n"
        copyable_text += f"Script: {lang_info['script']}\n"
        copyable_text += f"Audio Duration: {audio_duration:.2f}s\n"
        copyable_text += f"Model Tested: {model_name}\n"
        if reference_text:
            copyable_text += f"Reference Text: {reference_text}\n"
        copyable_text += "\n" + "-"*55 + "\n\n"
        for i, row in enumerate(table_data):
            copyable_text += f"MODEL {i+1}: {row[0]}\n"
            copyable_text += f"Transcription: {row[1]}\n"
            copyable_text += f"WER: {row[2]}\n"
            copyable_text += f"CER: {row[3]}\n"
            copyable_text += f"RTF: {row[4]}\n"
            copyable_text += f"Time Taken: {row[5]}\n"
            copyable_text += "\n" + "-"*35 + "\n\n"
        return summary, table_data, copyable_text
    except Exception as e:
        error_msg = f"Error during transcription: {str(e)}"
        return error_msg, [], error_msg
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the benchmark.

    Layout: left column holds the inputs (language dropdown, audio upload,
    fixed model checkbox, optional reference text, submit button); right
    column holds the outputs (summary markdown, results dataframe, and a
    copy-friendly text export). Both the submit button and pressing Enter
    in the reference textbox trigger transcribe_audio.
    """
    language_choices = list(LANGUAGE_CONFIGS.keys())
    # Custom CSS: .copy-area styles the export textbox; .language-info is
    # declared but not attached to any component in this build.
    with gr.Blocks(title="Multilingual Speech-to-Text Benchmark", css="""
    .language-info { background: #f0f8ff; padding: 10px; border-radius: 5px; margin: 10px 0; }
    .copy-area { font-family: monospace; font-size: 12px; }
    """) as iface:
        gr.Markdown("""
        # 🌐 Multilingual Speech-to-Text Benchmark
        Using only the **IndicConformer** model for 22 Indian languages.
        """)
        with gr.Row():
            with gr.Column(scale=1):
                # Language selection (defaults to the first configured language)
                language_selection = gr.Dropdown(
                    choices=language_choices,
                    label="πŸ—£οΈ Select Language",
                    value=language_choices[0],
                    interactive=True
                )
                # "filepath" type hands transcribe_audio a path, not raw samples
                audio_input = gr.Audio(
                    label="πŸ“Ή Upload Audio File (16kHz recommended)",
                    type="filepath"
                )
                # Model selection is now a fixed checkbox
                model_selection = gr.CheckboxGroup(
                    choices=["IndicConformer"],
                    label="πŸ€– Select Models",
                    value=["IndicConformer"],
                    interactive=False  # Disabled as only one model is used
                )
                # Optional ground truth; enables WER/CER in the results table
                reference_input = gr.Textbox(
                    label="πŸ“„ Reference Text (optional, paste supported)",
                    placeholder="Paste reference transcription here...",
                    lines=4,
                    interactive=True
                )
                submit_btn = gr.Button("πŸš€ Run Multilingual Benchmark", variant="primary", size="lg")
            with gr.Column(scale=2):
                summary_output = gr.Markdown(
                    label="πŸ“Š Summary",
                    value="Select language, upload audio file and choose models to begin..."
                )
                # Columns mirror the row layout built in transcribe_audio
                results_table = gr.Dataframe(
                    headers=["Model", "Transcription", "WER", "CER", "RTF", "Time"],
                    datatype=["str", "str", "str", "str", "str", "str"],
                    label="πŸ† Results Comparison",
                    interactive=False,
                    wrap=True,
                    column_widths=[120, 350, 60, 60, 60, 80]
                )
                # Copyable results section
                with gr.Group():
                    gr.Markdown("### πŸ“‹ Export Results")
                    copyable_output = gr.Textbox(
                        label="Copy-Paste Friendly Results",
                        lines=12,
                        max_lines=25,
                        show_copy_button=True,
                        interactive=False,
                        elem_classes="copy-area",
                        placeholder="Benchmark results will appear here..."
                    )
        # Connect the main function (button click)
        submit_btn.click(
            fn=transcribe_audio,
            inputs=[audio_input, language_selection, model_selection, reference_input],
            outputs=[summary_output, results_table, copyable_output]
        )
        # Same handler fires when Enter is pressed in the reference textbox
        reference_input.submit(
            fn=transcribe_audio,
            inputs=[audio_input, language_selection, model_selection, reference_input],
            outputs=[summary_output, results_table, copyable_output]
        )
        # Language information display (static help text)
        gr.Markdown("""
        ---
        ### πŸ“€ Language & Model Support Matrix
        | Language | Script | IndicConformer |
        |----------|---------|---------------|
        | Hindi | Devanagari | βœ… |
        | Gujarati | Gujarati | βœ… |
        | Marathi | Devanagari | βœ… |
        | Tamil | Tamil | βœ… |
        | Telugu | Telugu | βœ… |
        | Kannada | Kannada | βœ… |
        | Malayalam | Malayalam | βœ… |
        ### πŸ’‘ Tips:
        - **Model is fixed** to IndicConformer for this app.
        - **Reference Text**: Enable WER/CER calculation by providing ground truth.
        - **Copy Results**: Export formatted results using the copy button.
        """)
    return iface
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces (container/Space friendly).
    iface = create_interface()
    launch_options = {
        "share": False,
        "debug": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    iface.launch(**launch_options)