# QualiHive — src/streamlit_app.py
# (Hugging Face Space by simzacademy; commit "Update src/streamlit_app.py", rev 9812441)
import streamlit as st
import torch
import tempfile
import os
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Fine-tuned Whisper checkpoint hosted on the Hugging Face Hub.
MODEL_NAME = "chiyo123/whisper-small-tonga"


@st.cache_resource
def load_model_and_processor():
    """Fetch the processor and model once and reuse them across Streamlit reruns.

    Returns:
        tuple: (WhisperProcessor, WhisperForConditionalGeneration) ready for inference.
    """
    proc = WhisperProcessor.from_pretrained(MODEL_NAME)
    net = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
    net.eval()  # inference mode: disables dropout / training-only behavior
    return proc, net
# Materialize the cached processor/model pair at import time so the first
# upload doesn't pay the download cost mid-interaction.
processor, model = load_model_and_processor()

# --- Streamlit UI ---
# NOTE(fix): the original title contained mojibake ("πŸ—£οΈ"), a UTF-8 emoji
# decoded as Latin-1; restored to the intended 🗣️ glyph.
st.title("🗣️ Custom Whisper Transcriber")
st.write("Upload an audio file and transcribe it using your fine-tuned Whisper model.")

# Accepted container formats; torchaudio handles decoding after upload.
uploaded_file = st.file_uploader("Upload audio", type=["mp3", "wav", "flac", "m4a"])
# Free-text language code passed to the Whisper decoder prompt (e.g. "loz").
language = st.text_input("Target language code (e.g., loz, bemba, en)", value="loz")
if uploaded_file:
    # Keep the upload's own extension (fall back to .wav) so torchaudio's
    # backend sniffing isn't misled — the original always wrote ".wav" even
    # for mp3/m4a payloads.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.read())
        tmp_path = tmp.name

    try:
        # Load and resample to the 16 kHz rate Whisper expects.
        speech_array, sampling_rate = torchaudio.load(tmp_path)
        speech_array = torchaudio.functional.resample(
            speech_array, orig_freq=sampling_rate, new_freq=16000
        )
        # Downmix multi-channel audio to mono: .squeeze() alone leaves a
        # (channels, time) tensor for stereo files, which the processor
        # cannot consume as a single waveform.
        if speech_array.dim() > 1 and speech_array.size(0) > 1:
            speech_array = speech_array.mean(dim=0)
        input_values = processor(
            speech_array.squeeze(), return_tensors="pt", sampling_rate=16000
        ).input_features

        # Generate the transcription.
        with st.spinner("Transcribing..."):
            # Steer the decoder toward the requested language + transcribe task.
            forced_decoder_ids = processor.get_decoder_prompt_ids(
                language=language, task="transcribe"
            )
            with torch.no_grad():  # inference only — skip autograd bookkeeping
                predicted_ids = model.generate(
                    input_values, forced_decoder_ids=forced_decoder_ids
                )
            transcription = processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )[0]

        # NOTE(fix): original subheader emoji was mojibake ("πŸ“„"); restored.
        st.subheader("📄 Transcription")
        st.success(transcription)
    finally:
        # Remove the temp file even when loading/transcription raises —
        # the original leaked it on any exception.
        os.remove(tmp_path)