Hugging Face Space app (page scraped while the Space showed build status "Runtime error").
import os

# Configure cache locations *before* importing Streamlit or Transformers —
# both libraries read these environment variables at import time.
os.environ['TRANSFORMERS_CACHE'] = '/cache/hf_cache'  # legacy alias; HF_HOME below is the current variable
os.environ['HF_HOME'] = '/cache/hf_cache'
os.environ['XDG_CACHE_HOME'] = '/cache/.cache'
os.environ['STREAMLIT_CONFIG_DIR'] = '/cache/.streamlit'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # silence fork-related tokenizer warnings

import asyncio
import tempfile

import librosa
import nest_asyncio
import pandas as pd
import streamlit as st
import torch
import torch._dynamo
from transformers import (
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    WhisperProcessor, WhisperForConditionalGeneration,
    AutoTokenizer, AutoModelForTokenClassification,
    AutoProcessor, AutoModelForSpeechSeq2Seq,
    pipeline,
)

# Allow re-entrant event loops (Streamlit's script runner has its own loop).
nest_asyncio.apply()

# Disable TorchDynamo compilation for stability on the Space's hardware.
torch._dynamo.disable()
# --- Load ASR processor & model (cached across Streamlit reruns) ---
@st.cache_resource
def load_asr(path):
    """Load the Whisper processor and model used for speech-to-text.

    Cached with ``st.cache_resource`` so the (large) model is downloaded and
    instantiated once per process, not on every Streamlit rerun — the
    original section comment already said "(cached)" but no cache was applied.

    Args:
        path: Hugging Face model id or local checkpoint path.

    Returns:
        Tuple of ``(WhisperProcessor, WhisperForConditionalGeneration)``.
    """
    proc = WhisperProcessor.from_pretrained(path, cache_dir=os.environ['HF_HOME'])
    mod = WhisperForConditionalGeneration.from_pretrained(path, cache_dir=os.environ['HF_HOME'])
    return proc, mod


asr_path = "Huydb/phowhisper-toxic"
asr_processor, asr_model = load_asr(asr_path)
# --- Load TSD (toxic-span detection) tokenizer & model (cached) ---
@st.cache_resource
def load_tsd(path):
    """Load the PhoBERT tokenizer and token-classification model.

    Cached with ``st.cache_resource`` so the model is loaded once per
    process instead of on every Streamlit rerun.

    Args:
        path: Hugging Face model id or local checkpoint path.

    Returns:
        Tuple of ``(tokenizer, AutoModelForTokenClassification)`` where the
        model has two labels (0 = clean, 1 = toxic).
    """
    tok = AutoTokenizer.from_pretrained(path, cache_dir=os.environ['HF_HOME'])
    mod = AutoModelForTokenClassification.from_pretrained(path, num_labels=2, cache_dir=os.environ['HF_HOME'])
    return tok, mod


tsd_path = "Huydb/PhoBERT-toxic"
tsd_tokenizer, tsd_model = load_tsd(tsd_path)
# --- Streamlit UI ---
# Page-wide CSS: a slow background color fade plus a red primary button.
st.markdown("""
<style> /* CSS animation & button */
@keyframes bgfade {0%{background-color:white;}50%{background-color:#889ECE;}100%{background-color:white;}}
html, body, .reportview-container, .main {height:100%!important; margin:0; padding:0; animation:bgfade 10s ease infinite;}
div.stButton>button:first-child{background-color:red!important;color:white!important;border:none;}
</style>
""", unsafe_allow_html=True)

st.title("🔊🤬 Toxic Spans Detection from Audio")

# Stop the script early until the user has provided an audio file.
uploaded_audio = st.file_uploader("1. Upload a WAV audio file", type=["wav"])
if not uploaded_audio:
    st.info("Please upload a WAV audio file to begin.")
    st.stop()

# Persist the upload to disk so librosa/st.audio can read it by path.
# NOTE(review): delete=False means each rerun leaves a temp file behind —
# acceptable for a short-lived Space container, but worth confirming.
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as wav_file:
    wav_file.write(uploaded_audio.read())
    audio_path = wav_file.name

st.success("Audio uploaded.")
st.audio(audio_path, format='audio/wav')
def highlight_toxic_span(words, labels):
    """Censor toxic words in a sentence.

    Each word whose label is 1 is replaced by a run of ``*`` of the same
    length; other words pass through unchanged. Pairs beyond the shorter of
    the two inputs are ignored (``zip`` truncates).

    Args:
        words: Sequence of word strings.
        labels: Parallel sequence of int labels (1 = toxic, anything else = clean).

    Returns:
        The reassembled sentence as a single space-joined string.
    """
    # Build via comprehension + join instead of repeated string concatenation
    # (the original `sen_hide += ...` loop is quadratic in sentence length).
    masked = ["*" * len(word) if label == 1 else word
              for word, label in zip(words, labels)]
    return " ".join(masked)
# Process button: transcribe the audio, then detect and censor toxic spans.
if st.button("Transcript and Detect Toxic Spans Now"):
    # --- ASR: Whisper speech-to-text at the model's expected 16 kHz rate ---
    waveform, _ = librosa.load(audio_path, sr=16000)
    input_features = asr_processor(
        waveform, return_tensors="pt", sampling_rate=16000
    ).input_features.to("cpu")
    predicted_ids = asr_model.generate(input_features)
    transcript_text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    st.subheader("Result")

    # --- TSD: token classification over the transcript ---
    # Pass the transcript pre-split into words, as is_split_into_words=True
    # requires (the original passed the whole sentence as one "word", and used
    # a character count as max_length, so labels never aligned with words).
    words = transcript_text.split()
    enc = tsd_tokenizer([words], is_split_into_words=True,
                        padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = tsd_model(input_ids=enc.input_ids, attention_mask=enc.attention_mask).logits
    token_labels = logits.argmax(-1)[0].cpu().tolist()

    # Map sub-token predictions back onto whitespace words: a word is toxic
    # if any of its sub-tokens is predicted toxic.
    try:
        word_ids = enc.word_ids(0)  # available on fast tokenizers only
        labels = [0] * len(words)
        for tok_idx, wid in enumerate(word_ids):
            if wid is not None and token_labels[tok_idx] == 1:
                labels[wid] = 1
    except ValueError:
        # Slow tokenizer fallback: skip the leading special token and assume
        # roughly one sub-token per word (approximate alignment).
        labels = token_labels[1:len(words) + 1]

    sen_hide = highlight_toxic_span(words, labels)
    st.markdown(f"<h5 style='text-align: center;'>{sen_hide}</h5>", unsafe_allow_html=True)