rocky250 committed on
Commit
1f375b8
·
verified ·
1 Parent(s): da258d7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +34 -142
src/streamlit_app.py CHANGED
@@ -13,14 +13,11 @@ class BanglaPoliticalNet(nn.Module):
13
  def __init__(self, num_classes=5):
14
  super().__init__()
15
  self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
16
-
17
  self.hidden_size = self.banglabert.config.hidden_size
18
-
19
  self.cnn_layers = nn.ModuleList([
20
  nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k//2)
21
  for k in [3,5,7]
22
  ])
23
-
24
  self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
25
  self.classifier = nn.Sequential(
26
  nn.LayerNorm(self.hidden_size),
@@ -32,27 +29,20 @@ class BanglaPoliticalNet(nn.Module):
32
  nn.GELU(),
33
  nn.Linear(256, num_classes)
34
  )
35
-
36
  self.explainability_weights = nn.Parameter(torch.ones(num_classes) * 0.1)
37
 
38
  def forward(self, input_ids, attention_mask=None):
39
  bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
40
-
41
  cnn_outs = [F.relu(cnn(bert_out.transpose(1,2)).transpose(1,2)) for cnn in self.cnn_layers]
42
  cnn_concat = torch.cat(cnn_outs, dim=-1)
43
-
44
  if not hasattr(self, 'cnn_proj'):
45
-
46
  self.cnn_proj = nn.Linear(384, self.hidden_size).to(input_ids.device)
47
-
48
  attn_input = self.cnn_proj(cnn_concat)
49
  attn_out, _ = self.attention(attn_input, attn_input, attn_input)
50
  pooled = attn_out[:, 0, :]
51
-
52
  logits = self.classifier(pooled)
53
  return logits, self.explainability_weights
54
 
55
-
56
  st.markdown("""
57
  <style>
58
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
@@ -219,55 +209,35 @@ h1, h2, h3 {
219
  """, unsafe_allow_html=True)
220
 
221
  id2label = {0: 'Very Negative', 1: 'Negative', 2: 'Neutral', 3: 'Positive', 4: 'Very Positive'}
222
- label_colors = {
223
- 'Very Negative': '#ef4444',
224
- 'Negative': '#f97316',
225
- 'Neutral': '#64748b',
226
- 'Positive': '#22c55e',
227
- 'Very Positive': '#16a34a'
228
- }
229
 
230
  @st.cache_resource
231
  def load_models():
232
  models_loaded = {}
233
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
234
-
235
  standard_models = {
236
  "model_banglabert": "rocky250/Sentiment-banglabert",
237
  "model_mbert": "rocky250/Sentiment-mbert",
238
  "model_bbase": "rocky250/Sentiment-bbase",
239
  "model_xlmr": "rocky250/Sentiment-xlmr"
240
  }
241
-
242
  for name, repo in standard_models.items():
243
  try:
244
  tokenizer = AutoTokenizer.from_pretrained(repo)
245
  model = AutoModelForSequenceClassification.from_pretrained(repo)
246
  models_loaded[name] = (tokenizer, model.to(device))
247
- except Exception as e:
248
- print(f"Skipped {name}: {e}")
249
  continue
250
-
251
  try:
252
-
253
  model_path = hf_hub_download(repo_id="rocky250/bangla-political", filename="pytorch_model.bin")
254
-
255
-
256
  tokenizer = AutoTokenizer.from_pretrained("rocky250/bangla-political")
257
-
258
-
259
  model = BanglaPoliticalNet(num_classes=5)
260
-
261
-
262
  if not hasattr(model, 'cnn_proj'):
263
  model.cnn_proj = nn.Linear(384, model.hidden_size)
264
-
265
  model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
266
-
267
  models_loaded["bangla_political"] = (tokenizer, model.to(device))
268
- except Exception as e:
269
- print(f"Skipped bangla_political: {e}")
270
-
271
  return models_loaded
272
 
273
  models_dict = load_models()
@@ -275,28 +245,23 @@ models_dict = load_models()
275
  def predict_single_model(text, model_name):
276
  clean_text = normalize(text)
277
  tokenizer, model = models_dict[model_name]
278
-
279
  device = next(model.parameters()).device
280
  inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
281
-
282
  with torch.no_grad():
283
  if isinstance(model, BanglaPoliticalNet):
284
  logits, _ = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
285
  else:
286
  outputs = model(**inputs)
287
  logits = outputs.logits
288
-
289
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
290
  pred_id = np.argmax(probs)
291
  prediction = id2label[pred_id]
292
-
293
  return prediction, probs
294
 
295
  def predict_ensemble(text):
296
  clean_text = normalize(text)
297
  all_probs = []
298
  all_predictions = []
299
-
300
  for name in models_dict.keys():
