"""Gradio web app for Kikuyu automatic speech recognition.

Loads a fine-tuned Whisper checkpoint from the HuggingFace Hub and exposes a
single-page UI where the user records or uploads audio and receives a text
transcription.
"""

import asyncio
import sys
import warnings

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# -------------------------------
# 0. SUPPRESS WARNINGS
# -------------------------------
warnings.filterwarnings("ignore", category=ResourceWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# -------------------------------
# 1. CONFIGURATION
# -------------------------------
MODEL_PATH = "MaryWambo/whisper-base-kikuyu4"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on {device}...")

# -------------------------------
# 2. LOAD MODEL & PROCESSOR
# -------------------------------
processor = WhisperProcessor.from_pretrained(MODEL_PATH)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)

# Force transcription mode (rather than translation).
# NOTE(review): Whisper has no Kikuyu language token, so Swahili is used as a
# proxy here — confirm this matches how the checkpoint was fine-tuned.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="swahili", task="transcribe"
)

# -------------------------------
# 3. CUSTOM CSS
# -------------------------------
theme_styles = """
body, .gradio-container { background-color: white !important; }
#title-text h1 { color: #8b0000 !important; font-weight: 900 !important; text-align: center; }
.upload-button svg, .mic-button svg { transform: scale(1.5) !important; color: #8b0000 !important; }
#predict-box textarea { font-size: 1.6rem !important; font-weight: 800 !important; color: #000000 !important; border: 3px solid #8b0000 !important; }
#run-btn { background: #8b0000 !important; color: white !important; font-weight: bold !important; font-size: 1.4rem !important; }
"""


# -------------------------------
# 4. TRANSCRIPTION FUNCTION
# -------------------------------
def transcribe_kikuyu(audio):
    """Transcribe a Kikuyu speech recording to text.

    Parameters
    ----------
    audio : str or None
        Filesystem path to the recorded/uploaded audio clip (Gradio
        ``type="filepath"``), or ``None`` when nothing was provided.

    Returns
    -------
    str
        The decoded transcription, or a human-readable error message —
        errors are reported as text so the UI never crashes.
    """
    if audio is None:
        return "Please record or upload audio."
    try:
        # Load audio, resampled to Whisper's expected 16 kHz mono.
        speech_array, sr = librosa.load(audio, sr=16000)

        # Convert to float32 (the feature extractor expects float32 input).
        if speech_array.dtype != np.float32:
            speech_array = speech_array.astype(np.float32)

        # Extract log-mel input features.
        inputs = processor(speech_array, sampling_rate=sr, return_tensors="pt")
        input_features = inputs.input_features.to(device)

        # Generate transcription (no gradients needed at inference time).
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                num_beams=5,
                max_new_tokens=255,
            )

        transcription = processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]
        return transcription
    except Exception as e:
        # Surface the failure in the output textbox instead of raising.
        return f"Error during transcription: {str(e)}"


# -------------------------------
# 5. GRADIO UI
# -------------------------------
# BUGFIX: theme_styles was defined but never passed to gr.Blocks, so the
# custom styling (which targets the elem_ids below) was silently ignored.
with gr.Blocks(css=theme_styles) as demo:
    gr.Markdown("# 🎙️ Kikuyu ASR", elem_id="title-text")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎤 Record or Upload Kikuyu Speech",
            )
            submit_btn = gr.Button(
                "🚀 RUN TRANSCRIPTION",
                elem_id="run-btn",
            )
        with gr.Column():
            text_out = gr.Textbox(
                label="🤖 AI Prediction",
                elem_id="predict-box",
                lines=8,
            )
    submit_btn.click(
        fn=transcribe_kikuyu,
        inputs=audio_input,
        outputs=text_out,
    )


# -------------------------------
# 6. LAUNCH APP
# -------------------------------
def _suppress_event_loop_closed(loop, context):
    """Swallow spurious 'Invalid file descriptor' errors emitted on shutdown;
    delegate everything else to the loop's default handler."""
    if "Invalid file descriptor" in str(context.get("exception", "")):
        return
    loop.default_exception_handler(context)


try:
    loop = asyncio.get_event_loop()
    loop.set_exception_handler(_suppress_event_loop_closed)
except RuntimeError:
    # No running/current event loop in this thread — nothing to patch.
    pass

demo.launch(ssr_mode=False)