# MLOps-Platforms / app.py
# (Hugging Face Space page header preserved as comments so the file parses:
#  songhieng's picture / Update app.py / c3d7a4c verified)
"""
MLOps Training Platform - Streamlit Application
==================================================
A beginner-friendly web interface for training text classification models
with built-in system checks and model management.
Run with: streamlit run app.py
"""
# CRITICAL: Set these environment variables FIRST, before any other imports,
# so transformers/torch never activate a TensorFlow backend and TF logging
# stays silent if TF happens to be importable.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'   # suppress TensorFlow C++ log spam
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # disable oneDNN optimizations
os.environ['TRANSFORMERS_NO_TF'] = '1'     # tell transformers to skip TF
os.environ['USE_TF'] = '0'                 # legacy flag with the same intent
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, List
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.preprocessing import LabelEncoder
# Add src directory to path for imports
sys.path.insert(0, str(Path(__file__).parent / 'src'))
from mlops.config import (
TrainingConfig,
MODEL_ARCHITECTURES,
MODEL_SELECTION_GUIDE,
ClassificationType
)
from mlops.preprocessor import TextPreprocessor, DataValidator
from mlops.trainer import ModelTrainer
from mlops.evaluator import ModelEvaluator
from mlops.system_check import SystemChecker, get_system_summary
# ==================== Page Configuration ====================
# Must be the first Streamlit call in the script.
st.set_page_config(
    page_title="MLOps Training Platform",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded"
)
# ==================== Custom CSS ====================
# Injected once per script run. The .info-box/.warning-box/.success-box/
# .error-box classes are referenced by create_info_box(); .main-header and
# .sub-header by the page/sidebar headers.
st.markdown("""
<style>
/* Main styling */
.main-header {
font-size: 2.5rem;
font-weight: 700;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 0.5rem;
}
.sub-header {
font-size: 1.1rem;
color: #666;
margin-bottom: 2rem;
}
/* Info boxes */
.info-box {
background-color: #f0f7ff;
border-left: 4px solid #667eea;
padding: 1rem;
margin: 1rem 0;
border-radius: 0 8px 8px 0;
}
.warning-box {
background-color: #fff7e6;
border-left: 4px solid #fa8c16;
padding: 1rem;
margin: 1rem 0;
border-radius: 0 8px 8px 0;
}
.success-box {
background-color: #f6ffed;
border-left: 4px solid #52c41a;
padding: 1rem;
margin: 1rem 0;
border-radius: 0 8px 8px 0;
}
.error-box {
background-color: #fff1f0;
border-left: 4px solid #ff4d4f;
padding: 1rem;
margin: 1rem 0;
border-radius: 0 8px 8px 0;
}
/* Metric cards */
.metric-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 1.5rem;
border-radius: 10px;
color: white;
text-align: center;
}
/* Hide default elements */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
""", unsafe_allow_html=True)
# ==================== Session State Initialization ====================
def init_session_state():
    """Seed st.session_state with a default value for every key the app uses.

    Keys already present are left untouched, so this is safe to call on every
    rerun and again after a full reset.
    """
    state_defaults = {
        # Classification type selection
        'classification_type': None,
        'classification_type_selected': False,
        # Prerequisites
        'prerequisites_checked': False,
        'cuda_status': None,
        'env_status': None,
        'models_downloaded': set(),
        # Training state
        'training_started': False,
        'training_completed': False,
        'training_progress': 0.0,
        'training_logs': [],
        'metrics_history': [],
        'model_path': None,
        # Data
        'uploaded_data': None,
        'preprocessed_data': None,
        'label_encoder': None,
        'label_classes': None,
        # Evaluation
        'evaluation_results': None,
        # Config
        'config': TrainingConfig(),
        # Selected model
        'selected_model': None,
    }
    for state_key, default_value in state_defaults.items():
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = default_value
init_session_state()
# ==================== Helper Functions ====================
def add_log(message: str):
    """Append *message* to the session-wide training log, stamped HH:MM:SS."""
    stamp = datetime.now().strftime("%H:%M:%S")
    entry = f"[{stamp}] {message}"
    st.session_state.training_logs.append(entry)
def create_info_box(text: str, box_type: str = "info"):
    """Render *text* in a styled HTML box.

    *box_type* selects one of the CSS classes defined at the top of the file:
    "info", "warning", "success" or "error".
    """
    html = f'<div class="{box_type}-box">{text}</div>'
    st.markdown(html, unsafe_allow_html=True)
