AvtnshM commited on
Commit
8ae6da7
·
verified ·
1 Parent(s): 2524349
Files changed (1) hide show
  1. app.py +42 -71
app.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  from jiwer import wer, cer
7
  from transformers import (
8
  WhisperProcessor, WhisperForConditionalGeneration,
9
- Wav2Vec2Processor, Wav2Vec2ForCTC
10
  )
11
 
12
  # Global variables for models (loaded once)
@@ -16,73 +16,46 @@ conformer_processor = None
16
  conformer_model = None
17
 
18
  def load_models():
19
- """Load models once at startup with error handling"""
20
  global whisper_processor, whisper_model, conformer_processor, conformer_model
21
 
22
  if whisper_processor is None:
23
- try:
24
- print("Loading IndicWhisper...")
25
- # Try the original model first
26
- whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
27
- whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
28
- print("βœ… Using OpenAI Whisper-medium as fallback")
29
- except Exception as e:
30
- print(f"❌ Error loading IndicWhisper: {e}")
31
- # Fallback to standard Whisper
32
- whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
33
- whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
34
- print("βœ… Using OpenAI Whisper-base as fallback")
35
 
36
- try:
37
- print("Loading IndicConformer...")
38
- conformer_processor = Wav2Vec2Processor.from_pretrained("ai4bharat/indic-conformer-600m-multilingual")
39
- conformer_model = Wav2Vec2ForCTC.from_pretrained("ai4bharat/indic-conformer-600m-multilingual")
40
- print("βœ… IndicConformer loaded successfully")
41
- except Exception as e:
42
- print(f"❌ Error loading IndicConformer: {e}")
43
- # Fallback to a working multilingual model
44
- conformer_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
45
- conformer_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53")
46
- print("βœ… Using Facebook XLSR-53 as fallback")
47
 
48
  print("Models loaded successfully!")
49
 
50
  def transcribe_whisper(audio_path):
51
- """Transcribe using Whisper model"""
52
- try:
53
- audio, sr = librosa.load(audio_path, sr=16000)
54
- input_features = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features
55
-
56
- start_time = time.time()
57
- with torch.no_grad():
58
- # Force Hindi language for better results
59
- predicted_ids = whisper_model.generate(
60
- input_features,
61
- forced_decoder_ids=whisper_processor.get_decoder_prompt_ids(language="hindi", task="transcribe")
62
- )
63
- end_time = time.time()
64
-
65
- transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
66
- return transcription, end_time - start_time
67
- except Exception as e:
68
- return f"Error in Whisper transcription: {str(e)}", 0
69
 
70
  def transcribe_conformer(audio_path):
71
- """Transcribe using Conformer model"""
72
- try:
73
- audio, sr = librosa.load(audio_path, sr=16000)
74
- input_values = conformer_processor(audio, sampling_rate=sr, return_tensors="pt").input_values
75
-
76
- start_time = time.time()
77
- with torch.no_grad():
78
- logits = conformer_model(input_values).logits
79
- predicted_ids = torch.argmax(logits, dim=-1)
80
- end_time = time.time()
81
-
82
- transcription = conformer_processor.batch_decode(predicted_ids)[0]
83
- return transcription, end_time - start_time
84
- except Exception as e:
85
- return f"Error in Conformer transcription: {str(e)}", 0
86
 
87
  def compare_models(audio_file, ground_truth_text):
88
  """Main comparison function for Gradio interface"""
@@ -113,7 +86,7 @@ def compare_models(audio_file, ground_truth_text):
113
 
114
  # Format results with metrics
115
  whisper_result = f"""
116
- ## πŸ“Š Whisper Results:
117
  **Prediction:** {whisper_pred}
118
 
119
  **WER:** {whisper_wer:.3f}
@@ -133,9 +106,9 @@ def compare_models(audio_file, ground_truth_text):
133
  """
134
 
135
  # Winner analysis
