AvtnshM committed on
Commit
2524349
·
verified ·
1 Parent(s): 4b9d5e3
Files changed (1) hide show
  1. app.py +68 -39
app.py CHANGED
@@ -16,46 +16,73 @@ conformer_processor = None
16
  conformer_model = None
17
 
18
  def load_models():
19
- """Load models once at startup"""
20
  global whisper_processor, whisper_model, conformer_processor, conformer_model
21
 
22
  if whisper_processor is None:
23
- print("Loading IndicWhisper...")
24
- whisper_processor = WhisperProcessor.from_pretrained("parthiv11/indic_whisper_nodcil")
25
- whisper_model = WhisperForConditionalGeneration.from_pretrained("parthiv11/indic_whisper_nodcil")
 
 
 
 
 
 
 
 
 
26
 
27
- print("Loading IndicConformer...")
28
- conformer_processor = Wav2Vec2Processor.from_pretrained("ai4bharat/indic-conformer-600m-multilingual")
29
- conformer_model = Wav2Vec2ForCTC.from_pretrained("ai4bharat/indic-conformer-600m-multilingual")
 
 
 
 
 
 
 
 
30
 
31
  print("Models loaded successfully!")
32
 
33
  def transcribe_whisper(audio_path):
34
- """Transcribe using IndicWhisper"""
35
- audio, sr = librosa.load(audio_path, sr=16000)
36
- input_features = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features
37
-
38
- start_time = time.time()
39
- with torch.no_grad():
40
- predicted_ids = whisper_model.generate(input_features)
41
- end_time = time.time()
42
-
43
- transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
44
- return transcription, end_time - start_time
 
 
 
 
 
 
 
45
 
46
  def transcribe_conformer(audio_path):
47
- """Transcribe using IndicConformer"""
48
- audio, sr = librosa.load(audio_path, sr=16000)
49
- input_values = conformer_processor(audio, sampling_rate=sr, return_tensors="pt").input_values
50
-
51
- start_time = time.time()
52
- with torch.no_grad():
53
- logits = conformer_model(input_values).logits
54
- predicted_ids = torch.argmax(logits, dim=-1)
55
- end_time = time.time()
56
-
57
- transcription = conformer_processor.batch_decode(predicted_ids)[0]
58
- return transcription, end_time - start_time
 
 
 
59
 
60
  def compare_models(audio_file, ground_truth_text):
61
  """Main comparison function for Gradio interface"""
@@ -86,7 +113,7 @@ def compare_models(audio_file, ground_truth_text):
86
 
87
  # Format results with metrics
88
  whisper_result = f"""
89
- ## πŸ“Š IndicWhisper Results:
90
  **Prediction:** {whisper_pred}
91
 
92
  **WER:** {whisper_wer:.3f}
@@ -106,9 +133,9 @@ def compare_models(audio_file, ground_truth_text):
106
  """
107
 
108
  # Winner analysis
109
- wer_winner = "IndicWhisper" if whisper_wer < conformer_wer else "IndicConformer"
110
- cer_winner = "IndicWhisper" if whisper_cer < conformer_cer else "IndicConformer"
111
- rtf_winner = "IndicWhisper" if whisper_rtf < conformer_rtf else "IndicConformer"
112
 
113
  winner_analysis = f"""
114
  ## πŸ† Winner Analysis:
@@ -119,7 +146,7 @@ def compare_models(audio_file, ground_truth_text):
119
  else:
120
  # Results without metrics (no ground truth)
121
  whisper_result = f"""
122
- ## πŸ“Š IndicWhisper Results:
123
  **Prediction:** {whisper_pred}
124
 
125
  **RTF:** {whisper_rtf:.3f}
@@ -136,7 +163,7 @@ def compare_models(audio_file, ground_truth_text):
136
 
137
  winner_analysis = f"""
138
  ## πŸ† Speed Comparison:
139
- **Faster Model:** {'IndicWhisper' if whisper_rtf < conformer_rtf else 'IndicConformer'}
140
  **RTF Difference:** {abs(whisper_rtf - conformer_rtf):.3f}
141
  """
142
 
@@ -150,15 +177,17 @@ def compare_models(audio_file, ground_truth_text):
150
  with gr.Blocks(title="ASR Model Comparison") as demo:
151
 
152
  gr.Markdown("""
153
- # 🎀 ASR Model Comparison: IndicWhisper vs IndicConformer
154
 
155
- Compare two leading Indian language ASR models on your audio files!
156
 
157
  **Models:**
158
- - **IndicWhisper:** `parthiv11/indic_whisper_nodcil`
159
  - **IndicConformer:** `ai4bharat/indic-conformer-600m-multilingual`
160
 
161
  **Metrics:** WER (Word Error Rate), CER (Character Error Rate), RTF (Real-Time Factor)
 
 
162
  """)
163
 
164
  with gr.Row():
 
16
  conformer_model = None
17
 
