Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from normalizer import normalize | |
| import torch.nn as nn | |
| from transformers import AutoModel | |
# Browser-tab title and full-width page layout for the whole app.
st.set_page_config(layout="wide", page_title="Political Sentiment")
class BanglaPoliticalNet(nn.Module):
    """BanglaBERT encoder + multi-width Conv1d features + self-attention + MLP classifier.

    Forward pipeline:
      1. Encode tokens with ``csebuetnlp/banglabert``.
      2. Run three Conv1d filters (kernel sizes 3, 5, 7; 128 output channels each) over the
         token sequence and concatenate their ReLU activations along the feature axis.
      3. Project the concatenated features back to the encoder width (``self.proj``) and
         apply 8-head self-attention.
      4. Classify the attended representation at sequence position 0.

    NOTE(review): this class is not referenced anywhere else in the visible file; the app
    loads every checkpoint through ``AutoModelForSequenceClassification``.

    NOTE(review): ``self.proj`` is a registered submodule, which adds ``proj.weight`` /
    ``proj.bias`` to the state dict. A checkpoint saved from the old version, which built the
    projection inside ``forward``, would need ``load_state_dict(..., strict=False)``.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
        self.hidden_size = self.banglabert.config.hidden_size
        # padding=k//2 keeps the sequence length unchanged for these odd kernel sizes, so the
        # three outputs line up position-by-position and can be concatenated.
        self.cnn_layers = nn.ModuleList([
            nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k // 2)
            for k in [3, 5, 7]
        ])
        # Built once here (instead of inside forward) so the projection is trained, saved and
        # loaded with the model. The old per-call construction gave a new random projection on
        # every forward pass.
        cnn_width = sum(conv.out_channels for conv in self.cnn_layers)
        self.proj = nn.Linear(cnn_width, self.hidden_size)
        self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            input_ids: token id tensor, presumably (batch, seq_len) as produced by a
                Hugging Face tokenizer with ``return_tensors="pt"``.
            attention_mask: optional tensor with 0 at padding positions. It is forwarded to
                the encoder and also used to keep self-attention from attending to padding.
        """
        bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
        # Conv1d wants (batch, channels, seq); transpose in and back out.
        cnn_features = [
            F.relu(cnn(bert_out.transpose(1, 2)).transpose(1, 2))
            for cnn in self.cnn_layers
        ]
        cnn_concat = torch.cat(cnn_features, dim=-1)
        attn_input = self.proj(cnn_concat)
        # MultiheadAttention's key_padding_mask uses True for positions to ignore.
        key_padding_mask = None if attention_mask is None else attention_mask == 0
        attn_out, _ = self.attention(
            attn_input, attn_input, attn_input, key_padding_mask=key_padding_mask
        )
        # Pool by taking the first position (presumably the encoder's leading special token).
        attn_pooled = attn_out[:, 0, :]
        return self.classifier(attn_pooled)