301
  try:
302
  pred, probs = predict_single_model(clean_text, name)
@@ -304,144 +269,71 @@ def predict_ensemble(text):
304
  all_predictions.append(pred)
305
  except:
306
  continue
307
-
308
  if all_probs:
309
  avg_probs = np.mean(all_probs, axis=0)
310
  final_pred = id2label[np.argmax(avg_probs)]
311
  return final_pred, all_predictions, avg_probs
312
  return "Error", [], np.zeros(5)
313
 
314
-
315
  st.markdown("""
316
- <div style='
317
- text-align: center;
318
- background: rgba(255,255,255,0.1);
319
- padding: 30px;
320
- border-radius: 20px;
321
- margin-bottom: 30px;
322
- backdrop-filter: blur(20px);
323
- '>
324
  <h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
325
  </div>
326
  """, unsafe_allow_html=True)
327
 
328
  col1, col2 = st.columns([3, 1])
329
  with col1:
330
- user_input = st.text_area("Enter Bengali political text:", height=140,
331
- placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...",
332
- help="Type or paste Bengali political text for sentiment analysis")
333
-
334
  with col2:
335
  st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
336
- mode = st.radio("Analysis Mode:",
337
- ["Single Model", "Ensemble"],
338
- horizontal=True)
339
-
340
  selected_model = None
341
  if mode == "Single Model":
342
  model_options = {name: name for name in models_dict.keys()}
343
  if model_options:
344
  selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)
345
- else:
346
- st.warning("No models loaded.")
347
 
348
  analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
349
 
350
  if analyze_btn and user_input.strip():
351
- with st.spinner('Processing with models...'):
352
  if mode == "Single Model":
353
  if selected_model:
354
  final_res, probs = predict_single_model(user_input, selected_model)
355
-
356
- col1, col2 = st.columns([1, 2])
357
- with col1:
358
- st.markdown(f"""
359
- <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}">
360
- <div class="result-title">{selected_model}</div>
361
- <div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div>
362
- <div style="font-size: 18px; color: #64748b; margin-top: 15px;">Confidence: {max(probs)*100:.1f}%</div>
363
- </div>
364
- """, unsafe_allow_html=True)
365
-
366
- with col2:
367
  st.markdown('<div class="section-header">Confidence Scores</div>', unsafe_allow_html=True)
368
  for i in range(5):
369
- label = id2label[i]
370
- prob = probs[i] * 100
371
- color = label_colors[label]
372
-
373
- st.markdown(f"""
374
- <div class="prob-row">
375
- <div class="prob-label">
376
- <span style="font-weight: 700;">{label}</span>
377
- <span style="font-weight: 700; color: {color};">{prob:.1f}%</span>
378
- </div>
379
- <div class="prob-bar-bg">
380
- <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
381
- </div>
382
- </div>
383
- """, unsafe_allow_html=True)
384
- else:
385
- st.error("Model not selected or failed to load.")
386
-
387
  else:
388
  final_res, all_votes, avg_probs = predict_ensemble(user_input)
389
-
390
- main_col, details_col = st.columns([1, 1.4])
391
-
392
- with main_col:
393
- st.markdown(f"""
394
- <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}; box-shadow: 0 25px 50px rgba(0,0,0,0.2);">
395
- <div class="result-title" style="font-size: 18px;">ENSEMBLE CONSENSUS</div>
396
- <div class="result-value" style="color: {label_colors[final_res]}; font-size: 60px;">{final_res}</div>
397
- </div>
398
- """, unsafe_allow_html=True)
399
-
400
- st.markdown('<div class="section-header">Ensemble Probabilities</div>', unsafe_allow_html=True)
401
-
402
  for i in range(5):
403
- label = id2label[i]
404
- prob = avg_probs[i] * 100
405
- color = label_colors[label]
406
-
407
- st.markdown(f"""
408
- <div class="prob-row">
409
- <div class="prob-label">
410
- <span>{label}</span>
411
- <span style="color: {color};">{prob:.1f}%</span>
412
- </div>
413
- <div class="prob-bar-bg">
414
- <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
415
- </div>
416
- </div>
417
- """, unsafe_allow_html=True)
418
-
419
- with details_col:
420
  st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
421
- model_cols = st.columns(2)
422
  for idx, (name, vote) in enumerate(zip(list(models_dict.keys()), all_votes)):
423
- with model_cols[idx % 2]:
424
- color = label_colors[vote]
425
- st.markdown(f"""
426
- <div class="model-card">
427
- <div class="model-name">{name}</div>
428
- <div style="color: {color}; font-weight: 800; font-size: 24px; margin-top: 8px;">{vote}</div>
429
- </div>
430
- """, unsafe_allow_html=True)
431
-
432
  elif analyze_btn and not user_input.strip():