# ==================== Sidebar ====================
def render_sidebar():
    """Draw the sidebar: per-stage status summaries plus a full-reset action."""
    with st.sidebar:
        st.markdown('<h1 class="main-header">MLOps Platform</h1>', unsafe_allow_html=True)
        st.markdown("---")

        # --- Classification type status ---
        st.subheader("Classification Type")
        if not st.session_state.classification_type_selected:
            st.warning("Not selected")
        else:
            is_binary = st.session_state.classification_type == ClassificationType.BINARY
            type_display = "Binary" if is_binary else "Multi-class"
            st.success(f"Selected: {type_display}")
        st.markdown("---")

        # --- Prerequisites status ---
        st.subheader("Prerequisites")
        if not st.session_state.prerequisites_checked:
            st.warning("Not checked")
        else:
            st.success("Checked")
            cuda_info = st.session_state.cuda_status
            if cuda_info:
                if cuda_info['available']:
                    st.info(f"GPU: {cuda_info['devices'][0]['name']}")
                else:
                    st.info("CPU Mode")
            downloaded = st.session_state.models_downloaded
            if downloaded:
                st.info(f"Models: {len(downloaded)}")
        st.markdown("---")

        # --- Training status ---
        st.subheader("Training Status")
        if st.session_state.training_completed:
            st.success("Completed")
        elif st.session_state.training_started:
            st.info(f"In Progress ({st.session_state.training_progress:.0f}%)")
        else:
            st.info("Not started")
        st.markdown("---")

        # --- Reset everything ---
        st.subheader("Quick Actions")
        if st.button("Reset All", width="stretch"):
            # Delete every key, re-seed defaults, then restart the script run.
            for state_key in list(st.session_state.keys()):
                del st.session_state[state_key]
            init_session_state()
            st.rerun()
render_sidebar()  # Sidebar is rendered on every script run
# ==================== Main Content ====================
# Header
st.markdown('<h1 class="main-header">MLOps Training Platform</h1>', unsafe_allow_html=True)
st.markdown('<p class="sub-header">Train and evaluate text classification models with ease</p>', unsafe_allow_html=True)
# ==================== STEP 1: Classification Type Selection ====================
# Shown only until a type is chosen; the st.stop() at the end keeps the rest
# of the app hidden during this step.
if not st.session_state.classification_type_selected:
    st.markdown("## Step 1: Choose Classification Type")
    create_info_box(
        "<b>First, select your classification task type:</b><br><br>"
        "• <b>Binary Classification:</b> Two classes (e.g., spam vs. not spam, positive vs. negative)<br>"
        "• <b>Multi-class Classification:</b> More than two classes (e.g., categorize news into politics, sports, entertainment, etc.)",
        "info"
    )
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("### Binary Classification")
        st.markdown("""
**Use when you have:**
- 2 categories/labels
- Yes/No questions
- Positive/Negative sentiment
**Examples:**
- Spam detection (spam/not spam)
- Sentiment analysis (positive/negative)
- Phishing detection (phishing/legitimate)
""")
        if st.button("Select Binary Classification", width="stretch", type="primary"):
            st.session_state.classification_type = ClassificationType.BINARY
            st.session_state.classification_type_selected = True
            st.session_state.config.num_labels = 2  # binary is always exactly two labels
            add_log("Selected Binary Classification")
            st.rerun()
    with col2:
        st.markdown("### Multi-class Classification")
        st.markdown("""
**Use when you have:**
- 3+ categories/labels
- Multiple distinct classes
- Topic categorization
**Examples:**
- News categorization (politics/sports/tech/entertainment)
- Product classification (electronics/clothing/books/toys)
- Language detection (English/Chinese/Spanish/etc.)
""")
        if st.button("Select Multi-class Classification", width="stretch"):
            st.session_state.classification_type = ClassificationType.MULTICLASS
            st.session_state.classification_type_selected = True
            # Will set num_labels after data upload when we know the number of classes
            add_log("Selected Multi-class Classification")
            st.rerun()
    st.stop()  # Don't render rest of the app until classification type is selected
