rocky250 committed on
Commit
75479e0
·
verified ·
1 Parent(s): d21ca0a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +212 -193
src/streamlit_app.py CHANGED
@@ -4,236 +4,255 @@ import torch.nn.functional as F
4
  import numpy as np
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from normalizer import normalize
 
 
7
 
8
  st.set_page_config(page_title="Political Sentiment", page_icon="🇧🇩", layout="wide")
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  st.markdown("""
11
  <style>
12
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
13
 
14
- html, body, [class*="css"] {
15
- font-family: 'Inter', sans-serif;
16
- color: #000000;
17
- }
18
-
19
- .stApp {
20
- background-color: #f4f6f9;
21
- }
22
-
23
- h1, h2, h3 {
24
- color: #111827 !important;
25
- }
26
-
27
- .main-card {
28
- background-color: white;
29
- padding: 30px;
30
- border-radius: 15px;
31
- box-shadow: 0 4px 10px rgba(0,0,0,0.08);
32
- margin-bottom: 25px;
33
- text-align: center;
34
- border: 1px solid #e5e7eb;
35
- }
36
-
37
- .result-title {
38
- color: #374151;
39
- font-size: 16px;
40
- text-transform: uppercase;
41
- letter-spacing: 1.5px;
42
- margin-bottom: 10px;
43
- font-weight: 700;
44
- }
45
-
46
- .result-value {
47
- font-size: 48px;
48
- font-weight: 800;
49
- margin: 0;
50
- }
51
-
52
- .section-header {
53
- font-size: 20px;
54
- font-weight: 700;
55
- color: #1f2937;
56
- margin-bottom: 20px;
57
- border-left: 5px solid #2563eb;
58
- padding-left: 10px;
59
- }
60
-
61
- .model-card {
62
- background-color: white;
63
- padding: 20px;
64
- border-radius: 12px;
65
- box-shadow: 0 2px 5px rgba(0,0,0,0.05);
66
- margin-bottom: 15px;
67
- border: 1px solid #e5e7eb;
68
- transition: all 0.2s ease;
69
- }
70
-
71
- .model-card:hover {
72
- transform: translateY(-3px);
73
- box-shadow: 0 8px 15px rgba(0,0,0,0.1);
74
- }
75
-
76
- .model-name {
77
- color: #4b5563;
78
- font-size: 14px;
79
- font-weight: 600;
80
- margin-bottom: 8px;
81
- border-bottom: 2px solid #f3f4f6;
82
- padding-bottom: 5px;
83
- }
84
-
85
- .prob-row {
86
- margin-bottom: 15px;
87
- }
88
-
89
- .prob-label {
90
- font-size: 14px;
91
- color: #111827;
92
- font-weight: 600;
93
- margin-bottom: 5px;
94
- display: flex;
95
- justify-content: space-between;
96
- }
97
-
98
- .prob-bar-bg {
99
- width: 100%;
100
- height: 10px;
101
- background-color: #e5e7eb;
102
- border-radius: 5px;
103
- overflow: hidden;
104
- }
105
-
106
- .prob-bar-fill {
107
- height: 100%;
108
- border-radius: 5px;
109
- }
110
-
111
- .stTextArea textarea {
112
- border-radius: 12px !important;
113
- border: 2px solid #d1d5db !important;
114
- padding: 15px !important;
115
- font-size: 16px;
116
- }
117
  </style>
118
  """, unsafe_allow_html=True)
119
 
120
  id2label = {0: 'Very Negative', 1: 'Negative', 2: 'Neutral', 3: 'Positive', 4: 'Very Positive'}
121
- label_colors = {
122
- 'Very Negative': '#DC2626',
123
- 'Negative': '#EA580C',
124
- 'Neutral': '#6B7280',
125
- 'Positive': '#16A34A',
126
- 'Very Positive': '#15803D'
127
- }
128
 
129
  @st.cache_resource
130
- def load_all_transformers():
131
- model_repos = {
 
 
132
  "BanglaBERT": "rocky250/political-banglabert",
133
  "mBERT": "rocky250/political-mbert",
134
  "B-Base": "rocky250/political-bbase",
135
  "XLM-R": "rocky250/political-xlmr"
136
  }
137
- loaded_models = {}
138
- for name, repo_path in model_repos.items():
139
- tokenizer = AutoTokenizer.from_pretrained(repo_path)
140
- model = AutoModelForSequenceClassification.from_pretrained(repo_path)
141
- loaded_models[name] = (tokenizer, model)
142
- return loaded_models
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- models_dict = load_all_transformers()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- def get_detailed_prediction(text):
147
  clean_text = normalize(text)