136
- wer_winner = "Whisper" if whisper_wer < conformer_wer else "IndicConformer"
137
- cer_winner = "Whisper" if whisper_cer < conformer_cer else "IndicConformer"
138
- rtf_winner = "Whisper" if whisper_rtf < conformer_rtf else "IndicConformer"
139
 
140
  winner_analysis = f"""
141
  ## πŸ† Winner Analysis:
@@ -146,7 +119,7 @@ def compare_models(audio_file, ground_truth_text):
146
  else:
147
  # Results without metrics (no ground truth)
148
  whisper_result = f"""
149
- ## πŸ“Š Whisper Results:
150
  **Prediction:** {whisper_pred}
151
 
152
  **RTF:** {whisper_rtf:.3f}
@@ -163,7 +136,7 @@ def compare_models(audio_file, ground_truth_text):
163
 
164
  winner_analysis = f"""
165
  ## πŸ† Speed Comparison:
166
- **Faster Model:** {'Whisper' if whisper_rtf < conformer_rtf else 'IndicConformer'}
167
  **RTF Difference:** {abs(whisper_rtf - conformer_rtf):.3f}
168
  """
169
 
@@ -177,17 +150,15 @@ def compare_models(audio_file, ground_truth_text):
177
  with gr.Blocks(title="ASR Model Comparison") as demo:
178
 
179
  gr.Markdown("""
180
- # 🎀 ASR Model Comparison: Whisper vs IndicConformer
181
 
182
- Compare **OpenAI Whisper** vs **AI4Bharat IndicConformer** on your audio files!
183
 
184
  **Models:**
185
- - **Whisper:** `openai/whisper-medium` (with Hindi language setting)
186
- - **IndicConformer:** `ai4bharat/indic-conformer-600m-multilingual`
187
 
188
  **Metrics:** WER (Word Error Rate), CER (Character Error Rate), RTF (Real-Time Factor)
189
-
190
- ⚠️ **Note:** Using standard Whisper model with Hindi language setting for comparison.
191
  """)
192
 
193
  with gr.Row():
@@ -201,7 +172,7 @@ with gr.Blocks(title="ASR Model Comparison") as demo:
201
  placeholder="Enter expected transcription for WER/CER calculation...",
202
  lines=3
203
  )
204
- compare_btn = gr.Button("πŸš€ Compare Models", variant="primary", size="lg")
205
 
206
  with gr.Column():
207
  audio_info = gr.Textbox(label="ℹ️ Audio Info", interactive=False)
 
6
  from jiwer import wer, cer
7
  from transformers import (
8
  WhisperProcessor, WhisperForConditionalGeneration,
9
+ AutoProcessor, AutoModelForCTC
10
  )
11
 
12
  # Global variables for models (loaded once)
 
16
  conformer_model = None
17
 
18
def load_models():
    """Load both ASR models once at startup (idempotent).

    Populates the module-level globals ``whisper_processor``/``whisper_model``
    and ``conformer_processor``/``conformer_model``. Safe to call repeatedly:
    each model pair is fetched only while its processor global is still None.
    """
    global whisper_processor, whisper_model, conformer_processor, conformer_model

    # Guard each pair separately: the original single `if whisper_processor is None`
    # guard would skip the conformer load forever if a previous call had loaded
    # Whisper but failed (or been interrupted) before loading the conformer.
    if whisper_processor is None:
        print("Loading IndicWhisper...")
        whisper_processor = WhisperProcessor.from_pretrained("parthiv11/indic_whisper_nodcil")
        whisper_model = WhisperForConditionalGeneration.from_pretrained("parthiv11/indic_whisper_nodcil")

    if conformer_processor is None:
        print("Loading IndicConformer...")
        conformer_processor = AutoProcessor.from_pretrained("ai4bharat/indicconformer_asr_conformer_multilingual")
        conformer_model = AutoModelForCTC.from_pretrained("ai4bharat/indicconformer_asr_conformer_multilingual")

    print("Models loaded successfully!")
32
 