433
  st.error("অনুগ্রহ করে কিছু টেক্সট লিখুন!")
434
 
435
- with st.expander("Example Political Texts", expanded=False):
436
- examples = [
437
- "সরকারের এই নীতি দেশকে ধ্বংসের দিকে নিয়ে যাবে!",
438
- "চমৎকার সিদ্ধান্ত! দেশের জন্য গর্বিত। ভালো চলবে!",
439
- "রাজনীতির কোনো পরিবর্তন হবে না, সব একই রকম"
440
- ]
441
- example_cols = st.columns(3)
442
- for idx, example in enumerate(examples):
443
- with example_cols[idx]:
444
- if st.button(example[:40] + "..." if len(example) > 40 else example,
445
- use_container_width=True):
446
- st.session_state.user_input = example
447
  st.rerun()
 
13
  def __init__(self, num_classes=5):
14
  super().__init__()
15
  self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
 
16
  self.hidden_size = self.banglabert.config.hidden_size
 
17
  self.cnn_layers = nn.ModuleList([
18
  nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k//2)
19
  for k in [3,5,7]
20
  ])
 
21
  self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
22
  self.classifier = nn.Sequential(
23
  nn.LayerNorm(self.hidden_size),
 
29
  nn.GELU(),
30
  nn.Linear(256, num_classes)
31
  )
 
32
  self.explainability_weights = nn.Parameter(torch.ones(num_classes) * 0.1)
33
 
34
  def forward(self, input_ids, attention_mask=None):
35
  bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
 
36
  cnn_outs = [F.relu(cnn(bert_out.transpose(1,2)).transpose(1,2)) for cnn in self.cnn_layers]
37
  cnn_concat = torch.cat(cnn_outs, dim=-1)
 
38
  if not hasattr(self, 'cnn_proj'):
 
39
  self.cnn_proj = nn.Linear(384, self.hidden_size).to(input_ids.device)
 
40
  attn_input = self.cnn_proj(cnn_concat)
41
  attn_out, _ = self.attention(attn_input, attn_input, attn_input)
42
  pooled = attn_out[:, 0, :]
 
43
  logits = self.classifier(pooled)
44
  return logits, self.explainability_weights
45
 
 
46
  st.markdown("""
47
  <style>
48
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
 
209
  """, unsafe_allow_html=True)
210
 
211
  id2label = {0: 'Very Negative', 1: 'Negative', 2: 'Neutral', 3: 'Positive', 4: 'Very Positive'}
212
+ label_colors = { 'Very Negative': '#ef4444', 'Negative': '#f97316', 'Neutral': '#64748b', 'Positive': '#22c55e', 'Very Positive': '#16a34a' }
 
 
 
 
 
 
213
 
214
  @st.cache_resource
215
  def load_models():
216
  models_loaded = {}
217
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
218
  standard_models = {
219
  "model_banglabert": "rocky250/Sentiment-banglabert",
220
  "model_mbert": "rocky250/Sentiment-mbert",
221
  "model_bbase": "rocky250/Sentiment-bbase",
222
  "model_xlmr": "rocky250/Sentiment-xlmr"
223
  }
 
224
  for name, repo in standard_models.items():
225
  try:
226
  tokenizer = AutoTokenizer.from_pretrained(repo)
227
  model = AutoModelForSequenceClassification.from_pretrained(repo)
228
  models_loaded[name] = (tokenizer, model.to(device))
229
+ except:
 
230
  continue
 
231
  try:
 
232
  model_path = hf_hub_download(repo_id="rocky250/bangla-political", filename="pytorch_model.bin")
 
 
233
  tokenizer = AutoTokenizer.from_pretrained("rocky250/bangla-political")
 
 
234
  model = BanglaPoliticalNet(num_classes=5)
 
 
235
  if not hasattr(model, 'cnn_proj'):
236
  model.cnn_proj = nn.Linear(384, model.hidden_size)
 
237
  model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
 
238
  models_loaded["bangla_political"] = (tokenizer, model.to(device))
239
+ except:
240
+ pass
 
241
  return models_loaded
242
 
243
  models_dict = load_models()
 
245
  def predict_single_model(text, model_name):
246
  clean_text = normalize(text)
247
  tokenizer, model = models_dict[model_name]
 
248
  device = next(model.parameters()).device
249
  inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
 
250
  with torch.no_grad():
251
  if isinstance(model, BanglaPoliticalNet):
252
  logits, _ = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
253
  else:
254
  outputs = model(**inputs)
255
  logits = outputs.logits
 
256
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
257
  pred_id = np.argmax(probs)
258
  prediction = id2label[pred_id]
 
259
  return prediction, probs
260
 
261
  def predict_ensemble(text):
262
  clean_text = normalize(text)
