Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +95 -38
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,97 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
st.
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 6 |
+
from normalizer import normalize
|
| 7 |
+
|
| 8 |
+
st.set_page_config(page_title="Political Sentiment AI", page_icon="🇧🇩", layout="wide")
|
| 9 |
+
|
| 10 |
+
st.markdown("""
|
| 11 |
+
<style>
|
| 12 |
+
.main { background-color: #f8f9fa; }
|
| 13 |
+
.stTextArea textarea { border-radius: 15px; border: 2px solid #e0e0e0; }
|
| 14 |
+
.sentiment-card { padding: 25px; border-radius: 15px; background-color: white; box-shadow: 0 4px 15px rgba(0,0,0,0.05); margin-bottom: 20px; border-left: 10px solid; }
|
| 15 |
+
.model-box { background-color: #ffffff; padding: 15px; border-radius: 10px; border: 1px solid #eee; text-align: center; }
|
| 16 |
+
.bar-container { width: 100%; background-color: #f1f1f1; border-radius: 10px; margin: 5px 0 15px 0; }
|
| 17 |
+
.bar-fill { height: 20px; border-radius: 10px; text-align: center; color: white; font-size: 12px; line-height: 20px; font-weight: bold; }
|
| 18 |
+
</style>
|
| 19 |
+
""", unsafe_allow_html=True)
|
| 20 |
+
|
| 21 |
+
id2label = {0: 'Very Negative', 1: 'Negative', 2: 'Neutral', 3: 'Positive', 4: 'Very Positive'}
|
| 22 |
+
label_colors = {
|
| 23 |
+
'Very Negative': '#D32F2F', 'Negative': '#F44336',
|
| 24 |
+
'Neutral': '#757575', 'Positive': '#4CAF50', 'Very Positive': '#1B5E20'
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
@st.cache_resource
|
| 28 |
+
def load_all_transformers():
|
| 29 |
+
# Replace 'your-username' with your actual Hugging Face username
|
| 30 |
+
model_repos = {
|
| 31 |
+
"BanglaBERT": "rocky250/political-banglabert",
|
| 32 |
+
"mBERT": "rocky250/political-mbert",
|
| 33 |
+
"B-Base": "rocky250/political-bbase",
|
| 34 |
+
"XLM-R": "rocky250/political-xlmr"
|
| 35 |
+
}
|
| 36 |
+
loaded_models = {}
|
| 37 |
+
for name, repo_path in model_repos.items():
|
| 38 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_path)
|
| 39 |
+
model = AutoModelForSequenceClassification.from_pretrained(repo_path)
|
| 40 |
+
loaded_models[name] = (tokenizer, model)
|
| 41 |
+
return loaded_models
|
| 42 |
+
|
| 43 |
+
models_dict = load_all_transformers()
|
| 44 |
+
|
| 45 |
+
def get_detailed_prediction(text):
|
| 46 |
+
clean_text = normalize(text)
|
| 47 |
+
all_probs = []
|
| 48 |
+
votes = []
|
| 49 |
+
for name, (tokenizer, model) in models_dict.items():
|
| 50 |
+
inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
|
| 51 |
+
with torch.no_grad():
|
| 52 |
+
logits = model(**inputs).logits
|
| 53 |
+
probs = F.softmax(logits, dim=1).numpy()[0]
|
| 54 |
+
all_probs.append(probs)
|
| 55 |
+
prediction_id = np.argmax(probs)
|
| 56 |
+
votes.append(id2label[prediction_id])
|
| 57 |
+
avg_probs = np.mean(all_probs, axis=0)
|
| 58 |
+
final_vote = max(set(votes), key=votes.count)
|
| 59 |
+
return final_vote, votes, avg_probs
|
| 60 |
+
|
| 61 |
+
st.title("🇧🇩 Political Sentiment Analysis")
|
| 62 |
+
st.markdown("Advanced Multi-Model Ensemble Dashboard")
|
| 63 |
+
|
| 64 |
+
with st.container():
|
| 65 |
+
st.markdown('<div class="sentiment-card" style="border-left-color: #007BFF;">', unsafe_allow_html=True)
|
| 66 |
+
user_input = st.text_area("Input Political Comment:", height=120)
|
| 67 |
+
analyze_btn = st.button("🚀 Analyze Sentiment")
|
| 68 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 69 |
|
| 70 |
+
if analyze_btn:
|
| 71 |
+
if user_input.strip() == "":
|
| 72 |
+
st.warning("Please provide input text.")
|
| 73 |
+
else:
|
| 74 |
+
with st.spinner('Fetching models from Cloud & Analyzing...'):
|
| 75 |
+
final_res, all_votes, avg_probs = get_detailed_prediction(user_input)
|
| 76 |
+
col1, col2 = st.columns([1, 1])
|
| 77 |
+
with col1:
|
| 78 |
+
st.markdown(f"""
|
| 79 |
+
<div class="sentiment-card" style="border-left-color: {label_colors[final_res]};">
|
| 80 |
+
<h3 style="margin:0;">Ensemble Decision</h3>
|
| 81 |
+
<h1 style="color: {label_colors[final_res]}; margin:0;">{final_res}</h1>
|
| 82 |
+
</div>
|
| 83 |
+
""", unsafe_allow_html=True)
|
| 84 |
+
for i in range(5):
|
| 85 |
+
label = id2label[i]
|
| 86 |
+
prob = avg_probs[i] * 100
|
| 87 |
+
st.markdown(f"""
|
| 88 |
+
<div style="display: flex; justify-content: space-between;"><span>{label}</span><span>{prob:.1f}%</span></div>
|
| 89 |
+
<div class="bar-container"><div class="bar-fill" style="width: {prob}%; background-color: {label_colors[label]};"></div></div>
|
| 90 |
+
""", unsafe_allow_html=True)
|
| 91 |
+
with col2:
|
| 92 |
+
m_names = list(models_dict.keys())
|
| 93 |
+
m_cols = st.columns(2)
|
| 94 |
+
for i in range(4):
|
| 95 |
+
with m_cols[i % 2]:
|
| 96 |
+
vote = all_votes[i]
|
| 97 |
+
st.markdown(f'<div class="model-box"><small>{m_names[i]}</small><div style="color: {label_colors[vote]}; font-weight: bold;">{vote}</div></div><br>', unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|