# ==================== TABS FOR REST OF WORKFLOW ====================
# One tab per workflow stage; later tabs guard themselves with st.stop().
tab1, tab2, tab3, tab4, tab5 = st.tabs([
    "Prerequisites",
    "Upload Data",
    "Configure Training",
    "Train Model",
    "Evaluate Model"
])
# ==================== TAB 1: Prerequisites ====================
# CUDA check, environment check, model-selection guide and model downloads.
# Check results are cached in session state so they survive reruns.
with tab1:
    st.markdown("## System Prerequisites")
    create_info_box(
        "<b>Important:</b> Complete all prerequisite checks before proceeding to training.<br>"
        "This ensures your system is properly configured and all required models are downloaded.",
        "warning"
    )
    # Initialize system checker (re-created each run; results live in session state)
    system_checker = SystemChecker(models_dir="models")
    # ===== CUDA/GPU Check =====
    st.markdown("### 1. CUDA/GPU Check")
    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("Check if CUDA-capable GPU is available for faster training.")
    with col2:
        if st.button("Check CUDA", width="stretch"):
            with st.spinner("Checking CUDA availability..."):
                cuda_status = system_checker.check_cuda()
                st.session_state.cuda_status = cuda_status
                add_log("CUDA check completed")
    # Render the last stored check result (if any)
    if st.session_state.cuda_status:
        cuda = st.session_state.cuda_status
        if cuda['available']:
            st.success(f"CUDA Available - {cuda['device_count']} GPU(s) detected")
            for device in cuda['devices']:
                with st.expander(f"Device: {device['name']} Details"):
                    col1, col2, col3 = st.columns(3)
                    col1.metric("Memory", f"{device['memory_total']:.2f} GB")
                    col2.metric("Compute", device['compute_capability'])
                    col3.metric("CUDA Version", cuda['cuda_version'])
            create_info_box(
                "<b>Recommendation:</b> Your GPU is ready for training! "
                "You can use any model from the list. XLM-RoBERTa and RoBERTa are recommended for best accuracy.",
                "success"
            )
        else:
            st.warning("No CUDA-capable GPU detected - Training will use CPU")
            create_info_box(
                "<b>Recommendation:</b> For CPU training, we recommend using <b>distilbert-base-multilingual-cased</b> "
                "as it's significantly faster while maintaining good accuracy.",
                "warning"
            )
    st.markdown("---")
    # ===== Environment Check =====
    st.markdown("### 2. Environment Check")
    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("Verify all required Python packages are installed with correct versions.")
    with col2:
        if st.button("Check Environment", width="stretch"):
            with st.spinner("Checking environment..."):
                env_status = system_checker.check_environment()
                st.session_state.env_status = env_status
                add_log("Environment check completed")
    if st.session_state.env_status:
        env = st.session_state.env_status
        if env['all_satisfied']:
            st.success("All required packages are installed")
        else:
            st.error(f"Missing packages: {', '.join(env['missing_packages'])}")
            create_info_box(
                f"<b>To install missing packages, run:</b><br>"
                f"<code>pip install {' '.join(env['missing_packages'])}</code>",
                "error"
            )
        with st.expander("View Package Details"):
            package_df = pd.DataFrame([
                {
                    'Package': pkg,
                    'Installed': info['installed'] or 'Not Installed',
                    'Required': info['required'],
                    'Status': 'OK' if info['satisfied'] else 'Missing'
                }
                for pkg, info in env['packages'].items()
            ])
            st.dataframe(package_df, width="stretch", hide_index=True)
    st.markdown("---")
    # ===== Model Selection Guide =====
    st.markdown("### 3. Model Selection Guide")
    create_info_box(
        "<b>How to choose the right model:</b><br><br>"
        "Consider these factors:<br>"
        "• <b>Language:</b> English only or multilingual?<br>"
        "• <b>Hardware:</b> GPU available or CPU only?<br>"
        "• <b>Speed vs Accuracy:</b> Need fast training or best accuracy?<br>"
        "• <b>Task Type:</b> Binary or multi-class classification?",
        "info"
    )
    # Display model comparison table built from the static MODEL_ARCHITECTURES registry
    model_comparison = []
    for model_id, model_info in MODEL_ARCHITECTURES.items():
        model_comparison.append({
            'Model': model_info['name'],
            'Languages': ', '.join(model_info['languages']),
            'Speed': model_info['speed'],
            'Size': model_info['size'],
            'Best For': model_info['best_use'],
            'ID': model_id
        })
    model_df = pd.DataFrame(model_comparison)
    st.dataframe(model_df, width="stretch", hide_index=True)
    # Quick recommendations
    st.markdown("#### Quick Recommendations:")
    rec_col1, rec_col2 = st.columns(2)
    with rec_col1:
        st.markdown("**For GPU Training:**")
        st.markdown("- Best: `xlm-roberta-base` (highest accuracy)")
        st.markdown("- Fast: `roberta-base` (English only)")
    with rec_col2:
        st.markdown("**For CPU Training:**")
        st.markdown("- Recommended: `distilbert-base-multilingual-cased`")
        st.markdown("- Fastest training and good performance")
    st.markdown("---")
    # ===== Model Download =====
    st.markdown("### 4. Download Models")
    create_info_box(
        "<b>Download models before training:</b><br>"
        "Models will be downloaded to the <code>models/</code> directory. "
        "This may take several minutes depending on your internet connection.",
        "info"
    )
    # Model selection
    selected_models = st.multiselect(
        "Select models to download:",
        options=list(MODEL_ARCHITECTURES.keys()),
        format_func=lambda x: f"{MODEL_ARCHITECTURES[x]['name']} ({MODEL_ARCHITECTURES[x]['size']})",
        help="Select one or more models to download. You can train with any downloaded model later."
    )
    col1, col2 = st.columns([3, 1])
    with col2:
        download_btn = st.button("Download Selected", width="stretch", type="primary", disabled=len(selected_models) == 0)
    if download_btn:
        progress_bar = st.progress(0)
        status_text = st.empty()
        for idx, model_id in enumerate(selected_models):
            status_text.text(f"Downloading {model_id}... ({idx + 1}/{len(selected_models)})")
            progress_bar.progress((idx) / len(selected_models))
            success, path, message = system_checker.download_model(
                model_id,
                progress_callback=lambda msg, prog: None  # Could add sub-progress here
            )
            if success:
                st.session_state.models_downloaded.add(model_id)
                add_log(f"Downloaded model: {model_id}")
            else:
                st.error(f"Failed to download {model_id}: {message}")
        progress_bar.progress(1.0)
        status_text.text("Download complete!")
        time.sleep(1)
        st.rerun()
    # Show downloaded models
    if st.session_state.models_downloaded:
        st.markdown("#### Downloaded Models:")
        for model_id in st.session_state.models_downloaded:
            model_info = system_checker.get_model_info(model_id)
            st.success(f"{MODEL_ARCHITECTURES[model_id]['name']} - {model_info['size_mb']:.0f} MB")
    st.markdown("---")
    # ===== Prerequisites Complete Button =====
    # All three checks must be done, all packages present, and >= 1 model downloaded.
    can_proceed = (
        st.session_state.cuda_status is not None and
        st.session_state.env_status is not None and
        st.session_state.env_status['all_satisfied'] and
        len(st.session_state.models_downloaded) > 0
    )
    if can_proceed:
        if st.button("Prerequisites Complete - Proceed to Data Upload", width="stretch", type="primary"):
            st.session_state.prerequisites_checked = True
            add_log("Prerequisites check completed successfully")
            st.success("All prerequisites satisfied! You can now proceed to upload your data.")
            time.sleep(1)
            st.rerun()
    else:
        create_info_box(
            "<b>Complete all checks above before proceeding:</b><br>"
            "- CUDA Check<br>"
            "- Environment Check (all packages installed)<br>"
            "- Download at least one model",
            "warning"
        )
