rocky250 committed on
Commit
b03b414
·
verified ·
1 Parent(s): 1f375b8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +148 -77
src/streamlit_app.py CHANGED
@@ -2,10 +2,10 @@ import streamlit as st
2
  import torch
3
  import torch.nn.functional as F
4
  import numpy as np
5
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
6
  from normalizer import normalize
7
  import torch.nn as nn
8
- from huggingface_hub import hf_hub_download
9
 
10
  st.set_page_config(page_title="Political Sentiment", layout="wide")
11
 
@@ -14,34 +14,37 @@ class BanglaPoliticalNet(nn.Module):
14
  super().__init__()
15
  self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
16
  self.hidden_size = self.banglabert.config.hidden_size
 
17
  self.cnn_layers = nn.ModuleList([
18
- nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k//2)
19
  for k in [3,5,7]
20
  ])
 
21
  self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
22
  self.classifier = nn.Sequential(
23
- nn.LayerNorm(self.hidden_size),
24
- nn.Dropout(0.4),
25
- nn.Linear(self.hidden_size, 512),
26
- nn.GELU(),
27
  nn.Dropout(0.3),
28
- nn.Linear(512, 256),
29
- nn.GELU(),
30
- nn.Linear(256, num_classes)
 
31
  )
32
- self.explainability_weights = nn.Parameter(torch.ones(num_classes) * 0.1)
33
-
34
  def forward(self, input_ids, attention_mask=None):
35
  bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
36
- cnn_outs = [F.relu(cnn(bert_out.transpose(1,2)).transpose(1,2)) for cnn in self.cnn_layers]
37
- cnn_concat = torch.cat(cnn_outs, dim=-1)
38
- if not hasattr(self, 'cnn_proj'):
39
- self.cnn_proj = nn.Linear(384, self.hidden_size).to(input_ids.device)
40
- attn_input = self.cnn_proj(cnn_concat)
 
 
 
 
41
  attn_out, _ = self.attention(attn_input, attn_input, attn_input)
42
- pooled = attn_out[:, 0, :]
43
- logits = self.classifier(pooled)
44
- return logits, self.explainability_weights
 
45
 
46
  st.markdown("""
47
  <style>
@@ -209,35 +212,34 @@ h1, h2, h3 {
209
  """, unsafe_allow_html=True)
210
 
211
  id2label = {0: 'Very Negative', 1: 'Negative', 2: 'Neutral', 3: 'Positive', 4: 'Very Positive'}
212
- label_colors = { 'Very Negative': '#ef4444', 'Negative': '#f97316', 'Neutral': '#64748b', 'Positive': '#22c55e', 'Very Positive': '#16a34a' }
 
 
 
 
 
 
213
 
214
  @st.cache_resource
215
  def load_models():
216
  models_loaded = {}
217
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
218
- standard_models = {
219
  "model_banglabert": "rocky250/Sentiment-banglabert",
220
  "model_mbert": "rocky250/Sentiment-mbert",
221
  "model_bbase": "rocky250/Sentiment-bbase",
222
- "model_xlmr": "rocky250/Sentiment-xlmr"
 
223
  }
224
- for name, repo in standard_models.items():
 
225
  try:
226
  tokenizer = AutoTokenizer.from_pretrained(repo)
227
  model = AutoModelForSequenceClassification.from_pretrained(repo)
228
- models_loaded[name] = (tokenizer, model.to(device))
229
  except:
230
  continue
231
- try:
232
- model_path = hf_hub_download(repo_id="rocky250/bangla-political", filename="pytorch_model.bin")
233
- tokenizer = AutoTokenizer.from_pretrained("rocky250/bangla-political")
234
- model = BanglaPoliticalNet(num_classes=5)
235
- if not hasattr(model, 'cnn_proj'):
236
- model.cnn_proj = nn.Linear(384, model.hidden_size)
237
- model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
238
- models_loaded["bangla_political"] = (tokenizer, model.to(device))
239
- except:
240
- pass
241
  return models_loaded
242
 
