rocky250 commited on
Commit
70b2f7a
·
verified ·
1 Parent(s): 9459640

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +90 -68
src/streamlit_app.py CHANGED
@@ -14,10 +14,12 @@ class BanglaPoliticalNet(nn.Module):
14
  super().__init__()
15
  self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
16
  self.hidden_size = self.banglabert.config.hidden_size
 
17
  self.cnn_layers = nn.ModuleList([
18
  nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k//2)
19
  for k in [3,5,7]
20
  ])
 
21
  self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
22
  self.classifier = nn.Sequential(
23
  nn.Dropout(0.3),
@@ -29,32 +31,39 @@ class BanglaPoliticalNet(nn.Module):
29
 
30
  def forward(self, input_ids, attention_mask=None):
31
  bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
 
32
  cnn_features = []
33
  for cnn in self.cnn_layers:
34
  cnn_out = cnn(bert_out.transpose(1,2)).transpose(1,2)
35
  cnn_features.append(F.relu(cnn_out))
 
36
  cnn_concat = torch.cat(cnn_features, dim=-1)
37
  proj = nn.Linear(384, self.hidden_size).to(input_ids.device)
38
  attn_input = proj(cnn_concat)
39
  attn_out, _ = self.attention(attn_input, attn_input, attn_input)
40
  attn_pooled = attn_out[:, 0, :]
 
41
  logits = self.classifier(attn_pooled)
42
  return logits
43
 
44
  st.markdown("""
45
  <style>
46
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
 
47
  html, body, [class*="css"] {
48
  font-family: 'Inter', sans-serif !important;
49
  color: #1f2937 !important;
50
  }
 
51
  .stApp {
52
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
53
  }
 
54
  h1, h2, h3 {
55
  color: #ffffff !important;
56
  text-shadow: 0 2px 4px rgba(0,0,0,0.3);
57
  }
 
58
  .stTextArea textarea {
59
  background-color: #ffffff !important;
60
  color: #1f2937 !important;
@@ -63,10 +72,12 @@ h1, h2, h3 {
63
  padding: 16px !important;
64
  font-size: 16px !important;
65
  }
 
66
  .stTextArea label {
67
  color: #ffffff !important;
68
  font-weight: 700 !important;
69
  }
 
70
  .main-card {
71
  background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%);
72
  padding: 35px;
@@ -77,6 +88,7 @@ h1, h2, h3 {
77
  border: 1px solid rgba(255,255,255,0.3);
78
  backdrop-filter: blur(10px);
79
  }
 
80
  .result-title {
81
  color: #475569 !important;
82
  font-size: 16px;
@@ -85,12 +97,14 @@ h1, h2, h3 {
85
  margin-bottom: 12px;
86
  font-weight: 700;
87
  }
 
88
  .result-value {
89
  font-size: 52px;
90
  font-weight: 800;
91
  margin: 0;
92
  text-shadow: 0 2px 4px rgba(0,0,0,0.1);
93
  }
 
94
  .section-header {
95
  font-size: 22px;
96
  font-weight: 700;
@@ -103,6 +117,7 @@ h1, h2, h3 {
103
  border-radius: 10px;
104
  box-shadow: 0 4px 12px rgba(0,0,0,0.1);
105
  }
 
106
  .model-card {
107
  background: linear-gradient(145deg, #ffffff 0%, #f1f5f9 100%);
108
  padding: 25px;
@@ -112,10 +127,12 @@ h1, h2, h3 {
112
  border: 1px solid rgba(255,255,255,0.5);
113
  transition: all 0.3s ease;
114
  }
 
115
  .model-card:hover {
116
  transform: translateY(-5px);
117
  box-shadow: 0 20px 40px rgba(0,0,0,0.2);
118
  }
 
119
  .model-name {
120
  color: #334155 !important;
121
  font-size: 15px;
@@ -124,6 +141,7 @@ h1, h2, h3 {
124
  border-bottom: 3px solid #e2e8f0;
125
  padding-bottom: 8px;
126
  }
 
127
  .prob-row {
128
  margin-bottom: 18px;
129
  background: rgba(255,255,255,0.9);
@@ -131,6 +149,7 @@ h1, h2, h3 {
131
  border-radius: 12px;
132
  box-shadow: 0 2px 8px rgba(0,0,0,0.05);
133
  }
 
134
  .prob-label {
135
  font-size: 15px;
136
  color: #1e293b !important;
@@ -140,6 +159,7 @@ h1, h2, h3 {
140
  justify-content: space-between;
141
  align-items: center;
142
  }
 
143
  .prob-bar-bg {
144
  width: 100%;
145
  height: 14px;
@@ -148,12 +168,14 @@ h1, h2, h3 {
148
  overflow: hidden;
149
  box-shadow: inset 0 2px 4px rgba(0,0,0,0.05);
150
  }
 
151
  .prob-bar-fill {
152
  height: 100%;
153
  border-radius: 7px;
154
  transition: width 0.8s ease;
155
  box-shadow: 0 0 20px rgba(0,0,0,0.2);
156
  }
 
157
  .stButton > button {
158
  background: linear-gradient(45deg, #3b82f6, #1d4ed8) !important;
159
  color: white !important;
@@ -165,18 +187,22 @@ h1, h2, h3 {
165
  box-shadow: 0 8px 25px rgba(59,130,246,0.4) !important;
166
  transition: all 0.3s ease !important;
167
  }
 
168
  .stButton > button:hover {
169
  transform: translateY(-2px) !important;
170
  box-shadow: 0 12px 35px rgba(59,130,246,0.6) !important;
171
  }
 
172
  .stRadio > div > label {
173
  color: #ffffff !important;
174
  font-weight: 600 !important;
175
  }
 
176
  .stSelectbox > label {
177
  color: #ffffff !important;
178
  font-weight: 600 !important;
179
  }
 
180
  .stExpander {
181
  background: rgba(255,255,255,0.1) !important;
182
  border-radius: 12px !important;
@@ -194,63 +220,60 @@ label_colors = {
194
  'Very Positive': '#16a34a'
195
  }
196
 
197
- models = {
198
- "model_banglabert": "rocky250/Sentiment-banglabert",
199
- "model_mbert": "rocky250/Sentiment-mbert",
200
- "model_bbase": "rocky250/Sentiment-bbase",
201
- "model_xlmr": "rocky250/Sentiment-xlmr",
202
- "bangla_political": "rocky250/bangla-political"
203
- }
204
-
205
  @st.cache_resource
206
  def load_models():
207
- models_dict = {}
208
- for key, repo in models.items():
 
 
 
 
 
 
 
 
 
209
  try:
210
  tokenizer = AutoTokenizer.from_pretrained(repo)
211
  model = AutoModelForSequenceClassification.from_pretrained(repo)
212
- models_dict[key] = (tokenizer, model.to('cuda' if torch.cuda.is_available() else 'cpu'))
213
  except:
214
  continue
215
- try:
216
- SA_tokenizer = AutoTokenizer.from_pretrained("rocky250/bangla-political")
217
- model_SA = BanglaPoliticalNet(num_classes=5)
218
- model_SA.load_state_dict(torch.load("rocky250/bangla-political/pytorch_model.bin", map_location='cpu'))
219
- model_SA = model_SA.to('cuda' if torch.cuda.is_available() else 'cpu')
220
- models_dict["Creative Model"] = (SA_tokenizer, model_SA)
221
- except:
222
- pass
223
- return models_dict
224
 
225
  models_dict = load_models()
226
 
227
  def predict_single_model(text, model_name):
228
  clean_text = normalize(text)
229
  tokenizer, model = models_dict[model_name]
 
230
  device = next(model.parameters()).device
231
  inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
 
232
  with torch.no_grad():
233
- if "Creative" in model_name:
234
- logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
235
- else:
236
- outputs = model(**inputs)
237
- logits = outputs.logits
238
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
239
  pred_id = np.argmax(probs)
240
  prediction = id2label[pred_id]
 
241
  return prediction, probs
242
 
243
  def predict_ensemble(text):
244
  clean_text = normalize(text)
245
  all_probs = []
246
  all_predictions = []
247
- for name in list(models_dict.keys())[:4]:
 
248
  try:
249
  pred, probs = predict_single_model(clean_text, name)
250
  all_probs.append(probs)
251
  all_predictions.append(pred)
252
  except:
253
  continue
 
254
  if all_probs:
255
  avg_probs = np.mean(all_probs, axis=0)
256
  final_pred = id2label[np.argmax(avg_probs)]
@@ -259,50 +282,49 @@ def predict_ensemble(text):
259
 
260
  st.markdown("""
261
  <div style='
262
- text-align: center;
263
- background: rgba(255,255,255,0.1);
264
- padding: 30px;
265
- border-radius: 20px;
266
- margin-bottom: 30px;
267
- backdrop-filter: blur(20px);
268
  '>
269
- <h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
270
  </div>
271
  """, unsafe_allow_html=True)
272
 
273
  col1, col2 = st.columns([3, 1])
274
  with col1:
275
  user_input = st.text_area("Enter Bengali political text:", height=140,
276
- placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...",
277
- help="Type or paste Bengali political text for sentiment analysis")
278
 
279
  with col2:
280
  st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
281
  mode = st.radio("Analysis Mode:",
282
- ["Single Model", "Ensemble"],
283
- horizontal=True)
284
 
285
- if mode == "Single Model":
286
- model_options = {f"({i+1}) {name}": name for i, name in enumerate(models_dict.keys())}
287
- selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)
288
- else:
289
- st.markdown("<div style='height: 50px'></div>", unsafe_allow_html=True)
290
 
291
  analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
292
 
293
  if analyze_btn and user_input.strip():
294
  with st.spinner('Processing with models...'):
295
  if mode == "Single Model":
296
- model_name = model_options[selected_model]
297
  final_res, probs = predict_single_model(user_input, model_name)
298
 
299
  col1, col2 = st.columns([1, 2])
300
  with col1:
301
  st.markdown(f"""
302
  <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}">
303
- <div class="result-title">{model_name}</div>
304
- <div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div>
305
- <div style="font-size: 18px; color: #64748b; margin-top: 15px;">Confidence: {max(probs)*100:.1f}%</div>
306
  </div>
307
  """, unsafe_allow_html=True)
308
 
@@ -315,13 +337,13 @@ if analyze_btn and user_input.strip():
315
 
316
  st.markdown(f"""
317
  <div class="prob-row">
318
- <div class="prob-label">
319
- <span style="font-weight: 700;">{label}</span>
320
- <span style="font-weight: 700; color: {color};">{prob:.1f}%</span>
321
- </div>
322
- <div class="prob-bar-bg">
323
- <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
324
- </div>
325
  </div>
326
  """, unsafe_allow_html=True)
327
 
@@ -333,8 +355,8 @@ if analyze_btn and user_input.strip():
333
  with main_col:
334
  st.markdown(f"""
335
  <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}; box-shadow: 0 25px 50px rgba(0,0,0,0.2);">
