|
|
import gradio as gr |
|
|
import torch |
|
|
import torchaudio |
|
|
from transformers import ( |
|
|
AutoModelForSpeechSeq2Seq, |
|
|
AutoProcessor, |
|
|
AutoModelForCTC, |
|
|
AutoModel, |
|
|
WhisperProcessor, |
|
|
WhisperForConditionalGeneration, |
|
|
) |
|
|
import librosa |
|
|
import numpy as np |
|
|
from jiwer import wer, cer |
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
# Supported languages: display name -> (ISO 639-1 code, writing script).
# Every entry is currently served by the single IndicConformer model.
_LANGUAGE_TABLE = [
    ("Hindi", "hi", "Devanagari"),
    ("Gujarati", "gu", "Gujarati"),
    ("Marathi", "mr", "Devanagari"),
    ("Tamil", "ta", "Tamil"),
    ("Telugu", "te", "Telugu"),
    ("Kannada", "kn", "Kannada"),
    ("Malayalam", "ml", "Malayalam"),
]

# Expanded per-language configuration consumed by the UI and the
# transcription pipeline: ISO code, script name, and usable models.
LANGUAGE_CONFIGS = {
    name: {"code": code, "script": script, "models": ["IndicConformer"]}
    for name, code, script in _LANGUAGE_TABLE
}
|
|
|
|
|
|
|
|
|
|
|
# Registry of available ASR models. Each entry records the Hugging Face
# repo to load from, how the model decodes (CTC/RNNT), and which language
# codes it accepts.
MODEL_CONFIGS = {
    "IndicConformer": dict(
        repo="ai4bharat/indic-conformer-600m-multilingual",
        model_type="ctc_rnnt",
        description="Supports 22 Indian languages",
        # The AI4Bharat repo ships custom modeling code, so loading
        # requires trust_remote_code=True.
        trust_remote_code=True,
        languages=["hi", "gu", "mr", "ta", "te", "kn", "ml", "bn", "pa", "or", "as", "ur"],
    ),
}
|
|
|
|
|
|
|
|
def load_model_and_processor(model_name):
    """Load the ASR model (and processor, if any) for *model_name*.

    Args:
        model_name: Key into ``MODEL_CONFIGS`` (currently only "IndicConformer").

    Returns:
        tuple: ``(model, processor, model_type)`` on success. On failure the
        triple is ``(None, None, error_string)`` where ``error_string`` starts
        with "Error" — callers detect failure via that prefix.
    """
    # Bug fix: an unknown name previously raised an uncaught KeyError here.
    config = MODEL_CONFIGS.get(model_name)
    if config is None:
        return None, None, f"Error loading model: unknown model {model_name!r}"

    repo = config["repo"]
    model_type = config["model_type"]

    # Bug fix: a configured-but-unhandled model name previously fell off the
    # end of the try block and returned a bare None, breaking the 3-tuple
    # unpacking at the call site.
    if model_name != "IndicConformer":
        return None, None, f"Error loading model: no loader implemented for {model_name!r}"

    try:
        print(f"Loading {model_name}...")
        try:
            model = AutoModel.from_pretrained(
                repo,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
        except Exception as e1:
            # Some transformers versions reject the extra kwargs; retry with
            # the minimal call before giving up.
            print(f"Primary loading failed, trying fallback: {e1}")
            model = AutoModel.from_pretrained(repo, trust_remote_code=True)
        # IndicConformer bundles its own preprocessing, so no separate
        # Hugging Face processor is needed.
        processor = None
        return model, processor, model_type
    except Exception as e:
        return None, None, f"Error loading model: {str(e)}"
|
|
|
|
|
|
|
|
def compute_metrics(reference, hypothesis, audio_duration, total_time):
    """Score a transcription against its reference text.

    Args:
        reference: Ground-truth transcription (may be empty/None).
        hypothesis: Model output to score (may be empty/None).
        audio_duration: Length of the audio clip in seconds.
        total_time: Wall-clock inference time in seconds.

    Returns:
        tuple: ``(wer, cer, rtf, total_time)``; all four are ``None`` when
        either text is missing or scoring fails.
    """
    if not reference or not hypothesis:
        return None, None, None, None
    try:
        # Case-insensitive comparison on trimmed text.
        ref = reference.strip().lower()
        hyp = hypothesis.strip().lower()
        # Real-time factor: processing time per second of audio.
        rtf = total_time / audio_duration if audio_duration > 0 else None
        return wer(ref, hyp), cer(ref, hyp), rtf, total_time
    except Exception:
        # Best effort: malformed inputs yield "no metrics" instead of a crash.
        return None, None, None, None
|
|
|
|
|
|
|
|
def transcribe_audio(audio_file, selected_language, selected_models, reference_text=""):
    """Transcribe *audio_file* with IndicConformer and report benchmark results.

    Args:
        audio_file: Filesystem path to the uploaded audio (Gradio filepath).
        selected_language: Display name; must be a key of ``LANGUAGE_CONFIGS``.
        selected_models: Model names chosen in the UI (must be non-empty).
        reference_text: Optional ground truth; enables WER/CER computation.

    Returns:
        tuple: ``(summary_markdown, table_rows, copyable_text)``.
    """
    # Guard clauses: each missing input returns a user-facing message.
    if not audio_file:
        return "Please upload an audio file.", [], ""
    if not selected_models:
        return "Please select at least one model.", [], ""
    if not selected_language:
        return "Please select a language.", [], ""

    lang_info = LANGUAGE_CONFIGS[selected_language]
    lang_code = lang_info["code"]
    table_data = []

    try:
        # Resample to the 16 kHz mono signal the model expects.
        audio, sr = librosa.load(audio_file, sr=16000)
        audio_duration = len(audio) / sr

        model_name = "IndicConformer"

        if model_name not in lang_info["models"]:
            # BUG FIX: previously this error row was appended and execution
            # *continued* into model loading and inference anyway, producing
            # a bogus second result row. Stop here instead.
            table_data.append([
                model_name,
                f"Language {selected_language} not supported by this model",
                "-", "-", "-", "-"
            ])
            summary = _build_summary(selected_language, lang_code, audio_duration,
                                     model_name, reference_text)
            copyable_text = _build_copyable(selected_language, lang_code, lang_info,
                                            audio_duration, model_name,
                                            reference_text, table_data)
            return summary, table_data, copyable_text

        model, processor, model_type = load_model_and_processor(model_name)
        if isinstance(model_type, str) and model_type.startswith("Error"):
            table_data.append([
                model_name,
                f"Error: {model_type}",
                "-", "-", "-", "-"
            ])
            return "Error loading model.", [], ""

        start_time = time.time()
        transcription = _run_inference(model, audio, lang_code)
        total_time = time.time() - start_time

        # Metrics are only meaningful with a reference text and a real
        # transcription (not an error placeholder).
        wer_score, cer_score, rtf = "-", "-", "-"
        if reference_text and transcription and not transcription.startswith("Processing error"):
            wer_val, cer_val, rtf_val, _ = compute_metrics(
                reference_text, transcription, audio_duration, total_time
            )
            wer_score = f"{wer_val:.3f}" if wer_val is not None else "-"
            cer_score = f"{cer_val:.3f}" if cer_val is not None else "-"
            rtf = f"{rtf_val:.3f}" if rtf_val is not None else "-"

        table_data.append([
            model_name,
            transcription,
            wer_score,
            cer_score,
            rtf,
            f"{total_time:.2f}s"
        ])

        summary = _build_summary(selected_language, lang_code, audio_duration,
                                 model_name, reference_text)
        copyable_text = _build_copyable(selected_language, lang_code, lang_info,
                                        audio_duration, model_name,
                                        reference_text, table_data)
        return summary, table_data, copyable_text

    except Exception as e:
        error_msg = f"Error during transcription: {str(e)}"
        return error_msg, [], error_msg


def _run_inference(model, audio, lang_code):
    """Run peak-normalized RNNT decoding; return text or an error string."""
    try:
        wav = torch.from_numpy(audio).unsqueeze(0)
        peak = torch.max(torch.abs(wav))
        if peak > 0:
            # Peak-normalize to [-1, 1] so level differences don't hurt decoding.
            wav = wav / peak
        with torch.no_grad():
            # IndicConformer's custom forward takes (waveform, lang, strategy);
            # "rnnt" selects the transducer decoding head.
            transcription = model(wav, lang_code, "rnnt")
        if isinstance(transcription, list):
            transcription = transcription[0] if transcription else ""
        return str(transcription).strip()
    except Exception as e:
        # Surface the failure in the results table instead of crashing the UI.
        return f"Processing error: {str(e)}"


def _build_summary(selected_language, lang_code, audio_duration, model_name, reference_text):
    """Assemble the short Markdown summary shown above the results table."""
    summary = f"**Language:** {selected_language} ({lang_code})\n"
    summary += f"**Audio Duration:** {audio_duration:.2f}s\n"
    summary += f"**Model Tested:** {model_name}\n"
    if reference_text:
        # Truncate long references in the on-screen summary only.
        summary += f"**Reference Text:** {reference_text[:100]}{'...' if len(reference_text) > 100 else ''}\n"
    return summary


def _build_copyable(selected_language, lang_code, lang_info, audio_duration,
                    model_name, reference_text, table_data):
    """Render the plain-text, copy-paste friendly benchmark report."""
    copyable_text = "MULTILINGUAL SPEECH-TO-TEXT BENCHMARK RESULTS\n" + "="*55 + "\n\n"
    copyable_text += f"Language: {selected_language} ({lang_code})\n"
    copyable_text += f"Script: {lang_info['script']}\n"
    copyable_text += f"Audio Duration: {audio_duration:.2f}s\n"
    copyable_text += f"Model Tested: {model_name}\n"
    if reference_text:
        copyable_text += f"Reference Text: {reference_text}\n"
    copyable_text += "\n" + "-"*55 + "\n\n"
    for i, row in enumerate(table_data):
        copyable_text += f"MODEL {i+1}: {row[0]}\n"
        copyable_text += f"Transcription: {row[1]}\n"
        copyable_text += f"WER: {row[2]}\n"
        copyable_text += f"CER: {row[3]}\n"
        copyable_text += f"RTF: {row[4]}\n"
        copyable_text += f"Time Taken: {row[5]}\n"
        copyable_text += "\n" + "-"*35 + "\n\n"
    return copyable_text
|
|
|
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the multilingual STT benchmark.

    Layout: a left column with inputs (language, audio, model checkbox,
    reference text, run button) and a right column with outputs (summary,
    results table, copy-friendly export box), followed by a static help
    section. Both the run button and Enter in the reference textbox trigger
    ``transcribe_audio``.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` interface.
    """
    language_choices = list(LANGUAGE_CONFIGS.keys())

    with gr.Blocks(title="Multilingual Speech-to-Text Benchmark", css="""
    .language-info { background: #f0f8ff; padding: 10px; border-radius: 5px; margin: 10px 0; }
    .copy-area { font-family: monospace; font-size: 12px; }
    """) as iface:
        # NOTE(review): the emoji in the UI strings below appear mojibake-
        # garbled (e.g. "π" sequences); left byte-identical on purpose —
        # confirm intended glyphs against the original source encoding.
        gr.Markdown("""
        # π Multilingual Speech-to-Text Benchmark

        Using only the **IndicConformer** model for 22 Indian languages.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # --- Input controls -------------------------------------
                # Language dropdown, defaulting to the first configured one.
                language_selection = gr.Dropdown(
                    choices=language_choices,
                    label="π£οΈ Select Language",
                    value=language_choices[0],
                    interactive=True
                )

                # filepath type hands transcribe_audio a path, not raw audio.
                audio_input = gr.Audio(
                    label="πΉ Upload Audio File (16kHz recommended)",
                    type="filepath"
                )

                # Only one model exists; shown pre-checked and locked.
                model_selection = gr.CheckboxGroup(
                    choices=["IndicConformer"],
                    label="π€ Select Models",
                    value=["IndicConformer"],
                    interactive=False
                )

                # Optional ground truth; enables WER/CER in the results.
                reference_input = gr.Textbox(
                    label="π Reference Text (optional, paste supported)",
                    placeholder="Paste reference transcription here...",
                    lines=4,
                    interactive=True
                )

                submit_btn = gr.Button("π Run Multilingual Benchmark", variant="primary", size="lg")

            with gr.Column(scale=2):
                # --- Output panels --------------------------------------
                summary_output = gr.Markdown(
                    label="π Summary",
                    value="Select language, upload audio file and choose models to begin..."
                )

                # Columns mirror the rows built by transcribe_audio.
                results_table = gr.Dataframe(
                    headers=["Model", "Transcription", "WER", "CER", "RTF", "Time"],
                    datatype=["str", "str", "str", "str", "str", "str"],
                    label="π Results Comparison",
                    interactive=False,
                    wrap=True,
                    column_widths=[120, 350, 60, 60, 60, 80]
                )

                # Plain-text export with a built-in copy button.
                with gr.Group():
                    gr.Markdown("### π Export Results")
                    copyable_output = gr.Textbox(
                        label="Copy-Paste Friendly Results",
                        lines=12,
                        max_lines=25,
                        show_copy_button=True,
                        interactive=False,
                        elem_classes="copy-area",
                        placeholder="Benchmark results will appear here..."
                    )

        # Run the benchmark when the button is clicked...
        submit_btn.click(
            fn=transcribe_audio,
            inputs=[audio_input, language_selection, model_selection, reference_input],
            outputs=[summary_output, results_table, copyable_output]
        )

        # ...and also when the user presses Enter in the reference textbox.
        reference_input.submit(
            fn=transcribe_audio,
            inputs=[audio_input, language_selection, model_selection, reference_input],
            outputs=[summary_output, results_table, copyable_output]
        )

        # Static footer: support matrix and usage tips.
        gr.Markdown("""
        ---
        ### π€ Language & Model Support Matrix

        | Language | Script | IndicConformer |
        |----------|---------|---------------|
        | Hindi | Devanagari | β
 |
        | Gujarati | Gujarati | β
 |
        | Marathi | Devanagari | β
 |
        | Tamil | Tamil | β
 |
        | Telugu | Telugu | β
 |
        | Kannada | Kannada | β
 |
        | Malayalam | Malayalam | β
 |

        ### π‘ Tips:
        - **Model is fixed** to IndicConformer for this app.
        - **Reference Text**: Enable WER/CER calculation by providing ground truth.
        - **Copy Results**: Export formatted results using the copy button.
        """)

    return iface
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at the standard Gradio port.
    iface = create_interface()
    launch_options = dict(
        share=False,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
    iface.launch(**launch_options)