songhieng commited on
Commit
f31a765
·
verified ·
1 Parent(s): bf40bd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -75
app.py CHANGED
@@ -25,6 +25,7 @@ import streamlit as st
25
  import pandas as pd
26
  import plotly.express as px
27
  import plotly.graph_objects as go
 
28
 
29
  # Add src directory to path for imports
30
  sys.path.insert(0, str(Path(__file__).parent / 'src'))
@@ -143,6 +144,8 @@ def init_session_state():
143
  # Data
144
  'uploaded_data': None,
145
  'preprocessed_data': None,
 
 
146
 
147
  # Evaluation
148
  'evaluation_results': None,
@@ -309,10 +312,10 @@ tab1, tab2, tab3, tab4, tab5 = st.tabs([
309
  # ==================== TAB 1: Prerequisites ====================
310
 
311
  with tab1:
312
- st.markdown("## 🔧 System Prerequisites")
313
 
314
  create_info_box(
315
- "⚠️ <b>Important:</b> Complete all prerequisite checks before proceeding to training.<br>"
316
  "This ensures your system is properly configured and all required models are downloaded.",
317
  "warning"
318
  )
@@ -321,13 +324,13 @@ with tab1:
321
  system_checker = SystemChecker(models_dir="models")
322
 
323
  # ===== CUDA/GPU Check =====
324
- st.markdown("### 🎮 1. CUDA/GPU Check")
325
 
326
  col1, col2 = st.columns([3, 1])
327
  with col1:
328
  st.markdown("Check if CUDA-capable GPU is available for faster training.")
329
  with col2:
330
- if st.button("🔍 Check CUDA", width="stretch"):
331
  with st.spinner("Checking CUDA availability..."):
332
  cuda_status = system_checker.check_cuda()
333
  st.session_state.cuda_status = cuda_status
@@ -337,24 +340,24 @@ with tab1:
337
  cuda = st.session_state.cuda_status
338
 
339
  if cuda['available']:
340
- st.success(f"CUDA Available - {cuda['device_count']} GPU(s) detected")
341
 
342
  for device in cuda['devices']:
343
- with st.expander(f"📊 {device['name']} Details"):
344
  col1, col2, col3 = st.columns(3)
345
  col1.metric("Memory", f"{device['memory_total']:.2f} GB")
346
  col2.metric("Compute", device['compute_capability'])
347
  col3.metric("CUDA Version", cuda['cuda_version'])
348
 
349
  create_info_box(
350
- "💡 <b>Recommendation:</b> Your GPU is ready for training! "
351
  "You can use any model from the list. XLM-RoBERTa and RoBERTa are recommended for best accuracy.",
352
  "success"
353
  )
354
  else:
355
- st.warning("⚠️ No CUDA-capable GPU detected - Training will use CPU")
356
  create_info_box(
357
- "💡 <b>Recommendation:</b> For CPU training, we recommend using <b>distilbert-base-multilingual-cased</b> "
358
  "as it's significantly faster while maintaining good accuracy.",
359
  "warning"
360
  )
@@ -362,13 +365,13 @@ with tab1:
362
  st.markdown("---")
363
 
364
  # ===== Environment Check =====
365
- st.markdown("### 🐍 2. Environment Check")
366
 
367
  col1, col2 = st.columns([3, 1])
368
  with col1:
369
  st.markdown("Verify all required Python packages are installed with correct versions.")
370
  with col2:
371
- if st.button("🔍 Check Environment", width="stretch"):
372
  with st.spinner("Checking environment..."):
373
  env_status = system_checker.check_environment()
374
  st.session_state.env_status = env_status
@@ -378,22 +381,22 @@ with tab1:
378
  env = st.session_state.env_status
379
 
380
  if env['all_satisfied']:
381
- st.success("All required packages are installed")
382
  else:
383
- st.error(f"Missing packages: {', '.join(env['missing_packages'])}")
384
  create_info_box(
385
  f"<b>To install missing packages, run:</b><br>"
386
  f"<code>pip install {' '.join(env['missing_packages'])}</code>",
387
  "error"
388
  )
389
 
390
- with st.expander("📦 View Package Details"):
391
  package_df = pd.DataFrame([
392
  {
393
  'Package': pkg,
394
  'Installed': info['installed'] or 'Not Installed',
395
  'Required': info['required'],
396
- 'Status': '' if info['satisfied'] else ''
397
  }
398
  for pkg, info in env['packages'].items()
399
  ])
@@ -402,10 +405,10 @@ with tab1:
402
  st.markdown("---")
403
 
404
  # ===== Model Selection Guide =====
405
- st.markdown("### 📚 3. Model Selection Guide")
406
 
407
  create_info_box(
408
- "📖 <b>How to choose the right model:</b><br><br>"
409
  "Consider these factors:<br>"
410
  "• <b>Language:</b> English only or multilingual?<br>"
411
  "• <b>Hardware:</b> GPU available or CPU only?<br>"
@@ -430,27 +433,27 @@ with tab1:
430
  st.dataframe(model_df, width="stretch", hide_index=True)
431
 
432
  # Quick recommendations
433
- st.markdown("#### 💡 Quick Recommendations:")
434
 
435
  rec_col1, rec_col2 = st.columns(2)
436
 
437
  with rec_col1:
438
  st.markdown("**For GPU Training:**")
439
- st.markdown("- 🏆 Best: `xlm-roberta-base` (highest accuracy)")
440
- st.markdown("- Fast: `roberta-base` (English only)")
441
 
442
  with rec_col2:
443
  st.markdown("**For CPU Training:**")
444
- st.markdown("- 🎯 Recommended: `distilbert-base-multilingual-cased`")
445
- st.markdown("- 💨 Fastest training and good performance")
446
 
447
  st.markdown("---")
448
 
449
  # ===== Model Download =====
450
- st.markdown("### 📥 4. Download Models")
451
 
452
  create_info_box(
453
- "⬇️ <b>Download models before training:</b><br>"
454
  "Models will be downloaded to the <code>models/</code> directory. "
455
  "This may take several minutes depending on your internet connection.",
456
  "info"
@@ -466,7 +469,7 @@ with tab1:
466
 
467
  col1, col2 = st.columns([3, 1])
468
  with col2:
469
- download_btn = st.button("⬇️ Download Selected", width="stretch", type="primary", disabled=len(selected_models) == 0)
470
 
471
  if download_btn:
472
  progress_bar = st.progress(0)
@@ -488,16 +491,16 @@ with tab1:
488
  st.error(f"Failed to download {model_id}: {message}")
489
 
490
  progress_bar.progress(1.0)
491
- status_text.text("Download complete!")
492
  time.sleep(1)
493
  st.rerun()
494
 
495
  # Show downloaded models
496
  if st.session_state.models_downloaded:
497
- st.markdown("#### Downloaded Models:")
498
  for model_id in st.session_state.models_downloaded:
499
  model_info = system_checker.get_model_info(model_id)
500
- st.success(f"📦 {MODEL_ARCHITECTURES[model_id]['name']} - {model_info['size_mb']:.0f} MB")
501
 
502
  st.markdown("---")
503
 
@@ -510,35 +513,35 @@ with tab1:
510
  )
511
 
512
  if can_proceed:
513
- if st.button("Prerequisites Complete - Proceed to Data Upload", width="stretch", type="primary"):
514
  st.session_state.prerequisites_checked = True
515
  add_log("Prerequisites check completed successfully")
516
- st.success("🎉 All prerequisites satisfied! You can now proceed to upload your data.")
517
  time.sleep(1)
518
  st.rerun()
519
  else:
520
  create_info_box(
521
- "<b>Complete all checks above before proceeding:</b><br>"
522
- " CUDA Check<br>"
523
- " Environment Check (all packages installed)<br>"
524
- " Download at least one model",
525
  "warning"
526
  )
527
 
528
  # ==================== TAB 2: Upload Data ====================
529
 
530
  with tab2:
531
- st.markdown("## 📤 Upload Training Data")
532
 
533
  if not st.session_state.prerequisites_checked:
534
  create_info_box(
535
- "⚠️ Please complete the <b>Prerequisites</b> tab first before uploading data.",
536
  "warning"
537
  )
538
  st.stop()
539
 
540
  create_info_box(
541
- "📄 <b>Data Format Requirements:</b><br>"
542
  "• CSV file with at least two columns: text and label<br>"
543
  "• Text column: Contains the text samples to classify<br>"
544
  "• Label column: Contains the class labels (0/1 for binary, or class names for multi-class)<br>"
@@ -559,17 +562,17 @@ with tab2:
559
  df = pd.read_csv(uploaded_file)
560
  st.session_state.uploaded_data = df
561
 
562
- st.success(f"Uploaded {len(df)} samples")
563
 
564
  # Validate data
565
  validator = DataValidator()
566
  is_valid, message = validator.validate_dataframe(df)
567
 
568
  if is_valid:
569
- st.success(f"Data validation passed: {message}")
570
 
571
  # Show data preview
572
- st.markdown("### 📊 Data Preview")
573
  st.dataframe(df.head(10), width="stretch")
574
 
575
  # Show statistics
@@ -579,7 +582,7 @@ with tab2:
579
  col3.metric("Text Columns", len([c for c in df.columns if df[c].dtype == 'object']))
580
 
581
  # Label distribution
582
- st.markdown("### 📈 Label Distribution")
583
  label_counts = df['label'].value_counts()
584
  fig = px.bar(
585
  x=label_counts.index.astype(str),
@@ -593,12 +596,12 @@ with tab2:
593
  if st.session_state.classification_type == ClassificationType.MULTICLASS:
594
  num_classes = df['label'].nunique()
595
  st.session_state.config.num_labels = num_classes
596
- st.info(f"ℹ️ Detected {num_classes} classes for multi-class classification")
597
 
598
  add_log(f"Uploaded data with {len(df)} samples and {df['label'].nunique()} labels")
599
 
600
  else:
601
- st.error(f"Data validation failed: {message}")
602
 
603
  except Exception as e:
604
  st.error(f"Error reading file: {str(e)}")
@@ -606,28 +609,28 @@ with tab2:
606
  # ==================== TAB 3: Configure Training ====================
607
 
608
  with tab3:
609
- st.markdown("## ⚙️ Configure Training Parameters")
610
 
611
  if st.session_state.uploaded_data is None:
612
  create_info_box(
613
- "⚠️ Please upload your data in the <b>Upload Data</b> tab first.",
614
  "warning"
615
  )
616
  st.stop()
617
 
618
  create_info_box(
619
- "🎛️ <b>Configure your training settings:</b><br>"
620
  "Adjust the parameters below based on your needs. Hover over ⓘ for explanations.",
621
  "info"
622
  )
623
 
624
  # Model selection
625
- st.markdown("### 🤖 Model Selection")
626
 
627
  available_models = list(st.session_state.models_downloaded)
628
 
629
  if not available_models:
630
- st.error("No models downloaded. Please download models in the Prerequisites tab.")
631
  st.stop()
632
 
633
  selected_model = st.selectbox(
@@ -652,7 +655,7 @@ with tab3:
652
  st.markdown("---")
653
 
654
  # Training parameters
655
- st.markdown("### 🎯 Training Parameters")
656
 
657
  col1, col2 = st.columns(2)
658
 
@@ -722,7 +725,7 @@ with tab3:
722
  st.markdown("---")
723
 
724
  # Show configuration summary
725
- st.markdown("### 📋 Configuration Summary")
726
 
727
  config_summary = {
728
  "Classification Type": "Binary" if st.session_state.classification_type == ClassificationType.BINARY else "Multi-class",
@@ -744,29 +747,29 @@ with tab3:
744
  # ==================== TAB 4: Train Model ====================
745
 
746
  with tab4:
747
- st.markdown("## 🎯 Train Your Model")
748
 
749
  if st.session_state.uploaded_data is None:
750
  create_info_box(
751
- "⚠️ Please complete previous steps first.",
752
  "warning"
753
  )
754
  st.stop()
755
 
756
  if not st.session_state.training_started:
757
  create_info_box(
758
- "🚀 <b>Ready to train!</b><br>"
759
  f"Your {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']} model will be trained on {len(st.session_state.uploaded_data)} samples "
760
  f"for {st.session_state.config.num_epochs} epochs.",
761
  "info"
762
  )
763
 
764
- if st.button("🚀 Start Training", type="primary", width="stretch"):
765
  st.session_state.training_started = True
766
  st.rerun()
767
 
768
  if st.session_state.training_started and not st.session_state.training_completed:
769
- st.markdown("### Training in Progress...")
770
 
771
  # Progress display
772
  progress_bar = st.progress(0)
@@ -778,6 +781,14 @@ with tab4:
778
  status_text.text("Preparing data...")
779
  df = st.session_state.uploaded_data
780
 
 
 
 
 
 
 
 
 
781
  # Initialize trainer with absolute path
782
  import os
783
  st.session_state.config.output_dir = os.path.abspath("trained_models")
@@ -804,7 +815,7 @@ with tab4:
804
  # Train model
805
  result = trainer.train(
806
  texts=df['text'].tolist(),
807
- labels=df['label'].tolist(),
808
  progress_callback=progress_callback
809
  )
810
 
@@ -823,12 +834,12 @@ with tab4:
823
  progress_bar.progress(1.0)
824
  status_text.empty()
825
 
826
- st.success("🎉 Training completed successfully!")
827
  add_log(f"Training completed successfully. Model saved to: {result.model_path}")
828
 
829
  # Show final metrics
830
  if result.final_metrics:
831
- st.markdown("### 📊 Final Training Metrics")
832
  metrics = result.final_metrics.to_dict()
833
 
834
  col1, col2, col3, col4 = st.columns(4)
@@ -843,49 +854,49 @@ with tab4:
843
  except Exception as e:
844
  import traceback
845
  error_details = traceback.format_exc()
846
- st.error(f"Training failed: {str(e)}")
847
- with st.expander("🔍 Error Details"):
848
  st.code(error_details)
849
  st.session_state.training_started = False
850
  add_log(f"Training failed: {str(e)}")
851
 
852
  if st.session_state.training_completed:
853
- st.success("Training completed!")
854
 
855
- model_path_display = st.session_state.model_path if st.session_state.model_path else "⚠️ Path not available"
856
 
857
  create_info_box(
858
- f"🎉 <b>Model trained successfully!</b><br>"
859
  f"Model saved to: <code>{model_path_display}</code><br>"
860
  "Proceed to the <b>Evaluate Model</b> tab to analyze performance.",
861
  "success" if st.session_state.model_path else "warning"
862
  )
863
 
864
  # Show training logs
865
- with st.expander("📜 View Training Logs"):
866
  for log in st.session_state.training_logs[-20:]: # Show last 20 logs
867
  st.text(log)
868
 
869
  # ==================== TAB 5: Evaluate Model ====================
870
 
871
  with tab5:
872
- st.markdown("## 📊 Evaluate Model Performance")
873
 
874
  if not st.session_state.training_completed:
875
  create_info_box(
876
- "⚠️ Please train a model first in the <b>Train Model</b> tab.",
877
  "warning"
878
  )
879
  st.stop()
880
 
881
  create_info_box(
882
- "📈 <b>Model Evaluation:</b><br>"
883
  "Analyze your model's performance with detailed metrics and visualizations.",
884
  "info"
885
  )
886
 
887
  if st.session_state.evaluation_results is None:
888
- if st.button("🔍 Evaluate Model", type="primary", width="stretch"):
889
  with st.spinner("Evaluating model..."):
890
  try:
891
  # Initialize evaluator
@@ -899,10 +910,19 @@ with tab5:
899
  test_size = int(len(df) * st.session_state.config.validation_split)
900
  test_df = df.tail(test_size)
901
 
 
 
 
 
 
 
 
 
 
902
  # Evaluate
903
  results = evaluator.evaluate(
904
  texts=test_df['text'].tolist(),
905
- true_labels=test_df['label'].tolist(),
906
  batch_size=st.session_state.config.batch_size
907
  )
908
 
@@ -917,7 +937,7 @@ with tab5:
917
  results = st.session_state.evaluation_results
918
 
919
  # Overall metrics
920
- st.markdown("### 📊 Overall Metrics")
921
 
922
  col1, col2, col3, col4 = st.columns(4)
923
  col1.metric("Accuracy", f"{results['accuracy']:.2%}")
@@ -928,7 +948,7 @@ with tab5:
928
  st.markdown("---")
929
 
930
  # Confusion Matrix
931
- st.markdown("### 🔢 Confusion Matrix")
932
 
933
  if 'confusion_matrix' in results:
934
  cm = results['confusion_matrix']
@@ -956,7 +976,7 @@ with tab5:
956
  st.markdown("---")
957
 
958
  # Classification Report
959
- st.markdown("### 📋 Detailed Classification Report")
960
 
961
  if 'classification_report' in results:
962
  report = results['classification_report']
@@ -964,9 +984,9 @@ with tab5:
964
 
965
  # Download results
966
  st.markdown("---")
967
- st.markdown("### 💾 Download Results")
968
 
969
- if st.button("📥 Download Evaluation Report", width="stretch"):
970
  # Create downloadable report
971
  report_text = f"""
972
  MLOps Training Platform - Evaluation Report
@@ -993,7 +1013,7 @@ Training Configuration:
993
  """
994
 
995
  st.download_button(
996
- label="📄 Download Text Report",
997
  data=report_text,
998
  file_name=f"evaluation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
999
  mime="text/plain"
 
25
  import pandas as pd
26
  import plotly.express as px
27
  import plotly.graph_objects as go
28
+ from sklearn.preprocessing import LabelEncoder
29
 
30
  # Add src directory to path for imports
31
  sys.path.insert(0, str(Path(__file__).parent / 'src'))
 
144
  # Data
145
  'uploaded_data': None,
146
  'preprocessed_data': None,
147
+ 'label_encoder': None,
148
+ 'label_classes': None,
149
 
150
  # Evaluation
151
  'evaluation_results': None,
 
312
  # ==================== TAB 1: Prerequisites ====================
313
 
314
  with tab1:
315
+ st.markdown("## System Prerequisites")
316
 
317
  create_info_box(
318
+ "<b>Important:</b> Complete all prerequisite checks before proceeding to training.<br>"
319
  "This ensures your system is properly configured and all required models are downloaded.",
320
  "warning"
321
  )
 
324
  system_checker = SystemChecker(models_dir="models")
325
 
326
  # ===== CUDA/GPU Check =====
327
+ st.markdown("### 1. CUDA/GPU Check")
328
 
329
  col1, col2 = st.columns([3, 1])
330
  with col1:
331
  st.markdown("Check if CUDA-capable GPU is available for faster training.")
332
  with col2:
333
+ if st.button("Check CUDA", width="stretch"):
334
  with st.spinner("Checking CUDA availability..."):
335
  cuda_status = system_checker.check_cuda()
336
  st.session_state.cuda_status = cuda_status
 
340
  cuda = st.session_state.cuda_status
341
 
342
  if cuda['available']:
343
+ st.success(f"CUDA Available - {cuda['device_count']} GPU(s) detected")
344
 
345
  for device in cuda['devices']:
346
+ with st.expander(f"Device: {device['name']} Details"):
347
  col1, col2, col3 = st.columns(3)
348
  col1.metric("Memory", f"{device['memory_total']:.2f} GB")
349
  col2.metric("Compute", device['compute_capability'])
350
  col3.metric("CUDA Version", cuda['cuda_version'])
351
 
352
  create_info_box(
353
+ "<b>Recommendation:</b> Your GPU is ready for training! "
354
  "You can use any model from the list. XLM-RoBERTa and RoBERTa are recommended for best accuracy.",
355
  "success"
356
  )
357
  else:
358
+ st.warning("No CUDA-capable GPU detected - Training will use CPU")
359
  create_info_box(
360
+ "<b>Recommendation:</b> For CPU training, we recommend using <b>distilbert-base-multilingual-cased</b> "
361
  "as it's significantly faster while maintaining good accuracy.",
362
  "warning"
363
  )
 
365
  st.markdown("---")
366
 
367
  # ===== Environment Check =====
368
+ st.markdown("### 2. Environment Check")
369
 
370
  col1, col2 = st.columns([3, 1])
371
  with col1:
372
  st.markdown("Verify all required Python packages are installed with correct versions.")
373
  with col2:
374
+ if st.button("Check Environment", width="stretch"):
375
  with st.spinner("Checking environment..."):
376
  env_status = system_checker.check_environment()
377
  st.session_state.env_status = env_status
 
381
  env = st.session_state.env_status
382
 
383
  if env['all_satisfied']:
384
+ st.success("All required packages are installed")
385
  else:
386
+ st.error(f"Missing packages: {', '.join(env['missing_packages'])}")
387
  create_info_box(
388
  f"<b>To install missing packages, run:</b><br>"
389
  f"<code>pip install {' '.join(env['missing_packages'])}</code>",
390
  "error"
391
  )
392
 
393
+ with st.expander("View Package Details"):
394
  package_df = pd.DataFrame([
395
  {
396
  'Package': pkg,
397
  'Installed': info['installed'] or 'Not Installed',
398
  'Required': info['required'],
399
+ 'Status': 'OK' if info['satisfied'] else 'Missing'
400
  }
401
  for pkg, info in env['packages'].items()
402
  ])
 
405
  st.markdown("---")
406
 
407
  # ===== Model Selection Guide =====
408
+ st.markdown("### 3. Model Selection Guide")
409
 
410
  create_info_box(
411
+ "<b>How to choose the right model:</b><br><br>"
412
  "Consider these factors:<br>"
413
  "• <b>Language:</b> English only or multilingual?<br>"
414
  "• <b>Hardware:</b> GPU available or CPU only?<br>"
 
433
  st.dataframe(model_df, width="stretch", hide_index=True)
434
 
435
  # Quick recommendations
436
+ st.markdown("#### Quick Recommendations:")
437
 
438
  rec_col1, rec_col2 = st.columns(2)
439
 
440
  with rec_col1:
441
  st.markdown("**For GPU Training:**")
442
+ st.markdown("- Best: `xlm-roberta-base` (highest accuracy)")
443
+ st.markdown("- Fast: `roberta-base` (English only)")
444
 
445
  with rec_col2:
446
  st.markdown("**For CPU Training:**")
447
+ st.markdown("- Recommended: `distilbert-base-multilingual-cased`")
448
+ st.markdown("- Fastest training and good performance")
449
 
450
  st.markdown("---")
451
 
452
  # ===== Model Download =====
453
+ st.markdown("### 4. Download Models")
454
 
455
  create_info_box(
456
+ "<b>Download models before training:</b><br>"
457
  "Models will be downloaded to the <code>models/</code> directory. "
458
  "This may take several minutes depending on your internet connection.",
459
  "info"
 
469
 
470
  col1, col2 = st.columns([3, 1])
471
  with col2:
472
+ download_btn = st.button("Download Selected", width="stretch", type="primary", disabled=len(selected_models) == 0)
473
 
474
  if download_btn:
475
  progress_bar = st.progress(0)
 
491
  st.error(f"Failed to download {model_id}: {message}")
492
 
493
  progress_bar.progress(1.0)
494
+ status_text.text("Download complete!")
495
  time.sleep(1)
496
  st.rerun()
497
 
498
  # Show downloaded models
499
  if st.session_state.models_downloaded:
500
+ st.markdown("#### Downloaded Models:")
501
  for model_id in st.session_state.models_downloaded:
502
  model_info = system_checker.get_model_info(model_id)
503
+ st.success(f"{MODEL_ARCHITECTURES[model_id]['name']} - {model_info['size_mb']:.0f} MB")
504
 
505
  st.markdown("---")
506
 
 
513
  )
514
 
515
  if can_proceed:
516
+ if st.button("Prerequisites Complete - Proceed to Data Upload", width="stretch", type="primary"):
517
  st.session_state.prerequisites_checked = True
518
  add_log("Prerequisites check completed successfully")
519
+ st.success("All prerequisites satisfied! You can now proceed to upload your data.")
520
  time.sleep(1)
521
  st.rerun()
522
  else:
523
  create_info_box(
524
+ "<b>Complete all checks above before proceeding:</b><br>"
525
+ "- CUDA Check<br>"
526
+ "- Environment Check (all packages installed)<br>"
527
+ "- Download at least one model",
528
  "warning"
529
  )
530
 
531
  # ==================== TAB 2: Upload Data ====================
532
 
533
  with tab2:
534
+ st.markdown("## Upload Training Data")
535
 
536
  if not st.session_state.prerequisites_checked:
537
  create_info_box(
538
+ "Please complete the <b>Prerequisites</b> tab first before uploading data.",
539
  "warning"
540
  )
541
  st.stop()
542
 
543
  create_info_box(
544
+ "<b>Data Format Requirements:</b><br>"
545
  "• CSV file with at least two columns: text and label<br>"
546
  "• Text column: Contains the text samples to classify<br>"
547
  "• Label column: Contains the class labels (0/1 for binary, or class names for multi-class)<br>"
 
562
  df = pd.read_csv(uploaded_file)
563
  st.session_state.uploaded_data = df
564
 
565
+ st.success(f"Uploaded {len(df)} samples")
566
 
567
  # Validate data
568
  validator = DataValidator()
569
  is_valid, message = validator.validate_dataframe(df)
570
 
571
  if is_valid:
572
+ st.success(f"Data validation passed: {message}")
573
 
574
  # Show data preview
575
+ st.markdown("### Data Preview")
576
  st.dataframe(df.head(10), width="stretch")
577
 
578
  # Show statistics
 
582
  col3.metric("Text Columns", len([c for c in df.columns if df[c].dtype == 'object']))
583
 
584
  # Label distribution
585
+ st.markdown("### Label Distribution")
586
  label_counts = df['label'].value_counts()
587
  fig = px.bar(
588
  x=label_counts.index.astype(str),
 
596
  if st.session_state.classification_type == ClassificationType.MULTICLASS:
597
  num_classes = df['label'].nunique()
598
  st.session_state.config.num_labels = num_classes
599
+ st.info(f"Detected {num_classes} classes for multi-class classification")
600
 
601
  add_log(f"Uploaded data with {len(df)} samples and {df['label'].nunique()} labels")
602
 
603
  else:
604
+ st.error(f"Data validation failed: {message}")
605
 
606
  except Exception as e:
607
  st.error(f"Error reading file: {str(e)}")
 
609
  # ==================== TAB 3: Configure Training ====================
610
 
611
  with tab3:
612
+ st.markdown("## Configure Training Parameters")
613
 
614
  if st.session_state.uploaded_data is None:
615
  create_info_box(
616
+ "Please upload your data in the <b>Upload Data</b> tab first.",
617
  "warning"
618
  )
619
  st.stop()
620
 
621
  create_info_box(
622
+ "<b>Configure your training settings:</b><br>"
623
  "Adjust the parameters below based on your needs. Hover over ⓘ for explanations.",
624
  "info"
625
  )
626
 
627
  # Model selection
628
+ st.markdown("### Model Selection")
629
 
630
  available_models = list(st.session_state.models_downloaded)
631
 
632
  if not available_models:
633
+ st.error("No models downloaded. Please download models in the Prerequisites tab.")
634
  st.stop()
635
 
636
  selected_model = st.selectbox(
 
655
  st.markdown("---")
656
 
657
  # Training parameters
658
+ st.markdown("### Training Parameters")
659
 
660
  col1, col2 = st.columns(2)
661
 
 
725
  st.markdown("---")
726
 
727
  # Show configuration summary
728
+ st.markdown("### Configuration Summary")
729
 
730
  config_summary = {
731
  "Classification Type": "Binary" if st.session_state.classification_type == ClassificationType.BINARY else "Multi-class",
 
747
  # ==================== TAB 4: Train Model ====================
748
 
749
  with tab4:
750
+ st.markdown("## Train Your Model")
751
 
752
  if st.session_state.uploaded_data is None:
753
  create_info_box(
754
+ "Please complete previous steps first.",
755
  "warning"
756
  )
757
  st.stop()
758
 
759
  if not st.session_state.training_started:
760
  create_info_box(
761
+ "<b>Ready to train!</b><br>"
762
  f"Your {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']} model will be trained on {len(st.session_state.uploaded_data)} samples "
763
  f"for {st.session_state.config.num_epochs} epochs.",
764
  "info"
765
  )
766
 
767
+ if st.button("Start Training", type="primary", width="stretch"):
768
  st.session_state.training_started = True
769
  st.rerun()
770
 
771
  if st.session_state.training_started and not st.session_state.training_completed:
772
+ st.markdown("### Training in Progress...")
773
 
774
  # Progress display
775
  progress_bar = st.progress(0)
 
781
  status_text.text("Preparing data...")
782
  df = st.session_state.uploaded_data
783
 
784
+ # Encode labels to integers
785
+ label_encoder = LabelEncoder()
786
+ encoded_labels = label_encoder.fit_transform(df['label'])
787
+
788
+ # Store label encoder for later use
789
+ st.session_state.label_encoder = label_encoder
790
+ st.session_state.label_classes = label_encoder.classes_.tolist()
791
+
792
  # Initialize trainer with absolute path
793
  import os
794
  st.session_state.config.output_dir = os.path.abspath("trained_models")
 
815
  # Train model
816
  result = trainer.train(
817
  texts=df['text'].tolist(),
818
+ labels=encoded_labels.tolist(),
819
  progress_callback=progress_callback
820
  )
821
 
 
834
  progress_bar.progress(1.0)
835
  status_text.empty()
836
 
837
+ st.success("Training completed successfully!")
838
  add_log(f"Training completed successfully. Model saved to: {result.model_path}")
839
 
840
  # Show final metrics
841
  if result.final_metrics:
842
+ st.markdown("### Final Training Metrics")
843
  metrics = result.final_metrics.to_dict()
844
 
845
  col1, col2, col3, col4 = st.columns(4)
 
854
  except Exception as e:
855
  import traceback
856
  error_details = traceback.format_exc()
857
+ st.error(f"Training failed: {str(e)}")
858
+ with st.expander("Error Details"):
859
  st.code(error_details)
860
  st.session_state.training_started = False
861
  add_log(f"Training failed: {str(e)}")
862
 
863
  if st.session_state.training_completed:
864
+ st.success("Training completed!")
865
 
866
+ model_path_display = st.session_state.model_path if st.session_state.model_path else "Path not available"
867
 
868
  create_info_box(
869
+ f"<b>Model trained successfully!</b><br>"
870
  f"Model saved to: <code>{model_path_display}</code><br>"
871
  "Proceed to the <b>Evaluate Model</b> tab to analyze performance.",
872
  "success" if st.session_state.model_path else "warning"
873
  )
874
 
875
  # Show training logs
876
+ with st.expander("View Training Logs"):
877
  for log in st.session_state.training_logs[-20:]: # Show last 20 logs
878
  st.text(log)
879
 
880
  # ==================== TAB 5: Evaluate Model ====================
881
 
882
  with tab5:
883
+ st.markdown("## Evaluate Model Performance")
884
 
885
  if not st.session_state.training_completed:
886
  create_info_box(
887
+ "Please train a model first in the <b>Train Model</b> tab.",
888
  "warning"
889
  )
890
  st.stop()
891
 
892
  create_info_box(
893
+ "<b>Model Evaluation:</b><br>"
894
  "Analyze your model's performance with detailed metrics and visualizations.",
895
  "info"
896
  )
897
 
898
  if st.session_state.evaluation_results is None:
899
+ if st.button("Evaluate Model", type="primary", width="stretch"):
900
  with st.spinner("Evaluating model..."):
901
  try:
902
  # Initialize evaluator
 
910
  test_size = int(len(df) * st.session_state.config.validation_split)
911
  test_df = df.tail(test_size)
912
 
913
+ # Encode labels using the same encoder from training
914
+ if 'label_encoder' in st.session_state:
915
+ test_labels_encoded = st.session_state.label_encoder.transform(test_df['label']).tolist()
916
+ else:
917
+ # Fallback: create new encoder if not available
918
+ from sklearn.preprocessing import LabelEncoder
919
+ label_encoder = LabelEncoder()
920
+ test_labels_encoded = label_encoder.fit_transform(test_df['label']).tolist()
921
+
922
  # Evaluate
923
  results = evaluator.evaluate(
924
  texts=test_df['text'].tolist(),
925
+ true_labels=test_labels_encoded,
926
  batch_size=st.session_state.config.batch_size
927
  )
928
 
 
937
  results = st.session_state.evaluation_results
938
 
939
  # Overall metrics
940
+ st.markdown("### Overall Metrics")
941
 
942
  col1, col2, col3, col4 = st.columns(4)
943
  col1.metric("Accuracy", f"{results['accuracy']:.2%}")
 
948
  st.markdown("---")
949
 
950
  # Confusion Matrix
951
+ st.markdown("### Confusion Matrix")
952
 
953
  if 'confusion_matrix' in results:
954
  cm = results['confusion_matrix']
 
976
  st.markdown("---")
977
 
978
  # Classification Report
979
+ st.markdown("### Detailed Classification Report")
980
 
981
  if 'classification_report' in results:
982
  report = results['classification_report']
 
984
 
985
  # Download results
986
  st.markdown("---")
987
+ st.markdown("### Download Results")
988
 
989
+ if st.button("Download Evaluation Report", width="stretch"):
990
  # Create downloadable report
991
  report_text = f"""
992
  MLOps Training Platform - Evaluation Report
 
1013
  """
1014
 
1015
  st.download_button(
1016
+ label="Download Text Report",
1017
  data=report_text,
1018
  file_name=f"evaluation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
1019
  mime="text/plain"