336
- <div class="result-title" style="font-size: 18px;">ENSEMBLE CONSENSUS</div>
337
- <div class="result-value" style="color: {label_colors[final_res]}; font-size: 60px;">{final_res}</div>
338
  </div>
339
  """, unsafe_allow_html=True)
340
 
@@ -347,26 +369,26 @@ if analyze_btn and user_input.strip():
347
 
348
  st.markdown(f"""
349
  <div class="prob-row">
350
- <div class="prob-label">
351
- <span>{label}</span>
352
- <span style="color: {color};">{prob:.1f}%</span>
353
- </div>
354
- <div class="prob-bar-bg">
355
- <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
356
- </div>
357
  </div>
358
  """, unsafe_allow_html=True)
359
 
360
  with details_col:
361
  st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
362
  model_cols = st.columns(2)
363
- for idx, (name, vote) in enumerate(zip(list(models_dict.keys())[:4], all_votes[:4])):
364
  with model_cols[idx % 2]:
365
  color = label_colors[vote]
366
  st.markdown(f"""
367
  <div class="model-card">
368
- <div class="model-name">{name}</div>
369
- <div style="color: {color}; font-weight: 800; font-size: 24px; margin-top: 8px;">{vote}</div>
370
  </div>
371
  """, unsafe_allow_html=True)
372
 
@@ -382,7 +404,7 @@ with st.expander("Example Political Texts", expanded=False):
382
  example_cols = st.columns(3)
383
  for idx, example in enumerate(examples):
384
  with example_cols[idx]:
385
- if st.button(example[:40] + "..." if len(example) > 40 else example,
386
- use_container_width=True):
387
  st.session_state.user_input = example
388
- st.rerun()
 
14
  super().__init__()
15
  self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
16
  self.hidden_size = self.banglabert.config.hidden_size
17
+
18
  self.cnn_layers = nn.ModuleList([
19
  nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k//2)
20
  for k in [3,5,7]
21
  ])
22
+
23
  self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
24
  self.classifier = nn.Sequential(
25
  nn.Dropout(0.3),
 
31
 
32
  def forward(self, input_ids, attention_mask=None):
33
  bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
34
+
35
  cnn_features = []
36
  for cnn in self.cnn_layers:
37
  cnn_out = cnn(bert_out.transpose(1,2)).transpose(1,2)
38
  cnn_features.append(F.relu(cnn_out))
39
+
40
  cnn_concat = torch.cat(cnn_features, dim=-1)
41
  proj = nn.Linear(384, self.hidden_size).to(input_ids.device)
42
  attn_input = proj(cnn_concat)
43
  attn_out, _ = self.attention(attn_input, attn_input, attn_input)
44
  attn_pooled = attn_out[:, 0, :]
45
+
46
  logits = self.classifier(attn_pooled)
47
  return logits
48
 
49
  st.markdown("""
