# MaryWambo's picture
# Update app.py
# 6dc4a53 verified
import streamlit as st
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoModelForSeq2SeqLM,
AutoTokenizer,
GenerationConfig
)
import torch
import librosa
import numpy as np
import warnings
import os
# Suppress every warning category for the whole process.
warnings.filterwarnings(action="ignore")

# --- Page Configuration ---
st.set_page_config(
    page_title="Tunuh ASR and MT System",
    page_icon="🎙️",
    layout="wide",
)

# --- Initialize Session State for Resetting ---
# `reset_key` is interpolated into the audio widgets' keys further down, so it
# only needs to be seeded the first time a session runs this script.
if "reset_key" not in st.session_state:
    st.session_state["reset_key"] = 0
def clear_callback():
    """Advance ``reset_key`` in the session state by one.

    The audio input and file uploader build their widget keys from
    ``reset_key``, so changing it makes them use fresh keys on the next run.
    """
    st.session_state["reset_key"] = st.session_state["reset_key"] + 1
# --- Customized Theme ---
# Inject the app-wide stylesheet as raw HTML. It styles the hero banner, the
# audio input / file uploader widgets, text areas, and any <button> nested
# under an element with the `process-btn` or `clear-btn` class (those class
# names are emitted by the wrapper <div> markup in the UI layout section).
# `unsafe_allow_html=True` is required so the <style> tag is not escaped.
st.markdown("""
<style>
.main { background-color: #ffffff; }
.hero-banner {
background: linear-gradient(135deg, #800000 0%, #4a0000 100%);
border-radius: 15px;
padding: 40px;
text-align: center;
margin-bottom: 30px;
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
}
.hero-banner h1 { color: white !important; font-weight: 800; margin-bottom: 10px; }
.hero-banner p { color: #f5f5f5 !important; font-size: 1.2rem; }
.stAudioInput, .stFileUploader {
border: 2px solid #800000;
border-radius: 12px;
padding: 10px;
background-color: #fff5f5;
}
/* --- ENHANCED FIX FOR BUTTON COLORS --- */
.process-btn button, .clear-btn button {
background-color: #800000 !important;
color: white !important;
border: 2px solid #800000 !important;
border-radius: 8px !important;
height: 55px !important;
width: 100% !important;
font-weight: bold !important;
font-size: 16px !important;
text-transform: uppercase !important;
}
.process-btn button:hover, .clear-btn button:hover {
background-color: #a52a2a !important;
border-color: #a52a2a !important;
color: white !important;
}
.process-btn button p, .clear-btn button p {
color: white !important;
}
.process-btn button:active, .clear-btn button:active,
.process-btn button:focus:not(:active), .clear-btn button:focus:not(:active) {
background-color: #800000 !important;
color: white !important;
border-color: #800000 !important;
box-shadow: none !important;
}
.stTextArea textarea {
background-color: #f8f9fa;
border: 1px solid #800000;
}
</style>
""", unsafe_allow_html=True)
# --- Loading Models ---
@st.cache_resource
def load_asr_model():
    """Load the speech-recognition processor and model onto the CPU.

    Wrapped in ``st.cache_resource`` so the returned objects are cached by
    Streamlit instead of being rebuilt on every script run.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is the
        string ``"cpu"``.
    """
    checkpoint = "MaryWambo/Microsoft2-178h"
    target_device = "cpu"

    asr_proc = AutoProcessor.from_pretrained(checkpoint)

    # Full-precision weights, loaded with the low-memory code path.
    load_options = {
        "torch_dtype": torch.float32,
        "low_cpu_mem_usage": True,
    }
    speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint, **load_options)
    speech_model = speech_model.to(target_device)

    return asr_proc, speech_model, target_device
@st.cache_resource
def load_translation_model():
    """Load the NLLB translation model and its tokenizer onto the CPU.

    Wrapped in ``st.cache_resource`` so the returned objects are cached by
    Streamlit instead of being rebuilt on every script run.

    Returns:
        A ``(model, tokenizer)`` tuple. Note the order: model first.
    """
    checkpoint = "facebook/nllb-200-distilled-600M"

    nllb_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Full-precision weights, loaded with the low-memory code path.
    load_options = {
        "torch_dtype": torch.float32,
        "low_cpu_mem_usage": True,
    }
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, **load_options)
    nllb_model = nllb_model.to("cpu")

    return nllb_model, nllb_tokenizer