243
  models_dict = load_models()
@@ -245,23 +247,25 @@ models_dict = load_models()
245
  def predict_single_model(text, model_name):
246
  clean_text = normalize(text)
247
  tokenizer, model = models_dict[model_name]
 
248
  device = next(model.parameters()).device
249
  inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
 
250
  with torch.no_grad():
251
- if isinstance(model, BanglaPoliticalNet):
252
- logits, _ = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
253
- else:
254
- outputs = model(**inputs)
255
- logits = outputs.logits
256
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
257
  pred_id = np.argmax(probs)
258
  prediction = id2label[pred_id]
 
259
  return prediction, probs
260
 
261
  def predict_ensemble(text):
262
  clean_text = normalize(text)
263
  all_probs = []
264
  all_predictions = []
 
265
  for name in models_dict.keys():
266
  try:
267
  pred, probs = predict_single_model(clean_text, name)
@@ -269,6 +273,7 @@ def predict_ensemble(text):
269
  all_predictions.append(pred)
270
  except:
271
  continue
 
272
  if all_probs:
273
  avg_probs = np.mean(all_probs, axis=0)
274
  final_pred = id2label[np.argmax(avg_probs)]
@@ -276,64 +281,130 @@ def predict_ensemble(text):
276
  return "Error", [], np.zeros(5)
277
 
278
  st.markdown("""
279
- <div style='text-align: center; background: rgba(255,255,255,0.1); padding: 30px; border-radius: 20px; margin-bottom: 30px; backdrop-filter: blur(20px);'>
 
 
 
 
 
 
 
280
  <h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
281
  </div>
282
  """, unsafe_allow_html=True)
283
 
284
  col1, col2 = st.columns([3, 1])
285
  with col1:
286
- user_input = st.text_area("Enter Bengali political text:", height=140, placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...")
 
 
 
287
  with col2:
288
  st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
289
- mode = st.radio("Analysis Mode:", ["Single Model", "Ensemble"], horizontal=True)
 
 
 
290
  selected_model = None
291
  if mode == "Single Model":
292
  model_options = {name: name for name in models_dict.keys()}
293
- if model_options:
294
- selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)
295
 
296
  analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
297
 
298
  if analyze_btn and user_input.strip():
299
- with st.spinner('Processing...'):
300
  if mode == "Single Model":
