songhieng commited on
Commit
4e25137
·
verified ·
1 Parent(s): 9a534da

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +1013 -0
app.py ADDED
@@ -0,0 +1,1013 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MLOps Training Platform - Streamlit Application
3
+ ==================================================
4
+
5
+ A beginner-friendly web interface for training text classification models
6
+ with built-in system checks and model management.
7
+
8
+ Run with: streamlit run app.py
9
+ """
10
+
11
+ # CRITICAL: Set these environment variables FIRST, before any other imports
12
+ import os
13
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
14
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
15
+ os.environ['TRANSFORMERS_NO_TF'] = '1'
16
+ os.environ['USE_TF'] = '0'
17
+
18
+ import sys
19
+ import time
20
+ from datetime import datetime
21
+ from pathlib import Path
22
+ from typing import Optional, List
23
+
24
+ import streamlit as st
25
+ import pandas as pd
26
+ import plotly.express as px
27
+ import plotly.graph_objects as go
28
+
29
+ # Add src directory to path for imports
30
+ sys.path.insert(0, str(Path(__file__).parent / 'src'))
31
+
32
+ from mlops.config import (
33
+ TrainingConfig,
34
+ MODEL_ARCHITECTURES,
35
+ MODEL_SELECTION_GUIDE,
36
+ ClassificationType
37
+ )
38
+ from mlops.preprocessor import TextPreprocessor, DataValidator
39
+ from mlops.trainer import ModelTrainer
40
+ from mlops.evaluator import ModelEvaluator
41
+ from mlops.system_check import SystemChecker, get_system_summary
42
+
43
+ # ==================== Page Configuration ====================
44
+
45
+ st.set_page_config(
46
+ page_title="MLOps Training Platform",
47
+ page_icon="🤖",
48
+ layout="wide",
49
+ initial_sidebar_state="expanded"
50
+ )
51
+
52
+ # ==================== Custom CSS ====================
53
+
54
+ st.markdown("""
55
+ <style>
56
+ /* Main styling */
57
+ .main-header {
58
+ font-size: 2.5rem;
59
+ font-weight: 700;
60
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
61
+ -webkit-background-clip: text;
62
+ -webkit-text-fill-color: transparent;
63
+ margin-bottom: 0.5rem;
64
+ }
65
+
66
+ .sub-header {
67
+ font-size: 1.1rem;
68
+ color: #666;
69
+ margin-bottom: 2rem;
70
+ }
71
+
72
+ /* Info boxes */
73
+ .info-box {
74
+ background-color: #f0f7ff;
75
+ border-left: 4px solid #667eea;
76
+ padding: 1rem;
77
+ margin: 1rem 0;
78
+ border-radius: 0 8px 8px 0;
79
+ }
80
+
81
+ .warning-box {
82
+ background-color: #fff7e6;
83
+ border-left: 4px solid #fa8c16;
84
+ padding: 1rem;
85
+ margin: 1rem 0;
86
+ border-radius: 0 8px 8px 0;
87
+ }
88
+
89
+ .success-box {
90
+ background-color: #f6ffed;
91
+ border-left: 4px solid #52c41a;
92
+ padding: 1rem;
93
+ margin: 1rem 0;
94
+ border-radius: 0 8px 8px 0;
95
+ }
96
+
97
+ .error-box {
98
+ background-color: #fff1f0;
99
+ border-left: 4px solid #ff4d4f;
100
+ padding: 1rem;
101
+ margin: 1rem 0;
102
+ border-radius: 0 8px 8px 0;
103
+ }
104
+
105
+ /* Metric cards */
106
+ .metric-card {
107
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
108
+ padding: 1.5rem;
109
+ border-radius: 10px;
110
+ color: white;
111
+ text-align: center;
112
+ }
113
+
114
+ /* Hide default elements */
115
+ #MainMenu {visibility: hidden;}
116
+ footer {visibility: hidden;}
117
+ </style>
118
+ """, unsafe_allow_html=True)
119
+
120
+ # ==================== Session State Initialization ====================
121
+
def init_session_state():
    """Seed ``st.session_state`` with a default for every key the app reads.

    Keys already present are left alone, so calling this on every script
    rerun never overwrites values produced by earlier user interaction.
    """
    # Built fresh on each call, so the mutable defaults (set/list/config)
    # are never shared between sessions.
    defaults = dict(
        # Classification type selection
        classification_type=None,
        classification_type_selected=False,
        # Prerequisites
        prerequisites_checked=False,
        cuda_status=None,
        env_status=None,
        models_downloaded=set(),
        # Training state
        training_started=False,
        training_completed=False,
        training_progress=0.0,
        training_logs=[],
        metrics_history=[],
        model_path=None,
        # Data
        uploaded_data=None,
        preprocessed_data=None,
        # Evaluation
        evaluation_results=None,
        # Config
        config=TrainingConfig(),
        # Selected model
        selected_model=None,
    )

    missing = (name for name in defaults if name not in st.session_state)
    for name in missing:
        st.session_state[name] = defaults[name]
160
+
161
+ init_session_state()
162
+
163
+ # ==================== Helper Functions ====================
164
+
def add_log(message: str):
    """Append ``message`` to the session's training log.

    The entry is prefixed with the current local time as ``[HH:MM:SS]``.
    """
    stamp = f"{datetime.now():%H:%M:%S}"
    entry = f"[{stamp}] {message}"
    st.session_state.training_logs.append(entry)
169
+
def create_info_box(text: str, box_type: str = "info"):
    """Render ``text`` inside a ``<div class="{box_type}-box">`` element.

    ``text`` is inserted verbatim and rendered as HTML, so callers may pass
    markup such as ``<b>`` or ``<br>``. The page stylesheet defines the
    ``info``, ``warning``, ``success`` and ``error`` box classes.
    """
    html = '<div class="{}-box">{}</div>'.format(box_type, text)
    st.markdown(html, unsafe_allow_html=True)
173
+
174
+ # ==================== Sidebar ====================
175
+
def render_sidebar():
    """Draw the sidebar: title, per-step status summaries and the reset action."""
    state = st.session_state

    with st.sidebar:
        st.markdown('<h1 class="main-header">🤖 MLOps Platform</h1>', unsafe_allow_html=True)
        st.markdown("---")

        # --- Classification type ---
        st.subheader("📋 Classification Type")
        if not state.classification_type_selected:
            st.warning("⚠️ Not selected")
        else:
            chose_binary = state.classification_type == ClassificationType.BINARY
            type_display = "Binary" if chose_binary else "Multi-class"
            st.success(f"✅ {type_display}")

        st.markdown("---")

        # --- Prerequisites ---
        st.subheader("🔧 Prerequisites")
        if not state.prerequisites_checked:
            st.warning("⚠️ Not checked")
        else:
            st.success("✅ Checked")

            # Hardware line is shown only once a CUDA check result exists.
            cuda = state.cuda_status
            if cuda:
                hardware_line = (
                    f"🎮 GPU: {cuda['devices'][0]['name']}"
                    if cuda['available']
                    else "💻 CPU Mode"
                )
                st.info(hardware_line)

            downloaded = state.models_downloaded
            if downloaded:
                st.info(f"📦 Models: {len(downloaded)}")

        st.markdown("---")

        # --- Training status ---
        st.subheader("🎯 Training Status")
        if state.training_completed:
            st.success("✅ Completed")
        else:
            status_line = (
                f"⏳ In Progress ({state.training_progress:.0f}%)"
                if state.training_started
                else "💤 Not started"
            )
            st.info(status_line)

        st.markdown("---")

        # --- Quick actions ---
        st.subheader("⚡ Quick Actions")
        if st.button("🔄 Reset All", width="stretch"):
            # Snapshot the keys first: entries are deleted while looping.
            for stale_key in list(state.keys()):
                del state[stale_key]
            init_session_state()
            st.rerun()
232
+
233
+ render_sidebar()
234
+
235
+ # ==================== Main Content ====================
236
+
237
+ # Header
238
+ st.markdown('<h1 class="main-header">🤖 MLOps Training Platform</h1>', unsafe_allow_html=True)
239
+ st.markdown('<p class="sub-header">Train and evaluate text classification models with ease</p>', unsafe_allow_html=True)
240
+
241
+ # ==================== STEP 1: Classification Type Selection ====================
242
+
# Gate: nothing below this block is rendered until a task type is chosen.
if not st.session_state.classification_type_selected:
    st.markdown("## 📋 Step 1: Choose Classification Type")

    intro_html = (
        "🎯 <b>First, select your classification task type:</b><br><br>"
        "• <b>Binary Classification:</b> Two classes (e.g., spam vs. not spam, positive vs. negative)<br>"
        "• <b>Multi-class Classification:</b> More than two classes (e.g., categorize news into politics, sports, entertainment, etc.)"
    )
    create_info_box(intro_html, "info")

    binary_col, multiclass_col = st.columns(2)

    with binary_col:
        st.markdown("### 🔵 Binary Classification")
        st.markdown("""
        **Use when you have:**
        - 2 categories/labels
        - Yes/No questions
        - Positive/Negative sentiment

        **Examples:**
        - Spam detection (spam/not spam)
        - Sentiment analysis (positive/negative)
        - Phishing detection (phishing/legitimate)
        """)

        picked_binary = st.button("Select Binary Classification", width="stretch", type="primary")
        if picked_binary:
            st.session_state.classification_type = ClassificationType.BINARY
            st.session_state.classification_type_selected = True
            # Binary tasks always have exactly two labels.
            st.session_state.config.num_labels = 2
            add_log("Selected Binary Classification")
            st.rerun()

    with multiclass_col:
        st.markdown("### 🌈 Multi-class Classification")
        st.markdown("""
        **Use when you have:**
        - 3+ categories/labels
        - Multiple distinct classes
        - Topic categorization

        **Examples:**
        - News categorization (politics/sports/tech/entertainment)
        - Product classification (electronics/clothing/books/toys)
        - Language detection (English/Chinese/Spanish/etc.)
        """)

        picked_multiclass = st.button("Select Multi-class Classification", width="stretch")
        if picked_multiclass:
            st.session_state.classification_type = ClassificationType.MULTICLASS
            st.session_state.classification_type_selected = True
            # num_labels is filled in after upload, once the class count is known.
            add_log("Selected Multi-class Classification")
            st.rerun()

    # Halt this script run here so the workflow tabs stay hidden.
    st.stop()
298
+
299
+ # ==================== TABS FOR REST OF WORKFLOW ====================
300
+
301
+ tab1, tab2, tab3, tab4, tab5 = st.tabs([
302
+ "🔧 Prerequisites",
303
+ "📤 Upload Data",
304
+ "⚙️ Configure Training",
305
+ "🎯 Train Model",
306
+ "📊 Evaluate Model"
307
+ ])
308
+
309
+ # ==================== TAB 1: Prerequisites ====================
310
+
# Prerequisites tab: hardware check, package check, model guide and model
# download. The "proceed" button appears only once every check has passed.
with tab1:
    st.markdown("## 🔧 System Prerequisites")

    create_info_box(
        "⚠️ <b>Important:</b> Complete all prerequisite checks before proceeding to training.<br>"
        "This ensures your system is properly configured and all required models are downloaded.",
        "warning"
    )

    # Initialize system checker
    system_checker = SystemChecker(models_dir="models")

    # ===== CUDA/GPU Check =====
    st.markdown("### 🎮 1. CUDA/GPU Check")

    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("Check if CUDA-capable GPU is available for faster training.")
    with col2:
        if st.button("🔍 Check CUDA", width="stretch"):
            with st.spinner("Checking CUDA availability..."):
                cuda_status = system_checker.check_cuda()
                st.session_state.cuda_status = cuda_status
                add_log("CUDA check completed")

    if st.session_state.cuda_status:
        cuda = st.session_state.cuda_status

        if cuda['available']:
            st.success(f"✅ CUDA Available - {cuda['device_count']} GPU(s) detected")

            for device in cuda['devices']:
                with st.expander(f"📊 {device['name']} Details"):
                    col1, col2, col3 = st.columns(3)
                    col1.metric("Memory", f"{device['memory_total']:.2f} GB")
                    col2.metric("Compute", device['compute_capability'])
                    col3.metric("CUDA Version", cuda['cuda_version'])

            create_info_box(
                "💡 <b>Recommendation:</b> Your GPU is ready for training! "
                "You can use any model from the list. XLM-RoBERTa and RoBERTa are recommended for best accuracy.",
                "success"
            )
        else:
            st.warning("⚠️ No CUDA-capable GPU detected - Training will use CPU")
            create_info_box(
                "💡 <b>Recommendation:</b> For CPU training, we recommend using <b>distilbert-base-multilingual-cased</b> "
                "as it's significantly faster while maintaining good accuracy.",
                "warning"
            )

    st.markdown("---")

    # ===== Environment Check =====
    st.markdown("### 🐍 2. Environment Check")

    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("Verify all required Python packages are installed with correct versions.")
    with col2:
        if st.button("🔍 Check Environment", width="stretch"):
            with st.spinner("Checking environment..."):
                env_status = system_checker.check_environment()
                st.session_state.env_status = env_status
                add_log("Environment check completed")

    if st.session_state.env_status:
        env = st.session_state.env_status

        if env['all_satisfied']:
            st.success("✅ All required packages are installed")
        else:
            st.error(f"❌ Missing packages: {', '.join(env['missing_packages'])}")
            create_info_box(
                f"<b>To install missing packages, run:</b><br>"
                f"<code>pip install {' '.join(env['missing_packages'])}</code>",
                "error"
            )

        with st.expander("📦 View Package Details"):
            package_df = pd.DataFrame([
                {
                    'Package': pkg,
                    'Installed': info['installed'] or 'Not Installed',
                    'Required': info['required'],
                    'Status': '✅' if info['satisfied'] else '❌'
                }
                for pkg, info in env['packages'].items()
            ])
            st.dataframe(package_df, width="stretch", hide_index=True)

    st.markdown("---")

    # ===== Model Selection Guide =====
    st.markdown("### 📚 3. Model Selection Guide")

    create_info_box(
        "📖 <b>How to choose the right model:</b><br><br>"
        "Consider these factors:<br>"
        "• <b>Language:</b> English only or multilingual?<br>"
        "• <b>Hardware:</b> GPU available or CPU only?<br>"
        "• <b>Speed vs Accuracy:</b> Need fast training or best accuracy?<br>"
        "• <b>Task Type:</b> Binary or multi-class classification?",
        "info"
    )

    # Display model comparison table
    model_comparison = [
        {
            'Model': model_info['name'],
            'Languages': ', '.join(model_info['languages']),
            'Speed': model_info['speed'],
            'Size': model_info['size'],
            'Best For': model_info['best_use'],
            'ID': model_id
        }
        for model_id, model_info in MODEL_ARCHITECTURES.items()
    ]

    model_df = pd.DataFrame(model_comparison)
    st.dataframe(model_df, width="stretch", hide_index=True)

    # Quick recommendations
    st.markdown("#### 💡 Quick Recommendations:")

    rec_col1, rec_col2 = st.columns(2)

    with rec_col1:
        st.markdown("**For GPU Training:**")
        st.markdown("- 🏆 Best: `xlm-roberta-base` (highest accuracy)")
        st.markdown("- ⚡ Fast: `roberta-base` (English only)")

    with rec_col2:
        st.markdown("**For CPU Training:**")
        st.markdown("- 🎯 Recommended: `distilbert-base-multilingual-cased`")
        st.markdown("- 💨 Fastest training and good performance")

    st.markdown("---")

    # ===== Model Download =====
    st.markdown("### 📥 4. Download Models")

    create_info_box(
        "⬇️ <b>Download models before training:</b><br>"
        "Models will be downloaded to the <code>models/</code> directory. "
        "This may take several minutes depending on your internet connection.",
        "info"
    )

    # Model selection
    selected_models = st.multiselect(
        "Select models to download:",
        options=list(MODEL_ARCHITECTURES.keys()),
        format_func=lambda x: f"{MODEL_ARCHITECTURES[x]['name']} ({MODEL_ARCHITECTURES[x]['size']})",
        help="Select one or more models to download. You can train with any downloaded model later."
    )

    col1, col2 = st.columns([3, 1])
    with col2:
        download_btn = st.button("⬇️ Download Selected", width="stretch", type="primary", disabled=len(selected_models) == 0)

    if download_btn:
        progress_bar = st.progress(0)
        status_text = st.empty()
        total_models = len(selected_models)
        failed_models = []

        for idx, model_id in enumerate(selected_models):
            status_text.text(f"Downloading {model_id}... ({idx + 1}/{total_models})")
            progress_bar.progress(idx / total_models)

            success, _path, message = system_checker.download_model(
                model_id,
                progress_callback=lambda msg, prog: None  # Could add sub-progress here
            )

            if success:
                st.session_state.models_downloaded.add(model_id)
                add_log(f"Downloaded model: {model_id}")
            else:
                failed_models.append(model_id)
                st.error(f"Failed to download {model_id}: {message}")
                add_log(f"Failed to download model: {model_id}")

        progress_bar.progress(1.0)
        if failed_models:
            # Do not rerun here: a rerun would erase the error messages above
            # before the user can read them. The rest of this tab still
            # reflects the models that did download; the sidebar count
            # refreshes on the next interaction.
            status_text.text(f"⚠️ Download finished with {len(failed_models)} failure(s)")
        else:
            status_text.text("✅ Download complete!")
            time.sleep(1)
            st.rerun()

    # Show downloaded models
    if st.session_state.models_downloaded:
        st.markdown("#### ✅ Downloaded Models:")
        for model_id in st.session_state.models_downloaded:
            model_info = system_checker.get_model_info(model_id)
            st.success(f"📦 {MODEL_ARCHITECTURES[model_id]['name']} - {model_info['size_mb']:.0f} MB")

    st.markdown("---")

    # ===== Prerequisites Complete Button =====
    can_proceed = (
        st.session_state.cuda_status is not None and
        st.session_state.env_status is not None and
        st.session_state.env_status['all_satisfied'] and
        len(st.session_state.models_downloaded) > 0
    )

    if can_proceed:
        if st.button("✅ Prerequisites Complete - Proceed to Data Upload", width="stretch", type="primary"):
            st.session_state.prerequisites_checked = True
            add_log("Prerequisites check completed successfully")
            st.success("🎉 All prerequisites satisfied! You can now proceed to upload your data.")
            time.sleep(1)
            st.rerun()
    else:
        create_info_box(
            "⏳ <b>Complete all checks above before proceeding:</b><br>"
            "✓ CUDA Check<br>"
            "✓ Environment Check (all packages installed)<br>"
            "✓ Download at least one model",
            "warning"
        )
527
+
528
+ # ==================== TAB 2: Upload Data ====================
529
+
# Upload tab: read a CSV, validate it, preview it and record it in session
# state for the later tabs.
with tab2:
    st.markdown("## 📤 Upload Training Data")

    if not st.session_state.prerequisites_checked:
        # Branch instead of calling st.stop(): st.stop() ends the whole script
        # run, which would leave every later tab empty.
        create_info_box(
            "⚠️ Please complete the <b>Prerequisites</b> tab first before uploading data.",
            "warning"
        )
    else:
        create_info_box(
            "📄 <b>Data Format Requirements:</b><br>"
            "• CSV file with at least two columns: text and label<br>"
            "• Text column: Contains the text samples to classify<br>"
            "• Label column: Contains the class labels (0/1 for binary, or class names for multi-class)<br>"
            "• Minimum 20 samples recommended for training",
            "info"
        )

        # File uploader
        uploaded_file = st.file_uploader(
            "Upload your CSV file",
            type=['csv'],
            help="Upload a CSV file with 'text' and 'label' columns"
        )

        if uploaded_file is not None:
            # Broad catch is deliberate: this is the UI boundary, and any
            # parsing or validation failure should surface as a message.
            try:
                # Read data
                df = pd.read_csv(uploaded_file)

                st.success(f"✅ Uploaded {len(df)} samples")

                # Validate data
                validator = DataValidator()
                is_valid, message = validator.validate_dataframe(df)

                if is_valid:
                    # This block runs on every rerun while a file is attached;
                    # remember whether the content changed so the log gets one
                    # entry per upload rather than one per rerun.
                    previous_df = st.session_state.uploaded_data
                    is_new_upload = previous_df is None or not previous_df.equals(df)

                    # Store only validated data: later tabs treat a non-None
                    # value as "ready to configure and train".
                    st.session_state.uploaded_data = df

                    st.success(f"✅ Data validation passed: {message}")

                    # Show data preview
                    st.markdown("### 📊 Data Preview")
                    st.dataframe(df.head(10), width="stretch")

                    # Show statistics
                    col1, col2, col3 = st.columns(3)
                    col1.metric("Total Samples", len(df))
                    col2.metric("Unique Labels", df['label'].nunique())
                    col3.metric("Text Columns", len([c for c in df.columns if df[c].dtype == 'object']))

                    # Label distribution
                    st.markdown("### 📈 Label Distribution")
                    label_counts = df['label'].value_counts()
                    fig = px.bar(
                        x=label_counts.index.astype(str),
                        y=label_counts.values,
                        labels={'x': 'Label', 'y': 'Count'},
                        title='Number of samples per label'
                    )
                    st.plotly_chart(fig, width="stretch")

                    # Update num_labels for multi-class
                    if st.session_state.classification_type == ClassificationType.MULTICLASS:
                        num_classes = df['label'].nunique()
                        st.session_state.config.num_labels = num_classes
                        st.info(f"ℹ️ Detected {num_classes} classes for multi-class classification")

                    if is_new_upload:
                        add_log(f"Uploaded data with {len(df)} samples and {df['label'].nunique()} labels")

                else:
                    # Drop any earlier dataset so an invalid file cannot be
                    # trained on by mistake.
                    st.session_state.uploaded_data = None
                    st.error(f"❌ Data validation failed: {message}")

            except Exception as e:
                st.error(f"Error reading file: {str(e)}")
605
+
606
+ # ==================== TAB 3: Configure Training ====================
607
+
# Configure tab: pick a downloaded model and set hyperparameters. Every widget
# value is written straight into st.session_state.config.
with tab3:
    st.markdown("## ⚙️ Configure Training Parameters")

    if st.session_state.uploaded_data is None:
        # Branch instead of calling st.stop(): st.stop() ends the whole script
        # run, which would leave the Train and Evaluate tabs empty.
        create_info_box(
            "⚠️ Please upload your data in the <b>Upload Data</b> tab first.",
            "warning"
        )
    else:
        create_info_box(
            "🎛️ <b>Configure your training settings:</b><br>"
            "Adjust the parameters below based on your needs. Hover over ⓘ for explanations.",
            "info"
        )

        # Model selection
        st.markdown("### 🤖 Model Selection")

        available_models = list(st.session_state.models_downloaded)

        if not available_models:
            st.error("❌ No models downloaded. Please download models in the Prerequisites tab.")
            # Kept as a hard stop: the Train tab looks up selected_model,
            # which is only assigned further down in this branch.
            st.stop()

        selected_model = st.selectbox(
            "Choose model:",
            options=available_models,
            format_func=lambda x: f"{MODEL_ARCHITECTURES[x]['name']} - {MODEL_ARCHITECTURES[x]['best_use']}",
            help="Select the model architecture to use for training"
        )

        st.session_state.selected_model = selected_model
        st.session_state.config.model_name = selected_model

        # Show model info
        model_info = MODEL_ARCHITECTURES[selected_model]
        with st.expander("ℹ️ Selected Model Information"):
            st.markdown(f"**Name:** {model_info['name']}")
            st.markdown(f"**Description:** {model_info['description']}")
            st.markdown(f"**Best For:** {model_info['best_use']}")
            st.markdown(f"**Speed:** {model_info['speed']}")
            st.markdown(f"**Size:** {model_info['size']}")

        st.markdown("---")

        # Training parameters
        st.markdown("### 🎯 Training Parameters")

        col1, col2 = st.columns(2)

        with col1:
            epochs = st.slider(
                "Number of Epochs",
                min_value=1,
                max_value=20,
                value=3,
                help="Number of complete passes through the training dataset. More epochs = longer training but potentially better performance."
            )
            st.session_state.config.num_epochs = epochs

            batch_size = st.select_slider(
                "Batch Size",
                options=[4, 8, 16, 32, 64],
                value=16,
                help="Number of samples processed together. Larger batches train faster but require more GPU memory."
            )
            st.session_state.config.batch_size = batch_size

            learning_rate = st.select_slider(
                "Learning Rate",
                options=[1e-5, 2e-5, 3e-5, 5e-5, 1e-4],
                value=2e-5,
                format_func=lambda x: f"{x:.0e}",
                help="Step size for model parameter updates. 2e-5 is a good default for BERT-like models."
            )
            st.session_state.config.learning_rate = learning_rate

        with col2:
            max_length = st.slider(
                "Max Sequence Length",
                min_value=128,
                max_value=512,
                value=128,
                step=64,
                help="Maximum length of input text in tokens. Longer sequences require more memory."
            )
            st.session_state.config.max_length = max_length

            val_split = st.select_slider(
                "Validation Split",
                options=[0.1, 0.15, 0.2, 0.25, 0.3],
                value=0.2,
                format_func=lambda x: f"{x*100:.0f}%",
                help="Percentage of data reserved for validation during training."
            )
            st.session_state.config.validation_split = val_split
            st.session_state.config.train_split = 0.9 - val_split  # Keep 0.1 for test

            # NOTE(review): the early-stopping switch and patience below are
            # read from the widgets but never written to the config, so they
            # appear to have no effect on training. Wire them up once the
            # matching TrainingConfig fields are confirmed.
            early_stopping = st.checkbox(
                "Enable Early Stopping",
                value=True,
                help="Stop training automatically if validation performance stops improving."
            )

            if early_stopping:
                patience = st.slider(
                    "Early Stopping Patience",
                    min_value=2,
                    max_value=5,
                    value=3,
                    help="Number of epochs to wait before stopping if no improvement."
                )

        st.markdown("---")

        # Show configuration summary
        st.markdown("### 📋 Configuration Summary")

        config_summary = {
            "Classification Type": "Binary" if st.session_state.classification_type == ClassificationType.BINARY else "Multi-class",
            "Number of Labels": st.session_state.config.num_labels,
            "Model": model_info['name'],
            "Epochs": epochs,
            "Batch Size": batch_size,
            "Learning Rate": f"{learning_rate:.0e}",
            "Max Length": max_length,
            "Validation Split": f"{val_split*100:.0f}%"
        }

        summary_df = pd.DataFrame([
            {"Parameter": k, "Value": str(v)}
            for k, v in config_summary.items()
        ])
        st.dataframe(summary_df, width="stretch", hide_index=True)
743
+
744
+ # ==================== TAB 4: Train Model ====================
745
+
# Train tab: start training, stream progress into the page, and record the
# resulting model path and metrics history in session state.
with tab4:
    st.markdown("## 🎯 Train Your Model")

    # A model must be chosen as well as data uploaded: selected_model is used
    # as a dictionary key below.
    steps_incomplete = (
        st.session_state.uploaded_data is None
        or st.session_state.selected_model is None
    )

    if steps_incomplete:
        # Branch instead of calling st.stop(): st.stop() ends the whole script
        # run, which would leave the Evaluate tab empty.
        create_info_box(
            "⚠️ Please complete previous steps first.",
            "warning"
        )
    else:
        if not st.session_state.training_started:
            create_info_box(
                "🚀 <b>Ready to train!</b><br>"
                f"Your {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']} model will be trained on {len(st.session_state.uploaded_data)} samples "
                f"for {st.session_state.config.num_epochs} epochs.",
                "info"
            )

            if st.button("🚀 Start Training", type="primary", width="stretch"):
                st.session_state.training_started = True
                st.rerun()

        if st.session_state.training_started and not st.session_state.training_completed:
            st.markdown("### ⏳ Training in Progress...")

            # Progress display
            progress_bar = st.progress(0)
            status_text = st.empty()
            # A single-slot placeholder: each update replaces the previous
            # metrics row instead of adding another one beneath it.
            metrics_placeholder = st.empty()

            try:
                # Prepare data
                status_text.text("Preparing data...")
                df = st.session_state.uploaded_data

                # Initialize trainer with absolute path
                st.session_state.config.output_dir = os.path.abspath("trained_models")
                trainer = ModelTrainer(config=st.session_state.config)

                # Training progress callback - receives TrainingProgress object
                def progress_callback(progress_obj):
                    percent = progress_obj.progress_percent
                    if percent > 0:
                        # Clamp: st.progress rejects fractions above 1.0.
                        progress_bar.progress(min(percent / 100.0, 1.0))

                    status_text.text(f"Training: {percent:.1f}% complete")
                    st.session_state.training_progress = percent

                    # Update metrics display from latest metrics
                    if progress_obj.metrics_history:
                        latest_metrics = progress_obj.metrics_history[-1]
                        with metrics_placeholder.container():
                            col1, col2, col3 = st.columns(3)
                            col1.metric("Epoch", f"{progress_obj.current_epoch}/{progress_obj.total_epochs}")
                            col2.metric("Train Loss", f"{latest_metrics.train_loss:.4f}")
                            if latest_metrics.eval_loss > 0:
                                col3.metric("Val Loss", f"{latest_metrics.eval_loss:.4f}")

                # Train model
                result = trainer.train(
                    texts=df['text'].tolist(),
                    labels=df['label'].tolist(),
                    progress_callback=progress_callback
                )

                # Check if training actually succeeded
                if result.status == "failed":
                    raise RuntimeError(result.error_message or "Training failed with unknown error")

                if result.model_path is None:
                    raise RuntimeError("Training completed but model path is None. Check logs for errors.")

                # Training complete
                st.session_state.training_completed = True
                st.session_state.model_path = result.model_path
                st.session_state.metrics_history = [m.to_dict() for m in result.metrics_history]

                progress_bar.progress(1.0)
                status_text.empty()

                st.success("🎉 Training completed successfully!")
                add_log(f"Training completed successfully. Model saved to: {result.model_path}")

                # Show final metrics
                if result.final_metrics:
                    st.markdown("### 📊 Final Training Metrics")
                    metrics = result.final_metrics.to_dict()

                    col1, col2, col3, col4 = st.columns(4)
                    col1.metric("Accuracy", f"{metrics.get('accuracy', 0):.2%}")
                    col2.metric("Precision", f"{metrics.get('precision', 0):.4f}")
                    col3.metric("Recall", f"{metrics.get('recall', 0):.4f}")
                    col4.metric("F1 Score", f"{metrics.get('f1', 0):.4f}")

                # Leave the final metrics on screen briefly before the rerun
                # switches the tab to its "completed" view.
                time.sleep(2)
                st.rerun()

            except Exception as e:
                # UI boundary: report the failure and let the user retry.
                import traceback
                error_details = traceback.format_exc()
                st.error(f"❌ Training failed: {str(e)}")
                with st.expander("🔍 Error Details"):
                    st.code(error_details)
                st.session_state.training_started = False
                add_log(f"Training failed: {str(e)}")

        if st.session_state.training_completed:
            st.success("✅ Training completed!")

            model_path_display = st.session_state.model_path if st.session_state.model_path else "⚠️ Path not available"

            create_info_box(
                f"🎉 <b>Model trained successfully!</b><br>"
                f"Model saved to: <code>{model_path_display}</code><br>"
                "Proceed to the <b>Evaluate Model</b> tab to analyze performance.",
                "success" if st.session_state.model_path else "warning"
            )

        # Show training logs
        with st.expander("📜 View Training Logs"):
            for log in st.session_state.training_logs[-20:]:  # Show last 20 logs
                st.text(log)
868
+
869
+ # ==================== TAB 5: Evaluate Model ====================
870
+
871
def _evaluate_trained_model():
    """Evaluate the trained model and store the outcome in session state.

    Builds a ``ModelEvaluator`` from ``st.session_state.model_path``, scores it
    on the last ``validation_split`` fraction of the uploaded data, saves the
    result to ``st.session_state.evaluation_results`` and reruns the script.
    Failures are reported in the UI and appended to the log; nothing is raised
    to the caller.
    """
    try:
        cuda_status = st.session_state.cuda_status
        evaluator = ModelEvaluator(
            model_path=st.session_state.model_path,
            use_cuda=cuda_status['available'] if cuda_status else False
        )

        # Prepare test data (use validation split from uploaded data).
        # NOTE(review): this takes the *tail* of the uploaded frame; confirm
        # it matches the rows the trainer held out, otherwise the metrics are
        # computed on data the model has already seen.
        df = st.session_state.uploaded_data
        test_size = int(len(df) * st.session_state.config.validation_split)
        if test_size <= 0:
            # df.tail(0) is an empty frame; fail loudly instead of handing the
            # evaluator zero samples.
            raise ValueError(
                "Validation split produced no rows to evaluate; "
                "upload more data or increase the validation split."
            )
        test_df = df.tail(test_size)

        results = evaluator.evaluate(
            texts=test_df['text'].tolist(),
            true_labels=test_df['label'].tolist(),
            batch_size=st.session_state.config.batch_size
        )
    except Exception as e:
        # Mirror the training tab: show the message, keep the traceback one
        # click away, and record the failure in the log.
        import traceback
        error_details = traceback.format_exc()
        st.error(f"Evaluation failed: {str(e)}")
        with st.expander("🔍 Error Details"):
            st.code(error_details)
        add_log(f"Evaluation failed: {str(e)}")
        return

    st.session_state.evaluation_results = results
    add_log("Model evaluation completed")
    # Rerun so the button is replaced by the results view.
    st.rerun()


def _build_evaluation_report(results):
    """Return the plain-text evaluation report offered for download.

    ``results`` is indexed by ``'accuracy'``, ``'precision'``, ``'recall'`` and
    ``'f1'`` and queried with ``.get('classification_report')``; model and
    training settings are read from ``st.session_state``.
    """
    config = st.session_state.config
    model_name = MODEL_ARCHITECTURES[st.session_state.selected_model]['name']
    is_binary = st.session_state.classification_type == ClassificationType.BINARY

    # The literal is kept flush-left so the generated file carries no
    # source-code indentation.
    return f"""