148
  all_probs = []
149
- votes = []
150
 
151
- for name, (tokenizer, model) in models_dict.items():
152
- inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
153
- with torch.no_grad():
154
- logits = model(**inputs).logits
155
- probs = F.softmax(logits, dim=1).numpy()[0]
156
  all_probs.append(probs)
157
- prediction_id = np.argmax(probs)
158
- votes.append(id2label[prediction_id])
159
-
160
- avg_probs = np.mean(all_probs, axis=0)
161
- final_vote = max(set(votes), key=votes.count)
162
- return final_vote, votes, avg_probs
 
 
 
163
 
164
- st.markdown("<h1 style='text-align: center; margin-bottom: 10px;'>Political Sentiment Analysis</h1>", unsafe_allow_html=True)
 
165
 
166
- with st.container():
167
- col_input, col_action = st.columns([4, 1])
168
- with col_input:
169
- user_input = st.text_area("", height=100, placeholder="Enter Bengali political text here...", label_visibility="collapsed")
170
- with col_action:
171
- st.markdown("<div style='height: 10px'></div>", unsafe_allow_html=True)
172
- analyze_btn = st.button("Analyze", type="primary", use_container_width=True)
 
 
 
 
 
 
 
 
 
173
 
174
  if analyze_btn and user_input.strip():
175
- with st.spinner('Running Ensemble Analysis...'):
176
- final_res, all_votes, avg_probs = get_detailed_prediction(user_input)
177
-
178
- st.markdown("<div style='height: 30px'></div>", unsafe_allow_html=True)
179
-
180
- main_col, details_col = st.columns([1, 1.5], gap="large")
181
-
182
- with main_col:
183
- st.markdown(f"""
184
- <div class="main-card" style="border-top: 6px solid {label_colors[final_res]}">
185
- <div class="result-title">Ensemble Consensus</div>
186
- <div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div>
187
- </div>
188
- """, unsafe_allow_html=True)
189
 
190
- st.markdown('<div class="section-header">Probability Distribution</div>', unsafe_allow_html=True)
191
-
192
- for i in range(5):
193
- label = id2label[i]
194
- raw_prob = avg_probs[i] * 100
195
-
196
- # Visual cleanup: Force 0% if extremely low to reduce noise
197
- if raw_prob < 1.0:
198
- prob = 0.0
199
- else:
200
- prob = raw_prob
201
-
202
- color = label_colors[label]
203
- # High opacity for significant numbers, low for 0-1%
204
- opacity = "1.0" if prob > 1 else "0.3"
205
-
206
  st.markdown(f"""
207
- <div class="prob-row" style="opacity: {opacity}">
208
- <div class="prob-label">
209
- <span>{label}</span>
210
- <span>{prob:.1f}%</span>
211
- </div>
212
- <div class="prob-bar-bg">
213
- <div class="prob-bar-fill" style="width: {prob}%; background-color: {color};"></div>
214
- </div>
215
  </div>
216
  """, unsafe_allow_html=True)
217
-
218
- with details_col:
219
- st.markdown('<div class="section-header">Individual Model Predictions</div>', unsafe_allow_html=True)
220
- m_names = list(models_dict.keys())
221
 
222
- row1 = st.columns(2)
223
- row2 = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
- for idx, col in enumerate(row1 + row2):
226
- model_name = m_names[idx]
227
- vote = all_votes[idx]
228
- color = label_colors[vote]
 
 
 
 
 
229
 
230
- with col:
 
 
 
 
 
