"""Tunuh ASR and MT Streamlit app.

Takes recorded or uploaded speech and transcribes it with the
``MaryWambo/Microsoft2-178h`` speech seq2seq checkpoint. It then translates the
transcript from Kikuyu (``kik_Latn``) to English (``eng_Latn``) with
``facebook/nllb-200-distilled-600M``. Both models run on CPU.
"""

import contextlib
import os
import tempfile
import warnings

import librosa
import numpy as np
import streamlit as st
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
)

warnings.filterwarnings("ignore")

# --- Constants ---
ASR_MODEL_ID = "MaryWambo/Microsoft2-178h"
MT_MODEL_ID = "facebook/nllb-200-distilled-600M"
# Audio is resampled to this rate before feature extraction.
TARGET_SAMPLE_RATE = 16000
# transcribe() prefixes failures with this, so translate_text() can recognise
# and skip them without misfiring on transcripts that merely contain "Error".
ASR_ERROR_PREFIX = "ASR Error:"
# Custom theme. NOTE(review): the original <style> rules are not present in
# this file; restore them here.
THEME_CSS = "<style></style>"

# --- Page Configuration (must be the first Streamlit call) ---
st.set_page_config(page_title="Tunuh ASR and MT System", page_icon="🎙️", layout="wide")

# --- Initialize Session State for Resetting ---
# The input widgets' keys embed reset_key, so bumping it makes Streamlit
# create fresh, empty widgets on the next rerun.
if "reset_key" not in st.session_state:
    st.session_state.reset_key = 0


def clear_callback():
    """Clear the audio inputs by forcing new widget keys on the next rerun."""
    st.session_state.reset_key += 1


# --- Customized Theme ---
st.markdown(THEME_CSS, unsafe_allow_html=True)


# --- Loading Models ---
@st.cache_resource
def load_asr_model():
    """Load (once per server process) the ASR processor and model on CPU.

    Returns:
        A ``(processor, model, device)`` tuple; ``device`` is ``"cpu"``.
    """
    device = "cpu"
    processor = AutoProcessor.from_pretrained(ASR_MODEL_ID)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        ASR_MODEL_ID, torch_dtype=torch.float32, low_cpu_mem_usage=True
    ).to(device)
    return processor, model, device


