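# Page-header residue from the hosting page this file was captured from
# (likely a Hugging Face Space); status shown at capture time: "Sleeping".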
"""
MLOps Training Platform - Streamlit Application
==================================================
A beginner-friendly web interface for training text classification models
with built-in system checks and model management.

Run with: streamlit run streamlit_app.py
"""

# CRITICAL: Set these environment variables FIRST, before any other imports.
# They must already be in os.environ when the ML libraries imported below are
# loaded, which is why they precede every other import.
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TRANSFORMERS_NO_TF'] = '1'
os.environ['USE_TF'] = '0'

import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, List

import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.preprocessing import LabelEncoder

# Add src directory to path for imports (the mlops package lives in ./src)
sys.path.insert(0, str(Path(__file__).parent / 'src'))

from mlops.config import (
    TrainingConfig,
    MODEL_ARCHITECTURES,
    MODEL_SELECTION_GUIDE,
    ClassificationType
)
from mlops.preprocessor import TextPreprocessor, DataValidator
from mlops.trainer import ModelTrainer
from mlops.evaluator import ModelEvaluator
from mlops.system_check import SystemChecker, get_system_summary
# ==================== Page Configuration ====================
# Browser-tab title/icon, wide layout, and an initially expanded sidebar.
_PAGE_CONFIG = dict(
    page_title="MLOps Training Platform",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# ==================== Custom CSS ====================
# One stylesheet injected as raw HTML. It provides the gradient header text,
# the info/warning/success/error "*-box" callouts used by create_info_box(),
# the metric-card style, and hides Streamlit's default main menu and footer.
_CUSTOM_CSS = """
<style>
/* Main styling */
.main-header {
font-size: 2.5rem;
font-weight: 700;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 0.5rem;
}
.sub-header {
font-size: 1.1rem;
color: #666;
margin-bottom: 2rem;
}
/* Info boxes */
.info-box {
background-color: #f0f7ff;
border-left: 4px solid #667eea;
padding: 1rem;
margin: 1rem 0;
border-radius: 0 8px 8px 0;
}
.warning-box {
background-color: #fff7e6;
border-left: 4px solid #fa8c16;
padding: 1rem;
margin: 1rem 0;
border-radius: 0 8px 8px 0;
}
.success-box {
background-color: #f6ffed;
border-left: 4px solid #52c41a;
padding: 1rem;
margin: 1rem 0;
border-radius: 0 8px 8px 0;
}
.error-box {
background-color: #fff1f0;
border-left: 4px solid #ff4d4f;
padding: 1rem;
margin: 1rem 0;
border-radius: 0 8px 8px 0;
}
/* Metric cards */
.metric-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 1.5rem;
border-radius: 10px;
color: white;
text-align: center;
}
/* Hide default elements */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ==================== Session State Initialization ====================
def init_session_state():
    """Populate ``st.session_state`` with a default for every missing key.

    Called on every script run and again after "Reset All". Keys that already
    exist are never overwritten. Defaults that need a fresh object per session
    (the containers and the ``TrainingConfig``) are listed as zero-argument
    factories. They are invoked only when their key is missing, instead of
    being constructed on every rerun and then thrown away.
    """
    simple_defaults = {
        # Classification type selection
        'classification_type': None,
        'classification_type_selected': False,
        # Prerequisites
        'prerequisites_checked': False,
        'cuda_status': None,
        'env_status': None,
        # Training state
        'training_started': False,
        'training_completed': False,
        'training_progress': 0.0,
        'model_path': None,
        # Data
        'uploaded_data': None,
        'preprocessed_data': None,
        'label_encoder': None,
        'label_classes': None,
        # Evaluation
        'evaluation_results': None,
        # Selected model
        'selected_model': None,
    }
    factory_defaults = {
        'models_downloaded': set,   # ids of models downloaded in this session
        'training_logs': list,      # "[HH:MM:SS] message" lines from add_log()
        'metrics_history': list,
        'config': TrainingConfig,   # mutated in place by the configuration widgets
    }

    for key, value in simple_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
    for key, factory in factory_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = factory()


init_session_state()
# ==================== Helper Functions ====================
def add_log(message: str):
    """Append *message* to ``st.session_state.training_logs``, prefixed with ``[HH:MM:SS]``."""
    clock = f"{datetime.now():%H:%M:%S}"
    log_line = f"[{clock}] {message}"
    st.session_state.training_logs.append(log_line)
def create_info_box(text: str, box_type: str = "info"):
    """Render *text* inside a ``<div class="{box_type}-box">`` callout.

    The markup is emitted with ``unsafe_allow_html=True``, so *text* is
    interpreted as raw HTML and must come from trusted code, not user input.
    ``box_type`` is expected to be one of the styled variants: ``info``,
    ``warning``, ``success`` or ``error``.
    """
    css_class = f"{box_type}-box"
    st.markdown(f'<div class="{css_class}">{text}</div>', unsafe_allow_html=True)
# ==================== Sidebar ====================
def render_sidebar():
    """Render the sidebar: a status summary for each workflow step plus a reset action."""
    with st.sidebar:
        st.markdown('<h1 class="main-header">MLOps Platform</h1>', unsafe_allow_html=True)
        st.markdown("---")

        # Classification Type Status
        st.subheader("Classification Type")
        if st.session_state.classification_type_selected:
            type_display = "Binary" if st.session_state.classification_type == ClassificationType.BINARY else "Multi-class"
            st.success(f"Selected: {type_display}")
        else:
            st.warning("Not selected")
        st.markdown("---")

        # Prerequisites Status
        st.subheader("Prerequisites")
        if st.session_state.prerequisites_checked:
            st.success("Checked")

            # CUDA Status
            cuda = st.session_state.cuda_status
            if cuda:
                if cuda['available']:
                    # Guard against an "available" status with an empty device
                    # list, which would otherwise raise IndexError on devices[0].
                    devices = cuda.get('devices') or []
                    gpu_name = devices[0]['name'] if devices else "unknown device"
                    st.info(f"GPU: {gpu_name}")
                else:
                    st.info("CPU Mode")

            # Models downloaded
            if st.session_state.models_downloaded:
                st.info(f"Models: {len(st.session_state.models_downloaded)}")
        else:
            st.warning("Not checked")
        st.markdown("---")

        # Training Status
        st.subheader("Training Status")
        if st.session_state.training_completed:
            st.success("Completed")
        elif st.session_state.training_started:
            st.info(f"In Progress ({st.session_state.training_progress:.0f}%)")
        else:
            st.info("Not started")
        st.markdown("---")

        # Quick Actions
        st.subheader("Quick Actions")
        if st.button("Reset All", width="stretch"):
            # Drop every key (including widget state), re-seed defaults, start over.
            for key in list(st.session_state.keys()):
                del st.session_state[key]
            init_session_state()
            st.rerun()


render_sidebar()
# ==================== Main Content ====================
# Page title and tagline, styled by the .main-header / .sub-header CSS classes.
_HEADER_HTML = '<h1 class="main-header">MLOps Training Platform</h1>'
_SUBHEADER_HTML = '<p class="sub-header">Train and evaluate text classification models with ease</p>'
for _snippet in (_HEADER_HTML, _SUBHEADER_HTML):
    st.markdown(_snippet, unsafe_allow_html=True)
# ==================== STEP 1: Classification Type Selection ====================
# The whole workflow is gated on this choice: until a type is picked, the
# st.stop() at the end of this branch prevents the tabs below from rendering.
if not st.session_state.classification_type_selected:
    st.markdown("## Step 1: Choose Classification Type")
    create_info_box(
        "<b>First, select your classification task type:</b><br><br>"
        "• <b>Binary Classification:</b> Two classes (e.g., spam vs. not spam, positive vs. negative)<br>"
        "• <b>Multi-class Classification:</b> More than two classes (e.g., categorize news into politics, sports, entertainment, etc.)",
        "info"
    )

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("### Binary Classification")
        # Blank lines separate the paragraphs from the bullet lists. Without the
        # one before "**Examples:**", Markdown folds that line into the last
        # bullet as a lazy continuation.
        st.markdown("""
**Use when you have:**

- 2 categories/labels
- Yes/No questions
- Positive/Negative sentiment

**Examples:**

- Spam detection (spam/not spam)
- Sentiment analysis (positive/negative)
- Phishing detection (phishing/legitimate)
""")
        if st.button("Select Binary Classification", width="stretch", type="primary"):
            st.session_state.classification_type = ClassificationType.BINARY
            st.session_state.classification_type_selected = True
            st.session_state.config.num_labels = 2
            add_log("Selected Binary Classification")
            st.rerun()

    with col2:
        st.markdown("### Multi-class Classification")
        st.markdown("""
**Use when you have:**

- 3+ categories/labels
- Multiple distinct classes
- Topic categorization

**Examples:**

- News categorization (politics/sports/tech/entertainment)
- Product classification (electronics/clothing/books/toys)
- Language detection (English/Chinese/Spanish/etc.)
""")
        if st.button("Select Multi-class Classification", width="stretch"):
            st.session_state.classification_type = ClassificationType.MULTICLASS
            st.session_state.classification_type_selected = True
            # num_labels is set after data upload, once the number of classes is known
            add_log("Selected Multi-class Classification")
            st.rerun()

    st.stop()  # Don't render rest of the app until classification type is selected
# ==================== TABS FOR REST OF WORKFLOW ====================
# One tab per workflow step, in the order the user is expected to follow.
_TAB_TITLES = [
    "Prerequisites",
    "Upload Data",
    "Configure Training",
    "Train Model",
    "Evaluate Model",
]
tab1, tab2, tab3, tab4, tab5 = st.tabs(_TAB_TITLES)
# ==================== TAB 1: Prerequisites ====================
# CUDA check, environment check, model guide, model download, and the final
# "prerequisites complete" gate that unlocks the Upload Data tab.
with tab1:
    st.markdown("## System Prerequisites")
    create_info_box(
        "<b>Important:</b> Complete all prerequisite checks before proceeding to training.<br>"
        "This ensures your system is properly configured and all required models are downloaded.",
        "warning"
    )

    # Initialize system checker (downloads go to the models/ directory)
    system_checker = SystemChecker(models_dir="models")

    # ===== CUDA/GPU Check =====
    st.markdown("### 1. CUDA/GPU Check")
    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("Check if CUDA-capable GPU is available for faster training.")
    with col2:
        if st.button("Check CUDA", width="stretch"):
            with st.spinner("Checking CUDA availability..."):
                st.session_state.cuda_status = system_checker.check_cuda()
            add_log("CUDA check completed")

    if st.session_state.cuda_status:
        cuda = st.session_state.cuda_status
        if cuda['available']:
            st.success(f"CUDA Available - {cuda['device_count']} GPU(s) detected")
            for device in cuda['devices']:
                with st.expander(f"Device: {device['name']} Details"):
                    mem_col, cc_col, ver_col = st.columns(3)
                    mem_col.metric("Memory", f"{device['memory_total']:.2f} GB")
                    cc_col.metric("Compute", device['compute_capability'])
                    ver_col.metric("CUDA Version", cuda['cuda_version'])
            create_info_box(
                "<b>Recommendation:</b> Your GPU is ready for training! "
                "You can use any model from the list. XLM-RoBERTa and RoBERTa are recommended for best accuracy.",
                "success"
            )
        else:
            st.warning("No CUDA-capable GPU detected - Training will use CPU")
            create_info_box(
                "<b>Recommendation:</b> For CPU training, we recommend using <b>distilbert-base-multilingual-cased</b> "
                "as it's significantly faster while maintaining good accuracy.",
                "warning"
            )

    st.markdown("---")

    # ===== Environment Check =====
    st.markdown("### 2. Environment Check")
    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("Verify all required Python packages are installed with correct versions.")
    with col2:
        if st.button("Check Environment", width="stretch"):
            with st.spinner("Checking environment..."):
                st.session_state.env_status = system_checker.check_environment()
            add_log("Environment check completed")

    if st.session_state.env_status:
        env = st.session_state.env_status
        if env['all_satisfied']:
            st.success("All required packages are installed")
        else:
            st.error(f"Missing packages: {', '.join(env['missing_packages'])}")
            create_info_box(
                f"<b>To install missing packages, run:</b><br>"
                f"<code>pip install {' '.join(env['missing_packages'])}</code>",
                "error"
            )

        with st.expander("View Package Details"):
            package_df = pd.DataFrame([
                {
                    'Package': pkg,
                    'Installed': info['installed'] or 'Not Installed',
                    'Required': info['required'],
                    'Status': 'OK' if info['satisfied'] else 'Missing'
                }
                for pkg, info in env['packages'].items()
            ])
            st.dataframe(package_df, width="stretch", hide_index=True)

    st.markdown("---")

    # ===== Model Selection Guide =====
    st.markdown("### 3. Model Selection Guide")
    create_info_box(
        "<b>How to choose the right model:</b><br><br>"
        "Consider these factors:<br>"
        "• <b>Language:</b> English only or multilingual?<br>"
        "• <b>Hardware:</b> GPU available or CPU only?<br>"
        "• <b>Speed vs Accuracy:</b> Need fast training or best accuracy?<br>"
        "• <b>Task Type:</b> Binary or multi-class classification?",
        "info"
    )

    # Display model comparison table
    model_df = pd.DataFrame([
        {
            'Model': arch['name'],
            'Languages': ', '.join(arch['languages']),
            'Speed': arch['speed'],
            'Size': arch['size'],
            'Best For': arch['best_use'],
            'ID': model_id
        }
        for model_id, arch in MODEL_ARCHITECTURES.items()
    ])
    st.dataframe(model_df, width="stretch", hide_index=True)

    # Quick recommendations
    st.markdown("#### Quick Recommendations:")
    rec_col1, rec_col2 = st.columns(2)
    with rec_col1:
        st.markdown("**For GPU Training:**")
        st.markdown("- Best: `xlm-roberta-base` (highest accuracy)")
        st.markdown("- Fast: `roberta-base` (English only)")
    with rec_col2:
        st.markdown("**For CPU Training:**")
        st.markdown("- Recommended: `distilbert-base-multilingual-cased`")
        st.markdown("- Fastest training and good performance")

    st.markdown("---")

    # ===== Model Download =====
    st.markdown("### 4. Download Models")
    create_info_box(
        "<b>Download models before training:</b><br>"
        "Models will be downloaded to the <code>models/</code> directory. "
        "This may take several minutes depending on your internet connection.",
        "info"
    )

    # Model selection
    selected_models = st.multiselect(
        "Select models to download:",
        options=list(MODEL_ARCHITECTURES.keys()),
        format_func=lambda x: f"{MODEL_ARCHITECTURES[x]['name']} ({MODEL_ARCHITECTURES[x]['size']})",
        help="Select one or more models to download. You can train with any downloaded model later."
    )

    col1, col2 = st.columns([3, 1])
    with col2:
        download_btn = st.button("Download Selected", width="stretch", type="primary", disabled=len(selected_models) == 0)

    if download_btn:
        progress_bar = st.progress(0)
        status_text = st.empty()
        total = len(selected_models)
        failed_models = []

        for idx, model_id in enumerate(selected_models):
            status_text.text(f"Downloading {model_id}... ({idx + 1}/{total})")
            progress_bar.progress(idx / total)

            success, _path, message = system_checker.download_model(
                model_id,
                progress_callback=lambda msg, prog: None  # Could add sub-progress here
            )

            if success:
                st.session_state.models_downloaded.add(model_id)
                add_log(f"Downloaded model: {model_id}")
            else:
                failed_models.append(model_id)
                st.error(f"Failed to download {model_id}: {message}")
                add_log(f"Failed to download model: {model_id}")

        progress_bar.progress(1.0)
        if failed_models:
            # No rerun here: it would immediately wipe the error messages above.
            status_text.text(f"Download finished with {len(failed_models)} failure(s).")
        else:
            status_text.text("Download complete!")
            time.sleep(1)
            st.rerun()

    # Show downloaded models (sorted for a stable display order)
    if st.session_state.models_downloaded:
        st.markdown("#### Downloaded Models:")
        for model_id in sorted(st.session_state.models_downloaded):
            disk_info = system_checker.get_model_info(model_id)
            st.success(f"{MODEL_ARCHITECTURES[model_id]['name']} - {disk_info['size_mb']:.0f} MB")

    st.markdown("---")

    # ===== Prerequisites Complete Button =====
    can_proceed = (
        st.session_state.cuda_status is not None and
        st.session_state.env_status is not None and
        st.session_state.env_status['all_satisfied'] and
        len(st.session_state.models_downloaded) > 0
    )

    if can_proceed:
        if st.button("Prerequisites Complete - Proceed to Data Upload", width="stretch", type="primary"):
            st.session_state.prerequisites_checked = True
            add_log("Prerequisites check completed successfully")
            st.success("All prerequisites satisfied! You can now proceed to upload your data.")
            time.sleep(1)
            st.rerun()
    else:
        create_info_box(
            "<b>Complete all checks above before proceeding:</b><br>"
            "- CUDA Check<br>"
            "- Environment Check (all packages installed)<br>"
            "- Download at least one model",
            "warning"
        )
# ==================== TAB 2: Upload Data ====================
def _render_upload_tab():
    """Render the data-upload step: read, validate, preview and accept a CSV.

    The DataFrame is stored in ``st.session_state.uploaded_data`` only after it
    passes every check; a rejected or unreadable file clears it, so later tabs
    never run on bad data. Missing prerequisites end the function with a plain
    ``return``, not ``st.stop()``, which would halt the whole script and leave
    the remaining tabs and the footer blank.
    """
    st.markdown("## Upload Training Data")

    if not st.session_state.prerequisites_checked:
        create_info_box(
            "Please complete the <b>Prerequisites</b> tab first before uploading data.",
            "warning"
        )
        return

    create_info_box(
        "<b>Data Format Requirements:</b><br>"
        "• CSV file with at least two columns: text and label<br>"
        "• Text column: Contains the text samples to classify<br>"
        "• Label column: Contains the class labels (0/1 for binary, or class names for multi-class)<br>"
        "• Minimum 20 samples recommended for training",
        "info"
    )

    # File uploader
    uploaded_file = st.file_uploader(
        "Upload your CSV file",
        type=['csv'],
        help="Upload a CSV file with 'text' and 'label' columns"
    )
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
        st.success(f"Uploaded {len(df)} samples")

        # Validate data
        validator = DataValidator()
        is_valid, message = validator.validate_dataframe(df)
        if not is_valid:
            st.session_state.uploaded_data = None
            st.error(f"Data validation failed: {message}")
            return
        st.success(f"Data validation passed: {message}")

        num_classes = df['label'].nunique()
        if st.session_state.classification_type == ClassificationType.BINARY and num_classes != 2:
            # Choosing Binary in Step 1 fixed config.num_labels at 2, so any
            # other label count would not match the configured number of labels.
            st.session_state.uploaded_data = None
            st.error(
                f"Binary classification requires exactly 2 distinct labels, but this file has {num_classes}. "
                "Fix the labels, or use \"Reset All\" in the sidebar and choose Multi-class Classification."
            )
            return

        # Accept the data only now that every check has passed.
        st.session_state.uploaded_data = df

        # Show data preview
        st.markdown("### Data Preview")
        st.dataframe(df.head(10), width="stretch")

        # Show statistics
        col1, col2, col3 = st.columns(3)
        col1.metric("Total Samples", len(df))
        col2.metric("Unique Labels", num_classes)
        col3.metric("Text Columns", sum(df[c].dtype == 'object' for c in df.columns))

        # Label distribution
        st.markdown("### Label Distribution")
        label_counts = df['label'].value_counts()
        fig = px.bar(
            x=label_counts.index.astype(str),
            y=label_counts.values,
            labels={'x': 'Label', 'y': 'Count'},
            title='Number of samples per label'
        )
        st.plotly_chart(fig, width="stretch")

        # Update num_labels for multi-class
        if st.session_state.classification_type == ClassificationType.MULTICLASS:
            st.session_state.config.num_labels = num_classes
            st.info(f"Detected {num_classes} classes for multi-class classification")

        # This code re-runs on every widget interaction while the file stays in
        # the uploader; log each distinct upload only once.
        upload_id = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get('logged_upload_id') != upload_id:
            add_log(f"Uploaded data with {len(df)} samples and {num_classes} labels")
            st.session_state['logged_upload_id'] = upload_id

    except Exception as e:
        st.session_state.uploaded_data = None
        st.error(f"Error reading file: {str(e)}")


with tab2:
    _render_upload_tab()
# ==================== TAB 3: Configure Training ====================
def _render_configure_tab():
    """Render the training-configuration step.

    Each widget value is written straight onto ``st.session_state.config``,
    which the Train tab then hands to the trainer. Missing prerequisites end
    the function with a plain ``return``, not ``st.stop()``, which would halt
    the whole script and blank the remaining tabs and the footer.
    """
    st.markdown("## Configure Training Parameters")

    if st.session_state.uploaded_data is None:
        create_info_box(
            "Please upload your data in the <b>Upload Data</b> tab first.",
            "warning"
        )
        return

    create_info_box(
        "<b>Configure your training settings:</b><br>"
        "Adjust the parameters below based on your needs. Hover over ⓘ for explanations.",
        "info"
    )

    config = st.session_state.config

    # Model selection
    st.markdown("### Model Selection")
    # sorted() keeps the option order stable instead of following set iteration order.
    available_models = sorted(st.session_state.models_downloaded)
    if not available_models:
        st.error("No models downloaded. Please download models in the Prerequisites tab.")
        return

    selected_model = st.selectbox(
        "Choose model:",
        options=available_models,
        format_func=lambda x: f"{MODEL_ARCHITECTURES[x]['name']} - {MODEL_ARCHITECTURES[x]['best_use']}",
        help="Select the model architecture to use for training"
    )
    st.session_state.selected_model = selected_model
    config.model_name = selected_model

    # Show model info
    model_info = MODEL_ARCHITECTURES[selected_model]
    with st.expander("ℹ️ Selected Model Information"):
        st.markdown(f"**Name:** {model_info['name']}")
        st.markdown(f"**Description:** {model_info['description']}")
        st.markdown(f"**Best For:** {model_info['best_use']}")
        st.markdown(f"**Speed:** {model_info['speed']}")
        st.markdown(f"**Size:** {model_info['size']}")

    st.markdown("---")

    # Training parameters
    st.markdown("### Training Parameters")
    col1, col2 = st.columns(2)

    with col1:
        epochs = st.slider(
            "Number of Epochs",
            min_value=1,
            max_value=20,
            value=3,
            help="Number of complete passes through the training dataset. More epochs = longer training but potentially better performance."
        )
        config.num_epochs = epochs

        batch_size = st.select_slider(
            "Batch Size",
            options=[4, 8, 16, 32, 64],
            value=16,
            help="Number of samples processed together. Larger batches train faster but require more GPU memory."
        )
        config.batch_size = batch_size

        learning_rate = st.select_slider(
            "Learning Rate",
            options=[1e-5, 2e-5, 3e-5, 5e-5, 1e-4],
            value=2e-5,
            format_func=lambda x: f"{x:.0e}",
            help="Step size for model parameter updates. 2e-5 is a good default for BERT-like models."
        )
        config.learning_rate = learning_rate

    with col2:
        max_length = st.slider(
            "Max Sequence Length",
            min_value=128,
            max_value=512,
            value=128,
            step=64,
            help="Maximum length of input text in tokens. Longer sequences require more memory."
        )
        config.max_length = max_length

        val_split = st.select_slider(
            "Validation Split",
            options=[0.1, 0.15, 0.2, 0.25, 0.3],
            value=0.2,
            format_func=lambda x: f"{x*100:.0f}%",
            help="Percentage of data reserved for validation during training."
        )
        config.validation_split = val_split
        # Keep 0.1 for test. round() strips the binary floating-point noise the
        # subtraction can introduce, so the splits print and compare cleanly.
        config.train_split = round(0.9 - val_split, 2)

        early_stopping = st.checkbox(
            "Enable Early Stopping",
            value=True,
            help="Stop training automatically if validation performance stops improving."
        )
        if early_stopping:
            patience = st.slider(
                "Early Stopping Patience",
                min_value=2,
                max_value=5,
                value=3,
                help="Number of epochs to wait before stopping if no improvement."
            )
        # NOTE(review): early_stopping / patience are never written to `config`,
        # so these controls currently have no effect on training. Wire them to the
        # matching TrainingConfig fields (not visible from this file) or remove them.

    st.markdown("---")

    # Show configuration summary
    st.markdown("### Configuration Summary")
    config_summary = {
        "Classification Type": "Binary" if st.session_state.classification_type == ClassificationType.BINARY else "Multi-class",
        "Number of Labels": config.num_labels,
        "Model": model_info['name'],
        "Epochs": epochs,
        "Batch Size": batch_size,
        "Learning Rate": f"{learning_rate:.0e}",
        "Max Length": max_length,
        "Validation Split": f"{val_split*100:.0f}%"
    }
    summary_df = pd.DataFrame([
        {"Parameter": k, "Value": str(v)}
        for k, v in config_summary.items()
    ])
    st.dataframe(summary_df, width="stretch", hide_index=True)


with tab3:
    _render_configure_tab()
# ==================== TAB 4: Train Model ====================
def _run_training():
    """Run one training job with the current config on the uploaded data.

    Live progress goes into placeholders created here. On success, stores the
    model path, label encoder/classes and metrics history in session state, then
    returns True. On any error, shows it with a traceback, resets
    ``training_started`` so the user can retry, then returns False.
    """
    st.markdown("### Training in Progress...")

    # Progress display
    progress_bar = st.progress(0)
    status_text = st.empty()
    # st.empty() holds a single element, so each callback *replaces* the metrics
    # row; writing into a plain container would stack a new row per update.
    metrics_placeholder = st.empty()

    try:
        # Prepare data
        status_text.text("Preparing data...")
        df = st.session_state.uploaded_data

        # Encode labels to integers and keep the encoder, so evaluation can
        # reuse exactly the same label -> index mapping.
        label_encoder = LabelEncoder()
        encoded_labels = label_encoder.fit_transform(df['label'])
        st.session_state.label_encoder = label_encoder
        st.session_state.label_classes = label_encoder.classes_.tolist()

        # Initialize trainer with an absolute output path (os is imported at file top)
        st.session_state.config.output_dir = os.path.abspath("trained_models")
        trainer = ModelTrainer(config=st.session_state.config)

        def progress_callback(progress_obj):
            """Mirror a TrainingProgress update into the progress bar and metrics row."""
            if progress_obj.progress_percent > 0:
                # st.progress rejects values outside [0, 1]; clamp defensively.
                progress_bar.progress(min(progress_obj.progress_percent / 100.0, 1.0))
                status_text.text(f"Training: {progress_obj.progress_percent:.1f}% complete")
                st.session_state.training_progress = progress_obj.progress_percent

            # Update metrics display from latest metrics
            if progress_obj.metrics_history:
                latest_metrics = progress_obj.metrics_history[-1]
                with metrics_placeholder.container():
                    col1, col2, col3 = st.columns(3)
                    col1.metric("Epoch", f"{progress_obj.current_epoch}/{progress_obj.total_epochs}")
                    col2.metric("Train Loss", f"{latest_metrics.train_loss:.4f}")
                    if latest_metrics.eval_loss > 0:
                        col3.metric("Val Loss", f"{latest_metrics.eval_loss:.4f}")

        # Train model
        result = trainer.train(
            texts=df['text'].tolist(),
            labels=encoded_labels.tolist(),
            progress_callback=progress_callback
        )

        # Check if training actually succeeded
        if result.status == "failed":
            raise RuntimeError(result.error_message or "Training failed with unknown error")
        if result.model_path is None:
            raise RuntimeError("Training completed but model path is None. Check logs for errors.")

        # Training complete
        st.session_state.training_completed = True
        st.session_state.model_path = result.model_path
        st.session_state.metrics_history = [m.to_dict() for m in result.metrics_history]

        progress_bar.progress(1.0)
        status_text.empty()
        st.success("Training completed successfully!")
        add_log(f"Training completed successfully. Model saved to: {result.model_path}")

        # Show final metrics
        if result.final_metrics:
            st.markdown("### Final Training Metrics")
            metrics = result.final_metrics.to_dict()
            col1, col2, col3, col4 = st.columns(4)
            col1.metric("Accuracy", f"{metrics.get('accuracy', 0):.2%}")
            col2.metric("Precision", f"{metrics.get('precision', 0):.4f}")
            col3.metric("Recall", f"{metrics.get('recall', 0):.4f}")
            col4.metric("F1 Score", f"{metrics.get('f1', 0):.4f}")
        return True

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        st.error(f"Training failed: {str(e)}")
        with st.expander("Error Details"):
            st.code(error_details)
        st.session_state.training_started = False
        add_log(f"Training failed: {str(e)}")
        return False


def _render_train_tab():
    """Render the training step: start button, live training run, and result summary.

    Missing prerequisites end the function with a plain ``return``, not
    ``st.stop()``, so the Evaluate tab and the footer still render.
    """
    st.markdown("## Train Your Model")

    # selected_model is set by the Configure tab; the model-name lookup below needs it.
    if st.session_state.uploaded_data is None or st.session_state.selected_model is None:
        create_info_box(
            "Please complete previous steps first.",
            "warning"
        )
        return

    if not st.session_state.training_started:
        create_info_box(
            "<b>Ready to train!</b><br>"
            f"Your {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']} model will be trained on {len(st.session_state.uploaded_data)} samples "
            f"for {st.session_state.config.num_epochs} epochs.",
            "info"
        )
        if st.button("Start Training", type="primary", width="stretch"):
            st.session_state.training_started = True
            st.rerun()

    if st.session_state.training_started and not st.session_state.training_completed:
        if _run_training():
            # Pause so the final metrics stay visible, then rerun into the
            # completed state. This happens outside _run_training's try/except,
            # so Streamlit's rerun signal can never be treated as a training error.
            time.sleep(2)
            st.rerun()

    if st.session_state.training_completed:
        st.success("Training completed!")
        model_path_display = st.session_state.model_path if st.session_state.model_path else "Path not available"
        create_info_box(
            f"<b>Model trained successfully!</b><br>"
            f"Model saved to: <code>{model_path_display}</code><br>"
            "Proceed to the <b>Evaluate Model</b> tab to analyze performance.",
            "success" if st.session_state.model_path else "warning"
        )

    # Show training logs
    with st.expander("View Training Logs"):
        for log in st.session_state.training_logs[-20:]:  # Show last 20 logs
            st.text(log)


with tab4:
    _render_train_tab()
# ==================== TAB 5: Evaluate Model ====================
def _run_evaluation():
    """Load the trained model and evaluate it on a hold-out slice of the uploaded data.

    Returns:
        The value returned by ``ModelEvaluator.evaluate``, which the caller reads
        as a dict with ``accuracy``/``precision``/``recall``/``f1`` and optional
        ``confusion_matrix``/``classification_report`` keys. Returns ``None``
        when the model cannot be loaded or evaluation raises; the problem is
        shown in the UI in that case.
    """
    try:
        evaluator = ModelEvaluator()
        if not evaluator.load_model(st.session_state.model_path):
            st.error("Failed to load model for evaluation")
            return None

        # Evaluate on the last `validation_split` fraction of rows (at least one).
        # NOTE(review): the trainer's own train/validation split is not visible
        # here. Unless it holds out exactly these tail rows, they may overlap the
        # training data and overstate the metrics. Confirm against ModelTrainer.
        df = st.session_state.uploaded_data
        test_size = max(1, int(len(df) * st.session_state.config.validation_split))
        test_df = df.tail(test_size)

        # Reuse the encoder fitted during training so label -> index matches the
        # model. The session key always exists (init_session_state seeds it with
        # None), so test the value, not key membership.
        label_encoder = st.session_state.label_encoder
        if label_encoder is not None:
            test_labels_encoded = label_encoder.transform(test_df['label']).tolist()
        else:
            # Fallback: fit a fresh encoder on the test labels. The mapping may
            # differ from training if some classes are absent from this slice.
            test_labels_encoded = LabelEncoder().fit_transform(test_df['label']).tolist()

        return evaluator.evaluate(
            texts=test_df['text'].tolist(),
            true_labels=test_labels_encoded,
            batch_size=st.session_state.config.batch_size
        )
    except Exception as e:
        st.error(f"Evaluation failed: {str(e)}")
        return None


def _build_report_text(results):
    """Return the plain-text evaluation report offered for download."""
    config = st.session_state.config
    task_type = 'Binary' if st.session_state.classification_type == ClassificationType.BINARY else 'Multi-class'
    lines = [
        "MLOps Training Platform - Evaluation Report",
        "=" * 60,
        f"Model: {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']}",
        f"Classification Type: {task_type}",
        f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "Overall Metrics:",
        f"- Accuracy: {results['accuracy']:.4f}",
        f"- Precision: {results['precision']:.4f}",
        f"- Recall: {results['recall']:.4f}",
        f"- F1 Score: {results['f1']:.4f}",
        "",
        "Classification Report:",
        str(results.get('classification_report', 'N/A')),
        "",
        "Training Configuration:",
        f"- Epochs: {config.num_epochs}",
        f"- Batch Size: {config.batch_size}",
        f"- Learning Rate: {config.learning_rate}",
        f"- Max Length: {config.max_length}",
    ]
    return "\n".join(lines) + "\n"


def _render_evaluate_tab():
    """Render the evaluation step: run evaluation on demand, then show metrics, plots and a report.

    When no model has been trained yet, the function ends with a plain
    ``return``, not ``st.stop()``, so the footer still renders.
    """
    st.markdown("## Evaluate Model Performance")

    if not st.session_state.training_completed:
        create_info_box(
            "Please train a model first in the <b>Train Model</b> tab.",
            "warning"
        )
        return

    create_info_box(
        "<b>Model Evaluation:</b><br>"
        "Analyze your model's performance with detailed metrics and visualizations.",
        "info"
    )

    if st.session_state.evaluation_results is None:
        if st.button("Evaluate Model", type="primary", width="stretch"):
            with st.spinner("Evaluating model..."):
                eval_results = _run_evaluation()
            if eval_results is not None:
                st.session_state.evaluation_results = eval_results
                add_log("Model evaluation completed")
                # Rerun outside any try/except so the rerun signal is never swallowed.
                st.rerun()

    results = st.session_state.evaluation_results
    if not results:
        return

    # Overall metrics
    st.markdown("### Overall Metrics")
    col1, col2, col3, col4 = st.columns(4)
    col1.metric("Accuracy", f"{results['accuracy']:.2%}")
    col2.metric("Precision", f"{results['precision']:.4f}")
    col3.metric("Recall", f"{results['recall']:.4f}")
    col4.metric("F1 Score", f"{results['f1']:.4f}")

    st.markdown("---")

    # Confusion Matrix
    st.markdown("### Confusion Matrix")
    if 'confusion_matrix' in results:
        cm = results['confusion_matrix']
        n_classes = len(cm)
        classes = st.session_state.label_classes
        # Use the original label names when there is exactly one per matrix row.
        # This assumes row/column i is encoded label i (TODO confirm in
        # ModelEvaluator); otherwise fall back to the bare indices.
        if classes and len(classes) == n_classes:
            tick_names = [str(c) for c in classes]
        else:
            tick_names = [str(i) for i in range(n_classes)]

        fig = go.Figure(data=go.Heatmap(
            z=cm,
            x=[f"Predicted {name}" for name in tick_names],
            y=[f"True {name}" for name in tick_names],
            colorscale='Blues',
            text=cm,
            texttemplate="%{text}",
            textfont={"size": 16}
        ))
        fig.update_layout(
            title="Confusion Matrix",
            xaxis_title="Predicted Label",
            yaxis_title="True Label",
            height=500
        )
        st.plotly_chart(fig, width="stretch")

    st.markdown("---")

    # Classification Report
    st.markdown("### Detailed Classification Report")
    if 'classification_report' in results:
        st.text(results['classification_report'])

    # Download results
    st.markdown("---")
    st.markdown("### Download Results")
    # Render the download button directly. Nesting it under another st.button
    # made it disappear on the next rerun and required two clicks.
    st.download_button(
        label="Download Text Report",
        data=_build_report_text(results),
        file_name=f"evaluation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
        mime="text/plain"
    )


with tab5:
    _render_evaluate_tab()
# ==================== Footer ====================
# Centered, muted footer rendered as raw HTML below a horizontal rule.
_FOOTER_HTML = """
<div style='text-align: center; color: #666; padding: 2rem;'>
<p> MLOps Training Platform | Built with Streamlit & PyTorch</p>
<p>For help and documentation, check the README.md file</p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)