"""Gradio demo app: Kikuyu speech recognition with a fine-tuned Whisper model."""
import gradio as gr
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import librosa
import warnings
import numpy as np
# -------------------------------
# 0. SUPPRESS WARNINGS
# -------------------------------
# Install an "ignore" filter for each warning category this app does not want
# printed. The order matches the original calls (ResourceWarning first).
for _ignored_category in (ResourceWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
# -------------------------------
# 1. CONFIGURATION
# -------------------------------
# Checkpoint identifier handed to `from_pretrained` for both processor and model.
MODEL_PATH = "MaryWambo/whisper-base-kikuyu4"

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Loading model on {device}...")
# -------------------------------
# 2. LOAD MODEL & PROCESSOR
# -------------------------------
# Load the processor and the model weights from the same checkpoint, then
# move the model onto the selected device.
processor = WhisperProcessor.from_pretrained(MODEL_PATH)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_PATH)
model = model.to(device)

# Force transcription mode by pinning the decoder prompt on the model config.
# NOTE(review): the language is pinned to "swahili" although the checkpoint is
# named for Kikuyu - presumably the language token used during fine-tuning;
# confirm before changing it.
_decoder_prompt_ids = processor.get_decoder_prompt_ids(
    language="swahili",
    task="transcribe",
)
model.config.forced_decoder_ids = _decoder_prompt_ids
# -------------------------------
# 3. CUSTOM CSS
# -------------------------------
# Stylesheet text for the demo page. The `#title-text`, `#run-btn` and
# `#predict-box` selectors correspond to the `elem_id` values assigned to the
# title, the run button and the output textbox in the UI section of this file.
# The dark-red accent colour (#8b0000) is shared by the title, icons, textbox
# border and button background.
# NOTE(review): `.upload-button` / `.mic-button` are presumably class names
# emitted by Gradio's audio component - confirm against the installed version.
theme_styles = """
body, .gradio-container { background-color: white !important; }
#title-text h1 {
color: #8b0000 !important;
font-weight: 900 !important;
text-align: center;
}
.upload-button svg, .mic-button svg {
transform: scale(1.5) !important;
color: #8b0000 !important;
}
#predict-box textarea {
font-size: 1.6rem !important;
font-weight: 800 !important;
color: #000000 !important;
border: 3px solid #8b0000 !important;
}
#run-btn {
background: #8b0000 !important;
color: white !important;
font-weight: bold !important;
font-size: 1.4rem !important;
}
"""
# -------------------------------
# 4. TRANSCRIPTION FUNCTION
# -------------------------------
def transcribe_kikuyu(audio):
    """Transcribe a recorded or uploaded audio file to text.

    Args:
        audio: Path of the audio file to transcribe, or ``None`` when no
            audio was supplied.

    Returns:
        The decoded transcription string. When ``audio`` is ``None`` a prompt
        asking for audio is returned instead, and when any step raises, an
        ``"Error during transcription: ..."`` message carrying the exception
        text is returned rather than propagating the exception.
    """
    if audio is None:
        return "Please record or upload audio."

    try:
        # Decode the file, requesting a 16 kHz sample rate from librosa.
        waveform, sample_rate = librosa.load(audio, sr=16000)

        # Ensure float32 samples; this is a no-op when they already are.
        waveform = np.asarray(waveform, dtype=np.float32)

        # Turn the waveform into model input features on the target device.
        encoded = processor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        features = encoded.input_features.to(device)

        # Beam-search decoding without tracking gradients.
        with torch.no_grad():
            token_ids = model.generate(
                features,
                num_beams=5,
                max_new_tokens=255,
            )

        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        return f"Error during transcription: {str(e)}"
# -------------------------------
# 5. GRADIO UI
# -------------------------------
# Build the page: a title, an input column (audio + run button) and an output
# column (transcription textbox).
# The stylesheet is passed to Blocks so that the elem_id hooks used below
# (#title-text, #run-btn, #predict-box) actually pick up their custom styling.
# NOTE(review): newer Gradio releases may expect `css` on `launch()` instead -
# confirm against the installed Gradio version.
with gr.Blocks(css=theme_styles) as demo:
    gr.Markdown("# 🎙️ Kikuyu ASR", elem_id="title-text")

    with gr.Row():
        with gr.Column():
            # type="filepath" hands the callback a path on disk, which is
            # what transcribe_kikuyu passes to librosa.load.
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎤 Record or Upload Kikuyu Speech",
            )
            submit_btn = gr.Button(
                "🚀 RUN TRANSCRIPTION",
                elem_id="run-btn",
            )

        with gr.Column():
            text_out = gr.Textbox(
                label="🤖 AI Prediction",
                elem_id="predict-box",
                lines=8,
            )

    # Clicking the button sends the audio path through the transcriber and
    # writes the returned string into the textbox.
    submit_btn.click(
        fn=transcribe_kikuyu,
        inputs=audio_input,
        outputs=text_out,
    )
# -------------------------------
# 6. LAUNCH APP
# -------------------------------
import asyncio
import sys


def _suppress_event_loop_closed(loop, context):
    """Asyncio exception handler that drops "Invalid file descriptor" errors.

    Args:
        loop: The event loop reporting the error.
        context: The asyncio error-context dict; only its optional
            ``"exception"`` entry is inspected.

    Any error whose exception text does not contain "Invalid file descriptor"
    is forwarded to the loop's default exception handler unchanged.
    """
    if "Invalid file descriptor" in str(context.get("exception", "")):
        return
    loop.default_exception_handler(context)


try:
    loop = asyncio.get_event_loop()
    loop.set_exception_handler(_suppress_event_loop_closed)
except RuntimeError:
    # No event loop could be obtained here; continue without the custom
    # handler rather than failing startup.
    pass

# Start the Gradio server with server-side rendering disabled.
demo.launch(ssr_mode=False)