samuelolubukun committed
Commit 5b42ac6 · verified · 1 Parent(s): f2877d5

Upload 2 files

Files changed (2):
  1. app.py +383 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,383 @@
+ # app.py
+ # Dependencies (gradio, transformers, torch, librosa, numpy, accelerate)
+ # are listed in requirements.txt.
+
+ import gradio as gr
+ import torch
+ from transformers import (
+     WhisperProcessor, WhisperForConditionalGeneration,
+     Wav2Vec2Processor, Wav2Vec2ForCTC
+ )
+ import librosa
+ import numpy as np
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+
+ class NigerianWhisperTranscriber:
+     def __init__(self):
+         self.models = {}
+         self.processors = {}
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Model configurations with their architectures
+         self.model_configs = {
+             "Yoruba": {
+                 "model_name": "DereAbdulhameed/Whisper-Yoruba",
+                 "architecture": "whisper"
+             },
+             "Hausa": {
+                 "model_name": "Baghdad99/saad-speech-recognition-hausa-audio-to-text",
+                 "architecture": "whisper"
+             },
+             "Igbo": {
+                 "model_name": "AstralZander/igbo_ASR",
+                 "architecture": "wav2vec2"
+             }
+         }
+
+         print(f"Using device: {self.device}")
+
+     def load_model(self, language):
+         """Load and cache the model and processor for a specific language."""
+         if language not in self.models:
+             try:
+                 print(f"Loading {language} model...")
+                 config = self.model_configs[language]
+                 model_name = config["model_name"]
+                 architecture = config["architecture"]
+
+                 if architecture == "whisper":
+                     processor = WhisperProcessor.from_pretrained(model_name)
+                     model = WhisperForConditionalGeneration.from_pretrained(model_name)
+                 elif architecture == "wav2vec2":
+                     processor = Wav2Vec2Processor.from_pretrained(model_name)
+                     model = Wav2Vec2ForCTC.from_pretrained(model_name)
+                 else:
+                     raise ValueError(f"Unknown architecture: {architecture}")
+
+                 self.processors[language] = processor
+                 self.models[language] = model.to(self.device)
+
+                 print(f"{language} model loaded successfully!")
+                 return True
+             except Exception as e:
+                 print(f"Error loading {language} model: {str(e)}")
+                 return False
+         return True
+
+     def preprocess_audio(self, audio_path):
+         """Load an audio file and resample it to 16 kHz mono."""
+         try:
+             audio, sr = librosa.load(audio_path, sr=16000)
+
+             # Ensure audio is not empty
+             if len(audio) == 0:
+                 raise ValueError("Audio file is empty")
+
+             # Ensure float32, the dtype both model families expect
+             audio = audio.astype(np.float32)
+
+             return audio
+         except Exception as e:
+             raise ValueError(f"Error processing audio: {str(e)}")
+
+     def chunk_audio(self, audio, chunk_length=25):
+         """Split audio into fixed-length chunks for longer recordings."""
+         sample_rate = 16000
+         chunk_samples = chunk_length * sample_rate
+
+         chunks = []
+         for i in range(0, len(audio), chunk_samples):
+             chunk = audio[i:i + chunk_samples]
+             if len(chunk) > sample_rate:  # skip fragments shorter than 1 second
+                 chunks.append(chunk)
+
+         return chunks
+
+     def transcribe_chunk(self, audio_chunk, language):
+         """Transcribe a single audio chunk."""
+         processor = self.processors[language]
+         model = self.models[language]
+         config = self.model_configs[language]
+
+         if config["architecture"] == "whisper":
+             # Whisper processing
+             inputs = processor(
+                 audio_chunk,
+                 sampling_rate=16000,
+                 return_tensors="pt"
+             )
+             input_features = inputs.input_features.to(self.device)
+
+             # Deterministic beam search; forward the attention mask only
+             # when the processor provides one
+             gen_kwargs = {
+                 "max_new_tokens": 400,
+                 "num_beams": 5,
+                 "do_sample": False,
+                 "use_cache": True,
+                 "pad_token_id": processor.tokenizer.eos_token_id,
+             }
+             if getattr(inputs, "attention_mask", None) is not None:
+                 gen_kwargs["attention_mask"] = inputs.attention_mask.to(self.device)
+
+             # Generate transcription
+             with torch.no_grad():
+                 predicted_ids = model.generate(input_features, **gen_kwargs)
+
+             # Decode transcription
+             transcription = processor.batch_decode(
+                 predicted_ids,
+                 skip_special_tokens=True
+             )[0]
+
+             return transcription.strip()
+
+         elif config["architecture"] == "wav2vec2":
+             # Wav2Vec2 processing
+             inputs = processor(
+                 audio_chunk,
+                 sampling_rate=16000,
+                 return_tensors="pt",
+                 padding=True
+             )
+             input_values = inputs.input_values.to(self.device)
+
+             # CTC decoding: pick the most likely token at each frame
+             with torch.no_grad():
+                 logits = model(input_values).logits
+             predicted_ids = torch.argmax(logits, dim=-1)
+
+             transcription = processor.batch_decode(
+                 predicted_ids,
+                 skip_special_tokens=True
+             )[0]
+
+             # Some Wav2Vec2 tokenizers do not fully strip padding tokens even
+             # with skip_special_tokens=True, so remove any remaining [PAD]
+             # explicitly and collapse the doubled spaces left behind.
+             transcription = transcription.replace("[PAD]", "").strip()
+             transcription = " ".join(transcription.split())
+
+             return transcription
+
+     def transcribe(self, audio_path, language):
+         """Transcribe an audio file in the specified language."""
+         try:
+             # Load model if not already loaded
+             if not self.load_model(language):
+                 return f"Error: Could not load {language} model"
+
+             # Preprocess audio
+             audio = self.preprocess_audio(audio_path)
+
+             # 25 seconds at 16 kHz = 400,000 samples
+             if len(audio) > 400000:
+                 # Process in chunks
+                 chunks = self.chunk_audio(audio, chunk_length=25)
+                 transcriptions = []
+
+                 for i, chunk in enumerate(chunks):
+                     print(f"Processing chunk {i + 1}/{len(chunks)}")
+                     transcriptions.append(self.transcribe_chunk(chunk, language))
+
+                 # Combine all chunk transcriptions
+                 return " ".join(transcriptions)
+
+             # Short audio is processed in one pass
+             return self.transcribe_chunk(audio, language)
+
+         except Exception as e:
+             return f"Error during transcription: {str(e)}"
+
+
+ # Initialize transcriber
+ transcriber = NigerianWhisperTranscriber()
+
+
+ def transcribe_audio_unified(audio_file, audio_mic, language):
+     """Gradio handler: transcribe from either an uploaded file or the microphone."""
+     audio_source = audio_file if audio_file is not None else audio_mic
+
+     if audio_source is None:
+         return "Please upload an audio file or record from the microphone"
+
+     try:
+         return transcriber.transcribe(audio_source, language)
+     except Exception as e:
+         return f"Transcription failed: {str(e)}"
+
+
+ def get_model_info(language):
+     """Return a short description of the selected model."""
+     model_info = {
+         "Yoruba": "DereAbdulhameed/Whisper-Yoruba - Whisper model specialized for Yoruba",
+         "Hausa": "Baghdad99/saad-speech-recognition-hausa-audio-to-text - Fine-tuned Whisper model for Hausa (WER: 44.4%)",
+         "Igbo": "AstralZander/igbo_ASR - Wav2Vec2-XLS-R model fine-tuned for Igbo (WER: 51%)"
+     }
+     return model_info.get(language, "Model information not available")
+
+
+ # Create Gradio interface
+ with gr.Blocks(
+     title="Nigerian Languages Speech Transcription",
+     theme=gr.themes.Soft(),
+     css="""
+     .main-header {
+         text-align: center;
+         color: #2E7D32;
+         margin-bottom: 20px;
+     }
+     .language-info {
+         background-color: #f5f5f5;
+         padding: 10px;
+         border-radius: 5px;
+         margin: 10px 0;
+     }
+     """
+ ) as demo:
+
+     gr.HTML("""
+     <h1 class="main-header">🎤 Nigerian Languages Speech Transcription</h1>
+     <p style="text-align: center; color: #666;">
+         Transcribe audio in Yoruba, Hausa, and Igbo using specialized speech recognition models
+     </p>
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Language selection
+             language_dropdown = gr.Dropdown(
+                 choices=["Yoruba", "Hausa", "Igbo"],
+                 value="Yoruba",
+                 label="Select Language",
+                 info="Choose the language of your audio"
+             )
+
+             # Audio input options
+             gr.HTML("<h3>🎵 Audio Input Options</h3>")
+
+             with gr.Tabs():
+                 with gr.TabItem("📁 Upload File"):
+                     audio_file = gr.Audio(
+                         label="Upload Audio File",
+                         sources=["upload"],
+                         type="filepath",
+                         format="wav"
+                     )
+
+                 with gr.TabItem("🎤 Record Speech"):
+                     audio_mic = gr.Audio(
+                         label="Record from Microphone",
+                         sources=["microphone"],
+                         type="filepath"
+                     )
+
+             # Transcribe button
+             transcribe_btn = gr.Button(
+                 "🎯 Transcribe Audio",
+                 variant="primary",
+                 size="lg"
+             )
+
+             # Model information
+             model_info_text = gr.Textbox(
+                 label="Model Information",
+                 value=get_model_info("Yoruba"),
+                 interactive=False,
+                 elem_classes="language-info"
+             )
+
+         with gr.Column(scale=2):
+             # Transcription output
+             transcription_output = gr.Textbox(
+                 label="Transcription Result",
+                 placeholder="Your transcription will appear here...",
+                 lines=10,
+                 max_lines=20,
+                 show_copy_button=True
+             )
+
+             # Usage instructions
+             gr.HTML("""
+             <div style="margin-top: 20px; padding: 15px; background-color: #e8f5e8; border-radius: 5px;">
+                 <h3>📋 How to Use:</h3>
+                 <ol>
+                     <li>Select your target language (Yoruba, Hausa, or Igbo)</li>
+                     <li><strong>Option 1:</strong> Upload an audio file (WAV, MP3, etc.)</li>
+                     <li><strong>Option 2:</strong> Open the microphone tab and record speech directly</li>
+                     <li>Click "Transcribe Audio" to get the text transcription</li>
+                     <li>Copy the result using the copy button</li>
+                 </ol>
+                 <p><strong>Note:</strong> First-time model loading may take a few minutes.</p>
+                 <p><strong>Recording Tip:</strong> Speak clearly and ensure good audio quality for better transcription accuracy.</p>
+                 <p><strong>Long Audio:</strong> Recordings longer than 25 seconds are processed in chunks automatically.</p>
+             </div>
+             """)
+
+     # Event handlers
+     transcribe_btn.click(
+         fn=transcribe_audio_unified,
+         inputs=[audio_file, audio_mic, language_dropdown],
+         outputs=transcription_output,
+         show_progress="full"
+     )
+
+     language_dropdown.change(
+         fn=get_model_info,
+         inputs=language_dropdown,
+         outputs=model_info_text
+     )
+
+     # Supported languages
+     gr.HTML("""
+     <div style="margin-top: 30px;">
+         <h3>🌍 Supported Languages:</h3>
+         <ul>
+             <li><strong>Yoruba:</strong> Widely spoken in Nigeria, Benin, and Togo</li>
+             <li><strong>Hausa:</strong> Major language in Northern Nigeria and Niger</li>
+             <li><strong>Igbo:</strong> Predominantly spoken in Southeastern Nigeria</li>
+         </ul>
+     </div>
+     """)
+
+
+ # Launch the application
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True,
+         debug=True,
+         show_error=True
+     )
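
For quick checks outside the Gradio UI, the transcriber can be called directly. A minimal sketch, assuming the file above is saved as app.py on the import path and that sample.wav (a hypothetical filename) is a speech recording:

    # test_transcribe.py - minimal sketch, not part of this commit
    from app import transcriber  # module-level instance created in app.py

    # load_model() is invoked lazily inside transcribe(), so the first call
    # for a language downloads and caches that model from the Hugging Face Hub.
    text = transcriber.transcribe("sample.wav", "Yoruba")
    print(text)
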
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ transformers
+ torch
+ librosa
+ numpy
+ accelerate
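
To run the Space locally, the standard Gradio workflow should apply: install the dependencies with pip install -r requirements.txt, then start the app with python app.py; demo.launch() serves the UI on port 7860.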