# ==================== TAB 2: Upload Data ====================
with tab2:
    st.markdown("## Upload Training Data")
    # Guard: prerequisites first. NOTE(review): st.stop() aborts the entire
    # script run, so tabs 3-5 and the footer also stop rendering here — confirm
    # this is the intended UX.
    if not st.session_state.prerequisites_checked:
        create_info_box(
            "Please complete the <b>Prerequisites</b> tab first before uploading data.",
            "warning"
        )
        st.stop()
    create_info_box(
        "<b>Data Format Requirements:</b><br>"
        "• CSV file with at least two columns: text and label<br>"
        "• Text column: Contains the text samples to classify<br>"
        "• Label column: Contains the class labels (0/1 for binary, or class names for multi-class)<br>"
        "• Minimum 20 samples recommended for training",
        "info"
    )
    # File uploader
    uploaded_file = st.file_uploader(
        "Upload your CSV file",
        type=['csv'],
        help="Upload a CSV file with 'text' and 'label' columns"
    )
    if uploaded_file is not None:
        try:
            # Read data
            df = pd.read_csv(uploaded_file)
            st.session_state.uploaded_data = df
            st.success(f"Uploaded {len(df)} samples")
            # Validate data (schema/shape checks live in mlops.preprocessor.DataValidator)
            validator = DataValidator()
            is_valid, message = validator.validate_dataframe(df)
            if is_valid:
                st.success(f"Data validation passed: {message}")
                # Show data preview
                st.markdown("### Data Preview")
                st.dataframe(df.head(10), width="stretch")
                # Show statistics
                col1, col2, col3 = st.columns(3)
                col1.metric("Total Samples", len(df))
                col2.metric("Unique Labels", df['label'].nunique())
                col3.metric("Text Columns", len([c for c in df.columns if df[c].dtype == 'object']))
                # Label distribution
                st.markdown("### Label Distribution")
                label_counts = df['label'].value_counts()
                fig = px.bar(
                    x=label_counts.index.astype(str),
                    y=label_counts.values,
                    labels={'x': 'Label', 'y': 'Count'},
                    title='Number of samples per label'
                )
                st.plotly_chart(fig, width="stretch")
                # Update num_labels for multi-class (binary was fixed to 2 at selection time)
                if st.session_state.classification_type == ClassificationType.MULTICLASS:
                    num_classes = df['label'].nunique()
                    st.session_state.config.num_labels = num_classes
                    st.info(f"Detected {num_classes} classes for multi-class classification")
                add_log(f"Uploaded data with {len(df)} samples and {df['label'].nunique()} labels")
            else:
                st.error(f"Data validation failed: {message}")
        except Exception as e:
            st.error(f"Error reading file: {str(e)}")