18
def load_models():
    """Load both ASR models once at startup, with per-model fallbacks.

    Populates the module-level globals:
        whisper_processor / whisper_model     -- OpenAI Whisper (seq2seq ASR)
        conformer_processor / conformer_model -- AI4Bharat IndicConformer (CTC ASR)

    Each model pair is guarded by its own ``is None`` check so that a
    failure (or an earlier partial load) of one model does not block
    loading the other.  Previously the conformer load was nested inside
    the whisper check, so a partially-initialized state could never
    recover; and the primary whisper load was mislabeled "fallback".
    """
    global whisper_processor, whisper_model, conformer_processor, conformer_model

    if whisper_processor is None:
        try:
            print("Loading Whisper...")
            # Primary checkpoint: whisper-medium (used with a Hindi
            # decoder prompt at transcription time).
            whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
            whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
            print("βœ… Loaded OpenAI Whisper-medium")
        except Exception as e:
            print(f"❌ Error loading Whisper-medium: {e}")
            # Fallback to the smaller standard checkpoint.
            whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
            whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
            print("βœ… Using OpenAI Whisper-base as fallback")

    # Guarded separately from the whisper check (bug fix: previously this
    # only ran when whisper_processor was None).
    if conformer_processor is None:
        try:
            print("Loading IndicConformer...")
            conformer_processor = Wav2Vec2Processor.from_pretrained("ai4bharat/indic-conformer-600m-multilingual")
            conformer_model = Wav2Vec2ForCTC.from_pretrained("ai4bharat/indic-conformer-600m-multilingual")
            print("βœ… IndicConformer loaded successfully")
        except Exception as e:
            print(f"❌ Error loading IndicConformer: {e}")
            # Fallback to a multilingual wav2vec2 checkpoint.
            conformer_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
            conformer_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53")
            print("βœ… Using Facebook XLSR-53 as fallback")

    print("Models loaded successfully!")
49
 
50
def transcribe_whisper(audio_path):
    """Transcribe ``audio_path`` with the global Whisper model.

    Audio is resampled to 16 kHz before feature extraction, and decoding
    is pinned to Hindi transcription via a forced decoder prompt.

    Returns:
        tuple[str, float]: ``(transcription, elapsed_seconds)`` where the
        elapsed time covers only the ``generate`` call.  On any failure an
        error-message string and ``0`` are returned instead.
    """
    try:
        waveform, rate = librosa.load(audio_path, sr=16000)
        features = whisper_processor(
            waveform, sampling_rate=rate, return_tensors="pt"
        ).input_features

        t0 = time.time()
        with torch.no_grad():
            # Force Hindi language for better results.
            token_ids = whisper_model.generate(
                features,
                forced_decoder_ids=whisper_processor.get_decoder_prompt_ids(
                    language="hindi", task="transcribe"
                ),
            )
        t1 = time.time()

        text = whisper_processor.batch_decode(token_ids, skip_special_tokens=True)[0]
        return text, t1 - t0
    except Exception as e:
        return f"Error in Whisper transcription: {str(e)}", 0
69
 
70
def transcribe_conformer(audio_path):
    """Transcribe ``audio_path`` with the global CTC (conformer) model.

    Audio is resampled to 16 kHz, passed through the wav2vec2-style
    processor, and decoded greedily (argmax over CTC logits).

    Returns:
        tuple[str, float]: ``(transcription, elapsed_seconds)`` where the
        elapsed time covers the forward pass and argmax.  On any failure
        an error-message string and ``0`` are returned instead.
    """
    try:
        waveform, rate = librosa.load(audio_path, sr=16000)
        inputs = conformer_processor(
            waveform, sampling_rate=rate, return_tensors="pt"
        ).input_values

        t0 = time.time()
        with torch.no_grad():
            # Greedy CTC decoding: take the highest-scoring token per frame.
            scores = conformer_model(inputs).logits
            token_ids = torch.argmax(scores, dim=-1)
        t1 = time.time()

        text = conformer_processor.batch_decode(token_ids)[0]
        return text, t1 - t0
    except Exception as e:
        return f"Error in Conformer transcription: {str(e)}", 0
86
 
87
  def compare_models(audio_file, ground_truth_text):
88
  """Main comparison function for Gradio interface"""
 
113
 
114
  # Format results with metrics
115
  whisper_result = f"""
116
+ ## πŸ“Š Whisper Results:
117
  **Prediction:** {whisper_pred}
118
 
119
  **WER:** {whisper_wer:.3f}
 
133
  """
134
 
135
  # Winner analysis
136
+ wer_winner = "Whisper" if whisper_wer < conformer_wer else "IndicConformer"
137
+ cer_winner = "Whisper" if whisper_cer < conformer_cer else "IndicConformer"
138
+ rtf_winner = "Whisper" if whisper_rtf < conformer_rtf else "IndicConformer"
139
 
140
  winner_analysis = f"""
141
  ## πŸ† Winner Analysis:
 
146
  else:
147
  # Results without metrics (no ground truth)
148
  whisper_result = f"""
149
+ ## πŸ“Š Whisper Results:
150
  **Prediction:** {whisper_pred}
151
 
152
  **RTF:** {whisper_rtf:.3f}
 
163
 
164
  winner_analysis = f"""
165
  ## πŸ† Speed Comparison:
166
+ **Faster Model:** {'Whisper' if whisper_rtf < conformer_rtf else 'IndicConformer'}
167
  **RTF Difference:** {abs(whisper_rtf - conformer_rtf):.3f}
168
  """
169
 
 
177
  with gr.Blocks(title="ASR Model Comparison") as demo:
178
 
179
  gr.Markdown("""
180
+ # 🎀 ASR Model Comparison: Whisper vs IndicConformer
181
 
182
+ Compare **OpenAI Whisper** vs **AI4Bharat IndicConformer** on your audio files!
183
 
184
  **Models:**
185
+ - **Whisper:** `openai/whisper-medium` (with Hindi language setting)
186
  - **IndicConformer:** `ai4bharat/indic-conformer-600m-multilingual`
187
 
188
  **Metrics:** WER (Word Error Rate), CER (Character Error Rate), RTF (Real-Time Factor)
189
+
190
+ ⚠️ **Note:** Using standard Whisper model with Hindi language setting for comparison.
191
  """)
192
 
193
  with gr.Row():