# Kikuyu_ASR / app.py
# Author: MaryWambo — "Update app.py" (commit db5f9eb, verified)
import streamlit as st
from transformers import WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig
import torch
import librosa
import numpy as np
import warnings
import os

# Silence noisy third-party warnings (librosa/transformers) so they don't
# surface in the Streamlit UI. NOTE(review): `numpy` and `os` appear unused
# below — verify before removing.
warnings.filterwarnings("ignore")
# --- Page Configuration ---
# Must be the first Streamlit call in the script; "wide" layout gives the
# two-column UI below room to breathe.
st.set_page_config(page_title="Kikuyu ASR System", page_icon="🎙️", layout="wide")

# --- Custom CSS for Maroon Theme ---
# Injected as raw HTML; unsafe_allow_html=True is required for <style> tags.
# Selectors target Streamlit's generated widget classes (.stAudioInput,
# .stFileUploader, div.stButton) — these class names can change between
# Streamlit releases, so re-check after upgrades.
st.markdown("""
<style>
.main { background-color: #ffffff; }
.hero-banner {
background: linear-gradient(135deg, #800000 0%, #4a0000 100%);
border-radius: 15px;
padding: 40px;
text-align: center;
margin-bottom: 30px;
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
}
.hero-banner h1 { color: white !important; font-weight: 800; margin-bottom: 10px; }
.hero-banner p { color: #f5f5f5 !important; font-size: 1.2rem; }
.stAudioInput, .stFileUploader {
border: 2px solid #800000;
border-radius: 12px;
padding: 10px;
background-color: #fff5f5;
}
div.stButton > button:first-child {
background-color: #800000;
color: white;
border-radius: 8px;
width: 100%;
height: 50px;
font-weight: bold;
border: none;
}
div.stButton > button:first-child:hover { background-color: #a52a2a; }
</style>
""", unsafe_allow_html=True)
# --- Model Loading (Cached) ---
@st.cache_resource
def load_model():
    """Download and cache the fine-tuned Whisper checkpoint for this session.

    Returns:
        tuple: (processor, model, device) — the ``WhisperProcessor``, the
        generation-ready model already moved to ``device``, and the device
        string ("cuda" or "cpu").
    """
    checkpoint = "MaryWambo/whisper-base-kikuyu4"
    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    # float16 halves VRAM on GPU; CPU kernels require float32 to avoid errors.
    dtype = torch.float16 if use_gpu else torch.float32

    processor = WhisperProcessor.from_pretrained(checkpoint)
    model = WhisperForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    return processor, model.to(device), device
# Loaded once at startup; st.cache_resource makes Streamlit reruns reuse
# the same processor/model objects instead of re-downloading.
processor, model, device = load_model()

# --- Transcription Logic ---
def transcribe(audio_file):
    """Run the cached Whisper model on one audio input and return the text.

    Args:
        audio_file: A path or file-like object accepted by ``librosa.load``
            (Streamlit's recorder/uploader objects qualify).

    Returns:
        str: The decoded transcript, or a human-readable error message if
        anything fails (the UI displays whatever comes back).
    """
    try:
        # Whisper expects 16 kHz mono audio; librosa resamples on load.
        waveform, sample_rate = librosa.load(audio_file, sr=16000)
        features = processor(
            waveform, sampling_rate=sample_rate, return_tensors="pt"
        ).input_features.to(device)
        if device == "cuda":
            # Cast to float16 to match the GPU weights loaded in load_model().
            features = features.half()

        # The checkpoint was fine-tuned using Swahili prompt tokens as a proxy
        # for Kikuyu (Whisper has no dedicated Kikuyu language token).
        prompt_ids = processor.get_decoder_prompt_ids(
            language="swahili", task="transcribe"
        )

        # Build a generation config from the model defaults, then pin the
        # decoding strategy: greedy (num_beams=1) with KV caching.
        gen_config = GenerationConfig.from_model_config(model.config)
        gen_config.update(
            forced_decoder_ids=prompt_ids,
            max_new_tokens=255,
            num_beams=1,
            use_cache=True,
        )

        with torch.no_grad():
            token_ids = model.generate(features, generation_config=gen_config)
        return processor.batch_decode(token_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # Best-effort UI: surface the failure as text rather than crashing.
        return f"Error during transcription: {str(e)}"
# --- UI Layout ---
st.markdown("""
<div class="hero-banner">
<h1>Kikuyu Automatic Speech Recognition System</h1>
<p>Powering the future of low resource languages through speech technologies</p>
</div>
""", unsafe_allow_html=True)
col1, col2 = st.columns(2, gap="large")
with col1:
# FIXED: Added the mandatory label argument here
input_method = st.radio("Choose Input Method:", ["Record Voice", "Upload File"])
audio_data = None
if input_method == "Record Voice":
# Ensure you are using a recent version of Streamlit (1.37.0+)
audio_data = st.audio_input("Record your Kikuyu speech")
else:
audio_data = st.file_uploader("Upload audio file", type=["wav", "mp3", "webm", "m4a"])
run_btn = st.button("🚀 RUN ")
with col2:
if run_btn:
if audio_data is not None:
with st.spinner("Transcribing..."):
result = transcribe(audio_data)
st.subheader("Transcription Result")
st.text_area("Transcript", value=result, height=300, label_visibility="collapsed")
else:
st.warning("Please record or upload an audio file first.")
else:
st.subheader("Transcription Result")
st.text_area("Transcript", placeholder="The transcript will appear here...", height=300, label_visibility="collapsed")
st.divider()