ViToSAResearch commited on
Commit
e404cbe
·
verified ·
1 Parent(s): c28ec78

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +104 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,106 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
 
 
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ # Thiết lập env trước khi import bất kỳ module nào dùng Streamlit hoặc Transformers
3
+ os.environ['TRANSFORMERS_CACHE'] = '/cache/hf_cache'
4
+ os.environ['HF_HOME'] = '/cache/hf_cache'
5
+ os.environ['XDG_CACHE_HOME'] = '/cache/.cache'
6
+ os.environ['STREAMLIT_CONFIG_DIR'] = '/cache/.streamlit'
7
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
8
+
9
+ import asyncio
10
+ import nest_asyncio
11
+ nest_asyncio.apply()
12
+
13
+ import tempfile
14
  import streamlit as st
15
+ import librosa
16
+ import torch
17
+ import pandas as pd
18
+ from transformers import (
19
+ Wav2Vec2Processor, Wav2Vec2ForCTC,
20
+ WhisperProcessor, WhisperForConditionalGeneration,
21
+ AutoTokenizer, AutoModelForTokenClassification,
22
+ AutoProcessor, AutoModelForSpeechSeq2Seq,
23
+ pipeline,
24
+ )
25
+
26
+ # disable torch dynamo for stability
27
+ import torch._dynamo
28
+ torch._dynamo.disable()
29
+
30
+ # --- Configuration: model paths ---
31
+ ASR_MODELS = {
32
+ "PhoWhisper": "Huydb/phowhisper-toxic", # ensure this repo has processor_config.json
33
+ }
34
+ TSD_MODELS = {
35
+ "PhoBERT": "Huydb/PhoBERT-toxic",
36
+ }
37
+
38
+ # --- Load ASR processors & models (cached) ---
39
+ @st.cache_resource
40
+ def load_asr(path):
41
+ proc = WhisperProcessor.from_pretrained(path, cache_dir=os.environ['HF_HOME'])
42
+ mod = WhisperForConditionalGeneration.from_pretrained(path, cache_dir=os.environ['HF_HOME'])
43
+ return proc, mod
44
+
45
+
46
+ asr_path = "Huydb/phowhisper-toxic"
47
+ asr_processor, asr_model = load_asr(asr_path)
48
+
49
+ # --- Load TSD tokenizers & models ---
50
+ @st.cache_resource
51
+ def load_tsd(path):
52
+ tok = AutoTokenizer.from_pretrained(path, cache_dir=os.environ['HF_HOME'])
53
+ mod = AutoModelForTokenClassification.from_pretrained(path, num_labels=2, cache_dir=os.environ['HF_HOME'])
54
+ return tok, mod
55
+
56
+ tsd_path = "Huydb/PhoBERT-toxic"
57
+ tsd_tokenizer, tsd_model = load_tsd(tsd_path)
58
+
59
+ # --- Streamlit UI ---
60
+ st.markdown("""
61
+ <style> /* CSS animation & button */
62
+ @keyframes bgfade {0%{background-color:white;}50%{background-color:#889ECE;}100%{background-color:white;}}
63
+ html, body, .reportview-container, .main {height:100%!important; margin:0; padding:0; animation:bgfade 10s ease infinite;}
64
+ div.stButton>button:first-child{background-color:red!important;color:white!important;border:none;}
65
+ </style>
66
+ """, unsafe_allow_html=True)
67
+
68
+ st.title("🔊🤬 Toxic Spans Detection from Audio")
69
+ uploaded_audio = st.file_uploader("1. Upload a WAV audio file", type=["wav"])
70
+ if not uploaded_audio:
71
+ st.info("Please upload a WAV audio file to begin.")
72
+ st.stop()
73
+
74
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tfile:
75
+ tfile.write(uploaded_audio.read())
76
+ audio_path = tfile.name
77
+ st.success("Audio uploaded.")
78
+ st.audio(audio_path, format='audio/wav')
79
+
80
+
81
+ # Process button
82
+ def highlight_toxic_span(words, labels):
83
+ sen_hide = ""
84
+ for word, label in zip(words, labels):
85
+ if label == 1:
86
+ sen_hide += "*"*len(word) + " "
87
+ else:
88
+ sen_hide += word + " "
89
+ return sen_hide.strip()
90
+
91
+ if st.button("Transcript and Detect Toxic Spans Now"):
92
+ waveform, _ = librosa.load(audio_path, sr=16000)
93
+ input_features = proc(waveform, return_tensors="pt", sampling_rate=16000).input_features.to("cpu")
94
+ predicted_ids = mod.generate(input_features)
95
+ transcript_text = proc.batch_decode(predicted_ids, skip_special_tokens=True)[0]
96
+
97
+ st.subheader("Result")
98
+ enc = tsd_tokenizer(list([transcript_text]), is_split_into_words=True,
99
+ padding='max_length', truncation=True,
100
+ max_length=len(list(transcript_text)), return_tensors="pt")
101
+ with torch.no_grad():
102
+ logits = tsd_model(input_ids=enc.input_ids, attention_mask=enc.attention_mask).logits
103
+ labels = logits.argmax(-1)[0].cpu().tolist()
104
+ sen_hide = highlight_toxic_span(transcript_text.split(), labels)
105
 
106
+ st.markdown(f"<h1 style='text-align: center; color: red;'>{sen_hide}</h1>", unsafe_allow_html=True)