# ==================== TAB 3: Configure Training ====================
# Every widget writes its value straight into st.session_state.config so the
# trainer in Tab 4 picks the settings up.
with tab3:
    st.markdown("## Configure Training Parameters")
    # Guard: requires uploaded data from Tab 2
    if st.session_state.uploaded_data is None:
        create_info_box(
            "Please upload your data in the <b>Upload Data</b> tab first.",
            "warning"
        )
        st.stop()
    create_info_box(
        "<b>Configure your training settings:</b><br>"
        "Adjust the parameters below based on your needs. Hover over ⓘ for explanations.",
        "info"
    )
    # Model selection — only models downloaded in Tab 1 are offered
    st.markdown("### Model Selection")
    available_models = list(st.session_state.models_downloaded)
    if not available_models:
        st.error("No models downloaded. Please download models in the Prerequisites tab.")
        st.stop()
    selected_model = st.selectbox(
        "Choose model:",
        options=available_models,
        format_func=lambda x: f"{MODEL_ARCHITECTURES[x]['name']} - {MODEL_ARCHITECTURES[x]['best_use']}",
        help="Select the model architecture to use for training"
    )
    # Persist the choice for the training/evaluation tabs
    st.session_state.selected_model = selected_model
    st.session_state.config.model_name = selected_model
    # Show model info
    model_info = MODEL_ARCHITECTURES[selected_model]
    with st.expander("ℹ️ Selected Model Information"):
        st.markdown(f"**Name:** {model_info['name']}")
        st.markdown(f"**Description:** {model_info['description']}")
        st.markdown(f"**Best For:** {model_info['best_use']}")
        st.markdown(f"**Speed:** {model_info['speed']}")
        st.markdown(f"**Size:** {model_info['size']}")
    st.markdown("---")
    # Training parameters
    st.markdown("### Training Parameters")
    col1, col2 = st.columns(2)
    with col1:
        epochs = st.slider(
            "Number of Epochs",
            min_value=1,
            max_value=20,
            value=3,
            help="Number of complete passes through the training dataset. More epochs = longer training but potentially better performance."
        )
        st.session_state.config.num_epochs = epochs
        batch_size = st.select_slider(
            "Batch Size",
            options=[4, 8, 16, 32, 64],
            value=16,
            help="Number of samples processed together. Larger batches train faster but require more GPU memory."
        )
        st.session_state.config.batch_size = batch_size
        learning_rate = st.select_slider(
            "Learning Rate",
            options=[1e-5, 2e-5, 3e-5, 5e-5, 1e-4],
            value=2e-5,
            format_func=lambda x: f"{x:.0e}",
            help="Step size for model parameter updates. 2e-5 is a good default for BERT-like models."
        )
        st.session_state.config.learning_rate = learning_rate
    with col2:
        max_length = st.slider(
            "Max Sequence Length",
            min_value=128,
            max_value=512,
            value=128,
            step=64,
            help="Maximum length of input text in tokens. Longer sequences require more memory."
        )
        st.session_state.config.max_length = max_length
        val_split = st.select_slider(
            "Validation Split",
            options=[0.1, 0.15, 0.2, 0.25, 0.3],
            value=0.2,
            format_func=lambda x: f"{x*100:.0f}%",
            help="Percentage of data reserved for validation during training."
        )
        st.session_state.config.validation_split = val_split
        st.session_state.config.train_split = 0.9 - val_split  # Keep 0.1 for test
        # NOTE(review): `early_stopping` and `patience` are read from the UI but
        # never written into st.session_state.config, so the trainer cannot see
        # them — confirm TrainingConfig's field names and persist these values.
        early_stopping = st.checkbox(
            "Enable Early Stopping",
            value=True,
            help="Stop training automatically if validation performance stops improving."
        )
        if early_stopping:
            patience = st.slider(
                "Early Stopping Patience",
                min_value=2,
                max_value=5,
                value=3,
                help="Number of epochs to wait before stopping if no improvement."
            )
    st.markdown("---")
    # Show configuration summary
    st.markdown("### Configuration Summary")
    config_summary = {
        "Classification Type": "Binary" if st.session_state.classification_type == ClassificationType.BINARY else "Multi-class",
        "Number of Labels": st.session_state.config.num_labels,
        "Model": model_info['name'],
        "Epochs": epochs,
        "Batch Size": batch_size,
        "Learning Rate": f"{learning_rate:.0e}",
        "Max Length": max_length,
        "Validation Split": f"{val_split*100:.0f}%"
    }
    summary_df = pd.DataFrame([
        {"Parameter": k, "Value": str(v)}
        for k, v in config_summary.items()
    ])
    st.dataframe(summary_df, width="stretch", hide_index=True)
