abdou21367 committed on
Commit
8351502
·
verified ·
1 Parent(s): f59316d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +830 -586
app.py CHANGED
@@ -1,587 +1,831 @@
1
- """
2
- Sentiment Analysis App - Custom Transformer vs BERTweet
3
- Lightweight version with automatic model download from Google Drive
4
- """
5
-
6
- import streamlit as st
7
- import torch
8
- import torch.nn.functional as F
9
- import pandas as pd
10
- import numpy as np
11
- import pickle
12
- from pathlib import Path
13
- import plotly.graph_objects as go
14
- from typing import Dict, List, Tuple
15
- import sys
16
- import os
17
- import requests
18
- import gdown
19
-
20
- # Add src to path
21
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
22
-
23
- try:
24
- from src.models.baseline.custom_transformer import CustomTransformer
25
- from src.models.pretrained.bertweet import BERTweetClassifier, get_bertweet_tokenizer
26
- from src.data.preprocessing import EnhancedTextPreprocessor
27
- except ModuleNotFoundError:
28
- from models.baseline.custom_transformer import CustomTransformer
29
- from models.pretrained.bertweet import BERTweetClassifier, get_bertweet_tokenizer
30
- from data.preprocessing import EnhancedTextPreprocessor
31
-
32
-
33
- # ============================================================================
34
- # CONFIGURATION
35
- # ============================================================================
36
-
37
- class Config:
38
- """App configuration"""
39
- # Local paths
40
- TRANSFORMER_MODEL_PATH = "https://drive.google.com/file/d/124EHm4lHWWWzJfVdlJF-9l8QRqTjI0BK/view?usp=drive_link"
41
- BERTWEET_MODEL_PATH = "https://drive.google.com/file/d/1DlGRe4qHypaWby6MU1ab0ZcpJp2FhQtL/view?usp=drive_link"
42
- VOCABULARY_PATH = "https://drive.google.com/file/d/1DkbnnYe1_dVFGuOwsDaZ_9UuH1zCXCwE/view?usp=drive_link"
43
-
44
- # Google Drive File IDs (REPLACE THESE WITH YOUR FILE IDs!)
45
- # Instructions: Upload to Google Drive, share publicly, extract FILE_ID from share link
46
- GDRIVE_IDS = {
47
- 'transformer': 'YOUR_TRANSFORMER_FILE_ID', # Replace with actual ID
48
- 'bertweet': 'YOUR_BERTWEET_FILE_ID', # Replace with actual ID
49
- 'vocab': 'YOUR_VOCAB_FILE_ID' # Replace with actual ID
50
- }
51
-
52
- # Model parameters
53
- TRANSFORMER_CONFIG = {
54
- 'vocab_size': 10000,
55
- 'd_model': 256,
56
- 'num_heads': 4,
57
- 'num_layers': 4,
58
- 'd_ff': 1024,
59
- 'num_classes': 3,
60
- 'max_len': 100,
61
- 'dropout': 0.1,
62
- 'padding_idx': 0
63
- }
64
-
65
- BERTWEET_CONFIG = {
66
- 'model_name': 'vinai/bertweet-base',
67
- 'num_classes': 3,
68
- 'dropout': 0.5
69
- }
70
-
71
- # Labels
72
- LABEL_MAP = {0: 'Negative', 1: 'Neutral', 2: 'Positive'}
73
- LABEL_COLORS = {
74
- 'Negative': '#FF4B4B',
75
- 'Neutral': '#FFA500',
76
- 'Positive': '#00D66C'
77
- }
78
-
79
- # Device
80
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
-
82
-
83
- # ============================================================================
84
- # MODEL DOWNLOAD FUNCTIONS
85
- # ============================================================================
86
-
87
- def download_from_gdrive(file_id: str, output_path: str):
88
- """Download file from Google Drive"""
89
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
90
- url = f"https://drive.google.com/uc?id={file_id}"
91
- gdown.download(url, output_path, quiet=False)
92
-
93
-
94
- def ensure_models_downloaded():
95
- """Download models if they don't exist"""
96
- models_dir = Path("models")
97
- models_dir.mkdir(exist_ok=True)
98
-
99
- files_to_check = [
100
- (Config.TRANSFORMER_MODEL_PATH, 'transformer'),
101
- (Config.BERTWEET_MODEL_PATH, 'bertweet'),
102
- (Config.VOCABULARY_PATH, 'vocab')
103
- ]
104
-
105
- for file_path, key in files_to_check:
106
- if not os.path.exists(file_path):
107
- file_id = Config.GDRIVE_IDS[key]
108
- if file_id == f'YOUR_{key.upper()}_FILE_ID':
109
- st.error(f"⚠️ Google Drive ID not configured for {key}!")
110
- st.info("Please update Config.GDRIVE_IDS in the code with your file IDs")
111
- st.stop()
112
-
113
- st.info(f"📥 Downloading {key} model... (first time only, ~1-2 min)")
114
- try:
115
- download_from_gdrive(file_id, file_path)
116
- st.success(f"✅ {key} downloaded!")
117
- except Exception as e:
118
- st.error(f"❌ Failed to download {key}: {e}")
119
- st.stop()
120
-
121
-
122
- # ============================================================================
123
- # MODEL LOADING & CACHING
124
- # ============================================================================
125
-
126
- @st.cache_resource
127
- def load_custom_transformer() -> CustomTransformer:
128
- """Load Custom Transformer model"""
129
- try:
130
- config = Config.TRANSFORMER_CONFIG.copy()
131
- model = CustomTransformer(**config)
132
-
133
- checkpoint = torch.load(Config.TRANSFORMER_MODEL_PATH, map_location=Config.DEVICE)
134
-
135
- if isinstance(checkpoint, dict):
136
- if 'model_state_dict' in checkpoint:
137
- model.load_state_dict(checkpoint['model_state_dict'])
138
- elif 'state_dict' in checkpoint:
139
- model.load_state_dict(checkpoint['state_dict'])
140
- else:
141
- model.load_state_dict(checkpoint)
142
- else:
143
- model.load_state_dict(checkpoint)
144
-
145
- model.to(Config.DEVICE)
146
- model.eval()
147
-
148
- return model
149
- except Exception as e:
150
- st.error(f"❌ Failed to load Custom Transformer: {e}")
151
- return None
152
-
153
-
154
- @st.cache_resource
155
- def load_bertweet_model() -> BERTweetClassifier:
156
- """Load BERTweet model"""
157
- try:
158
- model = BERTweetClassifier(**Config.BERTWEET_CONFIG)
159
-
160
- checkpoint = torch.load(Config.BERTWEET_MODEL_PATH, map_location=Config.DEVICE)
161
-
162
- if isinstance(checkpoint, dict):
163
- if 'model_state_dict' in checkpoint:
164
- model.load_state_dict(checkpoint['model_state_dict'])
165
- elif 'state_dict' in checkpoint:
166
- model.load_state_dict(checkpoint['state_dict'])
167
- else:
168
- model.load_state_dict(checkpoint)
169
- else:
170
- model.load_state_dict(checkpoint)
171
-
172
- model.to(Config.DEVICE)
173
- model.eval()
174
-
175
- return model
176
- except Exception as e:
177
- st.error(f"❌ Failed to load BERTweet: {e}")
178
- return None
179
-
180
-
181
- @st.cache_resource
182
- def load_bertweet_tokenizer():
183
- """Load BERTweet tokenizer"""
184
- try:
185
- return get_bertweet_tokenizer()
186
- except Exception as e:
187
- st.error(f"❌ Failed to load BERTweet tokenizer: {e}")
188
- return None
189
-
190
-
191
- @st.cache_resource
192
- def load_preprocessor():
193
- """Load text preprocessor"""
194
- try:
195
- preprocessor = EnhancedTextPreprocessor(
196
- vocab_size=10000,
197
- max_length=100,
198
- min_freq=2,
199
- use_spell_check=False,
200
- use_lemmatization=False
201
- )
202
-
203
- preprocessor.load_vocabulary(Config.VOCABULARY_PATH)
204
-
205
- return preprocessor
206
- except Exception as e:
207
- st.error(f"❌ Failed to load preprocessor: {e}")
208
- return None
209
-
210
-
211
- # ============================================================================
212
- # PREDICTION FUNCTIONS
213
- # ============================================================================
214
-
215
- def text_to_indices(text: str, preprocessor: EnhancedTextPreprocessor, max_len: int = 100) -> torch.Tensor:
216
- """Convert text to token indices for Custom Transformer"""
217
- tokens = text.lower().split()
218
- indices = [preprocessor.word2idx.get(token, preprocessor.word2idx.get('<UNK>', 1)) for token in tokens]
219
-
220
- if len(indices) < max_len:
221
- indices = indices + [0] * (max_len - len(indices))
222
- else:
223
- indices = indices[:max_len]
224
-
225
- return torch.tensor([indices], dtype=torch.long)
226
-
227
-
228
- def predict_custom_transformer(
229
- text: str,
230
- model: CustomTransformer,
231
- preprocessor: EnhancedTextPreprocessor
232
- ) -> Tuple[str, Dict[str, float]]:
233
- """Predict sentiment using Custom Transformer"""
234
- try:
235
- processed_text = preprocessor.clean_text(text)
236
- input_ids = text_to_indices(processed_text, preprocessor, max_len=100).to(Config.DEVICE)
237
- mask = (input_ids != 0).float()
238
-
239
- with torch.no_grad():
240
- logits = model(input_ids, mask=mask)
241
- probs = F.softmax(logits, dim=1)[0]
242
-
243
- pred_idx = torch.argmax(probs).item()
244
- pred_label = Config.LABEL_MAP[pred_idx]
245
-
246
- confidences = {
247
- Config.LABEL_MAP[i]: float(probs[i])
248
- for i in range(len(Config.LABEL_MAP))
249
- }
250
-
251
- return pred_label, confidences
252
-
253
- except Exception as e:
254
- st.error(f"❌ Custom Transformer prediction failed: {e}")
255
- return "Error", {}
256
-
257
-
258
- def predict_bertweet(
259
- text: str,
260
- model: BERTweetClassifier,
261
- tokenizer,
262
- preprocessor: EnhancedTextPreprocessor
263
- ) -> Tuple[str, Dict[str, float]]:
264
- """Predict sentiment using BERTweet"""
265
- try:
266
- processed_text = preprocessor.clean_text(text)
267
-
268
- encoded = tokenizer(
269
- processed_text,
270
- padding='max_length',
271
- truncation=True,
272
- max_length=128,
273
- return_tensors='pt'
274
- )
275
-
276
- input_ids = encoded['input_ids'].to(Config.DEVICE)
277
- attention_mask = encoded['attention_mask'].to(Config.DEVICE)
278
-
279
- with torch.no_grad():
280
- logits = model(input_ids=input_ids, attention_mask=attention_mask)
281
- probs = F.softmax(logits, dim=1)[0]
282
-
283
- pred_idx = torch.argmax(probs).item()
284
- pred_label = Config.LABEL_MAP[pred_idx]
285
-
286
- confidences = {
287
- Config.LABEL_MAP[i]: float(probs[i])
288
- for i in range(len(Config.LABEL_MAP))
289
- }
290
-
291
- return pred_label, confidences
292
-
293
- except Exception as e:
294
- st.error(f"❌ BERTweet prediction failed: {e}")
295
- return "Error", {}
296
-
297
-
298
- # ============================================================================
299
- # VISUALIZATION FUNCTIONS
300
- # ============================================================================
301
-
302
- def create_confidence_chart(confidences: Dict[str, float], model_name: str) -> go.Figure:
303
- """Create a beautiful confidence bar chart"""
304
- labels = list(confidences.keys())
305
- values = [confidences[label] * 100 for label in labels]
306
- colors = [Config.LABEL_COLORS[label] for label in labels]
307
-
308
- fig = go.Figure(data=[
309
- go.Bar(
310
- x=values,
311
- y=labels,
312
- orientation='h',
313
- marker=dict(
314
- color=colors,
315
- line=dict(color='rgba(0,0,0,0.3)', width=1)
316
- ),
317
- text=[f'{v:.1f}%' for v in values],
318
- textposition='auto',
319
- hovertemplate='<b>%{y}</b><br>Confidence: %{x:.2f}%<extra></extra>'
320
- )
321
- ])
322
-
323
- fig.update_layout(
324
- title=dict(
325
- text=f'{model_name} Confidence Scores',
326
- font=dict(size=16, family='Arial, sans-serif')
327
- ),
328
- xaxis=dict(
329
- title='Confidence (%)',
330
- range=[0, 100],
331
- showgrid=True,
332
- gridcolor='rgba(0,0,0,0.1)'
333
- ),
334
- yaxis=dict(
335
- title='',
336
- categoryorder='array',
337
- categoryarray=['Positive', 'Neutral', 'Negative']
338
- ),
339
- height=250,
340
- margin=dict(l=10, r=10, t=40, b=10),
341
- plot_bgcolor='rgba(0,0,0,0)',
342
- paper_bgcolor='rgba(0,0,0,0)',
343
- font=dict(family='Arial, sans-serif')
344
- )
345
-
346
- return fig
347
-
348
-
349
- def create_comparison_chart(transformer_conf: Dict, bertweet_conf: Dict) -> go.Figure:
350
- """Create side-by-side comparison chart"""
351
- labels = list(Config.LABEL_MAP.values())
352
-
353
- transformer_values = [transformer_conf[label] * 100 for label in labels]
354
- bertweet_values = [bertweet_conf[label] * 100 for label in labels]
355
-
356
- fig = go.Figure(data=[
357
- go.Bar(
358
- name='Custom Transformer',
359
- x=labels,
360
- y=transformer_values,
361
- marker_color='#636EFA',
362
- text=[f'{v:.1f}%' for v in transformer_values],
363
- textposition='auto',
364
- ),
365
- go.Bar(
366
- name='BERTweet',
367
- x=labels,
368
- y=bertweet_values,
369
- marker_color='#EF553B',
370
- text=[f'{v:.1f}%' for v in bertweet_values],
371
- textposition='auto',
372
- )
373
- ])
374
-
375
- fig.update_layout(
376
- title='Model Comparison',
377
- xaxis_title='Sentiment',
378
- yaxis_title='Confidence (%)',
379
- barmode='group',
380
- height=350,
381
- yaxis=dict(range=[0, 100]),
382
- legend=dict(
383
- orientation="h",
384
- yanchor="bottom",
385
- y=1.02,
386
- xanchor="right",
387
- x=1
388
- ),
389
- plot_bgcolor='rgba(0,0,0,0)',
390
- paper_bgcolor='rgba(0,0,0,0)',
391
- font=dict(family='Arial, sans-serif')
392
- )
393
-
394
- return fig
395
-
396
-
397
- # ============================================================================
398
- # STREAMLIT UI
399
- # ============================================================================
400
-
401
- def main():
402
- # Page config
403
- st.set_page_config(
404
- page_title="Sentiment Analysis App",
405
- page_icon="💭",
406
- layout="wide",
407
- initial_sidebar_state="expanded"
408
- )
409
-
410
- # Custom CSS
411
- st.markdown("""
412
- <style>
413
- .main {
414
- padding: 2rem;
415
- }
416
- .stButton>button {
417
- width: 100%;
418
- background-color: #4CAF50;
419
- color: white;
420
- border-radius: 8px;
421
- height: 3em;
422
- font-weight: 600;
423
- }
424
- .stButton>button:hover {
425
- background-color: #45a049;
426
- }
427
- .prediction-box {
428
- padding: 1.5rem;
429
- border-radius: 10px;
430
- margin: 1rem 0;
431
- text-align: center;
432
- font-size: 1.2rem;
433
- font-weight: 600;
434
- }
435
- .negative-box {
436
- background-color: #ffe6e6;
437
- border: 2px solid #FF4B4B;
438
- color: #c41e3a;
439
- }
440
- .neutral-box {
441
- background-color: #fff5e6;
442
- border: 2px solid #FFA500;
443
- color: #d97706;
444
- }
445
- .positive-box {
446
- background-color: #e6f7ed;
447
- border: 2px solid #00D66C;
448
- color: #059669;
449
- }
450
- </style>
451
- """, unsafe_allow_html=True)
452
-
453
- # Header
454
- st.title("💭 Sentiment Analysis App")
455
- st.markdown("### Compare Custom Transformer vs BERTweet Models")
456
- st.markdown("---")
457
-
458
- # Sidebar
459
- with st.sidebar:
460
- st.header("📊 Model Information")
461
-
462
- st.markdown("**Custom Transformer**")
463
- st.info("""
464
- - Architecture: From-scratch Transformer
465
- - Layers: 4 encoder layers
466
- - Attention Heads: 4
467
- - Parameters: ~2M
468
- """)
469
-
470
- st.markdown("**BERTweet**")
471
- st.info("""
472
- - Architecture: Twitter-specific BERT
473
- - Pretrained: vinai/bertweet-base
474
- - Parameters: ~135M
475
- - Fine-tuned for sentiment
476
- """)
477
-
478
- st.markdown("---")
479
- st.markdown("**Labels**")
480
- st.markdown("🔴 **Negative** - Negative sentiment")
481
- st.markdown("🟠 **Neutral** - Neutral sentiment")
482
- st.markdown("🟢 **Positive** - Positive sentiment")
483
-
484
- # Download models if needed
485
- ensure_models_downloaded()
486
-
487
- # Load models
488
- with st.spinner("🔄 Loading models..."):
489
- preprocessor = load_preprocessor()
490
- transformer_model = load_custom_transformer()
491
- bertweet_model = load_bertweet_model()
492
- bertweet_tokenizer = load_bertweet_tokenizer()
493
-
494
- # Check if models loaded
495
- models_loaded = all([
496
- preprocessor is not None,
497
- transformer_model is not None,
498
- bertweet_model is not None,
499
- bertweet_tokenizer is not None
500
- ])
501
-
502
- if not models_loaded:
503
- st.error("❌ Failed to load models. Please check configuration.")
504
- st.stop()
505
-
506
- st.success(f"✅ Models loaded! Running on: **{Config.DEVICE}**")
507
-
508
- # Main UI
509
- st.markdown("### Enter text to analyze sentiment")
510
-
511
- user_text = st.text_area(
512
- "Text Input",
513
- placeholder="Type or paste your text here...",
514
- height=120,
515
- label_visibility="collapsed"
516
- )
517
-
518
- # Examples
519
- col1, col2, col3 = st.columns(3)
520
- with col1:
521
- if st.button("😊 Example: Positive"):
522
- user_text = "This is absolutely amazing! I love it so much! 🎉"
523
- st.rerun()
524
- with col2:
525
- if st.button("😐 Example: Neutral"):
526
- user_text = "It was okay. Nothing special."
527
- st.rerun()
528
- with col3:
529
- if st.button("😞 Example: Negative"):
530
- user_text = "This is terrible. Very disappointed."
531
- st.rerun()
532
-
533
- # Predict
534
- if st.button("🔍 Analyze Sentiment", type="primary"):
535
- if not user_text.strip():
536
- st.warning("⚠️ Please enter some text!")
537
- else:
538
- with st.spinner("🤖 Analyzing..."):
539
- transformer_pred, transformer_conf = predict_custom_transformer(
540
- user_text, transformer_model, preprocessor
541
- )
542
- bertweet_pred, bertweet_conf = predict_bertweet(
543
- user_text, bertweet_model, bertweet_tokenizer, preprocessor
544
- )
545
-
546
- st.markdown("---")
547
- st.markdown("### 🎯 Results")
548
-
549
- col1, col2 = st.columns(2)
550
-
551
- with col1:
552
- st.markdown("#### Custom Transformer")
553
- sentiment_class = transformer_pred.lower()
554
- st.markdown(
555
- f'<div class="prediction-box {sentiment_class}-box">'
556
- f'{transformer_pred}'
557
- f'</div>',
558
- unsafe_allow_html=True
559
- )
560
- fig1 = create_confidence_chart(transformer_conf, "Custom Transformer")
561
- st.plotly_chart(fig1, use_container_width=True)
562
-
563
- with col2:
564
- st.markdown("#### BERTweet")
565
- sentiment_class = bertweet_pred.lower()
566
- st.markdown(
567
- f'<div class="prediction-box {sentiment_class}-box">'
568
- f'{bertweet_pred}'
569
- f'</div>',
570
- unsafe_allow_html=True
571
- )
572
- fig2 = create_confidence_chart(bertweet_conf, "BERTweet")
573
- st.plotly_chart(fig2, use_container_width=True)
574
-
575
- st.markdown("---")
576
- st.markdown("### 📊 Comparison")
577
- fig_comparison = create_comparison_chart(transformer_conf, bertweet_conf)
578
- st.plotly_chart(fig_comparison, use_container_width=True)
579
-
580
- if transformer_pred == bertweet_pred:
581
- st.success(f"✅ **Both models agree:** {transformer_pred}")
582
- else:
583
- st.warning(f"⚠️ **Disagreement:** Transformer={transformer_pred}, BERTweet={bertweet_pred}")
584
-
585
-
586
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
  main()
 
