# ViToSA-Demo / src/streamlit_app.py
# Author: ViToSAResearch
# Last commit: "Update steamlit_app.py (in UI replace h6 into h5)" (8843abc, verified)
# NOTE: the three lines above were copied from the Hugging Face file viewer;
# kept as comments so the script remains valid Python.
import os

# Set cache-related environment variables BEFORE importing any module that
# uses Streamlit or Transformers, so they pick up the writable /cache paths
# (required on Hugging Face Spaces, where $HOME is read-only).
os.environ['TRANSFORMERS_CACHE'] = '/cache/hf_cache'  # NOTE(review): deprecated in recent transformers; HF_HOME below supersedes it — confirm
os.environ['HF_HOME'] = '/cache/hf_cache'
os.environ['XDG_CACHE_HOME'] = '/cache/.cache'
os.environ['STREAMLIT_CONFIG_DIR'] = '/cache/.streamlit'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # silence the tokenizers fork-parallelism warning

import asyncio
import nest_asyncio

# Patch the running event loop to be re-entrant (Streamlit manages its own loop).
nest_asyncio.apply()

import tempfile
import streamlit as st
import librosa
import torch
import pandas as pd
from transformers import (
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    WhisperProcessor, WhisperForConditionalGeneration,
    AutoTokenizer, AutoModelForTokenClassification,
    AutoProcessor, AutoModelForSpeechSeq2Seq,
    pipeline,
)

# Disable torch dynamo for stability.
import torch._dynamo
torch._dynamo.disable()
# --- Load ASR processors & models (cached) ---
@st.cache_resource
def load_asr(path):
    """Return the Whisper (processor, model) pair for `path`.

    Cached by Streamlit so the checkpoint is downloaded/loaded only once
    per server process, not on every script rerun.
    """
    hf_cache = os.environ['HF_HOME']
    processor = WhisperProcessor.from_pretrained(path, cache_dir=hf_cache)
    model = WhisperForConditionalGeneration.from_pretrained(path, cache_dir=hf_cache)
    return processor, model

asr_path = "Huydb/phowhisper-toxic"
asr_processor, asr_model = load_asr(asr_path)
# --- Load TSD tokenizers & models ---
@st.cache_resource
def load_tsd(path):
    """Return the (tokenizer, model) pair for toxic-span token classification.

    The model head is configured for two labels (0 = clean, 1 = toxic).
    Cached by Streamlit across script reruns.
    """
    hf_cache = os.environ['HF_HOME']
    tokenizer = AutoTokenizer.from_pretrained(path, cache_dir=hf_cache)
    model = AutoModelForTokenClassification.from_pretrained(
        path, num_labels=2, cache_dir=hf_cache
    )
    return tokenizer, model

tsd_path = "Huydb/PhoBERT-toxic"
tsd_tokenizer, tsd_model = load_tsd(tsd_path)
# --- Streamlit UI ---
# Page-wide CSS: animated background fade plus a red primary button.
st.markdown("""
<style> /* CSS animation & button */
@keyframes bgfade {0%{background-color:white;}50%{background-color:#889ECE;}100%{background-color:white;}}
html, body, .reportview-container, .main {height:100%!important; margin:0; padding:0; animation:bgfade 10s ease infinite;}
div.stButton>button:first-child{background-color:red!important;color:white!important;border:none;}
</style>
""", unsafe_allow_html=True)
st.title("🔊🤬 Toxic Spans Detection from Audio")

uploaded_audio = st.file_uploader("1. Upload a WAV audio file", type=["wav"])
if not uploaded_audio:
    # Halt this script run until the user provides a file.
    st.info("Please upload a WAV audio file to begin.")
    st.stop()

# Persist the upload to a temp file so librosa / st.audio can read it by path.
# NOTE(review): delete=False leaks one temp file per upload — acceptable for a demo.
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tfile:
    tfile.write(uploaded_audio.read())
    audio_path = tfile.name
st.success("Audio uploaded.")
st.audio(audio_path, format='audio/wav')
# Helper: render a sentence with toxic words masked out
def highlight_toxic_span(words, labels):
    """Mask toxic words with asterisks.

    Args:
        words: sequence of word strings.
        labels: parallel sequence of 0/1 labels (1 = toxic). If the two
            sequences differ in length, the extra tail is ignored (zip
            truncates), matching the original behavior.

    Returns:
        The words joined by single spaces, with each toxic word replaced
        by '*' repeated to the word's length.
    """
    # " ".join over a generator instead of quadratic `+=` concatenation.
    return " ".join(
        "*" * len(word) if label == 1 else word
        for word, label in zip(words, labels)
    )
if st.button("Transcript and Detect Toxic Spans Now"):
    # --- ASR: transcribe the uploaded audio (Whisper expects 16 kHz mono) ---
    waveform, _ = librosa.load(audio_path, sr=16000)
    input_features = asr_processor(
        waveform, return_tensors="pt", sampling_rate=16000
    ).input_features.to("cpu")
    predicted_ids = asr_model.generate(input_features)
    transcript_text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    st.subheader("Result")

    # --- TSD: classify each word of the transcript as toxic / not toxic ---
    # BUGFIX: the original passed [transcript_text] with is_split_into_words=True,
    # which treats the ENTIRE transcript as a single word, and used a character
    # count as max_length. Tokenize the actual word list instead.
    words = transcript_text.split()
    enc = tsd_tokenizer(
        [words],                    # batch of one pre-split sentence
        is_split_into_words=True,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = tsd_model(input_ids=enc.input_ids, attention_mask=enc.attention_mask).logits
    token_labels = logits.argmax(-1)[0].cpu().tolist()

    # BUGFIX: predictions are per sub-token, not per word, so zipping them
    # positionally against `words` misaligns. Map sub-tokens back to their
    # source word via word_ids(); a word is toxic if any sub-token is toxic.
    # NOTE(review): word_ids() requires a fast tokenizer — confirm the
    # checkpoint ships one.
    word_labels = [0] * len(words)
    for token_idx, word_id in enumerate(enc.word_ids(batch_index=0)):
        if word_id is not None and token_labels[token_idx] == 1:
            word_labels[word_id] = 1

    sen_hide = highlight_toxic_span(words, word_labels)
    st.markdown(f"<h5 style='text-align: center;'>{sen_hide}</h5>", unsafe_allow_html=True)