# ==================== TAB 4: Train Model ====================
# Runs the actual training. Fixes over the previous revision:
#  * removed the redundant local `import os` (os is imported at module top);
#  * moved st.rerun() out of the `try` body into an `else:` clause — Streamlit's
#    st.rerun() works by raising RerunException (an Exception subclass), which
#    the broad `except Exception` below would otherwise swallow and report a
#    *successful* run as "Training failed".
with tab4:
    st.markdown("## Train Your Model")
    # Guard: data must be uploaded before training
    if st.session_state.uploaded_data is None:
        create_info_box(
            "Please complete previous steps first.",
            "warning"
        )
        st.stop()
    if not st.session_state.training_started:
        create_info_box(
            "<b>Ready to train!</b><br>"
            f"Your {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']} model will be trained on {len(st.session_state.uploaded_data)} samples "
            f"for {st.session_state.config.num_epochs} epochs.",
            "info"
        )
        if st.button("Start Training", type="primary", width="stretch"):
            st.session_state.training_started = True
            st.rerun()
    if st.session_state.training_started and not st.session_state.training_completed:
        st.markdown("### Training in Progress...")
        # Progress display widgets updated from the training callback
        progress_bar = st.progress(0)
        status_text = st.empty()
        metrics_container = st.container()
        try:
            # Prepare data
            status_text.text("Preparing data...")
            df = st.session_state.uploaded_data
            # Encode labels (possibly strings) to contiguous integer ids
            label_encoder = LabelEncoder()
            encoded_labels = label_encoder.fit_transform(df['label'])
            # Persist the encoder so evaluation maps labels identically
            st.session_state.label_encoder = label_encoder
            st.session_state.label_classes = label_encoder.classes_.tolist()
            # Initialize trainer with an absolute output path
            st.session_state.config.output_dir = os.path.abspath("trained_models")
            trainer = ModelTrainer(config=st.session_state.config)

            # Training progress callback - receives TrainingProgress object
            def progress_callback(progress_obj):
                """Refresh progress bar and metric cards from a TrainingProgress."""
                if progress_obj.progress_percent > 0:
                    progress_bar.progress(progress_obj.progress_percent / 100.0)
                    status_text.text(f"Training: {progress_obj.progress_percent:.1f}% complete")
                    st.session_state.training_progress = progress_obj.progress_percent
                # Update metrics display from latest metrics.
                # NOTE(review): each callback appends a new row of metric widgets
                # to metrics_container rather than replacing the previous one —
                # confirm whether that accumulation is intended.
                if progress_obj.metrics_history:
                    latest_metrics = progress_obj.metrics_history[-1]
                    with metrics_container:
                        col1, col2, col3 = st.columns(3)
                        col1.metric("Epoch", f"{progress_obj.current_epoch}/{progress_obj.total_epochs}")
                        col2.metric("Train Loss", f"{latest_metrics.train_loss:.4f}")
                        if latest_metrics.eval_loss > 0:
                            col3.metric("Val Loss", f"{latest_metrics.eval_loss:.4f}")

            # Train model
            result = trainer.train(
                texts=df['text'].tolist(),
                labels=encoded_labels.tolist(),
                progress_callback=progress_callback
            )
            # trainer.train reports failure via the result object, not exceptions
            if result.status == "failed":
                raise Exception(result.error_message or "Training failed with unknown error")
            if result.model_path is None:
                raise Exception("Training completed but model path is None. Check logs for errors.")
            # Training complete — persist results for the other tabs
            st.session_state.training_completed = True
            st.session_state.model_path = result.model_path
            st.session_state.metrics_history = [m.to_dict() for m in result.metrics_history]
            progress_bar.progress(1.0)
            status_text.empty()
            st.success("Training completed successfully!")
            add_log(f"Training completed successfully. Model saved to: {result.model_path}")
            # Show final metrics
            if result.final_metrics:
                st.markdown("### Final Training Metrics")
                metrics = result.final_metrics.to_dict()
                col1, col2, col3, col4 = st.columns(4)
                col1.metric("Accuracy", f"{metrics.get('accuracy', 0):.2%}")
                col2.metric("Precision", f"{metrics.get('precision', 0):.4f}")
                col3.metric("Recall", f"{metrics.get('recall', 0):.4f}")
                col4.metric("F1 Score", f"{metrics.get('f1', 0):.4f}")
        except Exception as e:
            import traceback  # local import: only needed on the failure path
            error_details = traceback.format_exc()
            st.error(f"Training failed: {str(e)}")
            with st.expander("Error Details"):
                st.code(error_details)
            # Let the user retry from the "Start Training" button
            st.session_state.training_started = False
            add_log(f"Training failed: {str(e)}")
        else:
            # Success path: pause briefly so the final metrics are visible, then
            # rerun OUTSIDE the try so RerunException is not caught above.
            time.sleep(2)
            st.rerun()
    if st.session_state.training_completed:
        st.success("Training completed!")
        model_path_display = st.session_state.model_path if st.session_state.model_path else "Path not available"
        create_info_box(
            f"<b>Model trained successfully!</b><br>"
            f"Model saved to: <code>{model_path_display}</code><br>"
            "Proceed to the <b>Evaluate Model</b> tab to analyze performance.",
            "success" if st.session_state.model_path else "warning"
        )
        # Show training logs
        with st.expander("View Training Logs"):
            for log in st.session_state.training_logs[-20:]:  # Show last 20 logs
                st.text(log)