263
  all_probs = []
264
  all_predictions = []
 
265
  for name in models_dict.keys():
266
  try:
267
  pred, probs = predict_single_model(clean_text, name)
 
269
  all_predictions.append(pred)
270
  except:
271
  continue
 
272
  if all_probs:
273
  avg_probs = np.mean(all_probs, axis=0)
274
  final_pred = id2label[np.argmax(avg_probs)]
275
  return final_pred, all_predictions, avg_probs
276
  return "Error", [], np.zeros(5)
277
 
 
278
  st.markdown("""
279
+ <div style='text-align: center; background: rgba(255,255,255,0.1); padding: 30px; border-radius: 20px; margin-bottom: 30px; backdrop-filter: blur(20px);'>
 
 
 
 
 
 
 
280
  <h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
281
  </div>
282
  """, unsafe_allow_html=True)
283
 
284
  col1, col2 = st.columns([3, 1])
285
  with col1:
286
+ user_input = st.text_area("Enter Bengali political text:", height=140, placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...")
 
 
 
287
  with col2:
288
  st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
289
+ mode = st.radio("Analysis Mode:", ["Single Model", "Ensemble"], horizontal=True)
 
 
 
290
  selected_model = None
291
  if mode == "Single Model":
292
  model_options = {name: name for name in models_dict.keys()}
293
  if model_options:
294
  selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)
 
 
295
 
296
  analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
297
 
298
  if analyze_btn and user_input.strip():
299
+ with st.spinner('Processing...'):
300
  if mode == "Single Model":
301
  if selected_model:
302
  final_res, probs = predict_single_model(user_input, selected_model)
303
+ c1, c2 = st.columns([1, 2])
304
+ with c1:
305
+ st.markdown(f'<div class="main-card" style="border-top: 8px solid {label_colors[final_res]}"><div class="result-title">{selected_model}</div><div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div><div style="font-size: 18px; color: #64748b; margin-top: 15px;">Confidence: {max(probs)*100:.1f}%</div></div>', unsafe_allow_html=True)
306
+ with c2:
 
 
 
 
 
 
 
 
307
  st.markdown('<div class="section-header">Confidence Scores</div>', unsafe_allow_html=True)
308
  for i in range(5):
309
+ lbl = id2label[i]
310
+ p = probs[i] * 100
311
+ clr = label_colors[lbl]
312
+ st.markdown(f'<div class="prob-row"><div class="prob-label"><span>{lbl}</span><span style="color: {clr};">{p:.1f}%</span></div><div class="prob-bar-bg"><div class="prob-bar-fill" style="width: {min(p, 100)}%; background: {clr};"></div></div></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  else:
314
  final_res, all_votes, avg_probs = predict_ensemble(user_input)
315
+ mc, dc = st.columns([1, 1.4])
316
+ with mc:
317
+ st.markdown(f'<div class="main-card" style="border-top: 8px solid {label_colors[final_res]}"><div class="result-title">ENSEMBLE CONSENSUS</div><div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
318
  for i in range(5):
319
+ lbl = id2label[i]
320
+ p = avg_probs[i] * 100
321
+ clr = label_colors[lbl]
322
+ st.markdown(f'<div class="prob-row"><div class="prob-label"><span>{lbl}</span><span style="color: {clr};">{p:.1f}%</span></div><div class="prob-bar-bg"><div class="prob-bar-fill" style="width: {min(p, 100)}%; background: {clr};"></div></div></div>', unsafe_allow_html=True)
323
+ with dc:
 
 
 
 
 
 
 
 
 
 
 
 
324
  st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
325
+ m_cols = st.columns(2)
326
  for idx, (name, vote) in enumerate(zip(list(models_dict.keys()), all_votes)):
327
+ with m_cols[idx % 2]:
328
+ st.markdown(f'<div class="model-card"><div class="model-name">{name}</div><div style="color: {label_colors[vote]}; font-weight: 800; font-size: 24px;">{vote}</div></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
329
  elif analyze_btn and not user_input.strip():
330
  st.error("অনুগ্রহ করে কিছু টেক্সট লিখুন!")
331
 
332
+ with st.expander("Example Political Texts"):
333
+ examples = ["সরকারের এই নীতি দেশকে ধ্বংসের দিকে নিয়ে যাবে!", "চমৎকার সিদ্ধান্ত! দেশের জন্য গর্বিত।", "রাজনীতির কোনো পরিবর্তন হবে না"]
334
+ ex_cols = st.columns(3)
335
+ for idx, ex in enumerate(examples):
336
+ with ex_cols[idx]:
337
+ if st.button(ex[:30] + "...", use_container_width=True):
338
+ st.session_state.user_input = ex
 
 
 
 
 
339
  st.rerun()