# app.py
"""Gradio app for transcribing Yoruba, Hausa, and Igbo speech.

Uses language-specific Whisper and Wav2Vec2 checkpoints from the
Hugging Face Hub, loaded lazily on first use.
"""
import gradio as gr
import torch
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Wav2Vec2Processor, Wav2Vec2ForCTC
)
import librosa
import numpy as np
import warnings

warnings.filterwarnings("ignore")

class NigerianWhisperTranscriber:
    def __init__(self):
        self.models = {}
        self.processors = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Model configurations with their architectures
        self.model_configs = {
            "Yoruba": {
                "model_name": "DereAbdulhameed/Whisper-Yoruba",
                "architecture": "whisper"
            },
            "Hausa": {
                "model_name": "Baghdad99/saad-speech-recognition-hausa-audio-to-text",
                "architecture": "whisper"
            },
            "Igbo": {
                "model_name": "AstralZander/igbo_ASR",
                "architecture": "wav2vec2"
            }
        }
        print(f"Using device: {self.device}")
    def load_model(self, language):
        """Load the model and processor for a specific language."""
        if language not in self.models:
            try:
                print(f"Loading {language} model...")
                config = self.model_configs[language]
                model_name = config["model_name"]
                architecture = config["architecture"]

                if architecture == "whisper":
                    processor = WhisperProcessor.from_pretrained(model_name)
                    model = WhisperForConditionalGeneration.from_pretrained(model_name)
                elif architecture == "wav2vec2":
                    processor = Wav2Vec2Processor.from_pretrained(model_name)
                    model = Wav2Vec2ForCTC.from_pretrained(model_name)
                else:
                    # Guard against configs with an unrecognized architecture,
                    # which would otherwise leave `model` undefined below
                    raise ValueError(f"Unknown architecture: {architecture}")
                model = model.to(self.device)

                self.processors[language] = processor
                self.models[language] = model
                print(f"{language} model loaded successfully!")
                return True
            except Exception as e:
                print(f"Error loading {language} model: {str(e)}")
                return False
        return True
    def preprocess_audio(self, audio_path):
        """Load an audio file and resample it to 16 kHz mono."""
        try:
            audio, sr = librosa.load(audio_path, sr=16000)
            # Ensure audio is not empty
            if len(audio) == 0:
                raise ValueError("Audio file is empty")
            # Ensure float32 dtype, as both model families expect
            audio = audio.astype(np.float32)
            return audio
        except Exception as e:
            raise ValueError(f"Error processing audio: {str(e)}")
    def chunk_audio(self, audio, chunk_length=25):
        """Split audio into chunks for processing longer recordings."""
        sample_rate = 16000
        chunk_samples = chunk_length * sample_rate
        chunks = []
        for i in range(0, len(audio), chunk_samples):
            chunk = audio[i:i + chunk_samples]
            if len(chunk) > sample_rate:  # Only keep chunks longer than 1 second
                chunks.append(chunk)
        return chunks
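
    # Note: the chunks above are non-overlapping, so a word straddling a
    # 25-second boundary can be split across two chunks and transcribed
    # imperfectly. A small overlap (e.g. 1-2 seconds) with de-duplication
    # would mitigate this, but is not implemented here.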

    def transcribe_chunk(self, audio_chunk, language):
        """Transcribe a single audio chunk."""
        processor = self.processors[language]
        model = self.models[language]
        config = self.model_configs[language]

        if config["architecture"] == "whisper":
            # Whisper: log-mel features in, token ids out
            inputs = processor(
                audio_chunk,
                sampling_rate=16000,
                return_tensors="pt"
            )
            input_features = inputs.input_features.to(self.device)

            # Pass the attention mask through when the processor provides one
            generate_kwargs = {}
            if getattr(inputs, "attention_mask", None) is not None:
                generate_kwargs["attention_mask"] = inputs.attention_mask.to(self.device)

            # Deterministic beam search; `temperature` is omitted because it
            # has no effect (and triggers a warning) when do_sample=False
            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    max_new_tokens=400,
                    num_beams=5,
                    do_sample=False,
                    use_cache=True,
                    pad_token_id=processor.tokenizer.eos_token_id,
                    **generate_kwargs
                )

            # Decode transcription
            transcription = processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]
            return transcription.strip()
        elif config["architecture"] == "wav2vec2":
            # Wav2Vec2: raw waveform in, CTC logits out
            inputs = processor(
                audio_chunk,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            )
            input_values = inputs.input_values.to(self.device)

            with torch.no_grad():
                logits = model(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)

            # skip_special_tokens removes most special tokens, but some
            # Wav2Vec2 tokenizers still leave [PAD] in the output depending on
            # how they were trained, so strip it explicitly and collapse
            # repeated whitespace as a fallback.
            transcription = processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]
            transcription = transcription.replace("[PAD]", "")
            transcription = " ".join(transcription.split())
            return transcription.strip()

    def transcribe(self, audio_path, language):
        """Transcribe an audio file in the specified language."""
        try:
            # Load the model if it is not already loaded
            if not self.load_model(language):
                return f"Error: Could not load {language} model"

            # Preprocess audio
            audio = self.preprocess_audio(audio_path)

            # 25 seconds at 16 kHz = 400,000 samples
            if len(audio) > 400000:
                # Process long recordings in chunks
                chunks = self.chunk_audio(audio, chunk_length=25)
                transcriptions = []
                for i, chunk in enumerate(chunks):
                    print(f"Processing chunk {i+1}/{len(chunks)}")
                    transcriptions.append(self.transcribe_chunk(chunk, language))
                # Combine all chunk transcriptions
                return " ".join(transcriptions)
            else:
                # Process short audio in one pass
                return self.transcribe_chunk(audio, language)
        except Exception as e:
            return f"Error during transcription: {str(e)}"

# Initialize transcriber
transcriber = NigerianWhisperTranscriber()
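# The transcriber can also be called directly, outside Gradio; the file path
# below is hypothetical:
#   text = transcriber.transcribe("sample_yoruba.wav", "Yoruba")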

def transcribe_audio_unified(audio_file, audio_mic, language):
    """Gradio handler: transcribe from either an uploaded file or the microphone."""
    # Prefer the uploaded file; fall back to the microphone recording
    audio_source = audio_file if audio_file is not None else audio_mic
    if audio_source is None:
        return "Please upload an audio file or record from microphone"
    try:
        return transcriber.transcribe(audio_source, language)
    except Exception as e:
        return f"Transcription failed: {str(e)}"

def get_model_info(language):
    """Return a short description of the selected model."""
    model_info = {
        "Yoruba": "DereAbdulhameed/Whisper-Yoruba - Whisper model specialized for Yoruba",
        "Hausa": "Baghdad99/saad-speech-recognition-hausa-audio-to-text - Fine-tuned Whisper model for Hausa (WER: 44.4%)",
        "Igbo": "AstralZander/igbo_ASR - Wav2Vec2-XLS-R model fine-tuned for Igbo (WER: 51%)"
    }
    return model_info.get(language, "Model information not available")

# Create Gradio interface
with gr.Blocks(
    title="Nigerian Languages Speech Transcription",
    theme=gr.themes.Soft(),
    css="""
    .main-header {
        text-align: center;
        color: #2E7D32;
        margin-bottom: 20px;
    }
    .language-info {
        background-color: #f5f5f5;
        padding: 10px;
        border-radius: 5px;
        margin: 10px 0;
    }
    """
) as demo:
    gr.HTML("""
        <h1 class="main-header">🎤 Nigerian Languages Speech Transcription</h1>
        <p style="text-align: center; color: #666;">
            Transcribe audio in Yoruba, Hausa, and Igbo using specialized Whisper and Wav2Vec2 models
        </p>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Language selection
            language_dropdown = gr.Dropdown(
                choices=["Yoruba", "Hausa", "Igbo"],
                value="Yoruba",
                label="Select Language",
                info="Choose the language of your audio file"
            )

            # Audio input options
            gr.HTML("<h3>🎵 Audio Input Options</h3>")
            with gr.Tabs():
                with gr.TabItem("📁 Upload File"):
                    audio_file = gr.Audio(
                        label="Upload Audio File",
                        type="filepath",
                        format="wav",
                        # Restrict this tab to uploads (assumes the Gradio 4.x
                        # `sources` parameter)
                        sources=["upload"]
                    )
                with gr.TabItem("🎤 Record Speech"):
                    audio_mic = gr.Audio(
                        label="Record from Microphone",
                        type="filepath",
                        # Restrict this tab to microphone input (assumes the
                        # Gradio 4.x `sources` parameter)
                        sources=["microphone"]
                    )

            # Transcribe button
            transcribe_btn = gr.Button(
                "🎯 Transcribe Audio",
                variant="primary",
                size="lg"
            )

            # Model information
            model_info_text = gr.Textbox(
                label="Model Information",
                value=get_model_info("Yoruba"),
                interactive=False,
                elem_classes="language-info"
            )
        with gr.Column(scale=2):
            # Transcription output
            transcription_output = gr.Textbox(
                label="Transcription Result",
                placeholder="Your transcription will appear here...",
                lines=10,
                max_lines=20,
                show_copy_button=True
            )

            # Usage instructions
            gr.HTML("""
                <div style="margin-top: 20px; padding: 15px; background-color: #e8f5e8; border-radius: 5px;">
                    <h3>📋 How to Use:</h3>
                    <ol>
                        <li>Select your target language (Yoruba, Hausa, or Igbo)</li>
                        <li><strong>Option 1:</strong> Upload an audio file (WAV, MP3, etc.)</li>
                        <li><strong>Option 2:</strong> Click the microphone tab and record speech directly</li>
                        <li>Click "Transcribe Audio" to get the text transcription</li>
                        <li>Copy the result using the copy button</li>
                    </ol>
                    <p><strong>Note:</strong> First-time model loading may take a few minutes.</p>
                    <p><strong>Recording Tip:</strong> Speak clearly and ensure good audio quality for better transcription accuracy.</p>
                    <p><strong>Long Audio:</strong> Audio longer than 25 seconds will be automatically processed in chunks.</p>
                </div>
            """)
    # Event handlers
    transcribe_btn.click(
        fn=transcribe_audio_unified,
        inputs=[audio_file, audio_mic, language_dropdown],
        outputs=transcription_output,
        show_progress=True
    )
    language_dropdown.change(
        fn=get_model_info,
        inputs=language_dropdown,
        outputs=model_info_text
    )
    # Examples section
    gr.HTML("""
        <div style="margin-top: 30px;">
            <h3>🌍 Supported Languages:</h3>
            <ul>
                <li><strong>Yoruba:</strong> Widely spoken in Nigeria, Benin, and Togo</li>
                <li><strong>Hausa:</strong> Major language in Northern Nigeria and Niger</li>
                <li><strong>Igbo:</strong> Predominantly spoken in Southeastern Nigeria</li>
            </ul>
        </div>
    """)

# Launch the application
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,  # Note: share links are ignored when running on Hugging Face Spaces
        debug=True,
        show_error=True
    )