Spaces:
Running
Running
| import streamlit as st | |
| from transformers import ( | |
| AutoModelForSpeechSeq2Seq, | |
| AutoProcessor, | |
| AutoModelForSeq2SeqLM, | |
| AutoTokenizer, | |
| GenerationConfig | |
| ) | |
| import torch | |
| import librosa | |
| import numpy as np | |
| import warnings | |
| import os | |
# Suppress every Python warning raised by this process.
warnings.filterwarnings("ignore")

# --- Page Configuration ---
st.set_page_config(
    page_title="Tunuh ASR and MT System",
    page_icon="🎙️",
    layout="wide",
)

# --- Initialize Session State for Resetting ---
# `reset_key` is a counter that the audio input widgets embed in their keys;
# it starts at 0 the first time a session runs the script.
if "reset_key" not in st.session_state:
    st.session_state["reset_key"] = 0
def clear_callback():
    """Increment the ``reset_key`` counter held in session state.

    Used as the ``on_click`` handler of the CLEAR button. The audio input
    widgets build their ``key`` from ``reset_key``, so a new counter value
    gives them new keys on the next run (the mechanism this app uses to
    reset them).
    """
    st.session_state["reset_key"] = st.session_state["reset_key"] + 1
# --- Customized Theme ---
# The stylesheet is kept flush-left inside the string so the text handed to
# st.markdown is exactly the raw <style> block, with no leading indentation.
_THEME_CSS = """
<style>
.main { background-color: #ffffff; }
.hero-banner {
background: linear-gradient(135deg, #800000 0%, #4a0000 100%);
border-radius: 15px;
padding: 40px;
text-align: center;
margin-bottom: 30px;
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
}
.hero-banner h1 { color: white !important; font-weight: 800; margin-bottom: 10px; }
.hero-banner p { color: #f5f5f5 !important; font-size: 1.2rem; }
.stAudioInput, .stFileUploader {
border: 2px solid #800000;
border-radius: 12px;
padding: 10px;
background-color: #fff5f5;
}
/* --- ENHANCED FIX FOR BUTTON COLORS --- */
.process-btn button, .clear-btn button {
background-color: #800000 !important;
color: white !important;
border: 2px solid #800000 !important;
border-radius: 8px !important;
height: 55px !important;
width: 100% !important;
font-weight: bold !important;
font-size: 16px !important;
text-transform: uppercase !important;
}
.process-btn button:hover, .clear-btn button:hover {
background-color: #a52a2a !important;
border-color: #a52a2a !important;
color: white !important;
}
.process-btn button p, .clear-btn button p {
color: white !important;
}
.process-btn button:active, .clear-btn button:active,
.process-btn button:focus:not(:active), .clear-btn button:focus:not(:active) {
background-color: #800000 !important;
color: white !important;
border-color: #800000 !important;
box-shadow: none !important;
}
.stTextArea textarea {
background-color: #f8f9fa;
border: 1px solid #800000;
}
</style>
"""

# Inject the stylesheet into the page as raw HTML.
st.markdown(_THEME_CSS, unsafe_allow_html=True)
| # --- Loading Models --- | |
@st.cache_resource
def load_asr_model():
    """Load the speech-recognition processor and model on the CPU.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the checkpoint is
    instantiated once and the same objects are handed back on later reruns
    instead of being rebuilt each time.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is the string
        ``"cpu"`` that the model was moved to.
    """
    MODEL_PATH = "MaryWambo/Microsoft2-178h"
    device = "cpu"
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    ).to(device)
    return processor, model, device
@st.cache_resource
def load_translation_model():
    """Load the NLLB translation model and its tokenizer on the CPU.

    Wrapped in ``st.cache_resource`` so the checkpoint is instantiated once
    and reused across Streamlit reruns, rather than being reloaded every
    time a widget interaction re-executes the script.

    Returns:
        tuple: ``(model, tokenizer)`` -- note the model comes first.
    """
    MODEL_ID = "facebook/nllb-200-distilled-600M"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    ).to("cpu")
    return model, tokenizer
# Build both pipelines when the script runs; transcribe() and translate_text()
# read the resulting module-level objects.
_asr_bundle = load_asr_model()
asr_processor, asr_model, asr_device = _asr_bundle