# Global theme: purple gradient background plus the CSS classes (main-card, section-header,
# model-card, prob-row, prob-bar-*) used by the HTML snippets rendered later in the script.
# unsafe_allow_html is required for Streamlit to emit the <style> tag.
# Changes: the Inter import now includes weight 800 (used by .result-value); the dead
# `padding-left: 15px` in .section-header (always overridden by the `padding` shorthand
# that followed it) has been dropped, so the rendered padding is unchanged.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700;800&display=swap');
html, body, [class*="css"] {
    font-family: 'Inter', sans-serif !important;
    color: #1f2937 !important;
}
.stApp {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
h1, h2, h3 {
    color: #ffffff !important;
    text-shadow: 0 2px 4px rgba(0,0,0,0.3);
}
.stTextArea textarea {
    background-color: #ffffff !important;
    color: #1f2937 !important;
    border: 2px solid #e5e7eb !important;
    border-radius: 12px !important;
    padding: 16px !important;
    font-size: 16px !important;
}
.stTextArea label {
    color: #ffffff !important;
    font-weight: 700 !important;
}
.main-card {
    background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%);
    padding: 35px;
    border-radius: 20px;
    box-shadow: 0 20px 40px rgba(0,0,0,0.15);
    margin-bottom: 25px;
    text-align: center;
    border: 1px solid rgba(255,255,255,0.3);
    backdrop-filter: blur(10px);
}
.result-title {
    color: #475569 !important;
    font-size: 16px;
    text-transform: uppercase;
    letter-spacing: 1.5px;
    margin-bottom: 12px;
    font-weight: 700;
}
.result-value {
    font-size: 52px;
    font-weight: 800;
    margin: 0;
    text-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.section-header {
    font-size: 22px;
    font-weight: 700;
    color: #1e293b !important;
    margin-bottom: 20px;
    border-left: 6px solid #3b82f6;
    background: rgba(255,255,255,0.8);
    padding: 12px 20px;
    border-radius: 10px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.1);
}
.model-card {
    background: linear-gradient(145deg, #ffffff 0%, #f1f5f9 100%);
    padding: 25px;
    border-radius: 16px;
    box-shadow: 0 8px 25px rgba(0,0,0,0.12);
    margin-bottom: 20px;
    border: 1px solid rgba(255,255,255,0.5);
    transition: all 0.3s ease;
}
.model-card:hover {
    transform: translateY(-5px);
    box-shadow: 0 20px 40px rgba(0,0,0,0.2);
}
.model-name {
    color: #334155 !important;
    font-size: 15px;
    font-weight: 700;
    margin-bottom: 12px;
    border-bottom: 3px solid #e2e8f0;
    padding-bottom: 8px;
}
.prob-row {
    margin-bottom: 18px;
    background: rgba(255,255,255,0.9);
    padding: 15px;
    border-radius: 12px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
.prob-label {
    font-size: 15px;
    color: #1e293b !important;
    font-weight: 700;
    margin-bottom: 8px;
    display: flex;
    justify-content: space-between;
    align-items: center;
}
.prob-bar-bg {
    width: 100%;
    height: 14px;
    background: linear-gradient(90deg, #f1f5f9, #e2e8f0);
    border-radius: 7px;
    overflow: hidden;
    box-shadow: inset 0 2px 4px rgba(0,0,0,0.05);
}
.prob-bar-fill {
    height: 100%;
    border-radius: 7px;
    transition: width 0.8s ease;
    box-shadow: 0 0 20px rgba(0,0,0,0.2);
}
.stButton > button {
    background: linear-gradient(45deg, #3b82f6, #1d4ed8) !important;
    color: white !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 14px 28px !important;
    font-weight: 700 !important;
    font-size: 16px !important;
    box-shadow: 0 8px 25px rgba(59,130,246,0.4) !important;
    transition: all 0.3s ease !important;
}
.stButton > button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 12px 35px rgba(59,130,246,0.6) !important;
}
.stRadio > div > label {
    color: #ffffff !important;
    font-weight: 600 !important;
}
.stSelectbox > label {
    color: #ffffff !important;
    font-weight: 600 !important;
}
.stExpander {
    background: rgba(255,255,255,0.1) !important;
    border-radius: 12px !important;
    border: 1px solid rgba(255,255,255,0.2) !important;
}
</style>
""", unsafe_allow_html=True)
# Class index -> display label, in index order 0..4.
# NOTE(review): assumes every loaded model emits its logits in this order — confirm against
# each checkpoint's config.id2label.
id2label = dict(enumerate([
    'Very Negative',
    'Negative',
    'Neutral',
    'Positive',
    'Very Positive',
]))

# Accent colour for each display label, paired positionally with id2label's labels.
label_colors = dict(zip(
    id2label.values(),
    ('#ef4444', '#f97316', '#64748b', '#22c55e', '#16a34a'),
))
@st.cache_resource
def load_models():
    """Load every sentiment checkpoint once per server process.

    Streamlit re-executes the whole script on every interaction. Without caching, each
    button click re-instantiated all the transformer models. ``st.cache_resource`` keeps the
    returned dict alive across reruns and sessions.

    Each checkpoint is loaded best-effort. A model whose tokenizer or weights fail to load is
    logged and skipped, so the returned dict may hold fewer entries than ``target_models``
    (possibly none). Because the result is cached, a skipped model stays skipped until the
    cache is cleared or the process restarts.

    NOTE(review): if ``rocky250/bangla-political`` holds ``BanglaPoliticalNet`` weights rather
    than a standard sequence-classification checkpoint, loading it here will fail and it will
    be skipped — confirm.

    Returns:
        dict mapping a short model name to a ``(tokenizer, model)`` tuple, with the model
        moved to CUDA when available (else CPU) and set to eval mode.
    """
    import logging

    logger = logging.getLogger(__name__)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    target_models = {
        "model_banglabert": "rocky250/Sentiment-banglabert",
        "model_mbert": "rocky250/Sentiment-mbert",
        "model_bbase": "rocky250/Sentiment-bbase",
        "model_xlmr": "rocky250/Sentiment-xlmr",
        "bangla_political": "rocky250/bangla-political",
    }
    models_loaded = {}
    for name, repo in target_models.items():
        try:
            tokenizer = AutoTokenizer.from_pretrained(repo)
            model = AutoModelForSequenceClassification.from_pretrained(repo).to(device)
        except Exception:
            # Best-effort: one broken checkpoint should not take the whole app down,
            # but the failure must be visible in the logs rather than silently ignored.
            logger.exception("Skipping model %s (%s): failed to load", name, repo)
            continue
        model.eval()  # inference only; disables dropout
        models_loaded[name] = (tokenizer, model)
    return models_loaded


models_dict = load_models()
def predict_single_model(text, model_name):
    """Classify *text* with one entry of ``models_dict``.

    Args:
        text: raw input text; it is passed through ``normalize`` before tokenization.
        model_name: key into ``models_dict``.

    Returns:
        ``(label, probabilities)``: the ``id2label`` entry of the most probable class, and the
        softmax probabilities for the single input as a NumPy array.
    """
    normalized = normalize(text)
    tokenizer, model = models_dict[model_name]
    # Encoded tensors must live on the same device as the model's weights.
    target_device = next(model.parameters()).device
    encoded = tokenizer(
        normalized,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    encoded = encoded.to(target_device)

    with torch.no_grad():
        scores = model(**encoded).logits

    # Row 0 is the only row: a single string was tokenized.
    probabilities = F.softmax(scores, dim=1).cpu().numpy()[0]
    best_index = probabilities.argmax()
    return id2label[best_index], probabilities
def predict_ensemble(text):
    """Soft-vote the sentiment of *text* across every model in ``models_dict``.

    Each model's softmax probabilities are averaged, and the final label is the argmax of
    that mean. Models that raise during inference are logged and skipped.

    Args:
        text: raw input text. ``predict_single_model`` normalizes it itself, so it is passed
            through untouched; the old code normalized it twice.

    Returns:
        ``(final_label, per_model_labels, avg_probs)``. ``per_model_labels`` only contains
        entries for models that succeeded, in ``models_dict`` iteration order. When no model
        succeeds, returns ``("Error", [], zeros)``, where zeros has one slot per class.
    """
    import logging

    all_probs = []
    all_predictions = []
    for name in models_dict:
        try:
            pred, probs = predict_single_model(text, name)
        except Exception:
            # Keep the ensemble best-effort, but record why a model dropped out.
            logging.getLogger(__name__).exception(
                "Model %s failed during ensemble inference", name
            )
            continue
        all_probs.append(probs)
        all_predictions.append(pred)

    if not all_probs:
        return "Error", [], np.zeros(len(id2label))

    avg_probs = np.mean(all_probs, axis=0)
    final_pred = id2label[int(np.argmax(avg_probs))]
    return final_pred, all_predictions, avg_probs
# Page title banner: a translucent blurred panel holding a gradient-text <h1>.
banner_html = """
<div style='
text-align: center;
background: rgba(255,255,255,0.1);
padding: 30px;
border-radius: 20px;
margin-bottom: 30px;
backdrop-filter: blur(20px);
'>
<h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
</div>
"""
st.markdown(banner_html, unsafe_allow_html=True)
# Every analysis path needs at least one loaded model. Stop here with a clear message rather
# than crashing later on models_dict[None].
if not models_dict:
    st.error("No sentiment models could be loaded. Please try again later.")
    st.stop()

col1, col2 = st.columns([3, 1])
with col1:
    # The example buttons store their text in st.session_state.user_input and rerun the
    # script. Feeding that value in as `value` makes the chosen example appear in the box;
    # previously it was never read, so the buttons had no visible effect. A widget key is
    # deliberately NOT used: Streamlit forbids assigning to a widget's session-state key after
    # that widget has been created in the same run, which is what the example buttons do.
    user_input = st.text_area(
        "Enter Bengali political text:",
        value=st.session_state.get("user_input", ""),
        height=140,
        placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...",
        help="Type or paste Bengali political text for sentiment analysis",
    )
with col2:
    # Spacer to push the controls down level with the text area.
    st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
    mode = st.radio("Analysis Mode:", ["Single Model", "Ensemble"], horizontal=True)
    selected_model = None
    if mode == "Single Model":
        selected_model = st.selectbox("Select Model:", list(models_dict), index=0)
    analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
def _render_probability_rows(probabilities):
    """Render one labelled, coloured progress bar per sentiment class.

    Args:
        probabilities: per-class probabilities indexed like ``id2label`` (softmax values in
            [0, 1]).
    """
    for class_id, label in id2label.items():
        percent = probabilities[class_id] * 100
        color = label_colors[label]
        # Single-line HTML so Markdown never mistakes indented lines for a code block.
        st.markdown(
            f'<div class="prob-row">'
            f'<div class="prob-label">'
            f'<span>{label}</span>'
            f'<span style="color: {color};">{percent:.1f}%</span>'
            f'</div>'
            f'<div class="prob-bar-bg">'
            f'<div class="prob-bar-fill" style="width: {min(percent, 100)}%; '
            f'background: linear-gradient(90deg, {color}, {color}cc);"></div>'
            f'</div>'
            f'</div>',
            unsafe_allow_html=True,
        )


if analyze_btn and user_input.strip():
    with st.spinner('Processing with models...'):
        if mode == "Single Model":
            model_name = selected_model
            if model_name not in models_dict:
                # The selectbox yields None when no model could be loaded.
                st.error("No sentiment model is available. Please try again later.")
            else:
                final_res, probs = predict_single_model(user_input, model_name)
                result_color = label_colors[final_res]
                result_col, scores_col = st.columns([1, 2])
                with result_col:
                    st.markdown(
                        f'<div class="main-card" style="border-top: 8px solid {result_color}">'
                        f'<div class="result-title">{model_name}</div>'
                        f'<div class="result-value" style="color: {result_color}">{final_res}</div>'
                        f'<div style="font-size: 18px; color: #64748b; margin-top: 15px;">'
                        f'Confidence: {max(probs) * 100:.1f}%</div>'
                        f'</div>',
                        unsafe_allow_html=True,
                    )
                with scores_col:
                    st.markdown('<div class="section-header">Confidence Scores</div>', unsafe_allow_html=True)
                    _render_probability_rows(probs)
        else:
            final_res, all_votes, avg_probs = predict_ensemble(user_input)
            if final_res not in label_colors:
                # predict_ensemble returns "Error" when no model produced a prediction;
                # looking it up in label_colors used to raise KeyError.
                st.error("None of the models could analyse this text. Please try again later.")
            else:
                result_color = label_colors[final_res]
                main_col, details_col = st.columns([1, 1.4])
                with main_col:
                    st.markdown(
                        f'<div class="main-card" style="border-top: 8px solid {result_color}; '
                        f'box-shadow: 0 25px 50px rgba(0,0,0,0.2);">'
                        f'<div class="result-title" style="font-size: 18px;">ENSEMBLE CONSENSUS</div>'
                        f'<div class="result-value" style="color: {result_color}; font-size: 60px;">'
                        f'{final_res}</div>'
                        f'</div>',
                        unsafe_allow_html=True,
                    )
                    st.markdown('<div class="section-header">Ensemble Probabilities</div>', unsafe_allow_html=True)
                    _render_probability_rows(avg_probs)
                with details_col:
                    st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
                    # predict_ensemble skips models that raise, so the votes only line up with
                    # models_dict's names when every model succeeded. Otherwise, avoid
                    # attributing a vote to the wrong model.
                    if len(all_votes) == len(models_dict):
                        vote_names = list(models_dict)
                    else:
                        st.caption("Some models failed; their votes are shown without model names.")
                        vote_names = [f"Model {i + 1}" for i in range(len(all_votes))]
                    model_cols = st.columns(2)
                    for idx, (name, vote) in enumerate(zip(vote_names, all_votes)):
                        with model_cols[idx % 2]:
                            vote_color = label_colors[vote]
                            st.markdown(
                                f'<div class="model-card">'
                                f'<div class="model-name">{name}</div>'
                                f'<div style="color: {vote_color}; font-weight: 800; '
                                f'font-size: 24px; margin-top: 8px;">{vote}</div>'
                                f'</div>',
                                unsafe_allow_html=True,
                            )
elif analyze_btn:
    # Reached only when the button was pressed with blank/whitespace-only input.
    st.error("অনুগ্রহ করে কিছু টেক্সট লিখুন!")
with st.expander("Example Political Texts", expanded=False):
    examples = [
        "সরকারের এই নীতি দেশকে ধ্বংসের দিকে নিয়ে যাবে!",
        "চমৎকার সিদ্ধান্ত! দেশের জন্য গর্বিত। ভালো চলবে!",
        "রাজনীতির কোনো পরিবর্তন হবে না, সব একই রকম"
    ]
    # One button per sample text, side by side. Long texts are shown truncated to 40
    # characters. Clicking stores the full text in st.session_state.user_input and reruns
    # the script.
    for column, sample in zip(st.columns(3), examples):
        caption = sample if len(sample) <= 40 else sample[:40] + "..."
        with column:
            if st.button(caption, use_container_width=True):
                st.session_state.user_input = sample
                st.rerun()