# Module-level handles used by transcribe() and translate_text() below.
# Both loaders are decorated with st.cache_resource, so these calls run on
# every script execution but are served from Streamlit's resource cache.
asr_processor, asr_model, asr_device = load_asr_model()
trans_model, trans_tokenizer = load_translation_model()
# --- Logic Functions ---
def transcribe(audio_file):
    """Transcribe an audio recording to text with the ASR model.

    Args:
        audio_file: Either a filesystem path or a file-like object (anything
            exposing ``read``), such as the value returned by
            ``st.audio_input`` or ``st.file_uploader``.

    Returns:
        The decoded transcript. If any ``Exception`` is raised along the way,
        or the audio contains no samples, a string starting with
        ``"ASR Error: "`` is returned instead of raising.
    """
    tmp_path = None
    try:
        source = audio_file
        if hasattr(audio_file, "read"):
            # librosa decodes file-like objects through soundfile only; its
            # audioread fallback (needed for codecs soundfile cannot read,
            # e.g. m4a/webm) is used only for real paths. Spool the upload to
            # a temporary file so every accepted upload type takes the same
            # path-based route.
            import tempfile

            if hasattr(audio_file, "seek"):
                # Start from the beginning even if the stream was read before.
                audio_file.seek(0)
            name = getattr(audio_file, "name", "")
            # Keep the original extension as a hint for the decoder backends.
            suffix = os.path.splitext(name)[1] if isinstance(name, str) else ""
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(audio_file.read())
                tmp_path = tmp.name
            source = tmp_path

        # Resample to 16 kHz while loading.
        speech_array, sr = librosa.load(source, sr=16000)
        if speech_array.size == 0:
            return "ASR Error: no audio samples found in the input"

        # Extract features using the new processor
        inputs = asr_processor(speech_array, sampling_rate=sr, return_tensors="pt")
        input_features = inputs.input_features.to(asr_device)

        # Setup generation structure for the updated model
        gen_config = GenerationConfig.from_model_config(asr_model.config)
        gen_config.update(max_new_tokens=255)

        with torch.no_grad():
            predicted_ids = asr_model.generate(input_features, generation_config=gen_config)
        return asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # Errors are reported in-band; translate_text() skips these strings.
        return f"ASR Error: {str(e)}"
    finally:
        # Best-effort cleanup of the spooled copy.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def translate_text(text):
    """Translate a transcript from ``kik_Latn`` to ``eng_Latn`` with NLLB.

    Args:
        text: The transcript to translate. May be ``None`` or empty.

    Returns:
        The translated string, or ``""`` when there is nothing to translate:
        ``text`` is empty, whitespace-only, or an error report produced by
        ``transcribe`` (a string starting with ``"ASR Error:"``).
    """
    if not text or not text.strip():
        return ""
    # Match only the exact error prefix emitted by transcribe(); a substring
    # test for "Error" would also discard genuine transcripts containing it.
    if text.startswith("ASR Error:"):
        return ""

    trans_tokenizer.src_lang = "kik_Latn"
    inputs = trans_tokenizer(text, return_tensors="pt").to("cpu")
    # Force the decoder to start with the target-language token.
    forced_bos_id = trans_tokenizer.convert_tokens_to_ids("eng_Latn")
    with torch.no_grad():
        outputs = trans_model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_id,
            max_new_tokens=128,
        )
    return trans_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# --- UI Layout ---
# Two-column page: audio input and action buttons on the left, transcript and
# translation on the right.
st.markdown('<div class="hero-banner"><h1>Tunuh AI System</h1><p>Powering the future of low resource languages through speech technologies</p></div>', unsafe_allow_html=True)

col_input, col_output = st.columns([1, 1], gap="large")

with col_input:
    st.markdown("### 🎙️ Input Speech")
    input_method = st.radio("Select input method", ["Record Voice", "Upload File"], horizontal=True)

    # The widget keys include reset_key, so bumping it (CLEAR button) makes
    # the next run use fresh widget keys.
    audio_data = None
    if input_method == "Record Voice":
        audio_data = st.audio_input("Record", key=f"audio_in_{st.session_state.reset_key}")
    else:
        audio_data = st.file_uploader("Upload audio file", type=["wav", "mp3", "webm", "m4a"], key=f"file_up_{st.session_state.reset_key}")

    # NOTE(review): each st.markdown call appears to render as its own
    # element, so these opening/closing <div> tags may not actually wrap the
    # buttons and the .process-btn / .clear-btn CSS may not apply — confirm
    # in the browser.
    btn_col1, btn_col2 = st.columns(2)
    with btn_col1:
        st.markdown('<div class="process-btn">', unsafe_allow_html=True)
        run_btn = st.button("PROCESS AUDIO")
        st.markdown('</div>', unsafe_allow_html=True)
    with btn_col2:
        st.markdown('<div class="clear-btn">', unsafe_allow_html=True)
        st.button(" CLEAR ", on_click=clear_callback)
        st.markdown('</div>', unsafe_allow_html=True)

with col_output:
    st.markdown("### Output")
    # Placeholders are created first so the two text areas keep their
    # position whether they show results or the empty state.
    transcript_container = st.empty()
    translation_container = st.empty()

    if run_btn and audio_data:
        with st.spinner("Transcribing..."):
            kikuyu_text = transcribe(audio_data)
        transcript_container.text_area("The transcript", value=kikuyu_text, height=150)

        with st.spinner("Translating..."):
            translated_text = translate_text(kikuyu_text)
        translation_container.text_area("Translated text", value=translated_text, height=150)
    else:
        if run_btn:
            # PROCESS AUDIO was pressed without any audio: say so instead of
            # silently showing empty boxes.
            st.warning("Please record or upload an audio file before processing.")
        transcript_container.text_area("The transcript", value="", height=150, placeholder="transcription results")
        translation_container.text_area("Translated text", value="", height=150, placeholder="Translation results")

st.divider()