rocky250 commited on
Commit
9459640
·
verified ·
1 Parent(s): ee406a9

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +214 -246
src/streamlit_app.py CHANGED
@@ -14,12 +14,10 @@ class BanglaPoliticalNet(nn.Module):
14
  super().__init__()
15
  self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
16
  self.hidden_size = self.banglabert.config.hidden_size
17
-
18
  self.cnn_layers = nn.ModuleList([
19
- nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k//2)
20
  for k in [3,5,7]
21
  ])
22
-
23
  self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
24
  self.classifier = nn.Sequential(
25
  nn.Dropout(0.3),
@@ -28,256 +26,224 @@ class BanglaPoliticalNet(nn.Module):
28
  nn.Dropout(0.2),
29
  nn.Linear(512, num_classes)
30
  )
31
-
32
  def forward(self, input_ids, attention_mask=None):
33
  bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
34
-
35
  cnn_features = []
36
  for cnn in self.cnn_layers:
37
  cnn_out = cnn(bert_out.transpose(1,2)).transpose(1,2)
38
  cnn_features.append(F.relu(cnn_out))
39
-
40
  cnn_concat = torch.cat(cnn_features, dim=-1)
41
  proj = nn.Linear(384, self.hidden_size).to(input_ids.device)
42
  attn_input = proj(cnn_concat)
43
  attn_out, _ = self.attention(attn_input, attn_input, attn_input)
44
  attn_pooled = attn_out[:, 0, :]
45
-
46
  logits = self.classifier(attn_pooled)
47
  return logits
48
 