50
  <style>
51
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
52
+
53
  html, body, [class*="css"] {
54
  font-family: 'Inter', sans-serif !important;
55
  color: #1f2937 !important;
56
  }
57
+
58
  .stApp {
59
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
60
  }
61
+
62
  h1, h2, h3 {
63
  color: #ffffff !important;
64
  text-shadow: 0 2px 4px rgba(0,0,0,0.3);
65
  }
66
+
67
  .stTextArea textarea {
68
  background-color: #ffffff !important;
69
  color: #1f2937 !important;
 
72
  padding: 16px !important;
73
  font-size: 16px !important;
74
  }
75
+
76
  .stTextArea label {
77
  color: #ffffff !important;
78
  font-weight: 700 !important;
79
  }
80
+
81
  .main-card {
82
  background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%);
83
  padding: 35px;
 
88
  border: 1px solid rgba(255,255,255,0.3);
89
  backdrop-filter: blur(10px);
90
  }
91
+
92
  .result-title {
93
  color: #475569 !important;
94
  font-size: 16px;
 
97
  margin-bottom: 12px;
98
  font-weight: 700;
99
  }
100
+
101
  .result-value {
102
  font-size: 52px;
103
  font-weight: 800;
104
  margin: 0;
105
  text-shadow: 0 2px 4px rgba(0,0,0,0.1);
106
  }