33
def transcribe_whisper(audio_path):
    """Transcribe an audio file with the IndicWhisper model.

    Returns a tuple ``(transcription, elapsed_seconds)`` where the elapsed
    time covers only the generation step (used for RTF computation).
    """
    # Whisper expects 16 kHz mono input.
    waveform, sample_rate = librosa.load(audio_path, sr=16000)
    features = whisper_processor(
        waveform, sampling_rate=sample_rate, return_tensors="pt"
    ).input_features

    t0 = time.time()
    with torch.no_grad():
        generated_ids = whisper_model.generate(features)
    elapsed = time.time() - t0

    text = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text, elapsed
 
 
 
 
 
 
 
45
 
46
def transcribe_conformer(audio_path):
    """Transcribe an audio file with the IndicConformer CTC model.

    Returns a tuple ``(transcription, elapsed_seconds)`` where the elapsed
    time covers the forward pass and greedy decoding (used for RTF).
    """
    # The CTC model also expects 16 kHz mono input.
    waveform, sample_rate = librosa.load(audio_path, sr=16000)
    inputs = conformer_processor(
        waveform, sampling_rate=sample_rate, return_tensors="pt"
    ).input_values

    t0 = time.time()
    with torch.no_grad():
        frame_logits = conformer_model(inputs).logits
        # Greedy CTC decoding: pick the highest-scoring token per frame.
        token_ids = torch.argmax(frame_logits, dim=-1)
    elapsed = time.time() - t0

    text = conformer_processor.batch_decode(token_ids)[0]
    return text, elapsed
 
 
 
59
 
60
  def compare_models(audio_file, ground_truth_text):
61
  """Main comparison function for Gradio interface"""
 
86
 
87
  # Format results with metrics
88
  whisper_result = f"""
89
+ ## πŸ“Š IndicWhisper Results:
90
  **Prediction:** {whisper_pred}
91
 
92
  **WER:** {whisper_wer:.3f}
 
106
  """
107
 
108
  # Winner analysis
109
+ wer_winner = "IndicWhisper" if whisper_wer < conformer_wer else "IndicConformer"
110
+ cer_winner = "IndicWhisper" if whisper_cer < conformer_cer else "IndicConformer"
111
+ rtf_winner = "IndicWhisper" if whisper_rtf < conformer_rtf else "IndicConformer"
112
 
113
  winner_analysis = f"""
114
  ## πŸ† Winner Analysis:
 
119
  else:
120
  # Results without metrics (no ground truth)
121
  whisper_result = f"""
122
+ ## πŸ“Š IndicWhisper Results:
123
  **Prediction:** {whisper_pred}
124
 
125
  **RTF:** {whisper_rtf:.3f}
 
136
 
137
  winner_analysis = f"""
138
  ## πŸ† Speed Comparison:
139
+ **Faster Model:** {'IndicWhisper' if whisper_rtf < conformer_rtf else 'IndicConformer'}
140
  **RTF Difference:** {abs(whisper_rtf - conformer_rtf):.3f}
141
  """
142
 
 
150
  with gr.Blocks(title="ASR Model Comparison") as demo:
151
 
152
  gr.Markdown("""
153
+ # 🎀 ASR Model Comparison: IndicWhisper vs IndicConformer
154
 
155
+ Compare two leading Indian language ASR models on your audio files!
156
 
157
  **Models:**
158
+ - **IndicWhisper:** `parthiv11/indic_whisper_nodcil`
159
+ - **IndicConformer:** `ai4bharat/indicconformer_asr_conformer_multilingual`
160
 
161
  **Metrics:** WER (Word Error Rate), CER (Character Error Rate), RTF (Real-Time Factor)
 
 
162
  """)
163
 
164
  with gr.Row():
 
172
  placeholder="Enter expected transcription for WER/CER calculation...",
173
  lines=3
174
  )
175
+ compare_btn = gr.Button("πŸš€ Compare Models", variant="primary")
176
 
177
  with gr.Column():
178
  audio_info = gr.Textbox(label="ℹ️ Audio Info", interactive=False)