Spaces:
Sleeping
Sleeping
faizan committed on
Commit ·
33beec1
1
Parent(s): 52a2c0b
feat: implement comprehensive MLflow tracking with MLOps best practices
Browse files- Created MLflow setup module (scripts/mlflow_setup.py)
- Automatic experiment creation and organization
- System info, model params, data info logging
- Model registry integration
- Created MLflow-enabled training script (scripts/train_with_mlflow.py)
- Full hyperparameter tracking
- Per-epoch metrics logging (loss, accuracy, LR)
- Model versioning in registry
- Artifact tracking (checkpoints, metrics, reports)
- Classification report and confusion matrix logging
- Added MLflow UI launcher (scripts/launch_mlflow_ui.sh)
- Demo training run: 98.90% test accuracy in 5 epochs
- Created inference module for deployment (scripts/inference.py)
- Started Gradio app development (app.py)
MLflow UI: http://localhost:5000
- app.py +197 -0
- planning.md +2 -2
- results/mlflow_test_metrics.json +203 -0
- results/mlflow_training_history.json +37 -0
- scripts/inference.py +240 -0
- scripts/launch_mlflow_ui.sh +14 -0
- scripts/mlflow_setup.py +272 -0
- scripts/train_with_mlflow.py +365 -0
app.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Gradio app for MNIST digit classification.

Interactive web interface for handwritten digit recognition using trained CNN model.
"""

import gradio as gr
from scripts.inference import DigitClassifier
from PIL import Image
import numpy as np


# Initialize classifier
# NOTE(review): runs at import time; assumes models/best_model.pt exists relative
# to the working directory — DigitClassifier raises FileNotFoundError otherwise.
print("Loading model...")
classifier = DigitClassifier('models/best_model.pt')
print(f"Model loaded on {classifier.device}")
def predict_digit(image):
    """
    Predict digit from user-drawn image.

    Args:
        image: numpy array from Gradio Sketchpad (H, W, 3) or (H, W), or a
            dict with a 'composite' key (newer Gradio Sketchpad payloads).

    Returns:
        Tuple of (predicted_digit, confidence_text, probability_dict)
    """
    if image is None:
        return "Please draw a digit", "", {}

    # Handle different image formats from Gradio
    if isinstance(image, dict):
        # Sketchpad returns dict with 'composite' key
        image = image.get('composite', image)
        # BUG FIX: a cleared sketchpad can deliver {'composite': None}; the
        # original fell through and crashed inside classifier.predict(None).
        if image is None:
            return "Please draw a digit", "", {}

    # Convert to PIL Image
    if isinstance(image, np.ndarray):
        # If RGB, convert to grayscale
        if len(image.shape) == 3:
            # Take only the drawn part (alpha channel if available)
            if image.shape[2] == 4:  # RGBA
                image = image[:, :, 3]  # Use alpha channel
            else:  # RGB
                image = np.mean(image, axis=2).astype(np.uint8)

        # Ensure values are in [0, 255] (some frontends send floats in [0, 1])
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)

        pil_image = Image.fromarray(image.astype(np.uint8), mode='L')
    else:
        pil_image = image

    # Get prediction
    result = classifier.predict(pil_image)

    # Format output
    digit = result['digit']
    confidence = result['confidence']
    probabilities = result['probabilities']

    # Create confidence text
    confidence_text = f"Confidence: {confidence*100:.1f}%"

    # Per-class probabilities keyed "0".."9" for the gr.Label bar chart
    prob_dict = {str(i): prob for i, prob in enumerate(probabilities)}

    return digit, confidence_text, prob_dict
# Stdlib import used for the examples-existence guard below
import os

# Custom CSS for better styling
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
    max-width: 900px;
    margin: auto;
}
.title {
    text-align: center;
    color: #2c3e50;
}
.description {
    text-align: center;
    color: #7f8c8d;
    margin-bottom: 20px;
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
    gr.Markdown(
        """
        # 🔢 Handwritten Digit Classifier

        Draw a digit (0-9) in the box below and the AI will predict what it is!

        This model uses a Convolutional Neural Network (CNN) trained on the MNIST dataset
        with **99.17% accuracy** on 10,000 test images.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Sketchpad for drawing
            input_image = gr.Sketchpad(
                label="Draw a digit here",
                type="numpy",
                image_mode="L",
                brush=gr.Brush(default_size=5, colors=["#000000"], color_mode="fixed"),
                height=280,
                width=280
            )

            # Buttons
            with gr.Row():
                predict_btn = gr.Button("🔍 Predict", variant="primary", scale=2)
                clear_btn = gr.ClearButton(components=[input_image], value="🗑️ Clear", scale=1)

        with gr.Column(scale=1):
            # Prediction output
            output_digit = gr.Textbox(
                label="Predicted Digit",
                placeholder="Draw a digit to see prediction",
                scale=1,
                lines=1,
                max_lines=1,
                interactive=False
            )

            output_confidence = gr.Textbox(
                label="Confidence",
                placeholder="",
                scale=1,
                lines=1,
                max_lines=1,
                interactive=False
            )

            # Probability distribution
            output_probs = gr.Label(
                label="Probability Distribution",
                num_top_classes=10
            )

    # Example images section.
    # BUG FIX: the original passed
    #   examples=[["examples/digit_0.png"] if __name__ != "__main__" else None]
    # which evaluated to examples=[None] when the app was launched as a script —
    # Gradio cannot render a None example. Only register examples that exist.
    example_files = [p for p in ["examples/digit_0.png"] if os.path.exists(p)]
    if example_files:
        gr.Markdown("### 📝 Try these examples:")
        gr.Examples(
            examples=[[p] for p in example_files],
            inputs=input_image,
            label="Example digits"
        )

    # Model info
    gr.Markdown(
        """
        ---
        ### 📊 Model Details
        - **Architecture**: Convolutional Neural Network (CNN)
        - **Parameters**: 421,066
        - **Training**: MNIST dataset (60,000 images)
        - **Test Accuracy**: 99.17%
        - **Framework**: PyTorch 2.0.1

        ### 💡 Tips for best results:
        - Draw the digit large and centered
        - Use a thick brush stroke
        - Draw in white on black background (like MNIST)
        - Make sure the digit is clear and recognizable
        """
    )

    # Connect events
    predict_btn.click(
        fn=predict_digit,
        inputs=input_image,
        outputs=[output_digit, output_confidence, output_probs]
    )

    # Also predict on sketchpad change (real-time prediction)
    input_image.change(
        fn=predict_digit,
        inputs=input_image,
        outputs=[output_digit, output_confidence, output_probs]
    )