49
  st.markdown("""
50
  <style>
51
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
52
-
53
- html, body, [class*="css"] {
54
- font-family: 'Inter', sans-serif !important;
55
- color: #1f2937 !important;
56
- }
57
-
58
- .stApp {
59
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
60
- }
61
-
62
- h1, h2, h3 {
63
- color: #ffffff !important;
64
- text-shadow: 0 2px 4px rgba(0,0,0,0.3);
65
- }
66
-
67
- .stTextArea textarea {
68
- background-color: #ffffff !important;
69
- color: #1f2937 !important;
70
- border: 2px solid #e5e7eb !important;
71
- border-radius: 12px !important;
72
- padding: 16px !important;
73
- font-size: 16px !important;
74
- }
75
-
76
- .stTextArea label {
77
- color: #ffffff !important;
78
- font-weight: 700 !important;
79
- }
80
-
81
- .main-card {
82
- background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%);
83
- padding: 35px;
84
- border-radius: 20px;
85
- box-shadow: 0 20px 40px rgba(0,0,0,0.15);
86
- margin-bottom: 25px;
87
- text-align: center;
88
- border: 1px solid rgba(255,255,255,0.3);
89
- backdrop-filter: blur(10px);
90
- }
91
-
92
- .result-title {
93
- color: #475569 !important;
94
- font-size: 16px;
95
- text-transform: uppercase;
96
- letter-spacing: 1.5px;
97
- margin-bottom: 12px;
98
- font-weight: 700;
99
- }
100
-
101
- .result-value {
102
- font-size: 52px;
103
- font-weight: 800;
104
- margin: 0;
105
- text-shadow: 0 2px 4px rgba(0,0,0,0.1);
106
- }
107
-
108
- .section-header {
109
- font-size: 22px;
110
- font-weight: 700;
111
- color: #1e293b !important;
112
- margin-bottom: 20px;
113
- border-left: 6px solid #3b82f6;
114
- padding-left: 15px;
115
- background: rgba(255,255,255,0.8);
116
- padding: 12px 20px;
117
- border-radius: 10px;
118
- box-shadow: 0 4px 12px rgba(0,0,0,0.1);
119
- }
120
-
121
- .model-card {
122
- background: linear-gradient(145deg, #ffffff 0%, #f1f5f9 100%);
123
- padding: 25px;
124
- border-radius: 16px;
125
- box-shadow: 0 8px 25px rgba(0,0,0,0.12);
126
- margin-bottom: 20px;
127
- border: 1px solid rgba(255,255,255,0.5);
128
- transition: all 0.3s ease;
129
- }
130
-
131
- .model-card:hover {
132
- transform: translateY(-5px);
133
- box-shadow: 0 20px 40px rgba(0,0,0,0.2);
134
- }
135
-
136
- .model-name {
137
- color: #334155 !important;
138
- font-size: 15px;
139
- font-weight: 700;
140
- margin-bottom: 12px;
141
- border-bottom: 3px solid #e2e8f0;
142
- padding-bottom: 8px;
143
- }
144
-
145
- .prob-row {
146
- margin-bottom: 18px;
147
- background: rgba(255,255,255,0.9);
148
- padding: 15px;
149
- border-radius: 12px;
150
- box-shadow: 0 2px 8px rgba(0,0,0,0.05);
151
- }
152
-
153
- .prob-label {
154
- font-size: 15px;
155
- color: #1e293b !important;
156
- font-weight: 700;
157
- margin-bottom: 8px;
158
- display: flex;
159
- justify-content: space-between;
160
- align-items: center;
161
- }
162
-
163
- .prob-bar-bg {
164
- width: 100%;
165
- height: 14px;
166
- background: linear-gradient(90deg, #f1f5f9, #e2e8f0);
167
- border-radius: 7px;
168
- overflow: hidden;
169
- box-shadow: inset 0 2px 4px rgba(0,0,0,0.05);
170
- }
171
-
172
- .prob-bar-fill {
173
- height: 100%;
174
- border-radius: 7px;
175
- transition: width 0.8s ease;
176
- box-shadow: 0 0 20px rgba(0,0,0,0.2);
177
- }
178
-
179
- .stButton > button {
180
- background: linear-gradient(45deg, #3b82f6, #1d4ed8) !important;
181
- color: white !important;
182
- border: none !important;
183
- border-radius: 12px !important;
184
- padding: 14px 28px !important;
185
- font-weight: 700 !important;
186
- font-size: 16px !important;
187
- box-shadow: 0 8px 25px rgba(59,130,246,0.4) !important;
188
- transition: all 0.3s ease !important;
189
- }
190
-
191
- .stButton > button:hover {
192
- transform: translateY(-2px) !important;
193
- box-shadow: 0 12px 35px rgba(59,130,246,0.6) !important;
194
- }
195
-
196
- .stRadio > div > label {
197
- color: #ffffff !important;
198
- font-weight: 600 !important;
199
- }
200
-
201
- .stSelectbox > label {
202
- color: #ffffff !important;
203
- font-weight: 600 !important;
204
- }
205
-
206
- .stExpander {
207
- background: rgba(255,255,255,0.1) !important;
208
- border-radius: 12px !important;
209
- border: 1px solid rgba(255,255,255,0.2) !important;
210
- }
211
  </style>
212
  """, unsafe_allow_html=True)
213
 
214
  id2label = {0: 'Very Negative', 1: 'Negative', 2: 'Neutral', 3: 'Positive', 4: 'Very Positive'}
215
  label_colors = {
216
- 'Very Negative': '#ef4444',
217
- 'Negative': '#f97316',
218
- 'Neutral': '#64748b',
219
- 'Positive': '#22c55e',
220
  'Very Positive': '#16a34a'
221
  }
222
 
 
 
 
 
 
 
 
 
223
  @st.cache_resource
224
  def load_models():
225
- models = {}
226
-
227
- standard_models = {
228
- "BanglaBERT": "rocky250/Sentiment-banglabert",
229
- "mBERT": "rocky250/Sentiment-mbert",
230
- "B-Base": "rocky250/Sentiment-bbase",
231
- "XLM-R": "rocky250/Sentiment-xlmr",
232
- "Bangla-P": "rocky250/bangla-political"
233
- }
234
-
235
- for name, repo in standard_models.items():
236
  try:
237
  tokenizer = AutoTokenizer.from_pretrained(repo)
238
  model = AutoModelForSequenceClassification.from_pretrained(repo)
239
- models[name] = (tokenizer, model.to('cuda' if torch.cuda.is_available() else 'cpu'))
240
  except:
