File size: 3,638 Bytes
e404cbe
 
 
 
 
 
 
 
 
 
 
 
 
85aa5a8
e404cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e0207
 
 
 
 
 
 
e404cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2665383
e404cbe
9c9a541
 
 
e404cbe
 
 
 
 
 
 
 
 
85aa5a8
8843abc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
# Set cache/config env vars BEFORE importing any module that uses
# Streamlit or Transformers — both read these at import time.
os.environ['TRANSFORMERS_CACHE'] = '/cache/hf_cache'
os.environ['HF_HOME'] = '/cache/hf_cache'
os.environ['XDG_CACHE_HOME'] = '/cache/.cache'
os.environ['STREAMLIT_CONFIG_DIR'] = '/cache/.streamlit'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import asyncio
import nest_asyncio
# Patch the event loop so nested asyncio usage (inside Streamlit /
# library internals) does not raise "event loop is already running".
nest_asyncio.apply()

import tempfile
import streamlit as st
import librosa
import torch
import pandas as pd
from transformers import (
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    WhisperProcessor, WhisperForConditionalGeneration,
    AutoTokenizer, AutoModelForTokenClassification,
    AutoProcessor, AutoModelForSpeechSeq2Seq,
    pipeline,
)

# Disable torch dynamo (compilation) for stability in this app.
import torch._dynamo
torch._dynamo.disable()


# --- Load ASR processors & models (cached) ---
@st.cache_resource
def load_asr(path):
    """Return the cached (processor, model) Whisper ASR pair from *path*.

    Weights are downloaded into the HF_HOME cache directory; Streamlit's
    resource cache ensures the pair is loaded only once per process.
    """
    hf_cache = os.environ['HF_HOME']
    processor = WhisperProcessor.from_pretrained(path, cache_dir=hf_cache)
    model = WhisperForConditionalGeneration.from_pretrained(path, cache_dir=hf_cache)
    return processor, model


# Hugging Face Hub id of the fine-tuned PhoWhisper ASR model.
asr_path = "Huydb/phowhisper-toxic"
asr_processor, asr_model = load_asr(asr_path)

# --- Load TSD tokenizers & models ---
@st.cache_resource
def load_tsd(path):
    """Return the cached (tokenizer, model) pair for toxic-span detection.

    The model is a binary token classifier (num_labels=2) whose weights
    are pulled from *path* into the HF_HOME cache directory.
    """
    hf_cache = os.environ['HF_HOME']
    tokenizer = AutoTokenizer.from_pretrained(path, cache_dir=hf_cache)
    model = AutoModelForTokenClassification.from_pretrained(
        path, num_labels=2, cache_dir=hf_cache
    )
    return tokenizer, model

# Hugging Face Hub id of the PhoBERT token-classification model used
# for toxic-span detection.
tsd_path = "Huydb/PhoBERT-toxic"
tsd_tokenizer, tsd_model = load_tsd(tsd_path)

# --- Streamlit UI ---
# Inject raw CSS: animated page background plus a red primary button.
st.markdown("""
<style> /* CSS animation & button */
@keyframes bgfade {0%{background-color:white;}50%{background-color:#889ECE;}100%{background-color:white;}}
html, body, .reportview-container, .main {height:100%!important; margin:0; padding:0; animation:bgfade 10s ease infinite;}
 div.stButton>button:first-child{background-color:red!important;color:white!important;border:none;}
</style>
""", unsafe_allow_html=True)

st.title("🔊🤬 Toxic Spans Detection from Audio")
uploaded_audio = st.file_uploader("1. Upload a WAV audio file", type=["wav"])
if not uploaded_audio:
    st.info("Please upload a WAV audio file to begin.")
    st.stop()  # halt this script run until a file is provided

# Persist the upload to disk so librosa/st.audio can read it by path.
# NOTE(review): delete=False means the temp file is never removed —
# consider cleaning it up after processing to avoid filling /tmp.
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tfile:
    tfile.write(uploaded_audio.read())
    audio_path = tfile.name
st.success("Audio uploaded.")
st.audio(audio_path, format='audio/wav')


def highlight_toxic_span(words, labels):
    """Mask toxic words: each word whose label is 1 is replaced by a run
    of '*' of the same length; other words pass through unchanged.

    Args:
        words: sequence of word strings.
        labels: per-word labels, where 1 marks a toxic word. If the two
            sequences differ in length the extra items are ignored
            (zip truncates), matching the original behavior.

    Returns:
        The space-joined sentence with toxic words masked.
    """
    # str.join avoids the original quadratic `+=` string build and the
    # trailing-space/strip() dance.
    return " ".join(
        "*" * len(word) if label == 1 else word
        for word, label in zip(words, labels)
    )

# Process button: transcribe the uploaded audio, classify each word as
# toxic/non-toxic, and render the transcript with toxic words masked.
if st.button("Transcript and Detect Toxic Spans Now"):
    # --- ASR: transcribe 16 kHz mono audio with Whisper ---
    waveform, _ = librosa.load(audio_path, sr=16000)
    input_features = asr_processor(waveform, return_tensors="pt", sampling_rate=16000).input_features.to("cpu")
    predicted_ids = asr_model.generate(input_features)
    transcript_text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    st.subheader("Result")

    # --- TSD: token-classify the transcript word by word ---
    # BUG FIX: the previous code passed the whole transcript as a single
    # "word" (list([transcript_text]) with is_split_into_words=True) and
    # then zipped per-subword labels directly against transcript.split(),
    # so the mask was misaligned with the words. Tokenize the actual word
    # list and map subword predictions back to words via word_ids().
    words = transcript_text.split()
    enc = tsd_tokenizer(words, is_split_into_words=True,
                        padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = tsd_model(input_ids=enc.input_ids, attention_mask=enc.attention_mask).logits
    token_labels = logits.argmax(-1)[0].cpu().tolist()

    # A word is toxic if ANY of its subword tokens is predicted toxic (1).
    # word_ids(0) maps each subword position to its source word index
    # (None for special tokens).
    word_labels = [0] * len(words)
    for pos, word_id in enumerate(enc.word_ids(0)):
        if word_id is not None and token_labels[pos] == 1:
            word_labels[word_id] = 1
    sen_hide = highlight_toxic_span(words, word_labels)

    st.markdown(f"<h5 style='text-align: center;'>{sen_hide}</h5>", unsafe_allow_html=True)