AvtnshM commited on
Commit
c511528
·
verified ·
1 Parent(s): 8ae6da7
Files changed (1) hide show
  1. app.py +97 -211
app.py CHANGED
@@ -1,219 +1,105 @@
1
- import gradio as gr
2
  import time
3
  import librosa
4
- import torch
5
- import numpy as np
6
  from jiwer import wer, cer
7
- from transformers import (
8
- WhisperProcessor, WhisperForConditionalGeneration,
9
- AutoProcessor, AutoModelForCTC
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  )
11
 
12
- # Global variables for models (loaded once)
13
- whisper_processor = None
14
- whisper_model = None
15
- conformer_processor = None
16
- conformer_model = None
17
-
18
- def load_models():
19
- """Load models once at startup"""
20
- global whisper_processor, whisper_model, conformer_processor, conformer_model
21
-
22
- if whisper_processor is None:
23
- print("Loading IndicWhisper...")
24
- whisper_processor = WhisperProcessor.from_pretrained("parthiv11/indic_whisper_nodcil")
25
- whisper_model = WhisperForConditionalGeneration.from_pretrained("parthiv11/indic_whisper_nodcil")
26
-
27
- print("Loading IndicConformer...")
28
- conformer_processor = AutoProcessor.from_pretrained("ai4bharat/indicconformer_asr_conformer_multilingual")
29
- conformer_model = AutoModelForCTC.from_pretrained("ai4bharat/indicconformer_asr_conformer_multilingual")
30
-
31
- print("Models loaded successfully!")
32
-
33
- def transcribe_whisper(audio_path):
34
- """Transcribe using IndicWhisper"""
35
- audio, sr = librosa.load(audio_path, sr=16000)
36
- input_features = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features
37
-
38
- start_time = time.time()
39
- with torch.no_grad():
40
- predicted_ids = whisper_model.generate(input_features)
41
- end_time = time.time()
42
-
43
- transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
44
- return transcription, end_time - start_time
45
-
46
- def transcribe_conformer(audio_path):
47
- """Transcribe using IndicConformer"""
48
- audio, sr = librosa.load(audio_path, sr=16000)
49
- input_values = conformer_processor(audio, sampling_rate=sr, return_tensors="pt").input_values
50
-
51
- start_time = time.time()
52
- with torch.no_grad():
53
- logits = conformer_model(input_values).logits
54
- predicted_ids = torch.argmax(logits, dim=-1)
55
- end_time = time.time()
56
-
57
- transcription = conformer_processor.batch_decode(predicted_ids)[0]
58
- return transcription, end_time - start_time
59
-
60
- def compare_models(audio_file, ground_truth_text):
61
- """Main comparison function for Gradio interface"""
62
-
63
- if audio_file is None:
64
- return "Please upload an audio file", "", "", "", "", ""
65
-
66
- load_models() # Ensure models are loaded
67
-
68
- try:
69
- # Get audio duration
70
- audio_duration = librosa.get_duration(filename=audio_file)
71
-
72
- # Test IndicWhisper
73
- whisper_pred, whisper_time = transcribe_whisper(audio_file)
74
- whisper_rtf = whisper_time / audio_duration if audio_duration > 0 else 0
75
-
76
- # Test IndicConformer
77
- conformer_pred, conformer_time = transcribe_conformer(audio_file)
78
- conformer_rtf = conformer_time / audio_duration if audio_duration > 0 else 0
79
-
80
- # Calculate metrics if ground truth provided
81
- if ground_truth_text and ground_truth_text.strip():
82
- whisper_wer = wer(ground_truth_text, whisper_pred)
83
- whisper_cer = cer(ground_truth_text, whisper_pred)
84
- conformer_wer = wer(ground_truth_text, conformer_pred)
85
- conformer_cer = cer(ground_truth_text, conformer_pred)
86
-
87
- # Format results with metrics
88
- whisper_result = f"""
89
- ## 📊 IndicWhisper Results:
90
- **Prediction:** {whisper_pred}
91
-
92
- **WER:** {whisper_wer:.3f}
93
- **CER:** {whisper_cer:.3f}
94
- **RTF:** {whisper_rtf:.3f} {'✅ Real-time' if whisper_rtf < 1.0 else '⚠️ Slower'}
95
- **Time:** {whisper_time:.2f}s
96
- """
97
-
98
- conformer_result = f"""
99
- ## 📊 IndicConformer Results:
100
- **Prediction:** {conformer_pred}
101
-
102
- **WER:** {conformer_wer:.3f}
103
- **CER:** {conformer_cer:.3f}
104
- **RTF:** {conformer_rtf:.3f} {'✅ Real-time' if conformer_rtf < 1.0 else '⚠️ Slower'}
105
- **Time:** {conformer_time:.2f}s
106
- """
107
-
108
- # Winner analysis
109
- wer_winner = "IndicWhisper" if whisper_wer < conformer_wer else "IndicConformer"
110
- cer_winner = "IndicWhisper" if whisper_cer < conformer_cer else "IndicConformer"
111
- rtf_winner = "IndicWhisper" if whisper_rtf < conformer_rtf else "IndicConformer"
112
-
113
- winner_analysis = f"""
114
- ## 🏆 Winner Analysis:
115
- **Best WER:** {wer_winner} ({min(whisper_wer, conformer_wer):.3f})
116
- **Best CER:** {cer_winner} ({min(whisper_cer, conformer_cer):.3f})
117
- **Fastest:** {rtf_winner} ({min(whisper_rtf, conformer_rtf):.3f})
118
- """
119
- else:
120
- # Results without metrics (no ground truth)
121
- whisper_result = f"""
122
- ## 📊 IndicWhisper Results:
123
- **Prediction:** {whisper_pred}
124
-
125
- **RTF:** {whisper_rtf:.3f}
126
- **Time:** {whisper_time:.2f}s
127
- """
128
-
129
- conformer_result = f"""
130
- ## 📊 IndicConformer Results:
131
- **Prediction:** {conformer_pred}
132
-
133
- **RTF:** {conformer_rtf:.3f}
134
- **Time:** {conformer_time:.2f}s
135
- """
136
-
137
- winner_analysis = f"""
138
- ## 🏆 Speed Comparison:
139
- **Faster Model:** {'IndicWhisper' if whisper_rtf < conformer_rtf else 'IndicConformer'}
140
- **RTF Difference:** {abs(whisper_rtf - conformer_rtf):.3f}
141
- """
142
-
143
- return whisper_result, conformer_result, winner_analysis, whisper_pred, conformer_pred, f"Audio duration: {audio_duration:.2f}s"
144
-
145
- except Exception as e:
146
- error_msg = f"❌ Error processing audio: {str(e)}"
147
- return error_msg, "", "", "", "", ""
148
-
149
- # Create Gradio Interface
150
- with gr.Blocks(title="ASR Model Comparison") as demo:
151
-
152
- gr.Markdown("""
153
- # 🎤 ASR Model Comparison: IndicWhisper vs IndicConformer
154
-
155
- Compare two leading Indian language ASR models on your audio files!
156
-
157
- **Models:**
158
- - **IndicWhisper:** `parthiv11/indic_whisper_nodcil`
159
- - **IndicConformer:** `ai4bharat/indicconformer_asr_conformer_multilingual`
160
-
161
- **Metrics:** WER (Word Error Rate), CER (Character Error Rate), RTF (Real-Time Factor)
162
- """)
163
-
164
- with gr.Row():
165
- with gr.Column():
166
- audio_input = gr.Audio(
167
- label="🎵 Upload Audio File",
168
- type="filepath"
169
- )
170
- ground_truth_input = gr.Textbox(
171
- label="📝 Ground Truth Text (Optional)",
172
- placeholder="Enter expected transcription for WER/CER calculation...",
173
- lines=3
174
- )
175
- compare_btn = gr.Button("🚀 Compare Models", variant="primary")
176
-
177
- with gr.Column():
178
- audio_info = gr.Textbox(label="ℹ️ Audio Info", interactive=False)
179
-
180
- with gr.Row():
181
- with gr.Column():
182
- whisper_output = gr.Markdown(label="IndicWhisper Results")
183
- with gr.Column():
184
- conformer_output = gr.Markdown(label="IndicConformer Results")
185
-
186
- winner_output = gr.Markdown(label="🏆 Comparison Summary")
187
-
188
- # Hidden outputs for API access
189
- with gr.Row(visible=False):
190
- whisper_text = gr.Textbox(label="Whisper Transcription")
191
- conformer_text = gr.Textbox(label="Conformer Transcription")
192
-
193
- compare_btn.click(
194
- fn=compare_models,
195
- inputs=[audio_input, ground_truth_input],
196
- outputs=[whisper_output, conformer_output, winner_output, whisper_text, conformer_text, audio_info]
197
- )
198
-
199
- gr.Markdown("""
200
- ## 📋 How to Use:
201
- 1. **Upload audio** in any supported format (WAV, MP3, M4A, etc.)
202
- 2. **Add ground truth** (optional) - if provided, you'll get WER/CER metrics
203
- 3. **Click Compare** to see results from both models
204
- 4. **Analyze** which model performs better for your use case
205
-
206
- ## 📖 Understanding Metrics:
207
- - **WER (Word Error Rate):** Percentage of words transcribed incorrectly (Lower = Better, 0 = Perfect)
208
- - **CER (Character Error Rate):** Percentage of characters transcribed incorrectly (Lower = Better, 0 = Perfect)
209
- - **RTF (Real-Time Factor):** Ratio of processing time to audio duration (Lower = Faster, <1.0 = Real-time capable)
210
-
211
- ## 🌐 Supported Languages:
212
- Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Odia, Punjabi, Sanskrit, Tamil, Telugu, Urdu
213
- """)
214
-
215
- # Load models on startup
216
- load_models()
217
 