241
  continue
242
-
243
- try:
244
- SA_tokenizer = AutoTokenizer.from_pretrained("rocky250/bangla-political")
245
- model_SA = BanglaPoliticalNet(num_classes=5)
246
- model_SA.load_state_dict(torch.load("rocky250/bangla-political/pytorch_model.bin", map_location='cpu'))
247
- model_SA = model_SA.to('cuda' if torch.cuda.is_available() else 'cpu')
248
- models["Creative Model"] = (SA_tokenizer, model_SA)
249
- except:
250
- pass
251
-
252
- return models
253
 
254
  models_dict = load_models()
255
 
256
  def predict_single_model(text, model_name):
257
  clean_text = normalize(text)
258
  tokenizer, model = models_dict[model_name]
259
-
260
  device = next(model.parameters()).device
261
  inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
262
-
263
  with torch.no_grad():
264
  if "Creative" in model_name:
265
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
266
  else:
267
  outputs = model(**inputs)
268
  logits = outputs.logits
269
-
270
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
271
  pred_id = np.argmax(probs)
272
  prediction = id2label[pred_id]
273
-
274
  return prediction, probs
275
 
276
  def predict_ensemble(text):
277
  clean_text = normalize(text)
278
  all_probs = []
279
  all_predictions = []
280
-
281
  for name in list(models_dict.keys())[:4]:
282
  try:
283
  pred, probs = predict_single_model(clean_text, name)
@@ -285,7 +251,6 @@ def predict_ensemble(text):
285
  all_predictions.append(pred)
286
  except:
287
  continue
288
-
289
  if all_probs:
290
  avg_probs = np.mean(all_probs, axis=0)
291
  final_pred = id2label[np.argmax(avg_probs)]
@@ -294,31 +259,34 @@ def predict_ensemble(text):
294
 
295
  st.markdown("""
296
  <div style='
297
- text-align: center;
298
- background: rgba(255,255,255,0.1);
299
- padding: 30px;
300
- border-radius: 20px;
301
- margin-bottom: 30px;
302
- backdrop-filter: blur(20px);
303
  '>
304
- <h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
305
  </div>
306
  """, unsafe_allow_html=True)
307
 
308
  col1, col2 = st.columns([3, 1])
309
  with col1:
310
- user_input = st.text_area("Enter Bengali political text:", height=140,
311
- placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...",
312
- help="Type or paste Bengali political text for sentiment analysis")
313
 
314
  with col2:
315
  st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
316
- mode = st.radio("Analysis Mode:",
317
- ["Single Model", "Ensemble"],
318
- horizontal=True)
319
-
 
320
  model_options = {f"({i+1}) {name}": name for i, name in enumerate(models_dict.keys())}
321
  selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)
 
 
322
 
323
  analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
324
 
@@ -327,68 +295,68 @@ if analyze_btn and user_input.strip():
327
  if mode == "Single Model":
328
  model_name = model_options[selected_model]
329
  final_res, probs = predict_single_model(user_input, model_name)
330
-
331
  col1, col2 = st.columns([1, 2])
332
  with col1:
333
  st.markdown(f"""
334
  <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}">
335
- <div class="result-title">{model_name}</div>
336
- <div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div>
337
- <div style="font-size: 18px; color: #64748b; margin-top: 15px;">Confidence: {max(probs)*100:.1f}%</div>
338
  </div>
339
  """, unsafe_allow_html=True)
340
-
341
  with col2:
342
  st.markdown('<div class="section-header">Confidence Scores</div>', unsafe_allow_html=True)
343
  for i in range(5):
344
  label = id2label[i]
345
  prob = probs[i] * 100
346
  color = label_colors[label]
347
-
348
  st.markdown(f"""
349
  <div class="prob-row">
350
- <div class="prob-label">
351
- <span style="font-weight: 700;">{label}</span>
352
- <span style="font-weight: 700; color: {color};">{prob:.1f}%</span>
353
- </div>
354
- <div class="prob-bar-bg">
355
- <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
356
- </div>
357
  </div>
358
  """, unsafe_allow_html=True)
