# trial2 / app.py
# MaryWambo's picture
# Update app.py
# 92d8bd7 verified
import gradio as gr
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import librosa
import warnings
import numpy as np
# -------------------------------
# 0. SUPPRESS WARNINGS
# -------------------------------
# Install one "ignore" filter per category, in the same order as before
# (ResourceWarning first, then FutureWarning).
for _ignored_category in (ResourceWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
# -------------------------------
# 1. CONFIGURATION
# -------------------------------
# Model identifier passed to `from_pretrained` for both processor and model.
MODEL_PATH = "MaryWambo/whisper-base-kikuyu4"

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Loading model on {device}...")
# -------------------------------
# 2. LOAD MODEL & PROCESSOR
# -------------------------------
# The processor is used later both to turn raw audio into model input
# features and to decode generated token ids back into text.
processor = WhisperProcessor.from_pretrained(MODEL_PATH)

# Load the weights first, then move the model to the selected device.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_PATH)
model = model.to(device)

# Force transcription mode.
# NOTE(review): the checkpoint name says Kikuyu but the decoder prompt is built
# for "swahili" — presumably the language token used during fine-tuning;
# confirm against the training setup.
# NOTE(review): recent transformers releases may prefer the generation config
# (or `generate(language=..., task=...)`) over `config.forced_decoder_ids` —
# TODO confirm this setting is honoured by the installed version.
_decoder_prompt_ids = processor.get_decoder_prompt_ids(
    language="swahili",
    task="transcribe",
)
model.config.forced_decoder_ids = _decoder_prompt_ids
# -------------------------------
# 3. CUSTOM CSS
# -------------------------------
# Stylesheet for the demo page. The id selectors below (#title-text,
# #predict-box, #run-btn) match the `elem_id` values given to the title,
# the output textbox and the run button in the UI section.
# NOTE(review): this string only has an effect if it is actually handed to
# Gradio (e.g. through a `css` argument) — verify that it is wired in.
theme_styles = """
body, .gradio-container { background-color: white !important; }
#title-text h1 {
color: #8b0000 !important;
font-weight: 900 !important;
text-align: center;
}
.upload-button svg, .mic-button svg {
transform: scale(1.5) !important;
color: #8b0000 !important;
}
#predict-box textarea {
font-size: 1.6rem !important;
font-weight: 800 !important;
color: #000000 !important;
border: 3px solid #8b0000 !important;
}
#run-btn {
background: #8b0000 !important;
color: white !important;
font-weight: bold !important;
font-size: 1.4rem !important;
}
"""
# -------------------------------
# 4. TRANSCRIPTION FUNCTION
# -------------------------------
def transcribe_kikuyu(audio):
    """Transcribe one audio file with the module-level Whisper model.

    Args:
        audio: Path to the audio file to transcribe (the Gradio audio
            component is configured with ``type="filepath"``), or ``None``
            when nothing was recorded or uploaded.

    Returns:
        The decoded transcription text. If ``audio`` is ``None`` a prompt
        asking for input is returned instead. Any exception raised while
        loading, processing or decoding is not propagated; it is reported
        as an ``"Error during transcription: ..."`` string.
    """
    if audio is None:
        return "Please record or upload audio."

    try:
        # Decode the file, resampling to 16 kHz.
        waveform, sample_rate = librosa.load(audio, sr=16000)

        # Ensure float32 samples; this is a no-op when the dtype already matches.
        waveform = waveform.astype(np.float32, copy=False)

        # Feature extraction: convert the waveform into model input features.
        encoded = processor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        features = encoded.input_features.to(device)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            token_ids = model.generate(
                features,
                num_beams=5,
                max_new_tokens=255,
            )

        # A single clip goes in, so the first decoded entry is the result.
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as exc:
        # UI boundary: show the failure in the output box instead of raising.
        return f"Error during transcription: {exc!s}"
# -------------------------------
# 5. GRADIO UI
# -------------------------------
# Build the page: audio input and run button on the left, transcription on
# the right.
#
# The stylesheet has to be handed to Gradio explicitly; without it the
# `elem_id` hooks used below (#title-text, #run-btn, #predict-box) have no
# styling attached to them.
# NOTE(review): newer Gradio releases may expect `css` on `launch()` rather
# than on `Blocks()` — confirm against the installed Gradio version.
#
# The emoji in the labels were mis-decoded UTF-8 byte sequences; they are
# restored to the intended characters here.
with gr.Blocks(css=theme_styles) as demo:
    gr.Markdown("# 🎙️ Kikuyu ASR", elem_id="title-text")
    with gr.Row():
        with gr.Column():
            # `type="filepath"` makes the callback receive a path string.
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎤 Record or Upload Kikuyu Speech",
            )
            submit_btn = gr.Button(
                "🚀 RUN TRANSCRIPTION",
                elem_id="run-btn",
            )
        with gr.Column():
            text_out = gr.Textbox(
                label="🤖 AI Prediction",
                elem_id="predict-box",
                lines=8,
            )
    # Clicking the button sends the audio to the transcription function and
    # writes its return value into the textbox.
    submit_btn.click(
        fn=transcribe_kikuyu,
        inputs=audio_input,
        outputs=text_out,
    )
# -------------------------------
# 6. LAUNCH APP
# -------------------------------
import asyncio
import sys
def _suppress_event_loop_closed(loop, context):
    """Event-loop exception handler that drops "Invalid file descriptor" errors.

    Args:
        loop: The event loop reporting the problem; its
            ``default_exception_handler`` is used for everything that is
            not suppressed.
        context: The exception-context mapping supplied by the loop. Only
            its optional ``"exception"`` entry is inspected.

    Returns:
        None. Contexts whose exception text contains
        ``"Invalid file descriptor"`` are silently ignored; all others are
        forwarded to the loop's default handler.
    """
    error_text = str(context.get("exception", ""))
    if "Invalid file descriptor" not in error_text:
        loop.default_exception_handler(context)
try:
    # Best effort: route event-loop errors through the filtering handler so
    # "Invalid file descriptor" reports are not surfaced.
    asyncio.get_event_loop().set_exception_handler(_suppress_event_loop_closed)
except RuntimeError:
    # No usable event loop could be obtained; keep the default handling.
    pass

# Start the Gradio app with server-side rendering turned off.
demo.launch(ssr_mode=False)