218
  if __name__ == "__main__":
219
  demo.launch()
 
 
1
  import time
2
  import librosa
3
+ import gradio as gr
4
+ from transformers import AutoModelForCTC, AutoProcessor, pipeline
5
  from jiwer import wer, cer
6
+
7
+ # ---------------------------
8
+ # Load Models (CPU only)
9
+ # ---------------------------
10
+
11
+ # 1. IndicConformer
12
+ indic_model_id = "ai4bharat/indic-conformer-600m-multilingual"
13
+ indic_processor = AutoProcessor.from_pretrained(indic_model_id)
14
+ indic_model = AutoModelForCTC.from_pretrained(indic_model_id)
15
+ indic_pipeline = pipeline(
16
+ "automatic-speech-recognition",
17
+ model=indic_model,
18
+ tokenizer=indic_processor.tokenizer,
19
+ feature_extractor=indic_processor.feature_extractor,
20
+ device=-1 # CPU
21
+ )
22
+
23
+ # 2. Facebook MMS (generic multilingual ASR)
24
+ mms_model_id = "facebook/mms-1b-all"
25
+ mms_processor = AutoProcessor.from_pretrained(mms_model_id)
26
+ mms_model = AutoModelForCTC.from_pretrained(mms_model_id)
27
+ mms_pipeline = pipeline(
28
+ "automatic-speech-recognition",
29
+ model=mms_model,
30
+ tokenizer=mms_processor.tokenizer,
31
+ feature_extractor=mms_processor.feature_extractor,
32
+ device=-1
33
  )