359
-
360
  else:
361
  final_res, all_votes, avg_probs = predict_ensemble(user_input)
362
-
363
  main_col, details_col = st.columns([1, 1.4])
364
-
365
  with main_col:
366
  st.markdown(f"""
367
  <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}; box-shadow: 0 25px 50px rgba(0,0,0,0.2);">
368
- <div class="result-title" style="font-size: 18px;">ENSEMBLE CONSENSUS</div>
369
- <div class="result-value" style="color: {label_colors[final_res]}; font-size: 60px;">{final_res}</div>
370
  </div>
371
  """, unsafe_allow_html=True)
372
-
373
  st.markdown('<div class="section-header">Ensemble Probabilities</div>', unsafe_allow_html=True)
374
-
375
  for i in range(5):
376
  label = id2label[i]
377
  prob = avg_probs[i] * 100
378
  color = label_colors[label]
379
-
380
  st.markdown(f"""
381
  <div class="prob-row">
382
- <div class="prob-label">
383
- <span>{label}</span>
384
- <span style="color: {color};">{prob:.1f}%</span>
385
- </div>
386
- <div class="prob-bar-bg">
387
- <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
388
- </div>
389
  </div>
390
  """, unsafe_allow_html=True)
391
-
392
  with details_col:
393
  st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
394
  model_cols = st.columns(2)
@@ -397,8 +365,8 @@ if analyze_btn and user_input.strip():
397
  color = label_colors[vote]
398
  st.markdown(f"""
399
  <div class="model-card">
400
- <div class="model-name">{name}</div>
401
- <div style="color: {color}; font-weight: 800; font-size: 24px; margin-top: 8px;">{vote}</div>
402
  </div>
403
  """, unsafe_allow_html=True)
404
 
@@ -414,7 +382,7 @@ with st.expander("Example Political Texts", expanded=False):
414
  example_cols = st.columns(3)
415
  for idx, example in enumerate(examples):
416
  with example_cols[idx]:
417
- if st.button(example[:40] + "..." if len(example) > 40 else example,
418
  use_container_width=True):
419
  st.session_state.user_input = example
420
  st.rerun()
 
14
  super().__init__()
15
  self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
16
  self.hidden_size = self.banglabert.config.hidden_size
 
17
  self.cnn_layers = nn.ModuleList([
18
+ nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k//2)
19
  for k in [3,5,7]
20
  ])
 
21
  self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
22
  self.classifier = nn.Sequential(
23
  nn.Dropout(0.3),
 
26
  nn.Dropout(0.2),
27
  nn.Linear(512, num_classes)
28
  )
29
+
30
  def forward(self, input_ids, attention_mask=None):
31
  bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
 
32
  cnn_features = []
33
  for cnn in self.cnn_layers:
34
  cnn_out = cnn(bert_out.transpose(1,2)).transpose(1,2)
35
  cnn_features.append(F.relu(cnn_out))
 
36
  cnn_concat = torch.cat(cnn_features, dim=-1)
37
  proj = nn.Linear(384, self.hidden_size).to(input_ids.device)
38
  attn_input = proj(cnn_concat)
39
  attn_out, _ = self.attention(attn_input, attn_input, attn_input)
40
  attn_pooled = attn_out[:, 0, :]
 
41
  logits = self.classifier(attn_pooled)
42
  return logits
43
 