_translation_bundle = load_translation_model()
trans_model, trans_tokenizer = _translation_bundle
| # --- Logic Functions --- | |
def transcribe(audio_file):
    """Transcribe an audio recording with the module-level ASR model.

    Args:
        audio_file: Anything ``librosa.load`` accepts; the UI passes the
            object returned by the Streamlit audio/file widget.

    Returns:
        str: The decoded text of the first sequence in the batch, or a
        message of the form ``"ASR Error: <details>"`` if any step raises.
    """
    try:
        # Decode the audio, asking librosa for a 16 kHz sample rate.
        waveform, sample_rate = librosa.load(audio_file, sr=16000)

        # Turn the waveform into model input features.
        features = asr_processor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        model_input = features.input_features.to(asr_device)

        # Derive a generation config from the model config and cap the
        # number of newly generated tokens.
        generation_config = GenerationConfig.from_model_config(asr_model.config)
        generation_config.update(max_new_tokens=255)

        with torch.no_grad():
            token_ids = asr_model.generate(
                model_input,
                generation_config=generation_config,
            )

        decoded = asr_processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as exc:
        # Errors are reported as text so the UI can show them in the
        # transcript box instead of crashing the app.
        return "ASR Error: " + str(exc)
def translate_text(text):
    """Translate a transcript to English with the module-level NLLB model.

    The source language is set to ``kik_Latn`` and generation is forced to
    start with the ``eng_Latn`` language token.

    Args:
        text: The transcript to translate. May be ``None``, empty,
            whitespace-only, or an ``"ASR Error: ..."`` message produced by
            ``transcribe``; none of those are sent to the model.

    Returns:
        str: The decoded translation of the first sequence in the batch, or
        ``""`` when there is nothing to translate.
    """
    # Nothing to translate: missing, empty or blank input.
    if not text or not text.strip():
        return ""
    # Skip only the error message format emitted by transcribe(). Matching
    # the bare substring "Error" would also discard genuine transcripts that
    # happen to contain that word.
    if text.startswith("ASR Error:"):
        return ""

    trans_tokenizer.src_lang = "kik_Latn"
    inputs = trans_tokenizer(text, return_tensors="pt").to("cpu")
    forced_bos_id = trans_tokenizer.convert_tokens_to_ids("eng_Latn")
    with torch.no_grad():
        outputs = trans_model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_id,
            max_new_tokens=128,
        )
    return trans_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# --- UI Layout ---
hero_html = (
    '<div class="hero-banner"><h1>Tunuh AI System</h1>'
    '<p>Powering the future of low resource languages through speech technologies</p></div>'
)
st.markdown(hero_html, unsafe_allow_html=True)

# Two equal-width columns: input controls on the left, results on the right.
col_input, col_output = st.columns([1, 1], gap="large")

with col_input:
    st.markdown("### 🎙️ Input Speech")
    input_method = st.radio(
        "Select input method",
        ["Record Voice", "Upload File"],
        horizontal=True,
    )

    # The counter is part of each widget key, so bumping it (see
    # clear_callback) gives the widget a different key on the next run.
    widget_generation = st.session_state.reset_key
    if input_method == "Record Voice":
        audio_data = st.audio_input("Record", key=f"audio_in_{widget_generation}")
    else:
        audio_data = st.file_uploader(
            "Upload audio file",
            type=["wav", "mp3", "webm", "m4a"],
            key=f"file_up_{widget_generation}",
        )

    btn_col1, btn_col2 = st.columns(2)
    with btn_col1:
        st.markdown('<div class="process-btn">', unsafe_allow_html=True)
        run_btn = st.button("PROCESS AUDIO")
        st.markdown('</div>', unsafe_allow_html=True)
    with btn_col2:
        st.markdown('<div class="clear-btn">', unsafe_allow_html=True)
        st.button(" CLEAR ", on_click=clear_callback)
        st.markdown('</div>', unsafe_allow_html=True)
with col_output:
    st.markdown("### Output")
    # Placeholders are created first so both boxes keep their position
    # whichever branch below fills them.
    transcript_container = st.empty()
    translation_container = st.empty()

    if run_btn and audio_data:
        with st.spinner("Transcribing..."):
            kikuyu_text = transcribe(audio_data)
        transcript_container.text_area("The transcript", value=kikuyu_text, height=150)

        with st.spinner("Translating..."):
            translated_text = translate_text(kikuyu_text)
        translation_container.text_area("Translated text", value=translated_text, height=150)
    else:
        # The button was pressed but there is no audio: tell the user rather
        # than silently doing nothing.
        if run_btn:
            st.warning("Please record or upload an audio file before processing.")
        transcript_container.text_area(
            "The transcript",
            value="",
            height=150,
            placeholder="transcription results",
        )
        translation_container.text_area(
            "Translated text",
            value="",
            height=150,
            placeholder="Translation results",
        )

st.divider()