if __name__ == "__main__":
    # Launch the app
    demo.launch(
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,       # Standard Gradio port
        share=False,            # Set to True to create public link
        show_error=True
    )
planning.md
CHANGED
|
@@ -536,7 +536,7 @@ Apply augmentations **on-the-fly** during training (not pre-generate). Reasons:
|
|
| 536 |
|
| 537 |
> **Purpose:** Design, implement, and train CNN architecture with rigorous evaluation
|
| 538 |
|
| 539 |
-
**Status:** 🟡 IN PROGRESS (
|
| 540 |
**Prerequisites:** Phase 1 complete (data pipeline working)
|
| 541 |
**Estimated Time:** 8-10 hours
|
| 542 |
|
|
@@ -817,7 +817,7 @@ Apply augmentations **on-the-fly** during training (not pre-generate). Reasons:
|
|
| 817 |
---
|
| 818 |
|
| 819 |
### **Task 2.4:** Comprehensive model evaluation
|
| 820 |
-
**Status:**
|
| 821 |
**Priority:** HIGH
|
| 822 |
**Objective:** Compute all required metrics for report
|
| 823 |
|
|
|
|
| 536 |
|
| 537 |
> **Purpose:** Design, implement, and train CNN architecture with rigorous evaluation
|
| 538 |
|
| 539 |
+
**Status:** 🟡 IN PROGRESS (4/6 tasks complete)
|
| 540 |
**Prerequisites:** Phase 1 complete (data pipeline working)
|
| 541 |
**Estimated Time:** 8-10 hours
|
| 542 |
|
|
|
|
| 817 |
---
|
| 818 |
|
| 819 |
### **Task 2.4:** Comprehensive model evaluation
|
| 820 |
+
**Status:** ✅ COMPLETE (integrated in training script)
|
| 821 |
**Priority:** HIGH
|
| 822 |
**Objective:** Compute all required metrics for report
|
| 823 |
|
results/mlflow_test_metrics.json
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"test_accuracy": 98.9,
|
| 3 |
+
"test_precision": 0.9890924021865366,
|
| 4 |
+
"test_recall": 0.9889686019869461,
|
| 5 |
+
"test_f1_score": 0.989004598229335,
|
| 6 |
+
"classification_report": {
|
| 7 |
+
"0": {
|
| 8 |
+
"precision": 0.9959141981613892,
|
| 9 |
+
"recall": 0.9948979591836735,
|
| 10 |
+
"f1-score": 0.9954058192955589,
|
| 11 |
+
"support": 980.0
|
| 12 |
+
},
|
| 13 |
+
"1": {
|
| 14 |
+
"precision": 0.9973404255319149,
|
| 15 |
+
"recall": 0.9911894273127754,
|
| 16 |
+
"f1-score": 0.9942554131683606,
|
| 17 |
+
"support": 1135.0
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"precision": 0.9753320683111955,
|
| 21 |
+
"recall": 0.9961240310077519,
|
| 22 |
+
"f1-score": 0.9856184084372004,
|
| 23 |
+
"support": 1032.0
|
| 24 |
+
},
|
| 25 |
+
"3": {
|
| 26 |
+
"precision": 0.9862475442043221,
|
| 27 |
+
"recall": 0.994059405940594,
|
| 28 |
+
"f1-score": 0.9901380670611439,
|
| 29 |
+
"support": 1010.0
|
| 30 |
+
},
|
| 31 |
+
"4": {
|
| 32 |
+
"precision": 0.9808853118712274,
|
| 33 |
+
"recall": 0.9928716904276986,
|
| 34 |
+
"f1-score": 0.9868421052631581,
|
| 35 |
+
"support": 982.0
|
| 36 |
+
},
|
| 37 |
+
"5": {
|
| 38 |
+
"precision": 0.9932508436445444,
|
| 39 |
+
"recall": 0.9899103139013453,
|
| 40 |
+
"f1-score": 0.9915777653003931,
|
| 41 |
+
"support": 892.0
|
| 42 |
+
},
|
| 43 |
+
"6": {
|
| 44 |
+
"precision": 0.9957939011566772,
|
| 45 |
+
"recall": 0.988517745302714,
|
| 46 |
+
"f1-score": 0.9921424829753797,
|
| 47 |
+
"support": 958.0
|
| 48 |
+
},
|
| 49 |
+
"7": {
|
| 50 |
+
"precision": 0.9863680623174295,
|
| 51 |
+
"recall": 0.9854085603112841,
|
| 52 |
+
"f1-score": 0.9858880778588808,
|
| 53 |
+
"support": 1028.0
|
| 54 |
+
},
|
| 55 |
+
"8": {
|
| 56 |
+
"precision": 0.9947916666666666,
|
| 57 |
+
"recall": 0.9804928131416838,
|
| 58 |
+
"f1-score": 0.9875904860392969,
|
| 59 |
+
"support": 974.0
|
| 60 |
+
},
|
| 61 |
+
"9": {
|
| 62 |
+
"precision": 0.985,
|
| 63 |
+
"recall": 0.9762140733399405,
|
| 64 |
+
"f1-score": 0.9805873568939771,
|
| 65 |
+
"support": 1009.0
|
| 66 |
+
},
|
| 67 |
+
"accuracy": 0.989,
|
| 68 |
+
"macro avg": {
|
| 69 |
+
"precision": 0.9890924021865366,
|
| 70 |
+
"recall": 0.9889686019869461,
|
| 71 |
+
"f1-score": 0.989004598229335,
|
| 72 |
+
"support": 10000.0
|
| 73 |
+
},
|
| 74 |
+
"weighted avg": {
|
| 75 |
+
"precision": 0.989058814881263,
|
| 76 |
+
"recall": 0.989,
|
| 77 |
+
"f1-score": 0.9890032775348696,
|
| 78 |
+
"support": 10000.0
|
| 79 |
+
}
|
| 80 |
+
},
|
| 81 |
+
"confusion_matrix": [
|
| 82 |
+
[
|
| 83 |
+
975,
|
| 84 |
+
0,
|
| 85 |
+
2,
|
| 86 |
+
0,
|
| 87 |
+
1,
|
| 88 |
+
0,
|
| 89 |
+
1,
|
| 90 |
+
1,
|
| 91 |
+
0,
|
| 92 |
+
0
|
| 93 |
+
],
|
| 94 |
+
[
|
| 95 |
+
0,
|
| 96 |
+
1125,
|
| 97 |
+
6,
|
| 98 |
+
0,
|
| 99 |
+
1,
|
| 100 |
+
0,
|
| 101 |
+
1,
|
| 102 |
+
2,
|
| 103 |
+
0,
|
| 104 |
+
0
|
| 105 |
+
],
|
| 106 |
+
[
|
| 107 |
+
1,
|
| 108 |
+
0,
|
| 109 |
+
1028,
|
| 110 |
+
0,
|
| 111 |
+
1,
|
| 112 |
+
0,
|
| 113 |
+
0,
|
| 114 |
+
2,
|
| 115 |
+
0,
|
| 116 |
+
0
|
| 117 |
+
],
|
| 118 |
+
[
|
| 119 |
+
0,
|
| 120 |
+
0,
|
| 121 |
+
1,
|
| 122 |
+
1004,
|
| 123 |
+
0,
|
| 124 |
+
1,
|
| 125 |
+
0,
|
| 126 |
+
2,
|
| 127 |
+
1,
|
| 128 |
+
1
|
| 129 |
+
],
|
| 130 |
+
[
|
| 131 |
+
0,
|
| 132 |
+
0,
|
| 133 |
+
0,
|
| 134 |
+
0,
|
| 135 |
+
975,
|
| 136 |
+
0,
|
| 137 |
+
0,
|
| 138 |
+
0,
|
| 139 |
+
1,
|
| 140 |
+
6
|
| 141 |
+
],
|
| 142 |
+
[
|
| 143 |
+
0,
|
| 144 |
+
0,
|
| 145 |
+
0,
|
| 146 |
+
7,
|
| 147 |
+
0,
|
| 148 |
+
883,
|
| 149 |
+
1,
|
| 150 |
+
1,
|
| 151 |
+
0,
|
| 152 |
+
0
|
| 153 |
+
],
|
| 154 |
+
[
|
| 155 |
+
1,
|
| 156 |
+
2,
|
| 157 |
+
2,
|
| 158 |
+
0,
|
| 159 |
+
4,
|
| 160 |
+
2,
|
| 161 |
+
947,
|
| 162 |
+
0,
|
| 163 |
+
0,
|
| 164 |
+
0
|
| 165 |
+
],
|
| 166 |
+
[
|
| 167 |
+
0,
|
| 168 |
+
1,
|
| 169 |
+
8,
|
| 170 |
+
0,
|
| 171 |
+
3,
|
| 172 |
+
0,
|
| 173 |
+
0,
|
| 174 |
+
1013,
|
| 175 |
+
1,
|
| 176 |
+
2
|
| 177 |
+
],
|
| 178 |
+
[
|
| 179 |
+
1,
|
| 180 |
+
0,
|
| 181 |
+
7,
|
| 182 |
+
0,
|
| 183 |
+
1,
|
| 184 |
+
2,
|
| 185 |
+
1,
|
| 186 |
+
1,
|
| 187 |
+
955,
|
| 188 |
+
6
|
| 189 |
+
],
|
| 190 |
+
[
|
| 191 |
+
1,
|
| 192 |
+
0,
|
| 193 |
+
0,
|
| 194 |
+
7,
|
| 195 |
+
8,
|
| 196 |
+
1,
|
| 197 |
+
0,
|
| 198 |
+
5,
|
| 199 |
+
2,
|
| 200 |
+
985
|
| 201 |
+
]
|
| 202 |
+
]
|
| 203 |
+
}
|
results/mlflow_training_history.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_loss": [
|
| 3 |
+
0.623640251447595,
|
| 4 |
+
0.25853686837229406,
|
| 5 |
+
0.20275674461196627,
|
| 6 |
+
0.16697734022463953,
|
| 7 |
+
0.15257755666153233
|
| 8 |
+
],
|
| 9 |
+
"train_accuracy": [
|
| 10 |
+
79.71565839788215,
|
| 11 |
+
92.16001568781253,
|
| 12 |
+
93.85233846455534,
|
| 13 |
+
94.94460241200117,
|
| 14 |
+
95.43092460045102
|
| 15 |
+
],
|
| 16 |
+
"val_loss": [
|
| 17 |
+
0.091112698253297,
|
| 18 |
+
0.06359739341700436,
|
| 19 |
+
0.0590373575881434,
|
| 20 |
+
0.05036303665776281,
|
| 21 |
+
0.05376061073117163
|
| 22 |
+
],
|
| 23 |
+
"val_accuracy": [
|
| 24 |
+
97.22376457523598,
|
| 25 |
+
98.06774014436424,
|
| 26 |
+
98.33425874514158,
|
| 27 |
+
98.75624652970572,
|
| 28 |
+
98.5341476957246
|
| 29 |
+
],
|
| 30 |
+
"learning_rate": [
|
| 31 |
+
0.001,
|
| 32 |
+
0.001,
|
| 33 |
+
0.001,
|
| 34 |
+
0.001,
|
| 35 |
+
0.001
|
| 36 |
+
]
|
| 37 |
+
}
|
scripts/inference.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference module for MNIST digit classification.
|
| 3 |
+
|
| 4 |
+
Provides a clean API for making predictions with the trained model.
|
| 5 |
+
Handles image preprocessing and returns predictions with confidence scores.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Union, Dict
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DigitClassifier:
    """Production inference wrapper for MNIST digit classifier."""

    def __init__(self, model_path: str, device: str = None):
        """
        Initialize the digit classifier.

        Args:
            model_path: Path to model checkpoint (.pt file)
            device: Device to run inference on ('cuda' or 'cpu').
                If None, auto-detects CUDA availability.

        Raises:
            FileNotFoundError: if no file exists at model_path.
        """
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found at {model_path}")

        # Load once at construction; eval() disables dropout/batch-norm updates.
        self.model = self._load_model()
        self.model.eval()

        # Normalization values (same as training)
        # NOTE(review): these are the conventional MNIST mean/std — must match
        # the training pipeline's transform; confirm against the training script.
        self.mean = 0.1307
        self.std = 0.3081

    def _load_model(self) -> torch.nn.Module:
        """Load model from checkpoint."""
        # Imported lazily so module import does not require scripts.models.
        from scripts.models import BaselineCNN

        model = BaselineCNN()

        # Load checkpoint; map_location keeps CPU-only hosts working.
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Handle different checkpoint formats: full training checkpoints wrap
        # the weights in 'model_state_dict'; bare state dicts load directly.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        return model.to(self.device)

    def preprocess(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Preprocess image for model input.

        Handles:
        - RGB to grayscale conversion
        - Resizing to 28x28
        - Normalization
        - Inversion if needed (white digit on black background)

        Args:
            image: PIL Image or numpy array

        Returns:
            Preprocessed tensor of shape (1, 1, 28, 28)
        """
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Convert to grayscale if RGB
        if image.mode != 'L':
            image = image.convert('L')

        # Resize to 28x28 if needed
        if image.size != (28, 28):
            image = image.resize((28, 28), Image.Resampling.LANCZOS)

        # Convert to numpy array
        img_array = np.array(image).astype(np.float32)

        # Normalize to [0, 1]
        img_array = img_array / 255.0

        # Check if inversion is needed (MNIST is white digit on black background)
        # If most pixels are bright, it's likely a black digit on white background
        if img_array.mean() > 0.5:
            img_array = 1.0 - img_array

        # Apply normalization (same as training)
        img_array = (img_array - self.mean) / self.std

        # Convert to tensor and add batch and channel dimensions
        img_tensor = torch.tensor(img_array).unsqueeze(0).unsqueeze(0)

        return img_tensor.to(self.device)

    def predict(self, image: Union[Image.Image, np.ndarray]) -> Dict:
        """
        Predict digit from image.

        Args:
            image: PIL Image or numpy array containing digit

        Returns:
            Dictionary with:
            - digit: Predicted digit (0-9)
            - confidence: Confidence score (0-1)
            - probabilities: List of probabilities for each digit
        """
        img_tensor = self.preprocess(image)

        # Inference only — no_grad skips autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0]
            confidence, predicted = torch.max(probabilities, dim=0)

        return {
            'digit': int(predicted.item()),
            'confidence': float(confidence.item()),
            'probabilities': probabilities.cpu().numpy().tolist()
        }

    def predict_batch(self, images: list) -> list:
        """
        Predict digits for a batch of images.

        Note: runs one forward pass per image (no true batched inference).

        Args:
            images: List of PIL Images or numpy arrays

        Returns:
            List of prediction dictionaries
        """
        return [self.predict(img) for img in images]
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def test_inference():
    """Test inference module with sample images.

    Smoke test: loads the trained checkpoint, predicts on 10 random MNIST test
    images, and probes a few edge cases (blank, all-white, resized input).
    Prints results to stdout; uses unseeded randomness, so sampled indices vary
    between runs.
    """
    import sys
    from pathlib import Path

    # Add project root to path so 'scripts.*' imports resolve when run directly
    project_root = Path(__file__).parent.parent
    sys.path.insert(0, str(project_root))

    from scripts.data_loader import MnistDataloader

    print("Testing Inference Module")
    print("=" * 50)

    # Check if model exists; bail out gracefully rather than raising
    model_path = project_root / 'models' / 'best_model.pt'
    if not model_path.exists():
        print(f"Error: Model not found at {model_path}")
        print("Please train a model first.")
        return

    # Load MNIST test data (raw IDX files expected under data/raw)
    data_path = project_root / 'data' / 'raw'
    loader = MnistDataloader(
        training_images_filepath=str(data_path / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(data_path / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(data_path / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
    )
    _, (x_test, y_test) = loader.load_data()

    # Initialize classifier
    print(f"\n1. Loading model from: {model_path}")
    classifier = DigitClassifier(str(model_path))
    print(f"   Device: {classifier.device}")

    # Test on a few images
    print("\n2. Testing predictions on 10 random test images:")
    print("-" * 50)

    indices = np.random.choice(len(x_test), 10, replace=False)
    correct = 0

    for i, idx in enumerate(indices, 1):
        image = x_test[idx]
        true_label = y_test[idx]

        # Convert list to numpy array if needed (loader may return nested lists)
        if isinstance(image, list):
            image = np.array(image)

        # Convert to PIL Image
        img = Image.fromarray(image.astype(np.uint8), mode='L')

        # Predict
        result = classifier.predict(img)

        is_correct = result['digit'] == true_label
        correct += is_correct

        print(f"   Image {i}: True={true_label}, Pred={result['digit']}, "
              f"Conf={result['confidence']:.4f} {'✓' if is_correct else '✗'}")

    accuracy = correct / len(indices) * 100
    print(f"\nAccuracy on {len(indices)} samples: {accuracy:.1f}%")

    # Test edge cases
    print("\n3. Testing edge cases:")
    print("-" * 50)

    # Blank image (all zeros — no digit drawn)
    blank = np.zeros((28, 28), dtype=np.uint8)
    blank_img = Image.fromarray(blank, mode='L')
    result = classifier.predict(blank_img)
    print(f"   Blank image: Pred={result['digit']}, Conf={result['confidence']:.4f}")

    # All white image (exercises the brightness-inversion heuristic)
    white = np.ones((28, 28), dtype=np.uint8) * 255
    white_img = Image.fromarray(white, mode='L')
    result = classifier.predict(white_img)
    print(f"   White image: Pred={result['digit']}, Conf={result['confidence']:.4f}")

    # Different size image (exercises resize path in preprocess)
    test_img = x_test[0]
    if isinstance(test_img, list):
        test_img = np.array(test_img)
    large = Image.fromarray(test_img.astype(np.uint8), mode='L').resize((56, 56))
    result = classifier.predict(large)
    print(f"   Resized image (56x56): Pred={result['digit']}, Conf={result['confidence']:.4f}")

    print("\n✓ Inference module test complete!")


if __name__ == '__main__':
    test_inference()
|
scripts/launch_mlflow_ui.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash
# Launch MLflow UI dashboard
#
# Usage:
#   ./scripts/launch_mlflow_ui.sh
#
# Access at: http://localhost:5000

# Fail fast on errors, unset variables, and pipeline failures
# (e.g. when the mlflow CLI is not installed).
set -euo pipefail

echo "Starting MLflow UI..."
echo "Access dashboard at: http://localhost:5000"
echo "Press Ctrl+C to stop"
echo ""

# Same local file store the training scripts write to (./mlruns)
mlflow ui --backend-store-uri file:./mlruns --host 0.0.0.0 --port 5000
scripts/mlflow_setup.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
MLflow Setup and Configuration

Utilities for MLflow experiment tracking with MLOps best practices:
- Automatic experiment naming and organization
- Parameter and metric logging
- Model registry integration
- Artifact tracking
"""

import mlflow
from pathlib import Path
from typing import Optional, Dict, Any
import os


# MLflow configuration
# Local file-based store; the UI launcher script points at the same ./mlruns dir.
MLFLOW_TRACKING_URI = "file:./mlruns"
DEFAULT_EXPERIMENT_NAME = "mnist-digit-classification"
|
| 22 |
+
def setup_mlflow(
    experiment_name: str = DEFAULT_EXPERIMENT_NAME,
    tracking_uri: Optional[str] = None
) -> str:
    """
    Configure MLflow tracking and ensure the target experiment exists.

    Args:
        experiment_name: Name of the experiment
        tracking_uri: MLflow tracking URI (default: local ./mlruns)

    Returns:
        experiment_id: MLflow experiment ID
    """
    # Fall back to the module-level default store when no URI was given.
    resolved_uri = MLFLOW_TRACKING_URI if tracking_uri is None else tracking_uri
    mlflow.set_tracking_uri(resolved_uri)

    # Reuse an existing experiment when present; otherwise create a tagged one.
    try:
        existing = mlflow.get_experiment_by_name(experiment_name)
        if existing is not None:
            experiment_id = existing.experiment_id
        else:
            experiment_id = mlflow.create_experiment(
                experiment_name,
                tags={
                    "project": "mnist-classification",
                    "framework": "pytorch",
                    "model_type": "cnn"
                }
            )
    except Exception as e:
        # Best-effort: fall back to the default experiment rather than aborting.
        print(f"Warning: Could not create experiment: {e}")
        experiment_id = "0"  # Default experiment

    mlflow.set_experiment(experiment_name)

    print(f"MLflow tracking URI: {resolved_uri}")
    print(f"Experiment: {experiment_name} (ID: {experiment_id})")

    return experiment_id
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def log_model_params(model: Any, prefix: str = "model") -> Dict[str, Any]:
    """
    Log model architecture parameters to MLflow.

    Args:
        model: PyTorch model
        prefix: Prefix for parameter names

    Returns:
        Dictionary of logged parameters
    """
    from scripts.models import count_parameters

    # FIX: count_parameters returns a (total, trainable) tuple (it is
    # unpacked that way in scripts/train_with_mlflow.py), so logging its
    # raw return value would store a tuple string as "total_params".
    total_params, trainable_params = count_parameters(model)

    params = {
        f"{prefix}_name": model.__class__.__name__,
        f"{prefix}_total_params": total_params,
        f"{prefix}_trainable_params": trainable_params
    }

    mlflow.log_params(params)
    return params
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def log_training_config(config: Dict[str, Any]) -> None:
    """
    Log training configuration to MLflow.

    Nested dictionaries are flattened one level deep by joining the outer
    and inner keys with an underscore (e.g. {"opt": {"lr": 1e-3}} ->
    "opt_lr").

    Args:
        config: Dictionary of training hyperparameters
    """
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        if not isinstance(value, dict):
            flat[key] = value
            continue
        for subkey, subvalue in value.items():
            flat[f"{key}_{subkey}"] = subvalue

    mlflow.log_params(flat)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def log_data_info(
    train_size: int,
    val_size: int,
    test_size: int,
    num_classes: int = 10,
    augmentation: bool = False
) -> None:
    """
    Log dataset split sizes and preprocessing flags to MLflow.

    Args:
        train_size: Number of training samples
        val_size: Number of validation samples
        test_size: Number of test samples
        num_classes: Number of classes
        augmentation: Whether data augmentation is used
    """
    data_params = {
        "data_train_size": train_size,
        "data_val_size": val_size,
        "data_test_size": test_size,
        "data_num_classes": num_classes,
        "data_augmentation": augmentation
    }
    mlflow.log_params(data_params)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def log_system_info() -> Dict[str, Any]:
    """
    Log host and runtime environment details (OS, Python, PyTorch, CUDA)
    to MLflow as run parameters.

    Returns:
        Dictionary of system information
    """
    import platform

    import torch

    # Query CUDA availability once and reuse the result.
    cuda_ok = torch.cuda.is_available()

    system_info = {
        "system_platform": platform.system(),
        "system_python_version": platform.python_version(),
        "system_pytorch_version": torch.__version__,
        "system_cuda_available": cuda_ok,
        "system_cuda_version": torch.version.cuda if cuda_ok else "N/A",
        "system_device": "cuda" if cuda_ok else "cpu"
    }

    if cuda_ok:
        system_info["system_gpu_name"] = torch.cuda.get_device_name(0)
        system_info["system_gpu_count"] = torch.cuda.device_count()

    mlflow.log_params(system_info)
    return system_info
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def log_metrics_epoch(metrics: Dict[str, float], step: int) -> None:
    """
    Record a batch of metrics against a given training epoch.

    Args:
        metrics: Dictionary of metric names and values
        step: Epoch number used as the MLflow step axis
    """
    mlflow.log_metrics(metrics, step=step)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def log_artifact_path(path: str, artifact_path: Optional[str] = None) -> None:
    """
    Log a file or directory as an MLflow artifact, if it exists on disk.

    Args:
        path: Path to file or directory
        artifact_path: Optional artifact path in MLflow
    """
    # Skip (with a warning) rather than raise when the artifact is missing.
    if not Path(path).exists():
        print(f"Warning: Artifact not found: {path}")
        return
    mlflow.log_artifact(path, artifact_path=artifact_path)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def log_model_to_registry(
    model: Any,
    model_name: str,
    artifact_path: str = "model",
    registered_model_name: Optional[str] = None
) -> None:
    """
    Log a PyTorch model to MLflow, optionally registering it in the
    model registry.

    Args:
        model: PyTorch model
        model_name: Name for the model artifact. NOTE(review): currently
            unused — kept for interface stability; `artifact_path`
            controls where the model is stored.
        artifact_path: Artifact path in MLflow
        registered_model_name: Name for model registry (optional); when
            given, a new model version is registered under that name
    """
    # Removed an unused local `import torch` — nothing in this body
    # references torch directly.
    mlflow.pytorch.log_model(
        pytorch_model=model,
        artifact_path=artifact_path,
        registered_model_name=registered_model_name
    )
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_or_create_run(
    run_name: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None
) -> mlflow.ActiveRun:
    """
    Start an MLflow run and return its active-run context manager.

    Args:
        run_name: Name for the run
        tags: Tags for the run

    Returns:
        MLflow active run context
    """
    run = mlflow.start_run(run_name=run_name, tags=tags)
    return run
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def end_run() -> None:
    """Terminate the currently active MLflow run, if any."""
    mlflow.end_run()
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def test_mlflow_setup():
    """Smoke-test MLflow setup: create an experiment, log params/metrics."""
    print("Testing MLflow Setup")
    print("=" * 50)

    # Create (or reuse) a throwaway experiment.
    exp_id = setup_mlflow("test-experiment")

    with mlflow.start_run(run_name="test-run"):
        # Dummy hyperparameters, logged in one batch.
        mlflow.log_params({
            "learning_rate": 0.001,
            "batch_size": 64,
            "epochs": 10
        })

        # Three fake epochs of steadily improving metrics.
        for step in range(3):
            fake_metrics = {
                "train_loss": 0.5 - step * 0.1,
                "val_loss": 0.6 - step * 0.1,
                "train_accuracy": 0.8 + step * 0.05,
                "val_accuracy": 0.75 + step * 0.05
            }
            mlflow.log_metrics(fake_metrics, step=step)

        # Capture host/runtime details on the same run and echo them.
        system_info = log_system_info()
        print("\nSystem Info:")
        for key, value in system_info.items():
            print(f"  {key}: {value}")

    print("\n✓ MLflow test complete!")
    print(f"View results at: mlflow ui --backend-store-uri {MLFLOW_TRACKING_URI}")


if __name__ == "__main__":
    test_mlflow_setup()
|
scripts/train_with_mlflow.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MLflow-Enabled Training Script for MNIST CNN
|
| 3 |
+
|
| 4 |
+
Full training script with comprehensive MLflow tracking:
|
| 5 |
+
- Hyperparameters and model architecture
|
| 6 |
+
- Per-epoch metrics (loss, accuracy, learning rate)
|
| 7 |
+
- System information and environment
|
| 8 |
+
- Model artifacts and checkpoints
|
| 9 |
+
- Training visualizations
|
| 10 |
+
- Confusion matrix and classification report
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python scripts/train_with_mlflow.py --epochs 20 --lr 0.001 --augment
|
| 14 |
+
python scripts/train_with_mlflow.py --help
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.optim as optim
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
import json
|
| 23 |
+
import sys
|
| 24 |
+
import numpy as np
|
| 25 |
+
import mlflow
|
| 26 |
+
|
| 27 |
+
# Add project root to path
|
| 28 |
+
project_root = Path(__file__).parent.parent
|
| 29 |
+
sys.path.insert(0, str(project_root))
|
| 30 |
+
|
| 31 |
+
from scripts.models import BaselineCNN, count_parameters
|
| 32 |
+
from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val
|
| 33 |
+
from scripts.train import train_epoch, validate, evaluate_model, save_training_history
|
| 34 |
+
from scripts.data_loader import MnistDataloader
|
| 35 |
+
from scripts.augmentation import get_train_augmentation
|
| 36 |
+
from scripts.mlflow_setup import (
|
| 37 |
+
setup_mlflow, log_model_params, log_training_config,
|
| 38 |
+
log_data_info, log_system_info, log_metrics_epoch,
|
| 39 |
+
log_artifact_path, log_model_to_registry
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def train_with_mlflow(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
    config: dict,
    run_name: str = None
) -> dict:
    """
    Train model with full MLflow tracking.

    Logs hyperparameters, per-epoch metrics, system info, the best
    checkpoint, test-set metrics and supporting artifacts to a single
    MLflow run. Training early-stops when validation loss fails to
    improve for `patience` consecutive epochs.

    Args:
        model: PyTorch model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        test_loader: Test data loader
        config: Training configuration dictionary; must contain
            'device', 'num_epochs' and 'learning_rate'
        run_name: Optional name for MLflow run

    Returns:
        Training history dictionary (per-epoch losses, accuracies, LRs)
    """
    device = config['device']
    num_epochs = config['num_epochs']
    learning_rate = config['learning_rate']

    # Setup MLflow
    setup_mlflow("mnist-digit-classification")

    # Start MLflow run
    with mlflow.start_run(run_name=run_name):
        print("\n" + "="*70)
        print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")
        print("="*70 + "\n")

        # Log all configuration up front so even a crashed run is traceable.
        print("Logging configuration to MLflow...")
        log_training_config(config)
        log_model_params(model)
        log_data_info(
            train_size=len(train_loader.dataset),
            val_size=len(val_loader.dataset),
            test_size=len(test_loader.dataset),
            num_classes=10,
            augmentation=config.get('augmentation', False)
        )
        log_system_info()

        # Log model architecture as a text artifact.
        total_params, trainable_params = count_parameters(model)
        model_summary = f"""
Model: {model.__class__.__name__}
Total Parameters: {total_params:,}
Trainable Parameters: {trainable_params:,}
Device: {device}

Architecture:
{str(model)}
"""
        mlflow.log_text(model_summary, "model_architecture.txt")

        # Setup training: CE loss, Adam, LR halved after 3 stale epochs.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5, verbose=True
        )

        # Training history
        history = {
            'train_loss': [],
            'train_accuracy': [],
            'val_loss': [],
            'val_accuracy': [],
            'learning_rate': []
        }

        best_val_loss = float('inf')
        # FIX: initialize before the loop — previously `best_epoch` was only
        # assigned inside the improvement branch, so the summary below could
        # raise NameError if the loop body never executed.
        best_epoch = 0
        patience = 5
        patience_counter = 0

        print(f"\nStarting training for {num_epochs} epochs...")
        print(f"Device: {device}")
        total_p, _ = count_parameters(model)
        print(f"Model: {model.__class__.__name__} ({total_p:,} parameters)")
        print("-" * 70)

        for epoch in range(num_epochs):
            # Train
            train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)

            # Validate
            val_metrics = validate(model, val_loader, criterion, device)

            # Get current learning rate (captured before the scheduler
            # possibly reduces it for the next epoch)
            current_lr = optimizer.param_groups[0]['lr']

            # Update scheduler on validation loss
            scheduler.step(val_metrics['loss'])

            # Save history
            history['train_loss'].append(train_metrics['loss'])
            history['train_accuracy'].append(train_metrics['accuracy'])
            history['val_loss'].append(val_metrics['loss'])
            history['val_accuracy'].append(val_metrics['accuracy'])
            history['learning_rate'].append(current_lr)

            # Log metrics to MLflow
            mlflow_metrics = {
                'train_loss': train_metrics['loss'],
                'train_accuracy': train_metrics['accuracy'],
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['accuracy'],
                'learning_rate': current_lr,
                'epoch': epoch + 1
            }
            log_metrics_epoch(mlflow_metrics, step=epoch)

            # Print progress
            print(f"Epoch {epoch+1}/{num_epochs} | "
                  f"Train Loss: {train_metrics['loss']:.4f} ({train_metrics['accuracy']:.2f}%) | "
                  f"Val Loss: {val_metrics['loss']:.4f} ({val_metrics['accuracy']:.2f}%) | "
                  f"LR: {current_lr:.6f}")

            # Save best model
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                best_epoch = epoch + 1
                patience_counter = 0

                # Save checkpoint — FIX: create models/ first (the results/
                # path below is mkdir'd; this one previously was not, so a
                # fresh checkout crashed here)
                checkpoint_path = project_root / 'models' / 'best_model_mlflow.pt'
                checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_metrics['loss'],
                    'val_loss': val_metrics['loss'],
                    'val_accuracy': val_metrics['accuracy'],
                }, checkpoint_path)

                print(f" → New best model! (Val Loss: {best_val_loss:.4f})")

                # Log model to MLflow (registers a new version each time)
                mlflow.pytorch.log_model(
                    model,
                    "model",
                    registered_model_name="mnist-cnn-baseline"
                )
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                mlflow.log_param("early_stopped", True)
                mlflow.log_param("early_stop_epoch", epoch + 1)
                break

        print("-" * 70)
        print(f"\nTraining complete!")
        print(f"Best epoch: {best_epoch} (Val Loss: {best_val_loss:.4f})")

        # Log best/final summary metrics
        mlflow.log_metrics({
            'best_epoch': best_epoch,
            'best_val_loss': best_val_loss,
            'final_train_loss': history['train_loss'][-1],
            'final_val_loss': history['val_loss'][-1]
        })

        # Evaluate on test set
        print("\nEvaluating on test set...")
        test_metrics = evaluate_model(model, test_loader, device)

        test_accuracy = test_metrics['accuracy']
        test_report = test_metrics['classification_report']

        # Extract macro average metrics
        test_precision = test_report['macro avg']['precision']
        test_recall = test_report['macro avg']['recall']
        test_f1_score = test_report['macro avg']['f1-score']

        print(f"Test Accuracy: {test_accuracy:.2f}%")
        print(f"Test Precision: {test_precision:.4f}")
        print(f"Test Recall: {test_recall:.4f}")
        print(f"Test F1-Score: {test_f1_score:.4f}")

        # Log test metrics to MLflow
        mlflow.log_metrics({
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1_score': test_f1_score
        })

        # Save and log artifacts
        print("\nSaving artifacts...")

        # Save history
        history_path = project_root / 'results' / 'mlflow_training_history.json'
        history_path.parent.mkdir(exist_ok=True)
        save_training_history(history, history_path)
        log_artifact_path(str(history_path))

        # Save test metrics
        metrics_to_save = {
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1_score': test_f1_score,
            'classification_report': test_report,
            'confusion_matrix': test_metrics['confusion_matrix'].tolist()
        }
        metrics_path = project_root / 'results' / 'mlflow_test_metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_to_save, f, indent=2)
        log_artifact_path(str(metrics_path))

        # Save model checkpoint
        log_artifact_path(str(project_root / 'models' / 'best_model_mlflow.pt'))

        # Log confusion matrix as JSON
        conf_matrix_dict = {
            f"row_{i}": test_metrics['confusion_matrix'][i].tolist()
            for i in range(len(test_metrics['confusion_matrix']))
        }
        mlflow.log_dict(conf_matrix_dict, "confusion_matrix.json")

        # Log classification report
        mlflow.log_dict(test_report, "classification_report.json")

        print(f"\n✓ All artifacts logged to MLflow")
        print(f"View results: mlflow ui --backend-store-uri file:./mlruns")

    return history
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def main():
    """CLI entry point: parse arguments, build data/model, train with MLflow."""
    arg_parser = argparse.ArgumentParser(description='Train MNIST CNN with MLflow tracking')
    arg_parser.add_argument('--epochs', type=int, default=20, help='Number of epochs (default: 20)')
    arg_parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    arg_parser.add_argument('--batch-size', type=int, default=64, help='Batch size (default: 64)')
    arg_parser.add_argument('--augment', action='store_true', help='Use data augmentation')
    arg_parser.add_argument('--run-name', type=str, default=None, help='MLflow run name')
    arg_parser.add_argument('--seed', type=int, default=42, help='Random seed (default: 42)')
    args = arg_parser.parse_args()

    # Seed every RNG we rely on, for reproducibility.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Assemble the training configuration from CLI args + fixed choices.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = {
        'num_epochs': args.epochs,
        'learning_rate': args.lr,
        'batch_size': args.batch_size,
        'augmentation': args.augment,
        'random_seed': args.seed,
        'device': run_device,
        'optimizer': 'Adam',
        'scheduler': 'ReduceLROnPlateau',
        'early_stopping_patience': 5
    }

    print("Training Configuration:")
    print(json.dumps(config, indent=2))

    # Load the raw IDX-format MNIST files.
    print("\nLoading MNIST data...")
    raw_dir = project_root / 'data' / 'raw'
    mnist = MnistDataloader(
        training_images_filepath=str(raw_dir / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(raw_dir / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(raw_dir / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(raw_dir / 't10k-labels.idx1-ubyte')
    )
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Hold out 15% of the training set for validation.
    (x_tr, y_tr), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=args.seed
    )

    # Augmentation applies to the training split only.
    train_transform = get_train_augmentation() if args.augment else None
    train_dataset = MnistDataset(x_tr, y_tr, transform=train_transform)
    val_dataset = MnistDataset(x_val, y_val, transform=None)
    test_dataset = MnistDataset(x_test, y_test, transform=None)

    # Build the data loaders (test loader kept in order for evaluation).
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
    )

    print(f"Train: {len(train_loader.dataset)} samples")
    print(f"Val: {len(val_loader.dataset)} samples")
    print(f"Test: {len(test_loader.dataset)} samples")

    # Instantiate the model on the chosen device and train with tracking.
    model = BaselineCNN().to(config['device'])
    train_with_mlflow(
        model, train_loader, val_loader, test_loader,
        config, run_name=args.run_name
    )

    print("\n" + "="*70)
    print("Training complete! View MLflow dashboard:")
    print(" ./scripts/launch_mlflow_ui.sh")
    print(" or: mlflow ui --backend-store-uri file:./mlruns")
    print("="*70)


if __name__ == '__main__':
    main()
|