44
  st.markdown("""
45
  <style>
46
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
47
+ html, body, [class*="css"] {
48
+ font-family: 'Inter', sans-serif !important;
49
+ color: #1f2937 !important;
50
+ }
51
+ .stApp {
52
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
53
+ }
54
+ h1, h2, h3 {
55
+ color: #ffffff !important;
56
+ text-shadow: 0 2px 4px rgba(0,0,0,0.3);
57
+ }
58
+ .stTextArea textarea {
59
+ background-color: #ffffff !important;
60
+ color: #1f2937 !important;
61
+ border: 2px solid #e5e7eb !important;
62
+ border-radius: 12px !important;
63
+ padding: 16px !important;
64
+ font-size: 16px !important;
65
+ }
66
+ .stTextArea label {
67
+ color: #ffffff !important;
68
+ font-weight: 700 !important;
69
+ }
70
+ .main-card {
71
+ background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%);
72
+ padding: 35px;
73
+ border-radius: 20px;
74
+ box-shadow: 0 20px 40px rgba(0,0,0,0.15);
75
+ margin-bottom: 25px;
76
+ text-align: center;
77
+ border: 1px solid rgba(255,255,255,0.3);
78
+ backdrop-filter: blur(10px);
79
+ }
80
+ .result-title {
81
+ color: #475569 !important;
82
+ font-size: 16px;
83
+ text-transform: uppercase;
84
+ letter-spacing: 1.5px;
85
+ margin-bottom: 12px;
86
+ font-weight: 700;
87
+ }
88
+ .result-value {
89
+ font-size: 52px;
90
+ font-weight: 800;
91
+ margin: 0;
92
+ text-shadow: 0 2px 4px rgba(0,0,0,0.1);
93
+ }
94
+ .section-header {
95
+ font-size: 22px;
96
+ font-weight: 700;
97
+ color: #1e293b !important;
98
+ margin-bottom: 20px;
99
+ border-left: 6px solid #3b82f6;
100
+ padding-left: 15px;
101
+ background: rgba(255,255,255,0.8);
102
+ padding: 12px 20px;
103
+ border-radius: 10px;
104
+ box-shadow: 0 4px 12px rgba(0,0,0,0.1);
105
+ }
106
+ .model-card {
107
+ background: linear-gradient(145deg, #ffffff 0%, #f1f5f9 100%);
108
+ padding: 25px;
109
+ border-radius: 16px;
110
+ box-shadow: 0 8px 25px rgba(0,0,0,0.12);
111
+ margin-bottom: 20px;
112
+ border: 1px solid rgba(255,255,255,0.5);
113
+ transition: all 0.3s ease;
114
+ }
115
+ .model-card:hover {
116
+ transform: translateY(-5px);
117
+ box-shadow: 0 20px 40px rgba(0,0,0,0.2);
118
+ }
119
+ .model-name {
120
+ color: #334155 !important;
121
+ font-size: 15px;
122
+ font-weight: 700;
123
+ margin-bottom: 12px;
124
+ border-bottom: 3px solid #e2e8f0;
125
+ padding-bottom: 8px;
126
+ }
127
+ .prob-row {
128
+ margin-bottom: 18px;
129
+ background: rgba(255,255,255,0.9);
130
+ padding: 15px;
131
+ border-radius: 12px;
132
+ box-shadow: 0 2px 8px rgba(0,0,0,0.05);
133
+ }
134
+ .prob-label {
135
+ font-size: 15px;
136
+ color: #1e293b !important;
137
+ font-weight: 700;
138
+ margin-bottom: 8px;
139
+ display: flex;
140
+ justify-content: space-between;
141
+ align-items: center;
142
+ }
143
+ .prob-bar-bg {
144
+ width: 100%;
145
+ height: 14px;
146
+ background: linear-gradient(90deg, #f1f5f9, #e2e8f0);
147
+ border-radius: 7px;
148
+ overflow: hidden;
149
+ box-shadow: inset 0 2px 4px rgba(0,0,0,0.05);
150
+ }
151
+ .prob-bar-fill {
152
+ height: 100%;
153
+ border-radius: 7px;
154
+ transition: width 0.8s ease;
155
+ box-shadow: 0 0 20px rgba(0,0,0,0.2);
156
+ }
157
+ .stButton > button {
158
+ background: linear-gradient(45deg, #3b82f6, #1d4ed8) !important;
159
+ color: white !important;
160
+ border: none !important;
161
+ border-radius: 12px !important;
162
+ padding: 14px 28px !important;
163
+ font-weight: 700 !important;
164
+ font-size: 16px !important;
165
+ box-shadow: 0 8px 25px rgba(59,130,246,0.4) !important;
166
+ transition: all 0.3s ease !important;
167
+ }
168
+ .stButton > button:hover {
169
+ transform: translateY(-2px) !important;
170
+ box-shadow: 0 12px 35px rgba(59,130,246,0.6) !important;
171
+ }
172
+ .stRadio > div > label {
173
+ color: #ffffff !important;
174
+ font-weight: 600 !important;
175
+ }
176
+ .stSelectbox > label {
177
+ color: #ffffff !important;
178
+ font-weight: 600 !important;
179
+ }
180
+ .stExpander {
181
+ background: rgba(255,255,255,0.1) !important;
182
+ border-radius: 12px !important;
183
+ border: 1px solid rgba(255,255,255,0.2) !important;
184
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  </style>
186
  """, unsafe_allow_html=True)