MLOps Training Platform - Evaluation Report
{'='*60}

Model: {model_name}
Classification Type: {'Binary' if is_binary else 'Multi-class'}
Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

Overall Metrics:
- Accuracy: {results['accuracy']:.4f}
- Precision: {results['precision']:.4f}
- Recall: {results['recall']:.4f}
- F1 Score: {results['f1']:.4f}

Classification Report:
{results.get('classification_report', 'N/A')}

Training Configuration:
- Epochs: {config.num_epochs}
- Batch Size: {config.batch_size}
- Learning Rate: {config.learning_rate}
- Max Length: {config.max_length}
"""


def _render_evaluation_results(results):
    """Render metrics, confusion matrix, report and download for ``results``."""
    # Overall metrics
    st.markdown("### 📊 Overall Metrics")

    col1, col2, col3, col4 = st.columns(4)
    col1.metric("Accuracy", f"{results['accuracy']:.2%}")
    col2.metric("Precision", f"{results['precision']:.4f}")
    col3.metric("Recall", f"{results['recall']:.4f}")
    col4.metric("F1 Score", f"{results['f1']:.4f}")

    st.markdown("---")

    # Confusion Matrix
    st.markdown("### 🔢 Confusion Matrix")

    if 'confusion_matrix' in results:
        cm = results['confusion_matrix']
        class_ids = range(len(cm))

        # Create heatmap; each cell is annotated with its own count.
        fig = go.Figure(data=go.Heatmap(
            z=cm,
            x=[f"Predicted {i}" for i in class_ids],
            y=[f"True {i}" for i in class_ids],
            colorscale='Blues',
            text=cm,
            texttemplate="%{text}",
            textfont={"size": 16}
        ))

        fig.update_layout(
            title="Confusion Matrix",
            xaxis_title="Predicted Label",
            yaxis_title="True Label",
            height=500
        )

        st.plotly_chart(fig, width="stretch")

    st.markdown("---")

    # Classification Report
    st.markdown("### 📋 Detailed Classification Report")

    if 'classification_report' in results:
        st.text(results['classification_report'])

    # Download results
    st.markdown("---")
    st.markdown("### 💾 Download Results")

    # The download button is rendered directly instead of behind st.button:
    # a plain button is only True for the single run triggered by its click,
    # so a nested download widget disappears on the next interaction.
    st.download_button(
        label="📄 Download Text Report",
        data=_build_evaluation_report(results),
        file_name=f"evaluation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
        mime="text/plain"
    )


with tab5:
    st.markdown("## 📊 Evaluate Model Performance")

    if not st.session_state.training_completed:
        # No st.stop() here: it would halt the whole script and suppress
        # everything rendered after this tab, including the page footer.
        create_info_box(
            "⚠️ Please train a model first in the <b>Train Model</b> tab.",
            "warning"
        )
    else:
        create_info_box(
            "📈 <b>Model Evaluation:</b><br>"
            "Analyze your model's performance with detailed metrics and visualizations.",
            "info"
        )

        if st.session_state.evaluation_results is None:
            if st.button("🔍 Evaluate Model", type="primary", width="stretch"):
                with st.spinner("Evaluating model..."):
                    _evaluate_trained_model()

        if st.session_state.evaluation_results:
            _render_evaluation_results(st.session_state.evaluation_results)
1001
+
1002
+ # ==================== Footer ====================
1003
+
1004
# Page footer: a divider followed by a centered, muted credits block.
_footer_html = """
<div style='text-align: center; color: #666; padding: 2rem;'>
<p> MLOps Training Platform | Built with Streamlit & PyTorch</p>
<p>For help and documentation, check the README.md file</p>
</div>
"""

st.markdown("---")
# Raw HTML is required for the inline styling, hence unsafe_allow_html.
st.markdown(_footer_html, unsafe_allow_html=True)