@st.cache_resource
def load_translation_model():
    """Load (once per server process) the NLLB translation model and tokenizer on CPU.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MT_MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MT_MODEL_ID, torch_dtype=torch.float32, low_cpu_mem_usage=True
    ).to("cpu")
    return model, tokenizer


asr_processor, asr_model, asr_device = load_asr_model()
trans_model, trans_tokenizer = load_translation_model()


# --- Logic Functions ---
def _load_audio(audio_file, sr=TARGET_SAMPLE_RATE):
    """Decode ``audio_file`` into a mono waveform resampled to ``sr``.

    Args:
        audio_file: A filesystem path, or a file-like object such as the
            Streamlit ``UploadedFile`` returned by ``st.audio_input`` /
            ``st.file_uploader``.
        sr: Target sampling rate passed to ``librosa.load``.

    Returns:
        The ``(waveform, sampling_rate)`` pair from ``librosa.load``.
    """
    if isinstance(audio_file, (str, os.PathLike)):
        return librosa.load(audio_file, sr=sr)

    # librosa's audioread fallback (needed for containers such as m4a/webm)
    # only accepts filesystem paths, so spill in-memory uploads to a temp
    # file that keeps the original extension for format detection.
    suffix = os.path.splitext(getattr(audio_file, "name", "") or "")[1] or ".wav"
    if hasattr(audio_file, "getvalue"):
        payload = audio_file.getvalue()
    else:
        audio_file.seek(0)
        payload = audio_file.read()

    # delete=False so the file can be reopened by name on every platform.
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        with tmp:
            tmp.write(payload)
        return librosa.load(tmp.name, sr=sr)
    finally:
        with contextlib.suppress(OSError):
            os.remove(tmp.name)


def transcribe(audio_file):
    """Transcribe a speech recording to text with the ASR model.

    Args:
        audio_file: A filesystem path or a Streamlit ``UploadedFile``.

    Returns:
        The decoded transcript. On any failure, returns a string starting with
        ``"ASR Error:"`` instead of raising, so the UI can display it.
    """
    try:
        speech_array, sr = _load_audio(audio_file)

        # NOTE(review): if this processor is a Whisper-style feature extractor
        # it pads/truncates audio to a fixed window (30 s for Whisper), so
        # longer recordings may be cut off -- confirm for this checkpoint.
        inputs = asr_processor(speech_array, sampling_rate=sr, return_tensors="pt")
        input_features = inputs.input_features.to(asr_device)

        # NOTE(review): from_model_config() derives settings from the model
        # config only, ignoring any generation_config.json shipped with the
        # checkpoint -- confirm that is intended for this model.
        gen_config = GenerationConfig.from_model_config(asr_model.config)
        gen_config.update(max_new_tokens=255)

        with torch.no_grad():
            predicted_ids = asr_model.generate(input_features, generation_config=gen_config)
        return asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    except Exception as e:  # UI boundary: surface the failure in the transcript box.
        return f"{ASR_ERROR_PREFIX} {e}"


def translate_text(text, src_lang="kik_Latn", tgt_lang="eng_Latn", max_new_tokens=128):
    """Translate ``text`` with the NLLB model.

    Args:
        text: Source text, normally the output of ``transcribe``.
        src_lang: NLLB language code of ``text`` (default Kikuyu).
        tgt_lang: NLLB language code to translate into (default English).
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The translation. Returns ``""`` for empty/blank input or for an ASR
        error message produced by ``transcribe``.
    """
    if not text or not text.strip() or text.startswith(ASR_ERROR_PREFIX):
        return ""

    # The tokenizer prepends the source-language token based on src_lang.
    trans_tokenizer.src_lang = src_lang
    inputs = trans_tokenizer(text, return_tensors="pt").to("cpu")
    # NLLB selects the output language by forcing its code as the first token.
    forced_bos_id = trans_tokenizer.convert_tokens_to_ids(tgt_lang)
    with torch.no_grad():
        outputs = trans_model.generate(
            **inputs, forced_bos_token_id=forced_bos_id, max_new_tokens=max_new_tokens
        )
    return trans_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


# --- UI Layout ---
st.markdown(
    '<div class="app-header">'
    "<h1>Tunuh AI System</h1>"
    "<p>Powering the future of low resource languages through speech technologies</p>"
    "</div>",
    unsafe_allow_html=True,
)

col_input, col_output = st.columns([1, 1], gap="large")

with col_input:
    st.markdown("### 🎙️ Input Speech")
    input_method = st.radio("Select input method", ["Record Voice", "Upload File"], horizontal=True)

    reset_key = st.session_state.reset_key
    if input_method == "Record Voice":
        audio_data = st.audio_input("Record", key=f"audio_in_{reset_key}")
    else:
        audio_data = st.file_uploader(
            "Upload audio file",
            type=["wav", "mp3", "webm", "m4a"],
            key=f"file_up_{reset_key}",
        )

    # Raw HTML opened in one st.markdown call cannot wrap a widget rendered by
    # a later call, so the buttons are placed directly in their columns.
    btn_col1, btn_col2 = st.columns(2)
    with btn_col1:
        run_btn = st.button("PROCESS AUDIO")
    with btn_col2:
        st.button(" CLEAR ", on_click=clear_callback)

with col_output:
    st.markdown("### Output")
    transcript_container = st.empty()
    translation_container = st.empty()

    if run_btn and audio_data is not None:
        with st.spinner("Transcribing..."):
            kikuyu_text = transcribe(audio_data)
        transcript_container.text_area("The transcript", value=kikuyu_text, height=150)

        with st.spinner("Translating..."):
            translated_text = translate_text(kikuyu_text)
        translation_container.text_area("Translated text", value=translated_text, height=150)
    else:
        if run_btn:
            # Button pressed with nothing recorded/uploaded: tell the user.
            st.warning("Please record or upload audio before processing.")
        transcript_container.text_area(
            "The transcript", value="", height=150, placeholder="transcription results"
        )
        translation_container.text_area(
            "Translated text", value="", height=150, placeholder="Translation results"
        )

st.divider()