Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,6 +25,7 @@ import streamlit as st
|
|
| 25 |
import pandas as pd
|
| 26 |
import plotly.express as px
|
| 27 |
import plotly.graph_objects as go
|
|
|
|
| 28 |
|
| 29 |
# Add src directory to path for imports
|
| 30 |
sys.path.insert(0, str(Path(__file__).parent / 'src'))
|
|
@@ -143,6 +144,8 @@ def init_session_state():
|
|
| 143 |
# Data
|
| 144 |
'uploaded_data': None,
|
| 145 |
'preprocessed_data': None,
|
|
|
|
|
|
|
| 146 |
|
| 147 |
# Evaluation
|
| 148 |
'evaluation_results': None,
|
|
@@ -309,10 +312,10 @@ tab1, tab2, tab3, tab4, tab5 = st.tabs([
|
|
| 309 |
# ==================== TAB 1: Prerequisites ====================
|
| 310 |
|
| 311 |
with tab1:
|
| 312 |
-
st.markdown("##
|
| 313 |
|
| 314 |
create_info_box(
|
| 315 |
-
"
|
| 316 |
"This ensures your system is properly configured and all required models are downloaded.",
|
| 317 |
"warning"
|
| 318 |
)
|
|
@@ -321,13 +324,13 @@ with tab1:
|
|
| 321 |
system_checker = SystemChecker(models_dir="models")
|
| 322 |
|
| 323 |
# ===== CUDA/GPU Check =====
|
| 324 |
-
st.markdown("###
|
| 325 |
|
| 326 |
col1, col2 = st.columns([3, 1])
|
| 327 |
with col1:
|
| 328 |
st.markdown("Check if CUDA-capable GPU is available for faster training.")
|
| 329 |
with col2:
|
| 330 |
-
if st.button("
|
| 331 |
with st.spinner("Checking CUDA availability..."):
|
| 332 |
cuda_status = system_checker.check_cuda()
|
| 333 |
st.session_state.cuda_status = cuda_status
|
|
@@ -337,24 +340,24 @@ with tab1:
|
|
| 337 |
cuda = st.session_state.cuda_status
|
| 338 |
|
| 339 |
if cuda['available']:
|
| 340 |
-
st.success(f"
|
| 341 |
|
| 342 |
for device in cuda['devices']:
|
| 343 |
-
with st.expander(f"
|
| 344 |
col1, col2, col3 = st.columns(3)
|
| 345 |
col1.metric("Memory", f"{device['memory_total']:.2f} GB")
|
| 346 |
col2.metric("Compute", device['compute_capability'])
|
| 347 |
col3.metric("CUDA Version", cuda['cuda_version'])
|
| 348 |
|
| 349 |
create_info_box(
|
| 350 |
-
"
|
| 351 |
"You can use any model from the list. XLM-RoBERTa and RoBERTa are recommended for best accuracy.",
|
| 352 |
"success"
|
| 353 |
)
|
| 354 |
else:
|
| 355 |
-
st.warning("
|
| 356 |
create_info_box(
|
| 357 |
-
"
|
| 358 |
"as it's significantly faster while maintaining good accuracy.",
|
| 359 |
"warning"
|
| 360 |
)
|
|
@@ -362,13 +365,13 @@ with tab1:
|
|
| 362 |
st.markdown("---")
|
| 363 |
|
| 364 |
# ===== Environment Check =====
|
| 365 |
-
st.markdown("###
|
| 366 |
|
| 367 |
col1, col2 = st.columns([3, 1])
|
| 368 |
with col1:
|
| 369 |
st.markdown("Verify all required Python packages are installed with correct versions.")
|
| 370 |
with col2:
|
| 371 |
-
if st.button("
|
| 372 |
with st.spinner("Checking environment..."):
|
| 373 |
env_status = system_checker.check_environment()
|
| 374 |
st.session_state.env_status = env_status
|
|
@@ -378,22 +381,22 @@ with tab1:
|
|
| 378 |
env = st.session_state.env_status
|
| 379 |
|
| 380 |
if env['all_satisfied']:
|
| 381 |
-
st.success("
|
| 382 |
else:
|
| 383 |
-
st.error(f"
|
| 384 |
create_info_box(
|
| 385 |
f"<b>To install missing packages, run:</b><br>"
|
| 386 |
f"<code>pip install {' '.join(env['missing_packages'])}</code>",
|
| 387 |
"error"
|
| 388 |
)
|
| 389 |
|
| 390 |
-
with st.expander("
|
| 391 |
package_df = pd.DataFrame([
|
| 392 |
{
|
| 393 |
'Package': pkg,
|
| 394 |
'Installed': info['installed'] or 'Not Installed',
|
| 395 |
'Required': info['required'],
|
| 396 |
-
'Status': '
|
| 397 |
}
|
| 398 |
for pkg, info in env['packages'].items()
|
| 399 |
])
|
|
@@ -402,10 +405,10 @@ with tab1:
|
|
| 402 |
st.markdown("---")
|
| 403 |
|
| 404 |
# ===== Model Selection Guide =====
|
| 405 |
-
st.markdown("###
|
| 406 |
|
| 407 |
create_info_box(
|
| 408 |
-
"
|
| 409 |
"Consider these factors:<br>"
|
| 410 |
"• <b>Language:</b> English only or multilingual?<br>"
|
| 411 |
"• <b>Hardware:</b> GPU available or CPU only?<br>"
|
|
@@ -430,27 +433,27 @@ with tab1:
|
|
| 430 |
st.dataframe(model_df, width="stretch", hide_index=True)
|
| 431 |
|
| 432 |
# Quick recommendations
|
| 433 |
-
st.markdown("####
|
| 434 |
|
| 435 |
rec_col1, rec_col2 = st.columns(2)
|
| 436 |
|
| 437 |
with rec_col1:
|
| 438 |
st.markdown("**For GPU Training:**")
|
| 439 |
-
st.markdown("-
|
| 440 |
-
st.markdown("-
|
| 441 |
|
| 442 |
with rec_col2:
|
| 443 |
st.markdown("**For CPU Training:**")
|
| 444 |
-
st.markdown("-
|
| 445 |
-
st.markdown("-
|
| 446 |
|
| 447 |
st.markdown("---")
|
| 448 |
|
| 449 |
# ===== Model Download =====
|
| 450 |
-
st.markdown("###
|
| 451 |
|
| 452 |
create_info_box(
|
| 453 |
-
"
|
| 454 |
"Models will be downloaded to the <code>models/</code> directory. "
|
| 455 |
"This may take several minutes depending on your internet connection.",
|
| 456 |
"info"
|
|
@@ -466,7 +469,7 @@ with tab1:
|
|
| 466 |
|
| 467 |
col1, col2 = st.columns([3, 1])
|
| 468 |
with col2:
|
| 469 |
-
download_btn = st.button("
|
| 470 |
|
| 471 |
if download_btn:
|
| 472 |
progress_bar = st.progress(0)
|
|
@@ -488,16 +491,16 @@ with tab1:
|
|
| 488 |
st.error(f"Failed to download {model_id}: {message}")
|
| 489 |
|
| 490 |
progress_bar.progress(1.0)
|
| 491 |
-
status_text.text("
|
| 492 |
time.sleep(1)
|
| 493 |
st.rerun()
|
| 494 |
|
| 495 |
# Show downloaded models
|
| 496 |
if st.session_state.models_downloaded:
|
| 497 |
-
st.markdown("####
|
| 498 |
for model_id in st.session_state.models_downloaded:
|
| 499 |
model_info = system_checker.get_model_info(model_id)
|
| 500 |
-
st.success(f"
|
| 501 |
|
| 502 |
st.markdown("---")
|
| 503 |
|
|
@@ -510,35 +513,35 @@ with tab1:
|
|
| 510 |
)
|
| 511 |
|
| 512 |
if can_proceed:
|
| 513 |
-
if st.button("
|
| 514 |
st.session_state.prerequisites_checked = True
|
| 515 |
add_log("Prerequisites check completed successfully")
|
| 516 |
-
st.success("
|
| 517 |
time.sleep(1)
|
| 518 |
st.rerun()
|
| 519 |
else:
|
| 520 |
create_info_box(
|
| 521 |
-
"
|
| 522 |
-
"
|
| 523 |
-
"
|
| 524 |
-
"
|
| 525 |
"warning"
|
| 526 |
)
|
| 527 |
|
| 528 |
# ==================== TAB 2: Upload Data ====================
|
| 529 |
|
| 530 |
with tab2:
|
| 531 |
-
st.markdown("##
|
| 532 |
|
| 533 |
if not st.session_state.prerequisites_checked:
|
| 534 |
create_info_box(
|
| 535 |
-
"
|
| 536 |
"warning"
|
| 537 |
)
|
| 538 |
st.stop()
|
| 539 |
|
| 540 |
create_info_box(
|
| 541 |
-
"
|
| 542 |
"• CSV file with at least two columns: text and label<br>"
|
| 543 |
"• Text column: Contains the text samples to classify<br>"
|
| 544 |
"• Label column: Contains the class labels (0/1 for binary, or class names for multi-class)<br>"
|
|
@@ -559,17 +562,17 @@ with tab2:
|
|
| 559 |
df = pd.read_csv(uploaded_file)
|
| 560 |
st.session_state.uploaded_data = df
|
| 561 |
|
| 562 |
-
st.success(f"
|
| 563 |
|
| 564 |
# Validate data
|
| 565 |
validator = DataValidator()
|
| 566 |
is_valid, message = validator.validate_dataframe(df)
|
| 567 |
|
| 568 |
if is_valid:
|
| 569 |
-
st.success(f"
|
| 570 |
|
| 571 |
# Show data preview
|
| 572 |
-
st.markdown("###
|
| 573 |
st.dataframe(df.head(10), width="stretch")
|
| 574 |
|
| 575 |
# Show statistics
|
|
@@ -579,7 +582,7 @@ with tab2:
|
|
| 579 |
col3.metric("Text Columns", len([c for c in df.columns if df[c].dtype == 'object']))
|
| 580 |
|
| 581 |
# Label distribution
|
| 582 |
-
st.markdown("###
|
| 583 |
label_counts = df['label'].value_counts()
|
| 584 |
fig = px.bar(
|
| 585 |
x=label_counts.index.astype(str),
|
|
@@ -593,12 +596,12 @@ with tab2:
|
|
| 593 |
if st.session_state.classification_type == ClassificationType.MULTICLASS:
|
| 594 |
num_classes = df['label'].nunique()
|
| 595 |
st.session_state.config.num_labels = num_classes
|
| 596 |
-
st.info(f"
|
| 597 |
|
| 598 |
add_log(f"Uploaded data with {len(df)} samples and {df['label'].nunique()} labels")
|
| 599 |
|
| 600 |
else:
|
| 601 |
-
st.error(f"
|
| 602 |
|
| 603 |
except Exception as e:
|
| 604 |
st.error(f"Error reading file: {str(e)}")
|
|
@@ -606,28 +609,28 @@ with tab2:
|
|
| 606 |
# ==================== TAB 3: Configure Training ====================
|
| 607 |
|
| 608 |
with tab3:
|
| 609 |
-
st.markdown("##
|
| 610 |
|
| 611 |
if st.session_state.uploaded_data is None:
|
| 612 |
create_info_box(
|
| 613 |
-
"
|
| 614 |
"warning"
|
| 615 |
)
|
| 616 |
st.stop()
|
| 617 |
|
| 618 |
create_info_box(
|
| 619 |
-
"
|
| 620 |
"Adjust the parameters below based on your needs. Hover over ⓘ for explanations.",
|
| 621 |
"info"
|
| 622 |
)
|
| 623 |
|
| 624 |
# Model selection
|
| 625 |
-
st.markdown("###
|
| 626 |
|
| 627 |
available_models = list(st.session_state.models_downloaded)
|
| 628 |
|
| 629 |
if not available_models:
|
| 630 |
-
st.error("
|
| 631 |
st.stop()
|
| 632 |
|
| 633 |
selected_model = st.selectbox(
|
|
@@ -652,7 +655,7 @@ with tab3:
|
|
| 652 |
st.markdown("---")
|
| 653 |
|
| 654 |
# Training parameters
|
| 655 |
-
st.markdown("###
|
| 656 |
|
| 657 |
col1, col2 = st.columns(2)
|
| 658 |
|
|
@@ -722,7 +725,7 @@ with tab3:
|
|
| 722 |
st.markdown("---")
|
| 723 |
|
| 724 |
# Show configuration summary
|
| 725 |
-
st.markdown("###
|
| 726 |
|
| 727 |
config_summary = {
|
| 728 |
"Classification Type": "Binary" if st.session_state.classification_type == ClassificationType.BINARY else "Multi-class",
|
|
@@ -744,29 +747,29 @@ with tab3:
|
|
| 744 |
# ==================== TAB 4: Train Model ====================
|
| 745 |
|
| 746 |
with tab4:
|
| 747 |
-
st.markdown("##
|
| 748 |
|
| 749 |
if st.session_state.uploaded_data is None:
|
| 750 |
create_info_box(
|
| 751 |
-
"
|
| 752 |
"warning"
|
| 753 |
)
|
| 754 |
st.stop()
|
| 755 |
|
| 756 |
if not st.session_state.training_started:
|
| 757 |
create_info_box(
|
| 758 |
-
"
|
| 759 |
f"Your {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']} model will be trained on {len(st.session_state.uploaded_data)} samples "
|
| 760 |
f"for {st.session_state.config.num_epochs} epochs.",
|
| 761 |
"info"
|
| 762 |
)
|
| 763 |
|
| 764 |
-
if st.button("
|
| 765 |
st.session_state.training_started = True
|
| 766 |
st.rerun()
|
| 767 |
|
| 768 |
if st.session_state.training_started and not st.session_state.training_completed:
|
| 769 |
-
st.markdown("###
|
| 770 |
|
| 771 |
# Progress display
|
| 772 |
progress_bar = st.progress(0)
|
|
@@ -778,6 +781,14 @@ with tab4:
|
|
| 778 |
status_text.text("Preparing data...")
|
| 779 |
df = st.session_state.uploaded_data
|
| 780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
# Initialize trainer with absolute path
|
| 782 |
import os
|
| 783 |
st.session_state.config.output_dir = os.path.abspath("trained_models")
|
|
@@ -804,7 +815,7 @@ with tab4:
|
|
| 804 |
# Train model
|
| 805 |
result = trainer.train(
|
| 806 |
texts=df['text'].tolist(),
|
| 807 |
-
labels=
|
| 808 |
progress_callback=progress_callback
|
| 809 |
)
|
| 810 |
|
|
@@ -823,12 +834,12 @@ with tab4:
|
|
| 823 |
progress_bar.progress(1.0)
|
| 824 |
status_text.empty()
|
| 825 |
|
| 826 |
-
st.success("
|
| 827 |
add_log(f"Training completed successfully. Model saved to: {result.model_path}")
|
| 828 |
|
| 829 |
# Show final metrics
|
| 830 |
if result.final_metrics:
|
| 831 |
-
st.markdown("###
|
| 832 |
metrics = result.final_metrics.to_dict()
|
| 833 |
|
| 834 |
col1, col2, col3, col4 = st.columns(4)
|
|
@@ -843,49 +854,49 @@ with tab4:
|
|
| 843 |
except Exception as e:
|
| 844 |
import traceback
|
| 845 |
error_details = traceback.format_exc()
|
| 846 |
-
st.error(f"
|
| 847 |
-
with st.expander("
|
| 848 |
st.code(error_details)
|
| 849 |
st.session_state.training_started = False
|
| 850 |
add_log(f"Training failed: {str(e)}")
|
| 851 |
|
| 852 |
if st.session_state.training_completed:
|
| 853 |
-
st.success("
|
| 854 |
|
| 855 |
-
model_path_display = st.session_state.model_path if st.session_state.model_path else "
|
| 856 |
|
| 857 |
create_info_box(
|
| 858 |
-
f"
|
| 859 |
f"Model saved to: <code>{model_path_display}</code><br>"
|
| 860 |
"Proceed to the <b>Evaluate Model</b> tab to analyze performance.",
|
| 861 |
"success" if st.session_state.model_path else "warning"
|
| 862 |
)
|
| 863 |
|
| 864 |
# Show training logs
|
| 865 |
-
with st.expander("
|
| 866 |
for log in st.session_state.training_logs[-20:]: # Show last 20 logs
|
| 867 |
st.text(log)
|
| 868 |
|
| 869 |
# ==================== TAB 5: Evaluate Model ====================
|
| 870 |
|
| 871 |
with tab5:
|
| 872 |
-
st.markdown("##
|
| 873 |
|
| 874 |
if not st.session_state.training_completed:
|
| 875 |
create_info_box(
|
| 876 |
-
"
|
| 877 |
"warning"
|
| 878 |
)
|
| 879 |
st.stop()
|
| 880 |
|
| 881 |
create_info_box(
|
| 882 |
-
"
|
| 883 |
"Analyze your model's performance with detailed metrics and visualizations.",
|
| 884 |
"info"
|
| 885 |
)
|
| 886 |
|
| 887 |
if st.session_state.evaluation_results is None:
|
| 888 |
-
if st.button("
|
| 889 |
with st.spinner("Evaluating model..."):
|
| 890 |
try:
|
| 891 |
# Initialize evaluator
|
|
@@ -899,10 +910,19 @@ with tab5:
|
|
| 899 |
test_size = int(len(df) * st.session_state.config.validation_split)
|
| 900 |
test_df = df.tail(test_size)
|
| 901 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
# Evaluate
|
| 903 |
results = evaluator.evaluate(
|
| 904 |
texts=test_df['text'].tolist(),
|
| 905 |
-
true_labels=
|
| 906 |
batch_size=st.session_state.config.batch_size
|
| 907 |
)
|
| 908 |
|
|
@@ -917,7 +937,7 @@ with tab5:
|
|
| 917 |
results = st.session_state.evaluation_results
|
| 918 |
|
| 919 |
# Overall metrics
|
| 920 |
-
st.markdown("###
|
| 921 |
|
| 922 |
col1, col2, col3, col4 = st.columns(4)
|
| 923 |
col1.metric("Accuracy", f"{results['accuracy']:.2%}")
|
|
@@ -928,7 +948,7 @@ with tab5:
|
|
| 928 |
st.markdown("---")
|
| 929 |
|
| 930 |
# Confusion Matrix
|
| 931 |
-
st.markdown("###
|
| 932 |
|
| 933 |
if 'confusion_matrix' in results:
|
| 934 |
cm = results['confusion_matrix']
|
|
@@ -956,7 +976,7 @@ with tab5:
|
|
| 956 |
st.markdown("---")
|
| 957 |
|
| 958 |
# Classification Report
|
| 959 |
-
st.markdown("###
|
| 960 |
|
| 961 |
if 'classification_report' in results:
|
| 962 |
report = results['classification_report']
|
|
@@ -964,9 +984,9 @@ with tab5:
|
|
| 964 |
|
| 965 |
# Download results
|
| 966 |
st.markdown("---")
|
| 967 |
-
st.markdown("###
|
| 968 |
|
| 969 |
-
if st.button("
|
| 970 |
# Create downloadable report
|
| 971 |
report_text = f"""
|
| 972 |
MLOps Training Platform - Evaluation Report
|
|
@@ -993,7 +1013,7 @@ Training Configuration:
|
|
| 993 |
"""
|
| 994 |
|
| 995 |
st.download_button(
|
| 996 |
-
label="
|
| 997 |
data=report_text,
|
| 998 |
file_name=f"evaluation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
|
| 999 |
mime="text/plain"
|
|
|
|
| 25 |
import pandas as pd
|
| 26 |
import plotly.express as px
|
| 27 |
import plotly.graph_objects as go
|
| 28 |
+
from sklearn.preprocessing import LabelEncoder
|
| 29 |
|
| 30 |
# Add src directory to path for imports
|
| 31 |
sys.path.insert(0, str(Path(__file__).parent / 'src'))
|
|
|
|
| 144 |
# Data
|
| 145 |
'uploaded_data': None,
|
| 146 |
'preprocessed_data': None,
|
| 147 |
+
'label_encoder': None,
|
| 148 |
+
'label_classes': None,
|
| 149 |
|
| 150 |
# Evaluation
|
| 151 |
'evaluation_results': None,
|
|
|
|
| 312 |
# ==================== TAB 1: Prerequisites ====================
|
| 313 |
|
| 314 |
with tab1:
|
| 315 |
+
st.markdown("## System Prerequisites")
|
| 316 |
|
| 317 |
create_info_box(
|
| 318 |
+
"<b>Important:</b> Complete all prerequisite checks before proceeding to training.<br>"
|
| 319 |
"This ensures your system is properly configured and all required models are downloaded.",
|
| 320 |
"warning"
|
| 321 |
)
|
|
|
|
| 324 |
system_checker = SystemChecker(models_dir="models")
|
| 325 |
|
| 326 |
# ===== CUDA/GPU Check =====
|
| 327 |
+
st.markdown("### 1. CUDA/GPU Check")
|
| 328 |
|
| 329 |
col1, col2 = st.columns([3, 1])
|
| 330 |
with col1:
|
| 331 |
st.markdown("Check if CUDA-capable GPU is available for faster training.")
|
| 332 |
with col2:
|
| 333 |
+
if st.button("Check CUDA", width="stretch"):
|
| 334 |
with st.spinner("Checking CUDA availability..."):
|
| 335 |
cuda_status = system_checker.check_cuda()
|
| 336 |
st.session_state.cuda_status = cuda_status
|
|
|
|
| 340 |
cuda = st.session_state.cuda_status
|
| 341 |
|
| 342 |
if cuda['available']:
|
| 343 |
+
st.success(f"CUDA Available - {cuda['device_count']} GPU(s) detected")
|
| 344 |
|
| 345 |
for device in cuda['devices']:
|
| 346 |
+
with st.expander(f"Device: {device['name']} Details"):
|
| 347 |
col1, col2, col3 = st.columns(3)
|
| 348 |
col1.metric("Memory", f"{device['memory_total']:.2f} GB")
|
| 349 |
col2.metric("Compute", device['compute_capability'])
|
| 350 |
col3.metric("CUDA Version", cuda['cuda_version'])
|
| 351 |
|
| 352 |
create_info_box(
|
| 353 |
+
"<b>Recommendation:</b> Your GPU is ready for training! "
|
| 354 |
"You can use any model from the list. XLM-RoBERTa and RoBERTa are recommended for best accuracy.",
|
| 355 |
"success"
|
| 356 |
)
|
| 357 |
else:
|
| 358 |
+
st.warning("No CUDA-capable GPU detected - Training will use CPU")
|
| 359 |
create_info_box(
|
| 360 |
+
"<b>Recommendation:</b> For CPU training, we recommend using <b>distilbert-base-multilingual-cased</b> "
|
| 361 |
"as it's significantly faster while maintaining good accuracy.",
|
| 362 |
"warning"
|
| 363 |
)
|
|
|
|
| 365 |
st.markdown("---")
|
| 366 |
|
| 367 |
# ===== Environment Check =====
|
| 368 |
+
st.markdown("### 2. Environment Check")
|
| 369 |
|
| 370 |
col1, col2 = st.columns([3, 1])
|
| 371 |
with col1:
|
| 372 |
st.markdown("Verify all required Python packages are installed with correct versions.")
|
| 373 |
with col2:
|
| 374 |
+
if st.button("Check Environment", width="stretch"):
|
| 375 |
with st.spinner("Checking environment..."):
|
| 376 |
env_status = system_checker.check_environment()
|
| 377 |
st.session_state.env_status = env_status
|
|
|
|
| 381 |
env = st.session_state.env_status
|
| 382 |
|
| 383 |
if env['all_satisfied']:
|
| 384 |
+
st.success("All required packages are installed")
|
| 385 |
else:
|
| 386 |
+
st.error(f"Missing packages: {', '.join(env['missing_packages'])}")
|
| 387 |
create_info_box(
|
| 388 |
f"<b>To install missing packages, run:</b><br>"
|
| 389 |
f"<code>pip install {' '.join(env['missing_packages'])}</code>",
|
| 390 |
"error"
|
| 391 |
)
|
| 392 |
|
| 393 |
+
with st.expander("View Package Details"):
|
| 394 |
package_df = pd.DataFrame([
|
| 395 |
{
|
| 396 |
'Package': pkg,
|
| 397 |
'Installed': info['installed'] or 'Not Installed',
|
| 398 |
'Required': info['required'],
|
| 399 |
+
'Status': 'OK' if info['satisfied'] else 'Missing'
|
| 400 |
}
|
| 401 |
for pkg, info in env['packages'].items()
|
| 402 |
])
|
|
|
|
| 405 |
st.markdown("---")
|
| 406 |
|
| 407 |
# ===== Model Selection Guide =====
|
| 408 |
+
st.markdown("### 3. Model Selection Guide")
|
| 409 |
|
| 410 |
create_info_box(
|
| 411 |
+
"<b>How to choose the right model:</b><br><br>"
|
| 412 |
"Consider these factors:<br>"
|
| 413 |
"• <b>Language:</b> English only or multilingual?<br>"
|
| 414 |
"• <b>Hardware:</b> GPU available or CPU only?<br>"
|
|
|
|
| 433 |
st.dataframe(model_df, width="stretch", hide_index=True)
|
| 434 |
|
| 435 |
# Quick recommendations
|
| 436 |
+
st.markdown("#### Quick Recommendations:")
|
| 437 |
|
| 438 |
rec_col1, rec_col2 = st.columns(2)
|
| 439 |
|
| 440 |
with rec_col1:
|
| 441 |
st.markdown("**For GPU Training:**")
|
| 442 |
+
st.markdown("- Best: `xlm-roberta-base` (highest accuracy)")
|
| 443 |
+
st.markdown("- Fast: `roberta-base` (English only)")
|
| 444 |
|
| 445 |
with rec_col2:
|
| 446 |
st.markdown("**For CPU Training:**")
|
| 447 |
+
st.markdown("- Recommended: `distilbert-base-multilingual-cased`")
|
| 448 |
+
st.markdown("- Fastest training and good performance")
|
| 449 |
|
| 450 |
st.markdown("---")
|
| 451 |
|
| 452 |
# ===== Model Download =====
|
| 453 |
+
st.markdown("### 4. Download Models")
|
| 454 |
|
| 455 |
create_info_box(
|
| 456 |
+
"<b>Download models before training:</b><br>"
|
| 457 |
"Models will be downloaded to the <code>models/</code> directory. "
|
| 458 |
"This may take several minutes depending on your internet connection.",
|
| 459 |
"info"
|
|
|
|
| 469 |
|
| 470 |
col1, col2 = st.columns([3, 1])
|
| 471 |
with col2:
|
| 472 |
+
download_btn = st.button("Download Selected", width="stretch", type="primary", disabled=len(selected_models) == 0)
|
| 473 |
|
| 474 |
if download_btn:
|
| 475 |
progress_bar = st.progress(0)
|
|
|
|
| 491 |
st.error(f"Failed to download {model_id}: {message}")
|
| 492 |
|
| 493 |
progress_bar.progress(1.0)
|
| 494 |
+
status_text.text("Download complete!")
|
| 495 |
time.sleep(1)
|
| 496 |
st.rerun()
|
| 497 |
|
| 498 |
# Show downloaded models
|
| 499 |
if st.session_state.models_downloaded:
|
| 500 |
+
st.markdown("#### Downloaded Models:")
|
| 501 |
for model_id in st.session_state.models_downloaded:
|
| 502 |
model_info = system_checker.get_model_info(model_id)
|
| 503 |
+
st.success(f"{MODEL_ARCHITECTURES[model_id]['name']} - {model_info['size_mb']:.0f} MB")
|
| 504 |
|
| 505 |
st.markdown("---")
|
| 506 |
|
|
|
|
| 513 |
)
|
| 514 |
|
| 515 |
if can_proceed:
|
| 516 |
+
if st.button("Prerequisites Complete - Proceed to Data Upload", width="stretch", type="primary"):
|
| 517 |
st.session_state.prerequisites_checked = True
|
| 518 |
add_log("Prerequisites check completed successfully")
|
| 519 |
+
st.success("All prerequisites satisfied! You can now proceed to upload your data.")
|
| 520 |
time.sleep(1)
|
| 521 |
st.rerun()
|
| 522 |
else:
|
| 523 |
create_info_box(
|
| 524 |
+
"<b>Complete all checks above before proceeding:</b><br>"
|
| 525 |
+
"- CUDA Check<br>"
|
| 526 |
+
"- Environment Check (all packages installed)<br>"
|
| 527 |
+
"- Download at least one model",
|
| 528 |
"warning"
|
| 529 |
)
|
| 530 |
|
| 531 |
# ==================== TAB 2: Upload Data ====================
|
| 532 |
|
| 533 |
with tab2:
|
| 534 |
+
st.markdown("## Upload Training Data")
|
| 535 |
|
| 536 |
if not st.session_state.prerequisites_checked:
|
| 537 |
create_info_box(
|
| 538 |
+
"Please complete the <b>Prerequisites</b> tab first before uploading data.",
|
| 539 |
"warning"
|
| 540 |
)
|
| 541 |
st.stop()
|
| 542 |
|
| 543 |
create_info_box(
|
| 544 |
+
"<b>Data Format Requirements:</b><br>"
|
| 545 |
"• CSV file with at least two columns: text and label<br>"
|
| 546 |
"• Text column: Contains the text samples to classify<br>"
|
| 547 |
"• Label column: Contains the class labels (0/1 for binary, or class names for multi-class)<br>"
|
|
|
|
| 562 |
df = pd.read_csv(uploaded_file)
|
| 563 |
st.session_state.uploaded_data = df
|
| 564 |
|
| 565 |
+
st.success(f"Uploaded {len(df)} samples")
|
| 566 |
|
| 567 |
# Validate data
|
| 568 |
validator = DataValidator()
|
| 569 |
is_valid, message = validator.validate_dataframe(df)
|
| 570 |
|
| 571 |
if is_valid:
|
| 572 |
+
st.success(f"Data validation passed: {message}")
|
| 573 |
|
| 574 |
# Show data preview
|
| 575 |
+
st.markdown("### Data Preview")
|
| 576 |
st.dataframe(df.head(10), width="stretch")
|
| 577 |
|
| 578 |
# Show statistics
|
|
|
|
| 582 |
col3.metric("Text Columns", len([c for c in df.columns if df[c].dtype == 'object']))
|
| 583 |
|
| 584 |
# Label distribution
|
| 585 |
+
st.markdown("### Label Distribution")
|
| 586 |
label_counts = df['label'].value_counts()
|
| 587 |
fig = px.bar(
|
| 588 |
x=label_counts.index.astype(str),
|
|
|
|
| 596 |
if st.session_state.classification_type == ClassificationType.MULTICLASS:
|
| 597 |
num_classes = df['label'].nunique()
|
| 598 |
st.session_state.config.num_labels = num_classes
|
| 599 |
+
st.info(f"Detected {num_classes} classes for multi-class classification")
|
| 600 |
|
| 601 |
add_log(f"Uploaded data with {len(df)} samples and {df['label'].nunique()} labels")
|
| 602 |
|
| 603 |
else:
|
| 604 |
+
st.error(f"Data validation failed: {message}")
|
| 605 |
|
| 606 |
except Exception as e:
|
| 607 |
st.error(f"Error reading file: {str(e)}")
|
|
|
|
| 609 |
# ==================== TAB 3: Configure Training ====================
|
| 610 |
|
| 611 |
with tab3:
|
| 612 |
+
st.markdown("## Configure Training Parameters")
|
| 613 |
|
| 614 |
if st.session_state.uploaded_data is None:
|
| 615 |
create_info_box(
|
| 616 |
+
"Please upload your data in the <b>Upload Data</b> tab first.",
|
| 617 |
"warning"
|
| 618 |
)
|
| 619 |
st.stop()
|
| 620 |
|
| 621 |
create_info_box(
|
| 622 |
+
"<b>Configure your training settings:</b><br>"
|
| 623 |
"Adjust the parameters below based on your needs. Hover over ⓘ for explanations.",
|
| 624 |
"info"
|
| 625 |
)
|
| 626 |
|
| 627 |
# Model selection
|
| 628 |
+
st.markdown("### Model Selection")
|
| 629 |
|
| 630 |
available_models = list(st.session_state.models_downloaded)
|
| 631 |
|
| 632 |
if not available_models:
|
| 633 |
+
st.error("No models downloaded. Please download models in the Prerequisites tab.")
|
| 634 |
st.stop()
|
| 635 |
|
| 636 |
selected_model = st.selectbox(
|
|
|
|
| 655 |
st.markdown("---")
|
| 656 |
|
| 657 |
# Training parameters
|
| 658 |
+
st.markdown("### Training Parameters")
|
| 659 |
|
| 660 |
col1, col2 = st.columns(2)
|
| 661 |
|
|
|
|
| 725 |
st.markdown("---")
|
| 726 |
|
| 727 |
# Show configuration summary
|
| 728 |
+
st.markdown("### Configuration Summary")
|
| 729 |
|
| 730 |
config_summary = {
|
| 731 |
"Classification Type": "Binary" if st.session_state.classification_type == ClassificationType.BINARY else "Multi-class",
|
|
|
|
| 747 |
# ==================== TAB 4: Train Model ====================
|
| 748 |
|
| 749 |
with tab4:
|
| 750 |
+
st.markdown("## Train Your Model")
|
| 751 |
|
| 752 |
if st.session_state.uploaded_data is None:
|
| 753 |
create_info_box(
|
| 754 |
+
"Please complete previous steps first.",
|
| 755 |
"warning"
|
| 756 |
)
|
| 757 |
st.stop()
|
| 758 |
|
| 759 |
if not st.session_state.training_started:
|
| 760 |
create_info_box(
|
| 761 |
+
"<b>Ready to train!</b><br>"
|
| 762 |
f"Your {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']} model will be trained on {len(st.session_state.uploaded_data)} samples "
|
| 763 |
f"for {st.session_state.config.num_epochs} epochs.",
|
| 764 |
"info"
|
| 765 |
)
|
| 766 |
|
| 767 |
+
if st.button("Start Training", type="primary", width="stretch"):
|
| 768 |
st.session_state.training_started = True
|
| 769 |
st.rerun()
|
| 770 |
|
| 771 |
if st.session_state.training_started and not st.session_state.training_completed:
|
| 772 |
+
st.markdown("### Training in Progress...")
|
| 773 |
|
| 774 |
# Progress display
|
| 775 |
progress_bar = st.progress(0)
|
|
|
|
| 781 |
status_text.text("Preparing data...")
|
| 782 |
df = st.session_state.uploaded_data
|
| 783 |
|
| 784 |
+
# Encode labels to integers
|
| 785 |
+
label_encoder = LabelEncoder()
|
| 786 |
+
encoded_labels = label_encoder.fit_transform(df['label'])
|
| 787 |
+
|
| 788 |
+
# Store label encoder for later use
|
| 789 |
+
st.session_state.label_encoder = label_encoder
|
| 790 |
+
st.session_state.label_classes = label_encoder.classes_.tolist()
|
| 791 |
+
|
| 792 |
# Initialize trainer with absolute path
|
| 793 |
import os
|
| 794 |
st.session_state.config.output_dir = os.path.abspath("trained_models")
|
|
|
|
| 815 |
# Train model
|
| 816 |
result = trainer.train(
|
| 817 |
texts=df['text'].tolist(),
|
| 818 |
+
labels=encoded_labels.tolist(),
|
| 819 |
progress_callback=progress_callback
|
| 820 |
)
|
| 821 |
|
|
|
|
| 834 |
progress_bar.progress(1.0)
|
| 835 |
status_text.empty()
|
| 836 |
|
| 837 |
+
st.success("Training completed successfully!")
|
| 838 |
add_log(f"Training completed successfully. Model saved to: {result.model_path}")
|
| 839 |
|
| 840 |
# Show final metrics
|
| 841 |
if result.final_metrics:
|
| 842 |
+
st.markdown("### Final Training Metrics")
|
| 843 |
metrics = result.final_metrics.to_dict()
|
| 844 |
|
| 845 |
col1, col2, col3, col4 = st.columns(4)
|
|
|
|
| 854 |
except Exception as e:
|
| 855 |
import traceback
|
| 856 |
error_details = traceback.format_exc()
|
| 857 |
+
st.error(f"Training failed: {str(e)}")
|
| 858 |
+
with st.expander("Error Details"):
|
| 859 |
st.code(error_details)
|
| 860 |
st.session_state.training_started = False
|
| 861 |
add_log(f"Training failed: {str(e)}")
|
| 862 |
|
| 863 |
if st.session_state.training_completed:
|
| 864 |
+
st.success("Training completed!")
|
| 865 |
|
| 866 |
+
model_path_display = st.session_state.model_path if st.session_state.model_path else "Path not available"
|
| 867 |
|
| 868 |
create_info_box(
|
| 869 |
+
f"<b>Model trained successfully!</b><br>"
|
| 870 |
f"Model saved to: <code>{model_path_display}</code><br>"
|
| 871 |
"Proceed to the <b>Evaluate Model</b> tab to analyze performance.",
|
| 872 |
"success" if st.session_state.model_path else "warning"
|
| 873 |
)
|
| 874 |
|
| 875 |
# Show training logs
|
| 876 |
+
with st.expander("View Training Logs"):
|
| 877 |
for log in st.session_state.training_logs[-20:]: # Show last 20 logs
|
| 878 |
st.text(log)
|
| 879 |
|
| 880 |
# ==================== TAB 5: Evaluate Model ====================
|
| 881 |
|
| 882 |
with tab5:
|
| 883 |
+
st.markdown("## Evaluate Model Performance")
|
| 884 |
|
| 885 |
if not st.session_state.training_completed:
|
| 886 |
create_info_box(
|
| 887 |
+
"Please train a model first in the <b>Train Model</b> tab.",
|
| 888 |
"warning"
|
| 889 |
)
|
| 890 |
st.stop()
|
| 891 |
|
| 892 |
create_info_box(
|
| 893 |
+
"<b>Model Evaluation:</b><br>"
|
| 894 |
"Analyze your model's performance with detailed metrics and visualizations.",
|
| 895 |
"info"
|
| 896 |
)
|
| 897 |
|
| 898 |
if st.session_state.evaluation_results is None:
|
| 899 |
+
if st.button("Evaluate Model", type="primary", width="stretch"):
|
| 900 |
with st.spinner("Evaluating model..."):
|
| 901 |
try:
|
| 902 |
# Initialize evaluator
|
|
|
|
| 910 |
test_size = int(len(df) * st.session_state.config.validation_split)
|
| 911 |
test_df = df.tail(test_size)
|
| 912 |
|
| 913 |
+
# Encode the evaluation labels with the same mapping that training used.
# The session-state defaults always create the 'label_encoder' key
# (initialised to None), so a membership test on the key is always true;
# check the stored value instead.
label_encoder = st.session_state.get('label_encoder')
if label_encoder is None:
    # Fallback: rebuild the encoder from the full label column rather than
    # from the test slice alone. A slice can miss classes, which would shift
    # the integer ids away from the ones produced when training fitted on
    # the whole column.
    # NOTE(review): assumes `df` here is the same uploaded dataframe that
    # training encoded - confirm against the lines above this fragment.
    from sklearn.preprocessing import LabelEncoder
    label_encoder = LabelEncoder().fit(df['label'])
test_labels_encoded = label_encoder.transform(test_df['label']).tolist()
|
| 921 |
+
|
| 922 |
# Evaluate
|
| 923 |
results = evaluator.evaluate(
|
| 924 |
texts=test_df['text'].tolist(),
|
| 925 |
+
true_labels=test_labels_encoded,
|
| 926 |
batch_size=st.session_state.config.batch_size
|
| 927 |
)
|
| 928 |
|
|
|
|
| 937 |
results = st.session_state.evaluation_results
|
| 938 |
|
| 939 |
# Overall metrics
|
| 940 |
+
st.markdown("### Overall Metrics")
|
| 941 |
|
| 942 |
col1, col2, col3, col4 = st.columns(4)
|
| 943 |
col1.metric("Accuracy", f"{results['accuracy']:.2%}")
|
|
|
|
| 948 |
st.markdown("---")
|
| 949 |
|
| 950 |
# Confusion Matrix
|
| 951 |
+
st.markdown("### Confusion Matrix")
|
| 952 |
|
| 953 |
if 'confusion_matrix' in results:
|
| 954 |
cm = results['confusion_matrix']
|
|
|
|
| 976 |
st.markdown("---")
|
| 977 |
|
| 978 |
# Classification Report
|
| 979 |
+
st.markdown("### Detailed Classification Report")
|
| 980 |
|
| 981 |
if 'classification_report' in results:
|
| 982 |
report = results['classification_report']
|
|
|
|
| 984 |
|
| 985 |
# Download results
|
| 986 |
st.markdown("---")
|
| 987 |
+
st.markdown("### Download Results")
|
| 988 |
|
| 989 |
+
if st.button("Download Evaluation Report", width="stretch"):
|
| 990 |
# Create downloadable report
|
| 991 |
report_text = f"""
|
| 992 |
MLOps Training Platform - Evaluation Report
|
|
|
|
| 1013 |
"""
|
| 1014 |
|
| 1015 |
st.download_button(
|
| 1016 |
+
label="Download Text Report",
|
| 1017 |
data=report_text,
|
| 1018 |
file_name=f"evaluation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
|
| 1019 |
mime="text/plain"
|