alzami commited on
Commit
d9b6560
·
verified ·
1 Parent(s): 5f889c1

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +889 -35
src/streamlit_app.py CHANGED
@@ -1,40 +1,894 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Streamlit App for Government Complaint Classification
4
+ Author: Based on XLM-RoBERTa implementation by Farrikh Alzami
5
+ """
6
+
7
  import streamlit as st
8
+ import pandas as pd
9
+ import numpy as np
10
+ import time
11
+ import io
12
+ from typing import List, Dict, Tuple
13
+ import os
14
+ from pathlib import Path
15
 
16
+ # Custom imports
17
+ from utils.model_loader import ModelLoader
18
+ from utils.text_preprocessor import TextPreprocessor
19
+ from utils.visualization import Visualizer
20
 
21
+ # Page configuration
22
+ st.set_page_config(
23
+ page_title="Government Complaint Classifier",
24
+ page_icon="πŸ›οΈ",
25
+ layout="wide",
26
+ initial_sidebar_state="expanded"
27
+ )
28
 
29
+ # Custom CSS for warm color scheme
30
+ st.markdown("""
31
+ <style>
32
+ .main-header {
33
+ background: linear-gradient(90deg, #FF6B35 0%, #F7931E 100%);
34
+ padding: 1rem;
35
+ border-radius: 10px;
36
+ margin-bottom: 2rem;
37
+ text-align: center;
38
+ color: white;
39
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
40
+ }
41
+
42
+ .metric-container {
43
+ background: linear-gradient(135deg, #FFF5E6 0%, #FFE5CC 100%);
44
+ padding: 1rem;
45
+ border-radius: 10px;
46
+ border-left: 4px solid #FF6B35;
47
+ margin: 0.5rem 0;
48
+ }
49
+
50
+ .prediction-container {
51
+ background: linear-gradient(135deg, #FFF9F5 0%, #FFEDE6 100%);
52
+ padding: 1.5rem;
53
+ border-radius: 15px;
54
+ border: 2px solid #FFB366;
55
+ margin: 1rem 0;
56
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
57
+ }
58
+
59
+ .stProgress > div > div > div > div {
60
+ background-color: #FF6B35;
61
+ }
62
+
63
+ div[data-testid="metric-container"] {
64
+ background-color: #FFF5E6;
65
+ border: 1px solid #FFD4A3;
66
+ padding: 1rem;
67
+ border-radius: 10px;
68
+ box-shadow: 0 2px 4px rgba(255, 107, 53, 0.1);
69
+ }
70
+ </style>
71
+ """, unsafe_allow_html=True)
72
+
73
class StreamlitApp:
    """Streamlit front-end for government complaint classification.

    Wires together the project-local ``ModelLoader``, ``TextPreprocessor``
    and ``Visualizer`` and renders three tabs: single-text analysis,
    batch (CSV) processing, and an about page.
    """

    def __init__(self):
        self.model_loader = ModelLoader()
        self.text_preprocessor = TextPreprocessor()
        self.visualizer = Visualizer()

        # Initialize session state (each key only on the first run of a
        # session; Streamlit re-executes the whole script on every rerun).
        if 'model_type' not in st.session_state:
            st.session_state.model_type = 'cross_entropy'
        if 'model_loaded' not in st.session_state:
            st.session_state.model_loaded = False
        if 'predictions_history' not in st.session_state:
            st.session_state.predictions_history = []
        if 'last_analyzed_text' not in st.session_state:
            st.session_state.last_analyzed_text = ""
        if 'current_results' not in st.session_state:
            st.session_state.current_results = None
        if 'batch_results' not in st.session_state:
            st.session_state.batch_results = None

    # ------------------------------------------------------------------
    # Internal helpers (shared by the single-text and batch flows)
    # ------------------------------------------------------------------

    def _reset_model_loader(self):
        """Drop every cached artifact held by the model loader."""
        self.model_loader.model = None
        self.model_loader.tokenizer = None
        self.model_loader.label_mappings = None
        self.model_loader.classifier_pipeline = None
        self.model_loader.current_model_type = None

    def _is_model_loaded(self) -> bool:
        """True when a live pipeline exists for the currently selected model type."""
        return (
            hasattr(self.model_loader, 'classifier_pipeline') and
            self.model_loader.classifier_pipeline is not None and
            self.model_loader.current_model_type == st.session_state.model_type
        )

    def _ensure_model_loaded(self, spinner_message: str = "Loading model..."):
        """(Re)load the selected model if it is missing or the type changed.

        Raises whatever ``ModelLoader.load_model`` raises on failure;
        callers translate that into user-facing error results.
        """
        force_reload = (
            not st.session_state.model_loaded or
            self.model_loader.current_model_type != st.session_state.model_type or
            self.model_loader.classifier_pipeline is None
        )
        if force_reload:
            with st.spinner(spinner_message):
                # Clear existing artifacts first so we never mix pieces
                # from two different checkpoints.
                self._reset_model_loader()
                self.model_loader.load_model(st.session_state.model_type)
                # Update session state explicitly
                st.session_state.model_loaded = True

    @staticmethod
    def _error_result(category: str, text: str, cleaned_text: str) -> Dict:
        """Uniform payload returned when loading or prediction fails.

        Always carries 'original_text' and 'cleaned_text' so downstream
        renderers (which read those keys unconditionally) never KeyError.
        """
        return {
            'predicted_category': category,
            'confidence': 0.0,
            'predicted_id': -1,
            'all_predictions': {'Error': 1.0},
            'processing_time': 0.0,
            'original_text': text,
            'cleaned_text': cleaned_text
        }

    def _clear_all_model_state(self, success_message: str):
        """Reset model + result session state, clear caches, and rerun the app."""
        st.session_state.model_loaded = False
        st.session_state.predictions_history = []
        st.session_state.last_analyzed_text = ""
        st.session_state.current_results = None
        self._reset_model_loader()
        st.cache_resource.clear()
        st.success(success_message)
        st.rerun()

    # ------------------------------------------------------------------
    # Layout chrome
    # ------------------------------------------------------------------

    def render_header(self):
        """Render application header"""
        st.markdown("""
        <div class="main-header">
            <h1>πŸ›οΈ Government Complaint Classifier</h1>
            <p>Klasifikasi Otomatis Keluhan Masyarakat menggunakan XLM-RoBERTa</p>
        </div>
        """, unsafe_allow_html=True)

    def render_sidebar(self):
        """Render sidebar with model selection, model status, reset and history."""
        with st.sidebar:
            st.header("βš™οΈ Model Configuration")

            # Model selection toggle
            model_options = {
                'cross_entropy': '🎯 Cross Entropy Loss',
                'focal_loss': 'πŸ”₯ Focal Loss'
            }

            selected_model = st.radio(
                "Pilih Model:",
                options=list(model_options.keys()),
                format_func=lambda x: model_options[x],
                index=0 if st.session_state.model_type == 'cross_entropy' else 1
            )

            # Switching models invalidates the loaded pipeline; force a rerun
            # so the rest of the page reflects the new selection immediately.
            if selected_model != st.session_state.model_type:
                st.session_state.model_type = selected_model
                st.session_state.model_loaded = False
                st.rerun()

            st.markdown("---")

            # Model availability check
            st.subheader("πŸ“ Model Files Status")
            available_models = self.model_loader.get_available_models()

            for model_type in ['cross_entropy', 'focal_loss']:
                if model_type in available_models:
                    # Check if this model is currently loaded
                    is_current_loaded = (
                        hasattr(self.model_loader, 'current_model_type') and
                        self.model_loader.current_model_type == model_type and
                        hasattr(self.model_loader, 'classifier_pipeline') and
                        self.model_loader.classifier_pipeline is not None
                    )

                    if is_current_loaded and model_type == st.session_state.model_type:
                        st.success(f"βœ… {model_type.replace('_', ' ').title()} (Currently Loaded)")
                    else:
                        st.success(f"βœ… {model_type.replace('_', ' ').title()}")
                else:
                    st.error(f"❌ {model_type.replace('_', ' ').title()}")

            if not available_models:
                st.warning("⚠️ No models found! Please check model directory.")
                st.info("""
                Expected structure:
                ```
                models/
                β”œβ”€β”€ cross_entropy/
                β”‚ β”œβ”€β”€ model.safetensors
                β”‚ β”œβ”€β”€ config.json
                β”‚ └── ...
                └── focal_loss/
                β”œβ”€β”€ model.safetensors
                β”œβ”€β”€ config.json
                └── ...
                ```
                """)

            st.markdown("---")

            # Model info
            st.subheader("πŸ“Š Model Information")

            # Real-time check of the loader, not just the session flag.
            if self._is_model_loaded():
                model_info = self.model_loader.get_model_info()
                st.success(f"**Status:** βœ… {model_info['status']}")
                st.info(f"**Current Model:** {model_info['model_type'].replace('_', ' ').title()}")
                st.info(f"**Device:** {model_info['device']}")
                st.info(f"**Categories:** {model_info['num_labels']}")

                # Show some model details
                with st.expander("πŸ” Model Details"):
                    st.write(f"**Model Size:** {model_info['model_size']}")
                    st.write(f"**Available Categories:**")
                    categories = model_info.get('categories', [])
                    if categories:
                        # Show first 10 categories
                        display_categories = categories[:10]
                        st.write(", ".join(display_categories))
                        if len(categories) > 10:
                            st.write(f"... and {len(categories) - 10} more categories")
                    else:
                        st.write("Categories not available")
            else:
                st.info(f"""
                **Current Model:** {model_options[st.session_state.model_type]}

                **Architecture:** XLM-RoBERTa Base

                **Max Length:** 256 tokens

                **Languages:** Multilingual (ID, EN, etc.)

                **Status:** ⏳ Not loaded (will load on first use)
                """)

                # Show loading hint
                if not st.session_state.model_loaded:
                    st.info("πŸ’‘ Model will be loaded automatically when you analyze text.")

            st.markdown("---")

            # Global reset button
            st.subheader("πŸ”„ Reset Application")
            if st.button("🧹 Clear All & Reset Models", use_container_width=True, type="secondary"):
                # Clear all session states owned by this app
                for key in list(st.session_state.keys()):
                    if key.startswith(('model_', 'predictions_', 'last_', 'current_', 'batch_')):
                        del st.session_state[key]

                # Reinitialize essential states
                st.session_state.model_type = 'cross_entropy'
                st.session_state.model_loaded = False
                st.session_state.predictions_history = []
                st.session_state.last_analyzed_text = ""
                st.session_state.current_results = None
                st.session_state.batch_results = None

                # Clear model loader state and cached resources
                self._reset_model_loader()
                st.cache_resource.clear()
                st.success("βœ… Application reset complete!")
                st.rerun()

            st.markdown("---")

            # Prediction history (last three entries, newest last)
            if st.session_state.predictions_history:
                st.subheader("πŸ“ˆ Recent Predictions")
                for i, pred in enumerate(st.session_state.predictions_history[-3:]):
                    with st.expander(f"Prediction {len(st.session_state.predictions_history) - i}"):
                        st.write(f"**Text:** {pred['text'][:100]}...")
                        st.write(f"**Category:** {pred['category']}")
                        st.write(f"**Confidence:** {pred['confidence']:.2%}")

    # ------------------------------------------------------------------
    # Prediction entry points
    # ------------------------------------------------------------------

    def predict_single_text(self, text: str) -> Dict:
        """Classify one complaint text; returns prediction dict with timing info."""
        start_time = time.time()

        # Preprocess text
        cleaned_text = self.text_preprocessor.clean_text(text)

        # Load model if needed
        try:
            self._ensure_model_loaded("Loading model...")
        except Exception as e:
            st.error(f"Failed to load model: {str(e)}")
            return self._error_result('Error: Model Loading Failed', text, cleaned_text)

        # Make prediction
        try:
            result = self.model_loader.predict(cleaned_text)
        except Exception as e:
            st.error(f"Failed to make prediction: {str(e)}")
            return self._error_result('Error: Prediction Failed', text, cleaned_text)

        result['processing_time'] = time.time() - start_time
        result['original_text'] = text
        result['cleaned_text'] = cleaned_text

        return result

    def predict_batch_texts(self, texts: List[str]) -> List[Dict]:
        """Classify a list of texts with a progress bar; one result dict per text."""
        # Load model once for the whole batch
        try:
            self._ensure_model_loaded("Loading model for batch processing...")
        except Exception as e:
            st.error(f"Failed to load model for batch processing: {str(e)}")
            # Distinct error dicts per text (never aliased), each carrying the
            # text keys the batch renderer reads unconditionally.
            return [
                self._error_result('Error: Model Loading Failed', text, text)
                for text in texts
            ]

        results = []
        progress_bar = st.progress(0)

        for i, text in enumerate(texts):
            try:
                # Preprocess, then predict
                cleaned_text = self.text_preprocessor.clean_text(text)
                result = self.model_loader.predict(cleaned_text)
                result['original_text'] = text
                result['cleaned_text'] = cleaned_text
                results.append(result)
            except Exception as e:
                st.warning(f"Failed to process text {i+1}: {str(e)}")
                # Add error result for this specific text only
                results.append(self._error_result(
                    'Error: Prediction Failed', text,
                    self.text_preprocessor.clean_text(text)
                ))

            # Update progress
            progress_bar.progress((i + 1) / len(texts))

        return results

    # ------------------------------------------------------------------
    # Tab renderers
    # ------------------------------------------------------------------

    def render_single_text_tab(self):
        """Render single text analysis tab"""
        st.header("πŸ“ Single Text Analysis")

        # Show current model status at top
        if self._is_model_loaded():
            st.success(f"🎯 Current Model: **{st.session_state.model_type.replace('_', ' ').title()} - READY**")
        else:
            st.info(f"⏳ Current Model: **{st.session_state.model_type.replace('_', ' ').title()} - Will load on first use**")

        # Text input
        user_text = st.text_area(
            "Masukkan teks keluhan masyarakat:",
            height=150,
            placeholder="Contoh: Saya ingin melaporkan jalan rusak di daerah saya yang sudah lama tidak diperbaiki...",
            key="main_text_input"
        )

        # Analysis / clear buttons
        col1, col2, col3, col4 = st.columns([2, 1, 1, 2])
        with col2:
            analyze_button = st.button(
                "πŸ” Analyze Text",
                type="primary",
                use_container_width=True
            )

        with col3:
            clear_button = st.button(
                "🧹 Clear",
                type="secondary",
                use_container_width=True,
                help="Clear results and reset model state"
            )

        # Single consolidated clear handler. (The original rendered two
        # identical handlers; the second was dead code because st.rerun()
        # ends the current script run.)
        if clear_button:
            self._clear_all_model_state("βœ… Cleared all states and model cache!")

        # Check if text has changed since last analysis
        text_changed = user_text.strip() != st.session_state.last_analyzed_text

        if analyze_button and user_text.strip():
            try:
                with st.spinner("Analyzing text..."):
                    result = self.predict_single_text(user_text)

                # Store in history and session state
                st.session_state.predictions_history.append({
                    'text': user_text,
                    'category': result['predicted_category'],
                    'confidence': result['confidence']
                })
                st.session_state.last_analyzed_text = user_text.strip()
                st.session_state.current_results = result

                # Display results
                self.display_single_prediction_results(result)

            except Exception as e:
                st.error(f"❌ Error during analysis: {str(e)}")
                st.info("πŸ’‘ Try clicking the 'Clear' button to reset the model state.")

        elif analyze_button and not user_text.strip():
            st.warning("⚠️ Please enter some text to analyze!")

        # Display previous results if available and text hasn't changed
        elif st.session_state.current_results and not text_changed and not analyze_button:
            st.info("πŸ“‹ Showing previous analysis results. Click 'Analyze Text' to update or 'Clear' to reset.")
            self.display_single_prediction_results(st.session_state.current_results)

        # Show hint if text has changed
        elif text_changed and st.session_state.current_results:
            st.info("✏️ Text has been modified. Click 'Analyze Text' to get new predictions or 'Clear' to reset.")

    def display_single_prediction_results(self, result: Dict):
        """Display single prediction results: headline category, metrics,
        per-category confidence chart, top-5 table and preprocessing view."""
        st.markdown("## πŸ“Š Analysis Results")

        # Main prediction container
        st.markdown(f"""
        <div class="prediction-container">
            <h3>🎯 Predicted Category</h3>
            <h2 style="color: #FF6B35; margin: 0;">{result['predicted_category']}</h2>
        </div>
        """, unsafe_allow_html=True)

        # Metrics
        col1, col2, col3 = st.columns(3)

        with col1:
            st.metric(
                label="🎯 Confidence Score",
                value=f"{result['confidence']:.2%}",
                delta="Top prediction"
            )

        with col2:
            st.metric(
                label="⏱️ Processing Time",
                value=f"{result['processing_time']:.3f}s",
                delta="Real-time"
            )

        with col3:
            st.metric(
                label="πŸ“ Text Length",
                value=f"{len(result['cleaned_text'])} chars",
                delta="After cleaning"
            )

        # Confidence visualization
        st.markdown("### πŸ“ˆ Confidence Scores by Category")
        fig = self.visualizer.plot_confidence_scores(result['all_predictions'])
        st.plotly_chart(fig, use_container_width=True)

        # Top predictions table
        st.markdown("### πŸ† Top 5 Predictions")
        top_predictions = sorted(
            result['all_predictions'].items(),
            key=lambda x: x[1],
            reverse=True
        )[:5]

        df_top = pd.DataFrame([
            {
                'Rank': i+1,
                'Category': category,
                'Confidence': f"{confidence:.2%}",
                'Confidence_Score': confidence
            }
            for i, (category, confidence) in enumerate(top_predictions)
        ])

        # Style the dataframe: gradient on the raw score, then hide it
        styled_df = df_top.style.format({
            'Confidence_Score': '{:.4f}'
        }).hide(['Confidence_Score'], axis=1).background_gradient(
            subset=['Confidence_Score'],
            cmap='Oranges'
        )

        st.dataframe(styled_df, use_container_width=True)

        # Show preprocessing details
        with st.expander("πŸ”§ Preprocessing Details"):
            col1, col2 = st.columns(2)

            with col1:
                st.markdown("**Original Text:**")
                st.text_area(
                    "Original Text",
                    value=result['original_text'],
                    height=100,
                    disabled=True,
                    key="original_text_display",
                    label_visibility="collapsed"
                )

            with col2:
                st.markdown("**Cleaned Text:**")
                st.text_area(
                    "Cleaned Text",
                    value=result['cleaned_text'],
                    height=100,
                    disabled=True,
                    key="cleaned_text_display",
                    label_visibility="collapsed"
                )

    def render_batch_processing_tab(self):
        """Render batch processing tab (CSV upload, column pick, batch run)."""
        st.header("πŸ“Š Batch Processing")

        # Show current model status at top
        if self._is_model_loaded():
            st.success(f"🎯 Current Model: **{st.session_state.model_type.replace('_', ' ').title()} - READY**")
        else:
            st.info(f"⏳ Current Model: **{st.session_state.model_type.replace('_', ' ').title()} - Will load on first use**")

        # File upload
        st.markdown("### πŸ“ Upload CSV File")
        uploaded_file = st.file_uploader(
            "Choose a CSV file containing texts to classify",
            type=['csv'],
            help="CSV should have a column named 'text' containing the texts to classify"
        )

        if uploaded_file is not None:
            try:
                # Read uploaded file
                df = pd.read_csv(uploaded_file)

                # Show preview
                st.markdown("### πŸ‘€ Data Preview")
                st.dataframe(df.head(10))

                # Column selection (default to a column literally named 'text')
                text_columns = df.columns.tolist()
                selected_column = st.selectbox(
                    "Select the text column to classify:",
                    options=text_columns,
                    index=0 if 'text' not in text_columns else text_columns.index('text')
                )

                # Batch processing buttons
                col1, col2, col3, col4 = st.columns([2, 1, 1, 2])
                with col2:
                    process_button = st.button(
                        "πŸš€ Process Batch",
                        type="primary",
                        use_container_width=True
                    )

                with col3:
                    clear_batch_button = st.button(
                        "🧹 Clear Batch",
                        type="secondary",
                        use_container_width=True,
                        help="Clear batch results and reset model"
                    )

                if clear_batch_button:
                    # Clear batch-specific states
                    st.session_state.batch_results = None
                    st.session_state.model_loaded = False
                    # Clear model loader state and cached resources
                    self._reset_model_loader()
                    st.cache_resource.clear()
                    st.success("βœ… Cleared batch results and model cache!")
                    st.rerun()

                if process_button:
                    texts = df[selected_column].astype(str).tolist()

                    st.markdown("### ⚑ Processing Batch...")
                    start_time = time.time()

                    try:
                        results = self.predict_batch_texts(texts)
                        total_time = time.time() - start_time

                        # Store results so they survive the next rerun
                        st.session_state.batch_results = {
                            'original_df': df,
                            'results': results,
                            'selected_column': selected_column,
                            'total_time': total_time
                        }

                        # Display batch results
                        self.display_batch_results(df, results, selected_column, total_time)

                    except Exception as e:
                        st.error(f"❌ Error during batch processing: {str(e)}")
                        st.info("πŸ’‘ Try clicking the 'Clear Batch' button to reset the model state.")

                # Display previous batch results if available
                elif st.session_state.batch_results:
                    st.info("πŸ“‹ Showing previous batch results. Upload new file to process again or click 'Clear Batch' to reset.")
                    batch_data = st.session_state.batch_results
                    self.display_batch_results(
                        batch_data['original_df'],
                        batch_data['results'],
                        batch_data['selected_column'],
                        batch_data['total_time']
                    )

            except Exception as e:
                st.error(f"Error reading CSV file: {str(e)}")

        else:
            # Show example CSV format
            st.markdown("### πŸ“‹ Expected CSV Format")
            example_df = pd.DataFrame({
                'id': [1, 2, 3],
                'text': [
                    'Jalan di depan rumah saya rusak parah',
                    'Pelayanan di kantor kelurahan lambat',
                    'Lingkungan sekitar kotor dan tidak terawat'
                ]
            })
            st.dataframe(example_df)

    def display_batch_results(self, original_df: pd.DataFrame, results: List[Dict],
                              text_column: str, total_time: float):
        """Display batch processing results and offer an Excel download."""
        st.markdown("## πŸ“Š Batch Processing Results")

        # Guard against division by zero for pathologically fast/empty runs.
        speed = len(results) / total_time if total_time > 0 else 0.0

        # Summary metrics
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            st.metric("πŸ“„ Total Texts", len(results))

        with col2:
            avg_confidence = np.mean([r['confidence'] for r in results])
            st.metric("🎯 Avg Confidence", f"{avg_confidence:.2%}")

        with col3:
            st.metric("⏱️ Total Time", f"{total_time:.2f}s")

        with col4:
            st.metric("πŸš€ Speed", f"{speed:.1f} texts/sec")

        # Create results dataframe
        results_df = original_df.copy()
        results_df['predicted_category'] = [r['predicted_category'] for r in results]
        results_df['confidence'] = [r['confidence'] for r in results]
        results_df['cleaned_text'] = [r['cleaned_text'] for r in results]

        # Category distribution
        st.markdown("### πŸ“ˆ Category Distribution")
        category_counts = results_df['predicted_category'].value_counts()
        fig = self.visualizer.plot_category_distribution(category_counts)
        st.plotly_chart(fig, use_container_width=True)

        # Results table
        st.markdown("### πŸ“‹ Detailed Results")
        display_df = results_df[[text_column, 'predicted_category', 'confidence']].copy()
        display_df['confidence'] = display_df['confidence'].apply(lambda x: f"{x:.2%}")

        st.dataframe(display_df, use_container_width=True)

        # Download results
        st.markdown("### πŸ’Ύ Download Results")

        # Prepare Excel data with all predictions
        excel_data = []
        for i, result in enumerate(results):
            row = original_df.iloc[i].to_dict()
            row['predicted_category'] = result['predicted_category']
            row['confidence'] = result['confidence']
            row['cleaned_text'] = result['cleaned_text']

            # Add top 3 predictions
            top_3 = sorted(result['all_predictions'].items(), key=lambda x: x[1], reverse=True)[:3]
            for j, (cat, conf) in enumerate(top_3, 1):
                row[f'top_{j}_category'] = cat
                row[f'top_{j}_confidence'] = conf

            excel_data.append(row)

        excel_df = pd.DataFrame(excel_data)

        # Create Excel file in memory
        output = io.BytesIO()
        with pd.ExcelWriter(output, engine='openpyxl') as writer:
            excel_df.to_excel(writer, sheet_name='Results', index=False)

            # Add summary sheet
            summary_df = pd.DataFrame([
                ['Total Texts Processed', len(results)],
                ['Average Confidence', f"{avg_confidence:.2%}"],
                ['Processing Time', f"{total_time:.2f} seconds"],
                ['Model Used', st.session_state.model_type.replace('_', ' ').title()],
                ['Processing Speed', f"{speed:.1f} texts/second"]
            ], columns=['Metric', 'Value'])

            summary_df.to_excel(writer, sheet_name='Summary', index=False)

        # Download button
        col1, col2, col3 = st.columns([2, 1, 2])
        with col2:
            st.download_button(
                label="πŸ“₯ Download Excel Report",
                data=output.getvalue(),
                file_name=f"complaint_classification_results_{st.session_state.model_type}.xlsx",
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                use_container_width=True
            )

    def render_about_tab(self):
        """Render about/help tab"""
        st.header("ℹ️ About This Application")

        st.markdown("""
        ### 🎯 Purpose
        This application automatically classifies government complaints using state-of-the-art
        XLM-RoBERTa transformer models. It supports both Cross Entropy and Focal Loss variants
        for handling imbalanced datasets.

        ### πŸ”§ Technical Details
        - **Model Architecture:** XLM-RoBERTa Base (Multi-lingual)
        - **Framework:** Hugging Face Transformers + PyTorch
        - **Preprocessing:** HTML cleaning, emoji removal, text normalization
        - **Maximum Input Length:** 256 tokens
        - **Languages Supported:** Indonesian, English, and more

        ### πŸ“Š Model Comparison
        - **Cross Entropy Loss:** Traditional classification loss with class weights
        - **Focal Loss:** Specialized for imbalanced datasets, focuses on hard examples

        ### πŸš€ Usage Guide

        #### Single Text Analysis:
        1. Select your preferred model from the sidebar
        2. Enter text in the textarea
        3. Click "Analyze Text"
        4. View predictions and confidence scores

        #### Batch Processing:
        1. Prepare a CSV file with text data
        2. Upload the file in the Batch Processing tab
        3. Select the text column to classify
        4. Click "Process Batch"
        5. Download results as Excel file

        ### πŸ“ CSV Format for Batch Processing
        Your CSV should contain at least one column with text data:
        ```
        id,text,other_columns...
        1,"Jalan rusak perlu diperbaiki",metadata
        2,"Pelayanan lambat di kantor",metadata
        ```

        ### ⚠️ Limitations
        - Maximum text length: 256 tokens (approximately 200-300 words)
        - Model performance depends on training data quality
        - Processing time varies with text length and batch size

        ### πŸ‘¨β€πŸ’» Credits
        Based on research implementation by Farrikh Alzami using XLM-RoBERTa for
        government complaint classification with focal loss optimization.
        """)

    def run(self):
        """Main application runner"""
        self.render_header()
        self.render_sidebar()

        # Main content tabs
        tab1, tab2, tab3 = st.tabs(["πŸ“ Single Text", "πŸ“Š Batch Processing", "ℹ️ About"])

        with tab1:
            self.render_single_text_tab()

        with tab2:
            self.render_batch_processing_tab()

        with tab3:
            self.render_about_tab()
+
884
+ def main():
885
+ """Main function"""
886
+ try:
887
+ app = StreamlitApp()
888
+ app.run()
889
+ except Exception as e:
890
+ st.error(f"Application error: {str(e)}")
891
+ st.info("Please ensure all model files are properly placed in the models/ directory.")
892
 
893
+ if __name__ == "__main__":
894
+ main()