187
 
188
  id2label = {0: 'Very Negative', 1: 'Negative', 2: 'Neutral', 3: 'Positive', 4: 'Very Positive'}
189
  label_colors = {
190
+ 'Very Negative': '#ef4444',
191
+ 'Negative': '#f97316',
192
+ 'Neutral': '#64748b',
193
+ 'Positive': '#22c55e',
194
  'Very Positive': '#16a34a'
195
  }
196
 
197
+ models = {
198
+ "model_banglabert": "rocky250/Sentiment-banglabert",
199
+ "model_mbert": "rocky250/Sentiment-mbert",
200
+ "model_bbase": "rocky250/Sentiment-bbase",
201
+ "model_xlmr": "rocky250/Sentiment-xlmr",
202
+ "bangla_political": "rocky250/bangla-political"
203
+ }
204
+
205
  @st.cache_resource
206
  def load_models():
207
+ models_dict = {}
208
+ for key, repo in models.items():
 
 
 
 
 
 
 
 
 
209
  try:
210
  tokenizer = AutoTokenizer.from_pretrained(repo)
211
  model = AutoModelForSequenceClassification.from_pretrained(repo)
212
+ models_dict[key] = (tokenizer, model.to('cuda' if torch.cuda.is_available() else 'cpu'))
213
  except:
214
  continue
215
+ try:
216
+ SA_tokenizer = AutoTokenizer.from_pretrained("rocky250/bangla-political")
217
+ model_SA = BanglaPoliticalNet(num_classes=5)
218
+ model_SA.load_state_dict(torch.load("rocky250/bangla-political/pytorch_model.bin", map_location='cpu'))
219
+ model_SA = model_SA.to('cuda' if torch.cuda.is_available() else 'cpu')
220
+ models_dict["Creative Model"] = (SA_tokenizer, model_SA)
221
+ except:
222
+ pass
223
+ return models_dict
 
 
224
 
225
  models_dict = load_models()
226
 
227
  def predict_single_model(text, model_name):
228
  clean_text = normalize(text)
229
  tokenizer, model = models_dict[model_name]
 
230
  device = next(model.parameters()).device
231
  inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
 
232
  with torch.no_grad():
233
  if "Creative" in model_name:
234
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
235
  else:
236
  outputs = model(**inputs)
237
  logits = outputs.logits
 
238
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
239
  pred_id = np.argmax(probs)
240
  prediction = id2label[pred_id]
 
241
  return prediction, probs
242
 
243
  def predict_ensemble(text):
244
  clean_text = normalize(text)
245
  all_probs = []
246
  all_predictions = []
 
247
  for name in list(models_dict.keys())[:4]:
248
  try:
249
  pred, probs = predict_single_model(clean_text, name)
 
251
  all_predictions.append(pred)
252
  except:
253
  continue
 
254
  if all_probs:
255
  avg_probs = np.mean(all_probs, axis=0)
256
  final_pred = id2label[np.argmax(avg_probs)]
 
259
 
260
  st.markdown("""
261
  <div style='
262
+ text-align: center;
263
+ background: rgba(255,255,255,0.1);
264
+ padding: 30px;
265
+ border-radius: 20px;
266
+ margin-bottom: 30px;
267
+ backdrop-filter: blur(20px);
268
  '>
269
+ <h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
270
  </div>
271
  """, unsafe_allow_html=True)