1
+ """
2
+ Sentiment Analysis App - Custom Transformer vs BERTweet
3
+ Clean UI with single text prediction, model comparison, and CSV batch processing
4
+ """
5
+
6
+ import streamlit as st
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import pandas as pd
10
+ import numpy as np
11
+ import pickle
12
+ from pathlib import Path
13
+ import plotly.graph_objects as go
14
+ from typing import Dict, List, Tuple
15
+ import sys
16
+ import os
17
+ import gdown
18
+
19
+ # Add project root and src to path
20
+ current_dir = os.path.dirname(os.path.abspath(__file__))
21
+ project_root = os.path.dirname(current_dir) if 'app' in current_dir else current_dir
22
+ src_path = os.path.join(project_root, 'src')
23
+
24
+ # Add both paths to ensure imports work
25
+ sys.path.insert(0, project_root)
26
+ sys.path.insert(0, src_path)
27
+
28
+ # Try different import methods to be robust
29
+ try:
30
+ from src.models.baseline.custom_transformer import CustomTransformer
31
+ from src.models.pretrained.bertweet import BERTweetClassifier, get_bertweet_tokenizer
32
+ from src.data.preprocessing import EnhancedTextPreprocessor
33
+ except ModuleNotFoundError:
34
+ from models.baseline.custom_transformer import CustomTransformer
35
+ from models.pretrained.bertweet import BERTweetClassifier, get_bertweet_tokenizer
36
+ from data.preprocessing import EnhancedTextPreprocessor
37
+
38
+
39
+ # ============================================================================
40
+ # CONFIGURATION
41
+ # ============================================================================
42
+
43
+ class Config:
44
+ """App configuration"""
45
+ # Local paths where models will be saved after download
46
+ TRANSFORMER_MODEL_PATH = "models/transformer_best_model.pt"
47
+ BERTWEET_MODEL_PATH = "models/bertweet_best_model.pt"
48
+ VOCABULARY_PATH = "models/vocabulary.pkl"
49
+
50
+ # Google Drive File IDs (extracted from your share links)
51
+ GDRIVE_IDS = {
52
+ 'transformer': '124EHm4lHWWWzJfVdlJF-9l8QRqTjI0BK',
53
+ 'bertweet': '1DlGRe4qHypaWby6MU1ab0ZcpJp2FhQtL',
54
+ 'vocab': '1DkbnnYe1_dVFGuOwsDaZ_9UuH1zCXCwE'
55
+ }
56
+
57
+ # Model parameters
58
+ TRANSFORMER_CONFIG = {
59
+ 'vocab_size': 10000,
60
+ 'd_model': 256,
61
+ 'num_heads': 4,
62
+ 'num_layers': 4,
63
+ 'd_ff': 1024,
64
+ 'num_classes': 3,
65
+ 'max_len': 100,
66
+ 'dropout': 0.1,
67
+ 'padding_idx': 0
68
+ }
69
+
70
+ BERTWEET_CONFIG = {
71
+ 'model_name': 'vinai/bertweet-base',
72
+ 'num_classes': 3,
73
+ 'dropout': 0.5
74
+ }
75
+
76
+ # Labels
77
+ LABEL_MAP = {0: 'Negative', 1: 'Neutral', 2: 'Positive'}
78
+ LABEL_COLORS = {
79
+ 'Negative': '#FF4B4B',
80
+ 'Neutral': '#FFA500',
81
+ 'Positive': '#00D66C'
82
+ }
83
+
84
+ # Device
85
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
86
+
87
+
88
+ # ============================================================================
89
+ # MODEL DOWNLOAD FUNCTIONS
90
+ # ============================================================================
91
+
92
+ def download_from_gdrive(file_id: str, output_path: str) -> bool:
93
+ """Download file from Google Drive using gdown"""
94
+ try:
95
+ # Create directory if needed
96
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
97
+
98
+ # Download URL
99
+ url = f"https://drive.google.com/uc?id={file_id}"
100
+
101
+ # Download with gdown
102
+ gdown.download(url, output_path, quiet=False, fuzzy=True)
103
+
104
+ # Verify download
105
+ if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
106
+ return True
107
+ else:
108
+ return False
109
+ except Exception as e:
110
+ st.error(f"Download error: {e}")
111
+ return False
112
+
113
+
114
+ @st.cache_resource
115
+ def ensure_models_downloaded():
116
+ """Download models from Google Drive if not present"""
117
+
118
+ files_to_download = [
119
+ ('transformer', Config.TRANSFORMER_MODEL_PATH),
120
+ ('bertweet', Config.BERTWEET_MODEL_PATH),
121
+ ('vocab', Config.VOCABULARY_PATH)
122
+ ]
123
+
124
+ for model_name, file_path in files_to_download:
125
+ # Check if file already exists
126
+ if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
127
+ st.info(f"✓ {model_name} already downloaded")
128
+ continue
129
+
130
+ # Download from Google Drive
131
+ st.info(f"📥 Downloading {model_name}... (first time only, may take 1-2 minutes)")
132
+
133
+ file_id = Config.GDRIVE_IDS[model_name]
134
+ success = download_from_gdrive(file_id, file_path)
135
+
136
+ if success:
137
+ st.success(f"✅ {model_name} downloaded successfully!")
138
+ else:
139
+ st.error(f"❌ Failed to download {model_name}")
140
+ return False
141
+
142
+ return True
143
+
144
+
145
+ # ============================================================================
146
+ # MODEL LOADING & CACHING
147
+ # ============================================================================
148
+
149
+ @st.cache_resource
150
+ def load_custom_transformer() -> CustomTransformer:
151
+ """Load Custom Transformer model"""
152
+ try:
153
+ config = Config.TRANSFORMER_CONFIG.copy()
154
+ model = CustomTransformer(**config)
155
+
156
+ checkpoint = torch.load(Config.TRANSFORMER_MODEL_PATH, map_location=Config.DEVICE)
157
+
158
+ if isinstance(checkpoint, dict):
159
+ if 'model_state_dict' in checkpoint:
160
+ model.load_state_dict(checkpoint['model_state_dict'])
161
+ elif 'state_dict' in checkpoint:
162
+ model.load_state_dict(checkpoint['state_dict'])
163
+ else:
164
+ model.load_state_dict(checkpoint)
165
+ else:
166
+ model.load_state_dict(checkpoint)
167
+
168
+ model.to(Config.DEVICE)
169
+ model.eval()
170
+
171
+ return model
172
+ except Exception as e:
173
+ st.error(f"❌ Failed to load Custom Transformer: {e}")
174
+ return None
175
+
176
+
177
+ @st.cache_resource
178
+ def load_bertweet_model() -> BERTweetClassifier:
179
+ """Load BERTweet model"""
180
+ try:
181
+ model = BERTweetClassifier(**Config.BERTWEET_CONFIG)
182
+
183
+ checkpoint = torch.load(Config.BERTWEET_MODEL_PATH, map_location=Config.DEVICE)
184
+
185
+ if isinstance(checkpoint, dict):
186
+ if 'model_state_dict' in checkpoint:
187
+ model.load_state_dict(checkpoint['model_state_dict'])
188
+ elif 'state_dict' in checkpoint:
189
+ model.load_state_dict(checkpoint['state_dict'])
190
+ else:
191
+ model.load_state_dict(checkpoint)
192
+ else:
193
+ model.load_state_dict(checkpoint)
194
+
195
+ model.to(Config.DEVICE)
196
+ model.eval()
197
+
198
+ return model
199
+ except Exception as e:
200
+ st.error(f"❌ Failed to load BERTweet: {e}")
201
+ return None
202
+
203
+
204
+ @st.cache_resource
205
+ def load_bertweet_tokenizer():
206
+ """Load BERTweet tokenizer"""
207
+ try:
208
+ return get_bertweet_tokenizer()
209
+ except Exception as e:
210
+ st.error(f"❌ Failed to load BERTweet tokenizer: {e}")
211
+ return None
212
+
213
+
214
+ @st.cache_resource
215
+ def load_preprocessor():
216
+ """Load text preprocessor"""
217
+ try:
218
+ preprocessor = EnhancedTextPreprocessor(
219
+ vocab_size=10000,
220
+ max_length=100,
221
+ min_freq=2,
222
+ use_spell_check=False,
223
+ use_lemmatization=False
224
+ )
225
+
226
+ preprocessor.load_vocabulary(Config.VOCABULARY_PATH)
227
+
228
+ return preprocessor
229
+ except Exception as e:
230
+ st.error(f"❌ Failed to load preprocessor: {e}")
231
+ return None
232
+
233
+
234
+ # ============================================================================
235
+ # PREDICTION FUNCTIONS
236
+ # ============================================================================
237
+
238
+ def text_to_indices(text: str, preprocessor: EnhancedTextPreprocessor, max_len: int = 100) -> torch.Tensor:
239
+ """Convert text to token indices for Custom Transformer"""
240
+ tokens = text.lower().split()
241
+ indices = [preprocessor.word2idx.get(token, preprocessor.word2idx.get('<UNK>', 1)) for token in tokens]
242
+
243
+ if len(indices) < max_len:
244
+ indices = indices + [0] * (max_len - len(indices))
245
+ else:
246
+ indices = indices[:max_len]
247
+
248
+ return torch.tensor([indices], dtype=torch.long)
249
+
250
+
251
+ def predict_custom_transformer(
252
+ text: str,
253
+ model: CustomTransformer,
254
+ preprocessor: EnhancedTextPreprocessor
255
+ ) -> Tuple[str, Dict[str, float]]:
256
+ """Predict sentiment using Custom Transformer"""
257
+ try:
258
+ processed_text = preprocessor.clean_text(text)
259
+ input_ids = text_to_indices(processed_text, preprocessor, max_len=100).to(Config.DEVICE)
260
+ mask = (input_ids != 0).float()
261
+
262
+ with torch.no_grad():
263
+ logits = model(input_ids, mask=mask)
264
+ probs = F.softmax(logits, dim=1)[0]
265
+
266
+ pred_idx = torch.argmax(probs).item()
267
+ pred_label = Config.LABEL_MAP[pred_idx]
268
+
269
+ confidences = {
270
+ Config.LABEL_MAP[i]: float(probs[i])
271
+ for i in range(len(Config.LABEL_MAP))
272
+ }
273
+
274
+ return pred_label, confidences
275
+
276
+ except Exception as e:
277
+ st.error(f"❌ Custom Transformer prediction failed: {e}")
278
+ return "Error", {}
279
+
280
+
281
+ def predict_bertweet(
282
+ text: str,
283
+ model: BERTweetClassifier,
284
+ tokenizer,
285
+ preprocessor: EnhancedTextPreprocessor
286
+ ) -> Tuple[str, Dict[str, float]]:
287
+ """Predict sentiment using BERTweet"""
288
+ try:
289
+ processed_text = preprocessor.clean_text(text)
290
+
291
+ encoded = tokenizer(
292
+ processed_text,
293
+ padding='max_length',
294
+ truncation=True,
295
+ max_length=128,
296
+ return_tensors='pt'
297
+ )
298
+
299
+ input_ids = encoded['input_ids'].to(Config.DEVICE)
300
+ attention_mask = encoded['attention_mask'].to(Config.DEVICE)
301
+
302
+ with torch.no_grad():
303
+ logits = model(input_ids=input_ids, attention_mask=attention_mask)
304
+ probs = F.softmax(logits, dim=1)[0]
305
+
306
+ pred_idx = torch.argmax(probs).item()
307
+ pred_label = Config.LABEL_MAP[pred_idx]
308
+
309
+ confidences = {
310
+ Config.LABEL_MAP[i]: float(probs[i])
311
+ for i in range(len(Config.LABEL_MAP))
312
+ }
313
+
314
+ return pred_label, confidences
315
+
316
+ except Exception as e:
317
+ st.error(f"❌ BERTweet prediction failed: {e}")
318
+ return "Error", {}
319
+
320
+
321
+ # ============================================================================
322
+ # VISUALIZATION FUNCTIONS
323
+ # ============================================================================
324
+
325
+ def create_confidence_chart(confidences: Dict[str, float], model_name: str) -> go.Figure:
326
+ """Create a beautiful confidence bar chart"""
327
+ labels = list(confidences.keys())
328
+ values = [confidences[label] * 100 for label in labels]
329
+ colors = [Config.LABEL_COLORS[label] for label in labels]
330
+
331
+ fig = go.Figure(data=[
332
+ go.Bar(
333
+ x=values,
334
+ y=labels,
335
+ orientation='h',
336
+ marker=dict(
337
+ color=colors,
338
+ line=dict(color='rgba(0,0,0,0.3)', width=1)
339
+ ),
340
+ text=[f'{v:.1f}%' for v in values],
341
+ textposition='auto',
342
+ hovertemplate='<b>%{y}</b><br>Confidence: %{x:.2f}%<extra></extra>'
343
+ )
344
+ ])
345
+
346
+ fig.update_layout(
347
+ title=dict(
348
+ text=f'{model_name} Confidence Scores',
349
+ font=dict(size=16, family='Arial, sans-serif')
350
+ ),
351
+ xaxis=dict(
352
+ title='Confidence (%)',
353
+ range=[0, 100],
354
+ showgrid=True,
355
+ gridcolor='rgba(0,0,0,0.1)'
356
+ ),
357
+ yaxis=dict(
358
+ title='',
359
+ categoryorder='array',
360
+ categoryarray=['Positive', 'Neutral', 'Negative']
361
+ ),
362
+ height=250,
363
+ margin=dict(l=10, r=10, t=40, b=10),
364
+ plot_bgcolor='rgba(0,0,0,0)',
365
+ paper_bgcolor='rgba(0,0,0,0)',
366
+ font=dict(family='Arial, sans-serif')
367
+ )
368
+
369
+ return fig
370
+
371
+
372
+ def create_comparison_chart(transformer_conf: Dict, bertweet_conf: Dict) -> go.Figure:
373
+ """Create side-by-side comparison chart"""
374
+ labels = list(Config.LABEL_MAP.values())
375
+
376
+ transformer_values = [transformer_conf[label] * 100 for label in labels]
377
+ bertweet_values = [bertweet_conf[label] * 100 for label in labels]
378
+
379
+ fig = go.Figure(data=[
380
+ go.Bar(
381
+ name='Custom Transformer',
382
+ x=labels,
383
+ y=transformer_values,
384
+ marker_color='#636EFA',
385
+ text=[f'{v:.1f}%' for v in transformer_values],
386
+ textposition='auto',
387
+ ),
388
+ go.Bar(
389
+ name='BERTweet',
390
+ x=labels,
391
+ y=bertweet_values,
392
+ marker_color='#EF553B',
393
+ text=[f'{v:.1f}%' for v in bertweet_values],
394
+ textposition='auto',
395
+ )
396
+ ])
397
+
398
+ fig.update_layout(
399
+ title='Model Comparison',
400
+ xaxis_title='Sentiment',
401
+ yaxis_title='Confidence (%)',
402
+ barmode='group',
403
+ height=350,
404
+ yaxis=dict(range=[0, 100]),
405
+ legend=dict(
406
+ orientation="h",
407
+ yanchor="bottom",
408
+ y=1.02,
409
+ xanchor="right",
410
+ x=1
411
+ ),
412
+ plot_bgcolor='rgba(0,0,0,0)',
413
+ paper_bgcolor='rgba(0,0,0,0)',
414
+ font=dict(family='Arial, sans-serif')
415
+ )
416
+
417
+ return fig
418
+
419
+
420
+ # ============================================================================
421
+ # STREAMLIT UI
422
+ # ============================================================================
423
+
424
+ def main():
425
+ # Page config
426
+ st.set_page_config(
427
+ page_title="Sentiment Analysis App",
428
+ page_icon="💭",
429
+ layout="wide",
430
+ initial_sidebar_state="expanded"
431
+ )
432
+
433
+ # Custom CSS for better styling
434
+ st.markdown("""
435
+ <style>
436
+ .main {
437
+ padding: 2rem;
438
+ }
439
+ .stButton>button {
440
+ width: 100%;
441
+ background-color: #4CAF50;
442
+ color: white;
443
+ border-radius: 8px;
444
+ height: 3em;
445
+ font-weight: 600;
446
+ }
447
+ .stButton>button:hover {
448
+ background-color: #45a049;
449
+ }
450
+ .prediction-box {
451
+ padding: 1.5rem;
452
+ border-radius: 10px;
453
+ margin: 1rem 0;
454
+ text-align: center;
455
+ font-size: 1.2rem;
456
+ font-weight: 600;
457
+ }
458
+ .negative-box {
459
+ background-color: #ffe6e6;
460
+ border: 2px solid #FF4B4B;
461
+ color: #c41e3a;
462
+ }
463
+ .neutral-box {
464
+ background-color: #fff5e6;
465
+ border: 2px solid #FFA500;
466
+ color: #d97706;
467
+ }
468
+ .positive-box {
469
+ background-color: #e6f7ed;
470
+ border: 2px solid #00D66C;
471
+ color: #059669;
472
+ }
473
+ .stTextArea textarea {
474
+ border-radius: 8px;
475
+ }
476
+ h1 {
477
+ color: #1e3a8a;
478
+ font-weight: 700;
479
+ }
480
+ h2, h3 {
481
+ color: #1e40af;
482
+ }
483
+ </style>
484
+ """, unsafe_allow_html=True)
485
+
486
+ # Header
487
+ st.title("πŸ’­ Sentiment Analysis App")
488
+ st.markdown("### Compare Custom Transformer vs BERTweet Models")
489
+ st.markdown("---")
490
+
491
+ # Sidebar
492
+ with st.sidebar:
493
+ st.header("πŸ“Š Model Information")
494
+
495
+ st.markdown("**Custom Transformer**")
496
+ st.info("""
497
+ - Architecture: From-scratch Transformer
498
+ - Layers: 4 encoder layers
499
+ - Attention Heads: 4
500
+ - Parameters: ~2M
501
+ """)
502
+
503
+ st.markdown("**BERTweet**")
504
+ st.info("""
505
+ - Architecture: Twitter-specific BERT
506
+ - Pretrained: vinai/bertweet-base
507
+ - Parameters: ~135M
508
+ - Fine-tuned for sentiment
509
+ """)
510
+
511
+ st.markdown("---")
512
+ st.markdown("**Labels**")
513
+ st.markdown("πŸ”΄ **Negative** - Negative sentiment")
514
+ st.markdown("🟠 **Neutral** - Neutral sentiment")
515
+ st.markdown("🟒 **Positive** - Positive sentiment")
516
+
517
+ # Download models if needed
518
+ with st.spinner("πŸ”„ Preparing models..."):
519
+ download_success = ensure_models_downloaded()
520
+
521
+ if not download_success:
522
+ st.error("❌ Failed to download models. Please refresh the page.")
523
+ st.stop()
524
+
525
+ # Load models
526
+ with st.spinner("πŸ”„ Loading models and data..."):
527
+ preprocessor = load_preprocessor()
528
+ transformer_model = load_custom_transformer()
529
+ bertweet_model = load_bertweet_model()
530
+ bertweet_tokenizer = load_bertweet_tokenizer()
531
+
532
+ # Check if models loaded successfully
533
+ models_loaded = all([
534
+ preprocessor is not None,
535
+ transformer_model is not None,
536
+ bertweet_model is not None,
537
+ bertweet_tokenizer is not None
538
+ ])
539
+
540
+ if not models_loaded:
541
+ st.error("❌ Failed to load one or more models. Please check the paths and try again.")
542
+ st.stop()
543
+
544
+ st.success(f"βœ… Models loaded successfully! Running on: **{Config.DEVICE}**")
545
+
546
+ # Tabs
547
+ tab1, tab2 = st.tabs(["πŸ“ Single Text Prediction", "πŸ“ Batch CSV Processing"])
548
+
549
+ # ========================================================================
550
+ # TAB 1: SINGLE TEXT PREDICTION
551
+ # ========================================================================
552
+ with tab1:
553
+ st.markdown("### Enter text to analyze sentiment")
554
+
555
+ # Text input
556
+ user_text = st.text_area(
557
+ "Text Input",
558
+ placeholder="Type or paste your text here... (e.g., 'This movie was absolutely amazing!')",
559
+ height=120,
560
+ label_visibility="collapsed"
561
+ )
562
+
563
+ # Example texts
564
+ col1, col2, col3 = st.columns(3)
565
+ with col1:
566
+ if st.button("😊 Example: Positive"):
567
+ user_text = "This is absolutely amazing! I love it so much! πŸŽ‰"
568
+ st.rerun()
569
+ with col2:
570
+ if st.button("😐 Example: Neutral"):
571
+ user_text = "It was okay. Nothing special, but not bad either."
572
+ st.rerun()
573
+ with col3:
574
+ if st.button("😞 Example: Negative"):
575
+ user_text = "This is terrible. Worst experience ever. Very disappointed."
576
+ st.rerun()
577
+
578
+ # Predict button
579
+ if st.button("πŸ” Analyze Sentiment", type="primary"):
580
+ if not user_text.strip():
581
+ st.warning("⚠️ Please enter some text to analyze!")
582
+ else:
583
+ with st.spinner("πŸ€– Analyzing sentiment..."):
584
+ # Get predictions from both models
585
+ transformer_pred, transformer_conf = predict_custom_transformer(
586
+ user_text, transformer_model, preprocessor
587
+ )
588
+ bertweet_pred, bertweet_conf = predict_bertweet(
589
+ user_text, bertweet_model, bertweet_tokenizer, preprocessor
590
+ )
591
+
592
+ # Display results
593
+ st.markdown("---")
594
+ st.markdown("### 🎯 Prediction Results")
595
+
596
+ # Side-by-side predictions
597
+ col1, col2 = st.columns(2)
598
+
599
+ with col1:
600
+ st.markdown("#### Custom Transformer")
601
+ sentiment_class = transformer_pred.lower()
602
+ st.markdown(
603
+ f'<div class="prediction-box {sentiment_class}-box">'
604
+ f'Sentiment: {transformer_pred}'
605
+ f'</div>',
606
+ unsafe_allow_html=True
607
+ )
608
+ fig1 = create_confidence_chart(transformer_conf, "Custom Transformer")
609
+ st.plotly_chart(fig1, use_container_width=True)
610
+
611
+ with col2:
612
+ st.markdown("#### BERTweet")
613
+ sentiment_class = bertweet_pred.lower()
614
+ st.markdown(
615
+ f'<div class="prediction-box {sentiment_class}-box">'
616
+ f'Sentiment: {bertweet_pred}'
617
+ f'</div>',
618
+ unsafe_allow_html=True
619
+ )
620
+ fig2 = create_confidence_chart(bertweet_conf, "BERTweet")
621
+ st.plotly_chart(fig2, use_container_width=True)
622
+
623
+ # Comparison chart
624
+ st.markdown("---")
625
+ st.markdown("### πŸ“Š Model Comparison")
626
+ fig_comparison = create_comparison_chart(transformer_conf, bertweet_conf)
627
+ st.plotly_chart(fig_comparison, use_container_width=True)
628
+
629
+ # Agreement/Disagreement indicator
630
+ if transformer_pred == bertweet_pred:
631
+ st.success(f"βœ… **Both models agree:** {transformer_pred}")
632
+ else:
633
+ st.warning(f"⚠️ **Models disagree:** Transformer={transformer_pred}, BERTweet={bertweet_pred}")
634
+
635
+ # ========================================================================
636
+ # TAB 2: BATCH CSV PROCESSING
637
+ # ========================================================================
638
+ with tab2:
639
+ st.markdown("### Upload CSV file for batch prediction")
640
+ st.markdown("**Required:** Your CSV must have a column named `text`")
641
+
642
+ # File uploader
643
+ uploaded_file = st.file_uploader(
644
+ "Choose a CSV file",
645
+ type=['csv'],
646
+ help="Upload a CSV file with a 'text' column"
647
+ )
648
+
649
+ if uploaded_file is not None:
650
+ try:
651
+ # Read CSV
652
+ df = pd.read_csv(uploaded_file)
653
+
654
+ # Check for 'text' column
655
+ if 'text' not in df.columns:
656
+ st.error("❌ CSV must contain a 'text' column!")
657
+ st.stop()
658
+
659
+ st.success(f"βœ… Loaded {len(df)} texts from CSV")
660
+
661
+ # Show preview
662
+ with st.expander("πŸ“‹ Preview Data", expanded=True):
663
+ st.dataframe(df.head(10), use_container_width=True)
664
+
665
+ # Process button
666
+ if st.button("πŸš€ Process All Texts", type="primary"):
667
+ with st.spinner(f"πŸ€– Processing {len(df)} texts..."):
668
+ # Initialize result lists
669
+ transformer_predictions = []
670
+ transformer_confidences = []
671
+ bertweet_predictions = []
672
+ bertweet_confidences = []
673
+
674
+ # Progress bar
675
+ progress_bar = st.progress(0)
676
+
677
+ # Process each text
678
+ for idx, text in enumerate(df['text']):
679
+ # Skip empty texts
680
+ if pd.isna(text) or str(text).strip() == '':
681
+ transformer_predictions.append('N/A')
682
+ transformer_confidences.append(0.0)
683
+ bertweet_predictions.append('N/A')
684
+ bertweet_confidences.append(0.0)
685
+ else:
686
+ # Transformer prediction
687
+ t_pred, t_conf = predict_custom_transformer(
688
+ str(text), transformer_model, preprocessor
689
+ )
690
+ transformer_predictions.append(t_pred)
691
+ transformer_confidences.append(max(t_conf.values()))
692
+
693
+ # BERTweet prediction
694
+ b_pred, b_conf = predict_bertweet(
695
+ str(text), bertweet_model, bertweet_tokenizer, preprocessor
696
+ )
697
+ bertweet_predictions.append(b_pred)
698
+ bertweet_confidences.append(max(b_conf.values()))
699
+
700
+ # Update progress
701
+ progress_bar.progress((idx + 1) / len(df))
702
+
703
+ progress_bar.empty()
704
+
705
+ # Add predictions to dataframe
706
+ results_df = df.copy()
707
+ results_df['Transformer_Prediction'] = transformer_predictions
708
+ results_df['Transformer_Confidence'] = [f"{c:.2%}" for c in transformer_confidences]
709
+ results_df['BERTweet_Prediction'] = bertweet_predictions
710
+ results_df['BERTweet_Confidence'] = [f"{c:.2%}" for c in bertweet_confidences]
711
+ results_df['Agreement'] = [
712
+ 'βœ…' if t == b else '❌'
713
+ for t, b in zip(transformer_predictions, bertweet_predictions)
714
+ ]
715
+
716
+ # Display results
717
+ st.markdown("---")
718
+ st.markdown("### 🎯 Batch Prediction Results")
719
+
720
+ # Summary metrics
721
+ col1, col2, col3, col4 = st.columns(4)
722
+
723
+ with col1:
724
+ st.metric("Total Texts", len(results_df))
725
+
726
+ with col2:
727
+ agreement_rate = (results_df['Agreement'] == 'βœ…').sum() / len(results_df) * 100
728
+ st.metric("Agreement Rate", f"{agreement_rate:.1f}%")
729
+
730
+ with col3:
731
+ avg_trans_conf = np.mean(transformer_confidences) * 100
732
+ st.metric("Avg Transformer Conf.", f"{avg_trans_conf:.1f}%")
733
+
734
+ with col4:
735
+ avg_bert_conf = np.mean(bertweet_confidences) * 100
736
+ st.metric("Avg BERTweet Conf.", f"{avg_bert_conf:.1f}%")
737
+
738
+ # Results table
739
+ st.markdown("#### πŸ“Š Detailed Results")
740
+
741
+ # Color-code predictions
742
+ def highlight_sentiment(row):
743
+ colors = []
744
+ for col in row.index:
745
+ if 'Prediction' in col:
746
+ if 'Positive' in str(row[col]):
747
+ colors.append('background-color: #d1fae5')
748
+ elif 'Negative' in str(row[col]):
749
+ colors.append('background-color: #fee2e2')
750
+ elif 'Neutral' in str(row[col]):
751
+ colors.append('background-color: #fef3c7')
752
+ else:
753
+ colors.append('')
754
+ else:
755
+ colors.append('')
756
+ return colors
757
+
758
+ styled_df = results_df.style.apply(highlight_sentiment, axis=1)
759
+ st.dataframe(styled_df, use_container_width=True, height=400)
760
+
761
+ # Download button
762
+ csv = results_df.to_csv(index=False).encode('utf-8')
763
+ st.download_button(
764
+ label="πŸ“₯ Download Results as CSV",
765
+ data=csv,
766
+ file_name="sentiment_analysis_results.csv",
767
+ mime="text/csv",
768
+ use_container_width=True
769
+ )
770
+
771
+ # Distribution charts
772
+ st.markdown("---")
773
+ st.markdown("### πŸ“ˆ Distribution Analysis")
774
+
775
+ col1, col2 = st.columns(2)
776
+
777
+ with col1:
778
+ # Transformer distribution
779
+ trans_counts = results_df['Transformer_Prediction'].value_counts()
780
+ fig_trans_dist = go.Figure(data=[
781
+ go.Pie(
782
+ labels=trans_counts.index,
783
+ values=trans_counts.values,
784
+ marker=dict(colors=[
785
+ Config.LABEL_COLORS.get(label, '#gray')
786
+ for label in trans_counts.index
787
+ ]),
788
+ hole=0.4
789
+ )
790
+ ])
791
+ fig_trans_dist.update_layout(
792
+ title="Custom Transformer Distribution",
793
+ height=300
794
+ )
795
+ st.plotly_chart(fig_trans_dist, use_container_width=True)
796
+
797
+ with col2:
798
+ # BERTweet distribution
799
+ bert_counts = results_df['BERTweet_Prediction'].value_counts()
800
+ fig_bert_dist = go.Figure(data=[
801
+ go.Pie(
802
+ labels=bert_counts.index,
803
+ values=bert_counts.values,
804
+ marker=dict(colors=[
805
+ Config.LABEL_COLORS.get(label, '#gray')
806
+ for label in bert_counts.index
807
+ ]),
808
+ hole=0.4
809
+ )
810
+ ])
811
+ fig_bert_dist.update_layout(
812
+ title="BERTweet Distribution",
813
+ height=300
814
+ )
815
+ st.plotly_chart(fig_bert_dist, use_container_width=True)
816
+
817
+ except Exception as e:
818
+ st.error(f"❌ Error processing CSV: {e}")
819
+
820
+ # Footer
821
+ st.markdown("---")
822
+ st.markdown(
823
+ "<div style='text-align: center; color: gray;'>"
824
+ "Built with Streamlit | Custom Transformer vs BERTweet Comparison"
825
+ "</div>",
826
+ unsafe_allow_html=True
827
+ )
828
+
829
+
830
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()