107
+
108
  .section-header {
109
  font-size: 22px;
110
  font-weight: 700;
 
117
  border-radius: 10px;
118
  box-shadow: 0 4px 12px rgba(0,0,0,0.1);
119
  }
120
+
121
  .model-card {
122
  background: linear-gradient(145deg, #ffffff 0%, #f1f5f9 100%);
123
  padding: 25px;
 
127
  border: 1px solid rgba(255,255,255,0.5);
128
  transition: all 0.3s ease;
129
  }
130
+
131
  .model-card:hover {
132
  transform: translateY(-5px);
133
  box-shadow: 0 20px 40px rgba(0,0,0,0.2);
134
  }
135
+
136
  .model-name {
137
  color: #334155 !important;
138
  font-size: 15px;
 
141
  border-bottom: 3px solid #e2e8f0;
142
  padding-bottom: 8px;
143
  }
144
+
145
  .prob-row {
146
  margin-bottom: 18px;
147
  background: rgba(255,255,255,0.9);
 
149
  border-radius: 12px;
150
  box-shadow: 0 2px 8px rgba(0,0,0,0.05);
151
  }
152
+
153
  .prob-label {
154
  font-size: 15px;
155
  color: #1e293b !important;
 
159
  justify-content: space-between;
160
  align-items: center;
161
  }
162
+
163
  .prob-bar-bg {
164
  width: 100%;
165
  height: 14px;
 
168
  overflow: hidden;
169
  box-shadow: inset 0 2px 4px rgba(0,0,0,0.05);
170
  }
171
+
172
  .prob-bar-fill {
173
  height: 100%;
174
  border-radius: 7px;
175
  transition: width 0.8s ease;
176
  box-shadow: 0 0 20px rgba(0,0,0,0.2);
177
  }
178
+
179
  .stButton > button {
180
  background: linear-gradient(45deg, #3b82f6, #1d4ed8) !important;
181
  color: white !important;
 
187
  box-shadow: 0 8px 25px rgba(59,130,246,0.4) !important;
188
  transition: all 0.3s ease !important;
189
  }
190
+
191
  .stButton > button:hover {
192
  transform: translateY(-2px) !important;
193
  box-shadow: 0 12px 35px rgba(59,130,246,0.6) !important;
194
  }
195
+
196
  .stRadio > div > label {
197
  color: #ffffff !important;
198
  font-weight: 600 !important;
199
  }
200
+
201
  .stSelectbox > label {
202
  color: #ffffff !important;
203
  font-weight: 600 !important;
204
  }
205
+
206
  .stExpander {
207
  background: rgba(255,255,255,0.1) !important;
208
  border-radius: 12px !important;
 
220
  'Very Positive': '#16a34a'
221
  }
222
 
 
 
 
 
 
 
 
 
223
  @st.cache_resource
224
  def load_models():
225
+ models_loaded = {}
226
+
227
+ target_models = {
228
+ "model_banglabert": "rocky250/Sentiment-banglabert",
229
+ "model_mbert": "rocky250/Sentiment-mbert",
230
+ "model_bbase": "rocky250/Sentiment-bbase",
231
+ "model_xlmr": "rocky250/Sentiment-xlmr",
232
+ "bangla_political": "rocky250/bangla-political"
233
+ }
234
+
235
+ for name, repo in target_models.items():
236
  try:
237
  tokenizer = AutoTokenizer.from_pretrained(repo)
238
  model = AutoModelForSequenceClassification.from_pretrained(repo)
239
+ models_loaded[name] = (tokenizer, model.to('cuda' if torch.cuda.is_available() else 'cpu'))
240
  except:
241
  continue
242
+
243
+ return models_loaded
 
 
 
 
 
 
 
244
 
245
  models_dict = load_models()
246
 
247
  def predict_single_model(text, model_name):
248
  clean_text = normalize(text)
249
  tokenizer, model = models_dict[model_name]
250
+
251
  device = next(model.parameters()).device
252
  inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
253
+
254
  with torch.no_grad():
255
+ outputs = model(**inputs)
256
+ logits = outputs.logits
257
+
 
 
258
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
259
  pred_id = np.argmax(probs)
260
  prediction = id2label[pred_id]
261
+
262
  return prediction, probs
263
 
264
  def predict_ensemble(text):
265
  clean_text = normalize(text)
266
  all_probs = []
267
  all_predictions = []
268
+
269
+ for name in models_dict.keys():
270
  try:
271
  pred, probs = predict_single_model(clean_text, name)
272
  all_probs.append(probs)
273
  all_predictions.append(pred)
274
  except:
275
  continue
276
+
277
  if all_probs:
278
  avg_probs = np.mean(all_probs, axis=0)
279
  final_pred = id2label[np.argmax(avg_probs)]
 
282
 
283
  st.markdown("""
284
  <div style='
285
+ text-align: center;
286
+ background: rgba(255,255,255,0.1);
287
+ padding: 30px;
288
+ border-radius: 20px;
289
+ margin-bottom: 30px;
290
+ backdrop-filter: blur(20px);
291
  '>
292
+ <h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
293
  </div>
294
  """, unsafe_allow_html=True)
295
 
296
  col1, col2 = st.columns([3, 1])
297
  with col1:
298
  user_input = st.text_area("Enter Bengali political text:", height=140,
299
+ placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...",
300
+ help="Type or paste Bengali political text for sentiment analysis")
301
 
302
  with col2:
303
  st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
304
  mode = st.radio("Analysis Mode:",
305
+ ["Single Model", "Ensemble"],
306
+ horizontal=True)
307
 
308
+ selected_model = None
309
+ if mode == "Single Model":
310
+ model_options = {name: name for name in models_dict.keys()}
311
+ selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)
 
312
 
313
  analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
314
 
315
  if analyze_btn and user_input.strip():
316
  with st.spinner('Processing with models...'):
317
  if mode == "Single Model":
318
+ model_name = selected_model
319
  final_res, probs = predict_single_model(user_input, model_name)
320
 
321
  col1, col2 = st.columns([1, 2])
322
  with col1:
323
  st.markdown(f"""
324
  <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}">
325
+ <div class="result-title">{model_name}</div>
326
+ <div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div>
327
+ <div style="font-size: 18px; color: #64748b; margin-top: 15px;">Confidence: {max(probs)*100:.1f}%</div>
328
  </div>
329
  """, unsafe_allow_html=True)
330
 
 
337
 
338
  st.markdown(f"""
339
  <div class="prob-row">
340
+ <div class="prob-label">
341
+ <span style="font-weight: 700;">{label}</span>
342
+ <span style="font-weight: 700; color: {color};">{prob:.1f}%</span>
343
+ </div>
344
+ <div class="prob-bar-bg">
345
+ <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
346
+ </div>
347
  </div>
348
  """, unsafe_allow_html=True)
349
 
 
355
  with main_col:
356
  st.markdown(f"""
357
  <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}; box-shadow: 0 25px 50px rgba(0,0,0,0.2);">
358
+ <div class="result-title" style="font-size: 18px;">ENSEMBLE CONSENSUS</div>
359
+ <div class="result-value" style="color: {label_colors[final_res]}; font-size: 60px;">{final_res}</div>
360
  </div>
361
  """, unsafe_allow_html=True)
362
 
 
369
 
370
  st.markdown(f"""
371
  <div class="prob-row">
372
+ <div class="prob-label">
373
+ <span>{label}</span>
374
+ <span style="color: {color};">{prob:.1f}%</span>
375
+ </div>
376
+ <div class="prob-bar-bg">
377
+ <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
378
+ </div>
379
  </div>
380
  """, unsafe_allow_html=True)
381
 
382
  with details_col:
383
  st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
384
  model_cols = st.columns(2)
385
+ for idx, (name, vote) in enumerate(zip(list(models_dict.keys()), all_votes)):
386
  with model_cols[idx % 2]:
387
  color = label_colors[vote]
388
  st.markdown(f"""
389
  <div class="model-card">
390
+ <div class="model-name">{name}</div>
391
+ <div style="color: {color}; font-weight: 800; font-size: 24px; margin-top: 8px;">{vote}</div>
392
  </div>
393
  """, unsafe_allow_html=True)
394
 
 
404
  example_cols = st.columns(3)
405
  for idx, example in enumerate(examples):
406
  with example_cols[idx]:
407
+ if st.button(example[:40] + "..." if len(example) > 40 else example,
408
+ use_container_width=True):
409
  st.session_state.user_input = example
410
+ st.rerun()