272
 
273
  col1, col2 = st.columns([3, 1])
274
  with col1:
275
+ user_input = st.text_area("Enter Bengali political text:", height=140,
276
+ placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...",
277
+ help="Type or paste Bengali political text for sentiment analysis")
278
 
279
  with col2:
280
  st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
281
+ mode = st.radio("Analysis Mode:",
282
+ ["Single Model", "Ensemble"],
283
+ horizontal=True)
284
+
285
+ if mode == "Single Model":
286
  model_options = {f"({i+1}) {name}": name for i, name in enumerate(models_dict.keys())}
287
  selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)
288
+ else:
289
+ st.markdown("<div style='height: 50px'></div>", unsafe_allow_html=True)
290
 
291
  analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
292
 
 
295
  if mode == "Single Model":
296
  model_name = model_options[selected_model]
297
  final_res, probs = predict_single_model(user_input, model_name)
298
+
299
  col1, col2 = st.columns([1, 2])
300
  with col1:
301
  st.markdown(f"""
302
  <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}">
303
+ <div class="result-title">{model_name}</div>
304
+ <div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div>
305
+ <div style="font-size: 18px; color: #64748b; margin-top: 15px;">Confidence: {max(probs)*100:.1f}%</div>
306
  </div>
307
  """, unsafe_allow_html=True)
308
+
309
  with col2:
310
  st.markdown('<div class="section-header">Confidence Scores</div>', unsafe_allow_html=True)
311
  for i in range(5):
312
  label = id2label[i]
313
  prob = probs[i] * 100
314
  color = label_colors[label]
315
+
316
  st.markdown(f"""
317
  <div class="prob-row">
318
+ <div class="prob-label">
319
+ <span style="font-weight: 700;">{label}</span>
320
+ <span style="font-weight: 700; color: {color};">{prob:.1f}%</span>
321
+ </div>
322
+ <div class="prob-bar-bg">
323
+ <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
324
+ </div>
325
  </div>
326
  """, unsafe_allow_html=True)
327
+
328
  else:
329
  final_res, all_votes, avg_probs = predict_ensemble(user_input)
330
+
331
  main_col, details_col = st.columns([1, 1.4])
332
+
333
  with main_col:
334
  st.markdown(f"""
335
  <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}; box-shadow: 0 25px 50px rgba(0,0,0,0.2);">
336
+ <div class="result-title" style="font-size: 18px;">ENSEMBLE CONSENSUS</div>
337
+ <div class="result-value" style="color: {label_colors[final_res]}; font-size: 60px;">{final_res}</div>
338
  </div>
339
  """, unsafe_allow_html=True)
340
+
341
  st.markdown('<div class="section-header">Ensemble Probabilities</div>', unsafe_allow_html=True)
342
+
343
  for i in range(5):
344
  label = id2label[i]
345
  prob = avg_probs[i] * 100
346
  color = label_colors[label]
347
+
348
  st.markdown(f"""
349
  <div class="prob-row">
350
+ <div class="prob-label">
351
+ <span>{label}</span>
352
+ <span style="color: {color};">{prob:.1f}%</span>
353
+ </div>
354
+ <div class="prob-bar-bg">
355
+ <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
356
+ </div>
357
  </div>
358
  """, unsafe_allow_html=True)
359
+
360
  with details_col:
361
  st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
362
  model_cols = st.columns(2)
 
365
  color = label_colors[vote]
366
  st.markdown(f"""
367
  <div class="model-card">
368
+ <div class="model-name">{name}</div>
369
+ <div style="color: {color}; font-weight: 800; font-size: 24px; margin-top: 8px;">{vote}</div>
370
  </div>
371
  """, unsafe_allow_html=True)
372
 
 
382
  example_cols = st.columns(3)
383
  for idx, example in enumerate(examples):
384
  with example_cols[idx]:
385
+ if st.button(example[:40] + "..." if len(example) > 40 else example,
386
  use_container_width=True):
387
  st.session_state.user_input = example
388
  st.rerun()