import os # Thiết lập env trước khi import bất kỳ module nào dùng Streamlit hoặc Transformers os.environ['TRANSFORMERS_CACHE'] = '/cache/hf_cache' os.environ['HF_HOME'] = '/cache/hf_cache' os.environ['XDG_CACHE_HOME'] = '/cache/.cache' os.environ['STREAMLIT_CONFIG_DIR'] = '/cache/.streamlit' os.environ['TOKENIZERS_PARALLELISM'] = 'false' import asyncio import nest_asyncio nest_asyncio.apply() import tempfile import streamlit as st import librosa import torch import pandas as pd from transformers import ( Wav2Vec2Processor, Wav2Vec2ForCTC, WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForTokenClassification, AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline, ) # disable torch dynamo for stability import torch._dynamo torch._dynamo.disable() # --- Load ASR processors & models (cached) --- @st.cache_resource def load_asr(path): proc = WhisperProcessor.from_pretrained(path, cache_dir=os.environ['HF_HOME']) mod = WhisperForConditionalGeneration.from_pretrained(path, cache_dir=os.environ['HF_HOME']) return proc, mod asr_path = "Huydb/phowhisper-toxic" asr_processor, asr_model = load_asr(asr_path) # --- Load TSD tokenizers & models --- @st.cache_resource def load_tsd(path): tok = AutoTokenizer.from_pretrained(path, cache_dir=os.environ['HF_HOME']) mod = AutoModelForTokenClassification.from_pretrained(path, num_labels=2, cache_dir=os.environ['HF_HOME']) return tok, mod tsd_path = "Huydb/PhoBERT-toxic" tsd_tokenizer, tsd_model = load_tsd(tsd_path) # --- Streamlit UI --- st.markdown(""" """, unsafe_allow_html=True) st.title("🔊🤬 Toxic Spans Detection from Audio") uploaded_audio = st.file_uploader("1. Upload a WAV audio file", type=["wav"]) if not uploaded_audio: st.info("Please upload a WAV audio file to begin.") st.stop() with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tfile: tfile.write(uploaded_audio.read()) audio_path = tfile.name st.success("Audio uploaded.") st.audio(audio_path, format='audio/wav') # Process button def highlight_toxic_span(words, labels): sen_hide = "" for word, label in zip(words, labels): if label == 1: sen_hide += "*"*len(word) + " " else: sen_hide += word + " " return sen_hide.strip() if st.button("Transcript and Detect Toxic Spans Now"): waveform, _ = librosa.load(audio_path, sr=16000) input_features = asr_processor(waveform, return_tensors="pt", sampling_rate=16000).input_features.to("cpu") predicted_ids = asr_model.generate(input_features) transcript_text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] st.subheader("Result") enc = tsd_tokenizer(list([transcript_text]), is_split_into_words=True, padding='max_length', truncation=True, max_length=len(list(transcript_text)), return_tensors="pt") with torch.no_grad(): logits = tsd_model(input_ids=enc.input_ids, attention_mask=enc.attention_mask).logits labels = logits.argmax(-1)[0].cpu().tolist() sen_hide = highlight_toxic_span(transcript_text.split(), labels) st.markdown(f"