34
 
35
+ # 3. Jivi AudioX (North example)
36
+ jivi_model_id = "jiviai/audioX-north-v1"
37
+ jivi_pipeline = pipeline(
38
+ "automatic-speech-recognition",
39
+ model=jivi_model_id,
40
+ device=-1
41
+ )
42
+
43
+ # ---------------------------
44
+ # Utility Functions
45
+ # ---------------------------
46
+
47
+ def evaluate_model(pipeline_fn, audio_path, reference_text):
48
+ # Load audio (resample to 16kHz for consistency)
49
+ speech, sr = librosa.load(audio_path, sr=16000)
50
+
51
+ # Measure runtime
52
+ start = time.time()
53
+ result = pipeline_fn(speech)
54
+ end = time.time()
55
+
56
+ # Extract transcription
57
+ hypothesis = result["text"]
58
+
59
+ # Compute metrics
60
+ word_error = wer(reference_text.lower(), hypothesis.lower())
61
+ char_error = cer(reference_text.lower(), hypothesis.lower())
62
+ rtf = (end - start) / (len(speech) / sr) # real-time factor
63
+
64
+ return hypothesis, word_error, char_error, rtf
65
+
66
+ def compare_models(audio, reference_text, lang="hi"):
67
+ results = {}
68
+
69
+ # IndicConformer
70
+ hyp, w, c, r = evaluate_model(indic_pipeline, audio, reference_text)
71
+ results["IndicConformer"] = (hyp, w, c, r)
72
+
73
+ # MMS
74
+ hyp, w, c, r = evaluate_model(mms_pipeline, audio, reference_text)
75
+ results["MMS"] = (hyp, w, c, r)
76
+
77
+ # Jivi
78
+ hyp, w, c, r = evaluate_model(jivi_pipeline, audio, reference_text)
79
+ results["Jivi"] = (hyp, w, c, r)
80
+
81
+ # Build results table
82
+ table = "| Model | Transcription | WER | CER | RTF |\n"
83
+ table += "|-------|---------------|-----|-----|-----|\n"
84
+ for model, (hyp, w, c, r) in results.items():
85
+ table += f"| {model} | {hyp} | {w:.3f} | {c:.3f} | {r:.3f} |\n"
86
+
87
+ return table
88
+
89
+ # ---------------------------
90
+ # Gradio UI
91
+ # ---------------------------
92
+ demo = gr.Interface(
93
+ fn=compare_models,
94
+ inputs=[
95
+ gr.Audio(type="filepath", label="Upload Audio (≤20s recommended)"),
96
+ gr.Textbox(label="Reference Text"),
97
+ gr.Dropdown(choices=["hi", "gu", "ta"], value="hi", label="Language")
98
+ ],
99
+ outputs=gr.Markdown(label="Results"),
100
+ title="ASR Benchmark (CPU mode): IndicConformer vs MMS vs Jivi",
101
+ description="Runs on free CPU Spaces. Upload short audio and reference text. Compares models on WER, CER, and RTF."
102
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  if __name__ == "__main__":
105
  demo.launch()