231
  st.markdown(f"""
232
- <div class="model-card">
233
- <div class="model-name">{model_name}</div>
234
- <div style="color: {color}; font-weight: 700; font-size: 20px;">{vote}</div>
 
 
 
 
 
235
  </div>
236
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  elif analyze_btn and not user_input.strip():
239
- st.warning("Please enter some text to analyze.")
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import numpy as np
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from normalizer import normalize
7
+ import torch.nn as nn
8
+ from transformers import AutoModel
9
 
10
  st.set_page_config(page_title="Political Sentiment", page_icon="🇧🇩", layout="wide")
11
 
12
+ # Custom hybrid model architecture (copied from earlier notebook cells)
13
class CreativeBanglaPoliticalNet(nn.Module):
    """BanglaBERT encoder + parallel 1-D convolutions + self-attention + MLP head.

    The encoder's token states are passed through three convolution branches
    (kernel sizes 3, 5 and 7), the branch outputs are concatenated on the
    feature axis, self-attention is applied over the sequence, and the first
    position of the attended sequence is classified.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
        self.hidden_size = self.banglabert.config.hidden_size

        kernel_sizes = (3, 5, 7)
        channels_per_branch = 128
        self.cnn_layers = nn.ModuleList([
            # padding=k//2 with an odd kernel keeps the sequence length unchanged
            nn.Conv1d(self.hidden_size, channels_per_branch, kernel_size=k, padding=k // 2)
            for k in kernel_sizes
        ])

        # forward() concatenates the branch outputs on the feature axis, so the
        # attention layer and the classifier must be sized for that width.
        # Sizing them with hidden_size (as before) makes MultiheadAttention
        # reject its input whenever hidden_size != 3 * 128.
        # NOTE(review): this changes the parameter shapes of `attention` and
        # `classifier`; confirm they match the checkpoint that is loaded.
        feature_dim = channels_per_branch * len(kernel_sizes)

        self.attention = nn.MultiheadAttention(feature_dim, 8, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        """Return unnormalised class logits, one row per input sequence."""
        bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state

        # Conv1d wants channels before sequence length; transpose in and out.
        channels_first = bert_out.transpose(1, 2)
        cnn_features = [
            F.relu(cnn(channels_first).transpose(1, 2))
            for cnn in self.cnn_layers
        ]

        cnn_concat = torch.cat(cnn_features, dim=-1)
        attn_out, _ = self.attention(cnn_concat, cnn_concat, cnn_concat)
        # Pool by taking the first sequence position.
        attn_pooled = attn_out[:, 0, :]

        return self.classifier(attn_pooled)
47
+
48
# Global page styling, injected once as a raw <style> element.
_PAGE_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');

html, body, [class*="css"] { font-family: 'Inter', sans-serif; color: #000000; }
.stApp { background-color: #f4f6f9; }
h1, h2, h3 { color: #111827 !important; }
.main-card { background-color: white; padding: 30px; border-radius: 15px; box-shadow: 0 4px 10px rgba(0,0,0,0.08); margin-bottom: 25px; text-align: center; border: 1px solid #e5e7eb; }
.result-title { color: #374151; font-size: 16px; text-transform: uppercase; letter-spacing: 1.5px; margin-bottom: 10px; font-weight: 700; }
.result-value { font-size: 48px; font-weight: 800; margin: 0; }
.section-header { font-size: 20px; font-weight: 700; color: #1f2937; margin-bottom: 20px; border-left: 5px solid #2563eb; padding-left: 10px; }
.model-card { background-color: white; padding: 20px; border-radius: 12px; box-shadow: 0 2px 5px rgba(0,0,0,0.05); margin-bottom: 15px; border: 1px solid #e5e7eb; }
.model-card:hover { transform: translateY(-3px); box-shadow: 0 8px 15px rgba(0,0,0,0.1); }
.model-name { color: #4b5563; font-size: 14px; font-weight: 600; margin-bottom: 8px; border-bottom: 2px solid #f3f4f6; padding-bottom: 5px; }
.prob-row { margin-bottom: 15px; }
.prob-label { font-size: 14px; color: #111827; font-weight: 600; margin-bottom: 5px; display: flex; justify-content: space-between; }
.prob-bar-bg { width: 100%; height: 10px; background-color: #e5e7eb; border-radius: 5px; overflow: hidden; }
.prob-bar-fill { height: 100%; border-radius: 5px; }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Class index -> human-readable sentiment label.
id2label = {
    0: 'Very Negative',
    1: 'Negative',
    2: 'Neutral',
    3: 'Positive',
    4: 'Very Positive',
}

# Sentiment label -> hex colour used for result cards and probability bars.
label_colors = {
    'Very Negative': '#DC2626',
    'Negative': '#EA580C',
    'Neutral': '#6B7280',
    'Positive': '#16A34A',
    'Very Positive': '#15803D',
}
 
 
 
 
 
 
71
 
72
@st.cache_resource
def load_models():
    """Load every available classifier once and cache the result.

    Returns:
        dict mapping a display name to a ``(tokenizer, model)`` tuple. Models
        that fail to load are left out; a Streamlit notice is shown and the
        underlying exception is logged.
    """
    import logging

    log = logging.getLogger(__name__)
    models = {}

    standard_models = {
        "BanglaBERT": "rocky250/political-banglabert",
        "mBERT": "rocky250/political-mbert",
        "B-Base": "rocky250/political-bbase",
        "XLM-R": "rocky250/political-xlmr"
    }

    for name, repo in standard_models.items():
        try:
            tokenizer = AutoTokenizer.from_pretrained(repo)
            model = AutoModelForSequenceClassification.from_pretrained(repo)
        except Exception:
            # Keep the app usable with the remaining models, but record why
            # this one is missing instead of discarding the error.
            log.exception("Failed to load %s from %s", name, repo)
            st.warning(f"Could not load {name}")
        else:
            model.eval()
            models[name] = (tokenizer, model)

    creative_repo = "rocky250/bangla-political"
    try:
        # Imported here so an older transformers release without this helper
        # only disables the optional model instead of breaking the app.
        from transformers.utils import cached_file

        creative_tokenizer = AutoTokenizer.from_pretrained(creative_repo)
        creative_model = CreativeBanglaPoliticalNet(num_classes=5)
        # Resolves the file inside a local directory of that name if one
        # exists, otherwise fetches it from the Hub. A plain
        # torch.load("<repo>/pytorch_model.bin") only works for a local path.
        weights_path = cached_file(creative_repo, "pytorch_model.bin")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted repository.
        state_dict = torch.load(weights_path, map_location="cpu")
        creative_model.load_state_dict(state_dict)
        # A freshly constructed nn.Module is in training mode; switch dropout
        # off for inference.
        creative_model.eval()
    except Exception:
        log.exception("Failed to load the custom model from %s", creative_repo)
        st.info("Bangla-political model not available - using standard models")
    else:
        models["Creative Hybrid"] = (creative_tokenizer, creative_model)

    return models


models_dict = load_models()
103
+
104
def predict_single_model(text, model_name):
    """Classify *text* with one of the loaded models.

    Args:
        text: Raw input text; it is passed through ``normalize`` here.
        model_name: Key into ``models_dict``.

    Returns:
        ``(label, probs)`` where ``label`` is the ``id2label`` entry for the
        highest-probability class and ``probs`` is a 1-D NumPy array of class
        probabilities.

    Raises:
        KeyError: If *model_name* is not in ``models_dict``.
    """
    clean_text = normalize(text)
    tokenizer, model = models_dict[model_name]

    # Run on whichever device the model's weights live on (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)

    # Make sure dropout is disabled; a hand-built nn.Module defaults to
    # training mode unless someone switched it.
    model.eval()

    with torch.no_grad():
        if "Creative" in model_name:
            # The custom network's forward() accepts only input_ids and
            # attention_mask and returns logits directly. Unpacking every
            # tokenizer output would pass unexpected keys (for example
            # token_type_ids) and raise TypeError.
            logits = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
        else:
            logits = model(**inputs).logits

    probs = F.softmax(logits, dim=1).cpu().numpy()[0]
    pred_id = int(np.argmax(probs))
    prediction = id2label[pred_id]

    return prediction, probs
124
 
125
def predict_ensemble(text):
    """Average the class probabilities of every loaded model.

    Args:
        text: Raw input text. It is handed to ``predict_single_model``
            unmodified, because that function normalizes it; normalizing here
            as well would apply ``normalize`` twice.

    Returns:
        ``(final_label, votes, avg_probs)``. ``votes`` holds the labels of the
        models that succeeded, in ``models_dict`` order; a model that fails is
        skipped, so ``votes`` can be shorter than ``models_dict``. If no model
        succeeds the result is ``("Error", [], np.zeros(5))``.
    """
    import logging

    log = logging.getLogger(__name__)
    all_probs = []
    all_predictions = []

    for name in models_dict:
        try:
            pred, probs = predict_single_model(text, name)
        except Exception:
            # Best effort: one broken model must not sink the ensemble, but
            # the failure is logged instead of vanishing.
            log.exception("Model %s failed during ensemble prediction", name)
            continue
        all_probs.append(probs)
        all_predictions.append(pred)

    if not all_probs:
        return "Error", [], np.zeros(5)

    avg_probs = np.mean(all_probs, axis=0)
    final_pred = id2label[int(np.argmax(avg_probs))]
    return final_pred, all_predictions, avg_probs
143
 
144
def _use_example(example_text):
    """Button callback: put *example_text* into the input box.

    Runs before the script reruns, which is the only point at which a keyed
    widget's session-state value may be assigned.
    """
    st.session_state.user_input = example_text


def _render_result_card(title, label):
    """Draw the large coloured card showing the predicted *label*."""
    st.markdown(f"""
    <div class="main-card" style="border-top: 6px solid {label_colors[label]}">
    <div class="result-title">{title}</div>
    <div class="result-value" style="color: {label_colors[label]}">{label}</div>
    </div>
    """, unsafe_allow_html=True)


def _render_probability_bars(probabilities, highlight_above, dim_opacity):
    """Draw one labelled bar per class; rows at or below the threshold are dimmed."""
    for class_id, label in id2label.items():
        prob = float(probabilities[class_id]) * 100
        color = label_colors[label]
        opacity = "1.0" if prob > highlight_above else dim_opacity

        st.markdown(f"""
        <div class="prob-row" style="opacity: {opacity}">
        <div class="prob-label">
        <span>{label}</span>
        <span>{prob:.1f}%</span>
        </div>
        <div class="prob-bar-bg">
        <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background-color: {color};"></div>
        </div>
        </div>
        """, unsafe_allow_html=True)


st.markdown("<h1 style='text-align: center;'>Political Sentiment Analysis</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align: center; color: #6B7280; font-size: 18px;'>Advanced Bengali Political Text Analysis with Custom Hybrid Architecture</p>", unsafe_allow_html=True)

# Without any model the selectbox has no options and every prediction fails.
if not models_dict:
    st.error("No sentiment models could be loaded.")
    st.stop()

# Input box and model selection
input_col, controls_col = st.columns([3, 1])
with input_col:
    # The key ties the box to st.session_state.user_input so the example
    # buttons can fill it; the label is hidden but must not be empty.
    user_input = st.text_area(
        "Political text",
        height=120,
        placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...",
        label_visibility="collapsed",
        key="user_input",
    )

with controls_col:
    st.markdown("<div style='height: 40px'></div>", unsafe_allow_html=True)
    mode = st.radio("Analysis Mode:",
                    ["Single Model", "Ensemble (Recommended)"],
                    horizontal=True,
                    label_visibility="collapsed")

    model_options = {f"({i+1}) {name}": name for i, name in enumerate(models_dict.keys())}
    selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)

    analyze_btn = st.button("Analyze Sentiment", type="primary", use_container_width=True)

if analyze_btn and user_input.strip():
    with st.spinner('Processing with advanced models...'):
        if mode == "Single Model":
            model_name = model_options[selected_model]
            final_res, probs = predict_single_model(user_input, model_name)

            result_col, scores_col = st.columns([1, 2])
            with result_col:
                _render_result_card(model_name, final_res)

            with scores_col:
                st.markdown('<div class="section-header">Confidence Scores</div>', unsafe_allow_html=True)
                _render_probability_bars(probs, highlight_above=5, dim_opacity="0.4")

        else:  # Ensemble mode
            final_res, all_votes, avg_probs = predict_ensemble(user_input)

            if final_res not in label_colors:
                # predict_ensemble signals total failure with a label that has
                # no colour entry; report it instead of raising KeyError.
                st.error("Analysis failed: no model produced a prediction.")
            else:
                main_col, details_col = st.columns([1, 1.5], gap="large")

                with main_col:
                    _render_result_card("Ensemble Consensus", final_res)
                    st.markdown('<div class="section-header">Ensemble Probabilities</div>', unsafe_allow_html=True)
                    _render_probability_bars(avg_probs, highlight_above=1, dim_opacity="0.3")

                with details_col:
                    st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
                    cols = st.columns(2)
                    # Show every vote (previously capped at four, hiding a fifth model).
                    # NOTE(review): names and votes are paired by position; if
                    # predict_ensemble skipped a failing model the pairing shifts.
                    for idx, (name, vote) in enumerate(zip(models_dict, all_votes)):
                        with cols[idx % 2]:
                            color = label_colors[vote]
                            st.markdown(f"""
                            <div class="model-card">
                            <div class="model-name">{name}</div>
                            <div style="color: {color}; font-weight: 700; font-size: 20px;">{vote}</div>
                            </div>
                            """, unsafe_allow_html=True)

elif analyze_btn and not user_input.strip():
    st.warning("অনুগ্রহ করে কিছু টেক্সট লিখুন!")

# Examples
with st.expander("Example Political Texts"):
    examples = [
        "সরকারের এই নীতি দেশকে ধ্বংসের দিকে নিয়ে যাবে!",
        "চমৎকার সিদ্ধান্ত! দেশের জন্য গর্বিত। ভালো চলবে!",
        "রাজনীতির কোনো পরিবর্তন হবে না, সব একই রকম"
    ]
    for example in examples:
        # The callback fills the input box; Streamlit reruns the script after
        # a button click on its own, so no explicit st.rerun() is needed.
        st.button(example, key=example[:20], on_click=_use_example, args=(example,))