# ==================== TAB 5: Evaluate Model ====================
# Fixes over the previous revision:
#  * `'label_encoder' in st.session_state` was always True because
#    init_session_state seeds that key with None, so the None value reached
#    .transform() and crashed — test the value instead of key membership;
#  * st.stop()/st.rerun() raise Streamlit control-flow exceptions that the
#    broad `except Exception` swallowed — load failure now raises a normal
#    error and the rerun happens in the `else:` clause;
#  * the test split is clamped to at least one row;
#  * removed the redundant local LabelEncoder import (imported at module top).
with tab5:
    st.markdown("## Evaluate Model Performance")
    # Guard: a trained model is required
    if not st.session_state.training_completed:
        create_info_box(
            "Please train a model first in the <b>Train Model</b> tab.",
            "warning"
        )
        st.stop()
    create_info_box(
        "<b>Model Evaluation:</b><br>"
        "Analyze your model's performance with detailed metrics and visualizations.",
        "info"
    )
    if st.session_state.evaluation_results is None:
        if st.button("Evaluate Model", type="primary", width="stretch"):
            with st.spinner("Evaluating model..."):
                try:
                    # Initialize evaluator and load the trained model
                    evaluator = ModelEvaluator()
                    if not evaluator.load_model(st.session_state.model_path):
                        # Raise instead of st.stop(): st.stop()'s StopException
                        # would be caught by the handler below.
                        raise RuntimeError("Failed to load model for evaluation")
                    # Prepare test data: hold out the tail of the uploaded data
                    # (at least one row, even for tiny datasets).
                    df = st.session_state.uploaded_data
                    test_size = max(1, int(len(df) * st.session_state.config.validation_split))
                    test_df = df.tail(test_size)
                    # Encode labels with the training-time encoder so ids match
                    # the model's outputs. The key always exists (seeded None by
                    # init_session_state), so check the value, not membership.
                    if st.session_state.label_encoder is not None:
                        test_labels_encoded = st.session_state.label_encoder.transform(test_df['label']).tolist()
                    else:
                        # Fallback: fit a fresh encoder (ids may not match training)
                        label_encoder = LabelEncoder()
                        test_labels_encoded = label_encoder.fit_transform(test_df['label']).tolist()
                    # Evaluate
                    results = evaluator.evaluate(
                        texts=test_df['text'].tolist(),
                        true_labels=test_labels_encoded,
                        batch_size=st.session_state.config.batch_size
                    )
                    st.session_state.evaluation_results = results
                    add_log("Model evaluation completed")
                except Exception as e:
                    st.error(f"Evaluation failed: {str(e)}")
                else:
                    # Rerun outside the try so RerunException is not swallowed
                    st.rerun()
    if st.session_state.evaluation_results:
        results = st.session_state.evaluation_results
        # Overall metrics
        st.markdown("### Overall Metrics")
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Accuracy", f"{results['accuracy']:.2%}")
        col2.metric("Precision", f"{results['precision']:.4f}")
        col3.metric("Recall", f"{results['recall']:.4f}")
        col4.metric("F1 Score", f"{results['f1']:.4f}")
        st.markdown("---")
        # Confusion Matrix
        st.markdown("### Confusion Matrix")
        if 'confusion_matrix' in results:
            cm = results['confusion_matrix']
            # Create heatmap
            fig = go.Figure(data=go.Heatmap(
                z=cm,
                x=[f"Predicted {i}" for i in range(len(cm))],
                y=[f"True {i}" for i in range(len(cm))],
                colorscale='Blues',
                text=cm,
                texttemplate="%{text}",
                textfont={"size": 16}
            ))
            fig.update_layout(
                title="Confusion Matrix",
                xaxis_title="Predicted Label",
                yaxis_title="True Label",
                height=500
            )
            st.plotly_chart(fig, width="stretch")
        st.markdown("---")
        # Classification Report
        st.markdown("### Detailed Classification Report")
        if 'classification_report' in results:
            report = results['classification_report']
            st.text(report)
        # Download results
        st.markdown("---")
        st.markdown("### Download Results")
        if st.button("Download Evaluation Report", width="stretch"):
            # Create downloadable plain-text report
            report_text = f"""
MLOps Training Platform - Evaluation Report
{'='*60}
Model: {MODEL_ARCHITECTURES[st.session_state.selected_model]['name']}
Classification Type: {'Binary' if st.session_state.classification_type == ClassificationType.BINARY else 'Multi-class'}
Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Overall Metrics:
- Accuracy: {results['accuracy']:.4f}
- Precision: {results['precision']:.4f}
- Recall: {results['recall']:.4f}
- F1 Score: {results['f1']:.4f}
Classification Report:
{results.get('classification_report', 'N/A')}
Training Configuration:
- Epochs: {st.session_state.config.num_epochs}
- Batch Size: {st.session_state.config.batch_size}
- Learning Rate: {st.session_state.config.learning_rate}
- Max Length: {st.session_state.config.max_length}
"""
            st.download_button(
                label="Download Text Report",
                data=report_text,
                file_name=f"evaluation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
                mime="text/plain"
            )
# ==================== Footer ====================
# Static footer; only rendered when no earlier st.stop() ended the run.
st.markdown("---")
st.markdown(
    """
<div style='text-align: center; color: #666; padding: 2rem;'>
<p> MLOps Training Platform | Built with Streamlit & PyTorch</p>
<p>For help and documentation, check the README.md file</p>
</div>
""",
    unsafe_allow_html=True
)