301
- if selected_model:
302
- final_res, probs = predict_single_model(user_input, selected_model)
303
- c1, c2 = st.columns([1, 2])
304
- with c1:
305
- st.markdown(f'<div class="main-card" style="border-top: 8px solid {label_colors[final_res]}"><div class="result-title">{selected_model}</div><div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div><div style="font-size: 18px; color: #64748b; margin-top: 15px;">Confidence: {max(probs)*100:.1f}%</div></div>', unsafe_allow_html=True)
306
- with c2:
307
- st.markdown('<div class="section-header">Confidence Scores</div>', unsafe_allow_html=True)
308
- for i in range(5):
309
- lbl = id2label[i]
310
- p = probs[i] * 100
311
- clr = label_colors[lbl]
312
- st.markdown(f'<div class="prob-row"><div class="prob-label"><span>{lbl}</span><span style="color: {clr};">{p:.1f}%</span></div><div class="prob-bar-bg"><div class="prob-bar-fill" style="width: {min(p, 100)}%; background: {clr};"></div></div></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  else:
314
  final_res, all_votes, avg_probs = predict_ensemble(user_input)
315
- mc, dc = st.columns([1, 1.4])
316
- with mc:
317
- st.markdown(f'<div class="main-card" style="border-top: 8px solid {label_colors[final_res]}"><div class="result-title">ENSEMBLE CONSENSUS</div><div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
318
  for i in range(5):
319
- lbl = id2label[i]
320
- p = avg_probs[i] * 100
321
- clr = label_colors[lbl]
322
- st.markdown(f'<div class="prob-row"><div class="prob-label"><span>{lbl}</span><span style="color: {clr};">{p:.1f}%</span></div><div class="prob-bar-bg"><div class="prob-bar-fill" style="width: {min(p, 100)}%; background: {clr};"></div></div></div>', unsafe_allow_html=True)
323
- with dc:
 
 
 
 
 
 
 
 
 
 
 
 
324
  st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
325
- m_cols = st.columns(2)
326
  for idx, (name, vote) in enumerate(zip(list(models_dict.keys()), all_votes)):
327
- with m_cols[idx % 2]:
328
- st.markdown(f'<div class="model-card"><div class="model-name">{name}</div><div style="color: {label_colors[vote]}; font-weight: 800; font-size: 24px;">{vote}</div></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
329
  elif analyze_btn and not user_input.strip():
330
  st.error("অনুগ্রহ করে কিছু টেক্সট লিখুন!")
331
 
332
- with st.expander("Example Political Texts"):
333
- examples = ["সরকারের এই নীতি দেশকে ধ্বংসের দিকে নিয়ে যাবে!", "চমৎকার সিদ্ধান্ত! দেশের জন্য গর্বিত।", "রাজনীতির কোনো পরিবর্তন হবে না"]
334
- ex_cols = st.columns(3)
335
- for idx, ex in enumerate(examples):
336
- with ex_cols[idx]:
337
- if st.button(ex[:30] + "...", use_container_width=True):
338
- st.session_state.user_input = ex
 
 
 
 
 
339
  st.rerun()
 
2
  import torch
3
  import torch.nn.functional as F
4
  import numpy as np
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from normalizer import normalize
7
  import torch.nn as nn
8
+ from transformers import AutoModel
9
 
10
  st.set_page_config(page_title="Political Sentiment", layout="wide")
11
 
 
14
  super().__init__()
15
  self.banglabert = AutoModel.from_pretrained("csebuetnlp/banglabert")
16
  self.hidden_size = self.banglabert.config.hidden_size
17
+
18
  self.cnn_layers = nn.ModuleList([
19
+ nn.Conv1d(self.hidden_size, 128, kernel_size=k, padding=k//2)
20
  for k in [3,5,7]
21
  ])
22
+
23
  self.attention = nn.MultiheadAttention(self.hidden_size, 8, batch_first=True)
24
  self.classifier = nn.Sequential(
 
 
 
 
25
  nn.Dropout(0.3),
26
+ nn.Linear(self.hidden_size, 512),
27
+ nn.ReLU(),
28
+ nn.Dropout(0.2),
29
+ nn.Linear(512, num_classes)
30
  )
31
+
 
32
  def forward(self, input_ids, attention_mask=None):
33
  bert_out = self.banglabert(input_ids, attention_mask=attention_mask).last_hidden_state
34
+
35
+ cnn_features = []
36
+ for cnn in self.cnn_layers:
37
+ cnn_out = cnn(bert_out.transpose(1,2)).transpose(1,2)
38
+ cnn_features.append(F.relu(cnn_out))
39
+
40
+ cnn_concat = torch.cat(cnn_features, dim=-1)
41
+ proj = nn.Linear(384, self.hidden_size).to(input_ids.device)
42
+ attn_input = proj(cnn_concat)
43
  attn_out, _ = self.attention(attn_input, attn_input, attn_input)
44
+ attn_pooled = attn_out[:, 0, :]
45
+
46
+ logits = self.classifier(attn_pooled)
47
+ return logits
48
 
49
  st.markdown("""
50
  <style>
 
212
  """, unsafe_allow_html=True)
213
 
214
  id2label = {0: 'Very Negative', 1: 'Negative', 2: 'Neutral', 3: 'Positive', 4: 'Very Positive'}
215
+ label_colors = {
216
+ 'Very Negative': '#ef4444',
217
+ 'Negative': '#f97316',
218
+ 'Neutral': '#64748b',
219
+ 'Positive': '#22c55e',
220
+ 'Very Positive': '#16a34a'
221
+ }
222
 
223
  @st.cache_resource
224
  def load_models():
225
  models_loaded = {}
226
+
227
+ target_models = {
228
  "model_banglabert": "rocky250/Sentiment-banglabert",
229
  "model_mbert": "rocky250/Sentiment-mbert",
230
  "model_bbase": "rocky250/Sentiment-bbase",
231
+ "model_xlmr": "rocky250/Sentiment-xlmr",
232
+ "bangla_political": "rocky250/bangla-political"
233
  }
234
+
235
+ for name, repo in target_models.items():
236
  try:
237
  tokenizer = AutoTokenizer.from_pretrained(repo)
238
  model = AutoModelForSequenceClassification.from_pretrained(repo)
239
+ models_loaded[name] = (tokenizer, model.to('cuda' if torch.cuda.is_available() else 'cpu'))
240
  except:
241
  continue
242
+
 
 
 
 
 
 
 
 
 
243
  return models_loaded
244
 
245
  models_dict = load_models()
 
247
  def predict_single_model(text, model_name):
248
  clean_text = normalize(text)
249
  tokenizer, model = models_dict[model_name]
250
+
251
  device = next(model.parameters()).device
252
  inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
253
+
254
  with torch.no_grad():
255
+ outputs = model(**inputs)
256
+ logits = outputs.logits
257
+
 
 
258
  probs = F.softmax(logits, dim=1).cpu().numpy()[0]
259
  pred_id = np.argmax(probs)
260
  prediction = id2label[pred_id]
261
+
262
  return prediction, probs
263
 
264
  def predict_ensemble(text):
265
  clean_text = normalize(text)
266
  all_probs = []
267
  all_predictions = []
268
+
269
  for name in models_dict.keys():
270
  try:
271
  pred, probs = predict_single_model(clean_text, name)
 
273
  all_predictions.append(pred)
274
  except:
275
  continue
276
+
277
  if all_probs:
278
  avg_probs = np.mean(all_probs, axis=0)
279
  final_pred = id2label[np.argmax(avg_probs)]
 
281
  return "Error", [], np.zeros(5)
282
 
283
  st.markdown("""
284
+ <div style='
285
+ text-align: center;
286
+ background: rgba(255,255,255,0.1);
287
+ padding: 30px;
288
+ border-radius: 20px;
289
+ margin-bottom: 30px;
290
+ backdrop-filter: blur(20px);
291
+ '>
292
  <h1 style='font-size: 3.5rem; margin: 0; background: linear-gradient(45deg, #ffffff, #e2e8f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 800;'>Political Sentiment Analysis</h1>
293
  </div>
294
  """, unsafe_allow_html=True)
295
 
296
  col1, col2 = st.columns([3, 1])
297
  with col1:
298
+ user_input = st.text_area("Enter Bengali political text:", height=140,
299
+ placeholder="এই বক্সে বাংলা রাজনৈতিক মন্তব্য লিখুন...",
300
+ help="Type or paste Bengali political text for sentiment analysis")
301
+
302
  with col2:
303
  st.markdown("<div style='height: 20px'></div>", unsafe_allow_html=True)
304
+ mode = st.radio("Analysis Mode:",
305
+ ["Single Model", "Ensemble"],
306
+ horizontal=True)
307
+
308
  selected_model = None
309
  if mode == "Single Model":
310
  model_options = {name: name for name in models_dict.keys()}
311
+ selected_model = st.selectbox("Select Model:", list(model_options.keys()), index=0)
 
312
 
313
  analyze_btn = st.button("ANALYZE SENTIMENT", type="primary", use_container_width=True)
314
 
315
  if analyze_btn and user_input.strip():
316
+ with st.spinner('Processing with models...'):
317
  if mode == "Single Model":
318
+ model_name = selected_model
319
+ final_res, probs = predict_single_model(user_input, model_name)
320
+
321
+ col1, col2 = st.columns([1, 2])
322
+ with col1:
323
+ st.markdown(f"""
324
+ <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}">
325
+ <div class="result-title">{model_name}</div>
326
+ <div class="result-value" style="color: {label_colors[final_res]}">{final_res}</div>
327
+ <div style="font-size: 18px; color: #64748b; margin-top: 15px;">Confidence: {max(probs)*100:.1f}%</div>
328
+ </div>
329
+ """, unsafe_allow_html=True)
330
+
331
+ with col2:
332
+ st.markdown('<div class="section-header">Confidence Scores</div>', unsafe_allow_html=True)
333
+ for i in range(5):
334
+ label = id2label[i]
335
+ prob = probs[i] * 100
336
+ color = label_colors[label]
337
+
338
+ st.markdown(f"""
339
+ <div class="prob-row">
340
+ <div class="prob-label">
341
+ <span style="font-weight: 700;">{label}</span>
342
+ <span style="font-weight: 700; color: {color};">{prob:.1f}%</span>
343
+ </div>
344
+ <div class="prob-bar-bg">
345
+ <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
346
+ </div>
347
+ </div>
348
+ """, unsafe_allow_html=True)
349
+
350
  else:
351
  final_res, all_votes, avg_probs = predict_ensemble(user_input)
352
+
353
+ main_col, details_col = st.columns([1, 1.4])
354
+
355
+ with main_col:
356
+ st.markdown(f"""
357
+ <div class="main-card" style="border-top: 8px solid {label_colors[final_res]}; box-shadow: 0 25px 50px rgba(0,0,0,0.2);">
358
+ <div class="result-title" style="font-size: 18px;">ENSEMBLE CONSENSUS</div>
359
+ <div class="result-value" style="color: {label_colors[final_res]}; font-size: 60px;">{final_res}</div>
360
+ </div>
361
+ """, unsafe_allow_html=True)
362
+
363
+ st.markdown('<div class="section-header">Ensemble Probabilities</div>', unsafe_allow_html=True)
364
+
365
  for i in range(5):
366
+ label = id2label[i]
367
+ prob = avg_probs[i] * 100
368
+ color = label_colors[label]
369
+
370
+ st.markdown(f"""
371
+ <div class="prob-row">
372
+ <div class="prob-label">
373
+ <span>{label}</span>
374
+ <span style="color: {color};">{prob:.1f}%</span>
375
+ </div>
376
+ <div class="prob-bar-bg">
377
+ <div class="prob-bar-fill" style="width: {min(prob, 100)}%; background: linear-gradient(90deg, {color}, {color}cc);"></div>
378
+ </div>
379
+ </div>
380
+ """, unsafe_allow_html=True)
381
+
382
+ with details_col:
383
  st.markdown('<div class="section-header">Individual Model Votes</div>', unsafe_allow_html=True)
384
+ model_cols = st.columns(2)
385
  for idx, (name, vote) in enumerate(zip(list(models_dict.keys()), all_votes)):
386
+ with model_cols[idx % 2]:
387
+ color = label_colors[vote]
388
+ st.markdown(f"""
389
+ <div class="model-card">
390
+ <div class="model-name">{name}</div>
391
+ <div style="color: {color}; font-weight: 800; font-size: 24px; margin-top: 8px;">{vote}</div>
392
+ </div>
393
+ """, unsafe_allow_html=True)
394
+
395
  elif analyze_btn and not user_input.strip():
396
  st.error("অনুগ্রহ করে কিছু টেক্সট লিখুন!")
397
 
398
+ with st.expander("Example Political Texts", expanded=False):
399
+ examples = [
400
+ "সরকারের এই নীতি দেশকে ধ্বংসের দিকে নিয়ে যাবে!",
401
+ "চমৎকার সিদ্ধান্ত! দেশের জন্য গর্বিত। ভালো চলবে!",
402
+ "রাজনীতির কোনো পরিবর্তন হবে না, সব একই রকম"
403
+ ]
404
+ example_cols = st.columns(3)
405
+ for idx, example in enumerate(examples):
406
+ with example_cols[idx]:
407
+ if st.button(example[:40] + "..." if len(example) > 40 else example,
408
+ use_container_width=True):
409
+ st.session_state.user_input = example
410
  st.rerun()