faizan committed on
Commit
33beec1
·
1 Parent(s): 52a2c0b

feat: implement comprehensive MLflow tracking with MLOps best practices

Browse files

- Created MLflow setup module (scripts/mlflow_setup.py)
- Automatic experiment creation and organization
- System info, model params, data info logging
- Model registry integration

- Created MLflow-enabled training script (scripts/train_with_mlflow.py)
- Full hyperparameter tracking
- Per-epoch metrics logging (loss, accuracy, LR)
- Model versioning in registry
- Artifact tracking (checkpoints, metrics, reports)
- Classification report and confusion matrix logging

- Added MLflow UI launcher (scripts/launch_mlflow_ui.sh)
- Demo training run: 98.90% test accuracy in 5 epochs
- Created inference module for deployment (scripts/inference.py)
- Started Gradio app development (app.py)

MLflow UI: http://localhost:5000

app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio app for MNIST digit classification.
3
+
4
+ Interactive web interface for handwritten digit recognition using trained CNN model.
5
+ """
6
+
7
+ import gradio as gr
8
+ from scripts.inference import DigitClassifier
9
+ from PIL import Image
10
+ import numpy as np
11
+
12
+
13
# Initialize the classifier once at module import time so every Gradio
# request reuses the same loaded model (loading the checkpoint is expensive).
print("Loading model...")
classifier = DigitClassifier('models/best_model.pt')  # NOTE(review): path is relative to CWD — confirm the app is launched from the project root
print(f"Model loaded on {classifier.device}")
17
+
18
+
19
def predict_digit(image):
    """
    Predict a digit from a user-drawn image.

    Args:
        image: Input from the Gradio Sketchpad. May be a numpy array of
            shape (H, W), (H, W, 3) or (H, W, 4), a dict wrapping such an
            array under the 'composite' key, a PIL Image, or None.

    Returns:
        Tuple of (predicted_digit, confidence_text, probability_dict).
        When nothing was drawn, returns a prompt string, an empty string,
        and an empty dict instead.
    """
    if image is None:
        return "Please draw a digit", "", {}

    # Newer Gradio Sketchpad components return a dict with a 'composite' key.
    if isinstance(image, dict):
        image = image.get('composite', image)
        # FIX: the composite itself can be None when the canvas is empty;
        # the original fell through and crashed inside classifier.predict.
        if image is None:
            return "Please draw a digit", "", {}

    # Convert numpy input to a grayscale PIL Image
    if isinstance(image, np.ndarray):
        if image.ndim == 3:
            if image.shape[2] == 4:  # RGBA: the drawing lives in the alpha channel
                image = image[:, :, 3]
            else:  # RGB: collapse channels to grayscale
                image = np.mean(image, axis=2).astype(np.uint8)

        # Scale float images in [0, 1] up to [0, 255]
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)

        pil_image = Image.fromarray(image.astype(np.uint8), mode='L')
    else:
        pil_image = image

    # Run inference
    result = classifier.predict(pil_image)

    digit = result['digit']
    confidence = result['confidence']
    probabilities = result['probabilities']

    confidence_text = f"Confidence: {confidence*100:.1f}%"

    # Per-digit probabilities keyed by digit string, for the gr.Label bar chart
    prob_dict = {str(i): prob for i, prob in enumerate(probabilities)}

    return digit, confidence_text, prob_dict
70
+
71
+
72
import os  # used below to filter example images down to files that exist

# Custom CSS for better styling
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
    max-width: 900px;
    margin: auto;
}
.title {
    text-align: center;
    color: #2c3e50;
}
.description {
    text-align: center;
    color: #7f8c8d;
    margin-bottom: 20px;
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
    gr.Markdown(
        """
        # 🔢 Handwritten Digit Classifier

        Draw a digit (0-9) in the box below and the AI will predict what it is!

        This model uses a Convolutional Neural Network (CNN) trained on the MNIST dataset
        with **99.17% accuracy** on 10,000 test images.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Sketchpad for drawing
            input_image = gr.Sketchpad(
                label="Draw a digit here",
                type="numpy",
                image_mode="L",
                brush=gr.Brush(default_size=5, colors=["#000000"], color_mode="fixed"),
                height=280,
                width=280
            )

            # Buttons
            with gr.Row():
                predict_btn = gr.Button("🔍 Predict", variant="primary", scale=2)
                clear_btn = gr.ClearButton(components=[input_image], value="🗑️ Clear", scale=1)

        with gr.Column(scale=1):
            # Prediction output
            output_digit = gr.Textbox(
                label="Predicted Digit",
                placeholder="Draw a digit to see prediction",
                scale=1,
                lines=1,
                max_lines=1,
                interactive=False
            )

            output_confidence = gr.Textbox(
                label="Confidence",
                placeholder="",
                scale=1,
                lines=1,
                max_lines=1,
                interactive=False
            )

            # Probability distribution
            output_probs = gr.Label(
                label="Probability Distribution",
                num_top_classes=10
            )

    # Example images section.
    # FIX: the original passed `["examples/digit_0.png"] if __name__ != "__main__"
    # else None`, which handed gr.Examples a None entry whenever the app was run
    # directly. Only register example files that actually exist on disk.
    _existing_examples = [
        [path] for path in (f"examples/digit_{i}.png" for i in range(10))
        if os.path.exists(path)
    ]
    gr.Markdown("### 📝 Try these examples:")
    if _existing_examples:
        gr.Examples(
            examples=_existing_examples,
            inputs=input_image,
            label="Example digits"
        )

    # Model info
    gr.Markdown(
        """
        ---
        ### 📊 Model Details
        - **Architecture**: Convolutional Neural Network (CNN)
        - **Parameters**: 421,066
        - **Training**: MNIST dataset (60,000 images)
        - **Test Accuracy**: 99.17%
        - **Framework**: PyTorch 2.0.1

        ### 💡 Tips for best results:
        - Draw the digit large and centered
        - Use a thick brush stroke
        - Draw in white on black background (like MNIST)
        - Make sure the digit is clear and recognizable
        """
    )

    # Connect events
    predict_btn.click(
        fn=predict_digit,
        inputs=input_image,
        outputs=[output_digit, output_confidence, output_probs]
    )

    # Also predict on sketchpad change (real-time prediction)
    input_image.change(
        fn=predict_digit,
        inputs=input_image,
        outputs=[output_digit, output_confidence, output_probs]
    )
188
+
189
+
190
if __name__ == "__main__":
    # Launch the app
    demo.launch(
        server_name="0.0.0.0",  # Bind all interfaces so the app is reachable from other hosts
        server_port=7860,       # Standard Gradio port
        share=False,            # Set to True to create a public Gradio share link
        show_error=True         # Surface Python errors in the browser UI
    )
planning.md CHANGED
@@ -536,7 +536,7 @@ Apply augmentations **on-the-fly** during training (not pre-generate). Reasons:
536
 
537
  > **Purpose:** Design, implement, and train CNN architecture with rigorous evaluation
538
 
539
- **Status:** 🟡 IN PROGRESS (3/6 tasks complete)
540
  **Prerequisites:** Phase 1 complete (data pipeline working)
541
  **Estimated Time:** 8-10 hours
542
 
@@ -817,7 +817,7 @@ Apply augmentations **on-the-fly** during training (not pre-generate). Reasons:
817
  ---
818
 
819
  ### **Task 2.4:** Comprehensive model evaluation
820
- **Status:** NOT STARTED
821
  **Priority:** HIGH
822
  **Objective:** Compute all required metrics for report
823
 
 
536
 
537
  > **Purpose:** Design, implement, and train CNN architecture with rigorous evaluation
538
 
539
+ **Status:** 🟡 IN PROGRESS (4/6 tasks complete)
540
  **Prerequisites:** Phase 1 complete (data pipeline working)
541
  **Estimated Time:** 8-10 hours
542
 
 
817
  ---
818
 
819
  ### **Task 2.4:** Comprehensive model evaluation
820
+ **Status:** COMPLETE (integrated in training script)
821
  **Priority:** HIGH
822
  **Objective:** Compute all required metrics for report
823
 
results/mlflow_test_metrics.json ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "test_accuracy": 98.9,
3
+ "test_precision": 0.9890924021865366,
4
+ "test_recall": 0.9889686019869461,
5
+ "test_f1_score": 0.989004598229335,
6
+ "classification_report": {
7
+ "0": {
8
+ "precision": 0.9959141981613892,
9
+ "recall": 0.9948979591836735,
10
+ "f1-score": 0.9954058192955589,
11
+ "support": 980.0
12
+ },
13
+ "1": {
14
+ "precision": 0.9973404255319149,
15
+ "recall": 0.9911894273127754,
16
+ "f1-score": 0.9942554131683606,
17
+ "support": 1135.0
18
+ },
19
+ "2": {
20
+ "precision": 0.9753320683111955,
21
+ "recall": 0.9961240310077519,
22
+ "f1-score": 0.9856184084372004,
23
+ "support": 1032.0
24
+ },
25
+ "3": {
26
+ "precision": 0.9862475442043221,
27
+ "recall": 0.994059405940594,
28
+ "f1-score": 0.9901380670611439,
29
+ "support": 1010.0
30
+ },
31
+ "4": {
32
+ "precision": 0.9808853118712274,
33
+ "recall": 0.9928716904276986,
34
+ "f1-score": 0.9868421052631581,
35
+ "support": 982.0
36
+ },
37
+ "5": {
38
+ "precision": 0.9932508436445444,
39
+ "recall": 0.9899103139013453,
40
+ "f1-score": 0.9915777653003931,
41
+ "support": 892.0
42
+ },
43
+ "6": {
44
+ "precision": 0.9957939011566772,
45
+ "recall": 0.988517745302714,
46
+ "f1-score": 0.9921424829753797,
47
+ "support": 958.0
48
+ },
49
+ "7": {
50
+ "precision": 0.9863680623174295,
51
+ "recall": 0.9854085603112841,
52
+ "f1-score": 0.9858880778588808,
53
+ "support": 1028.0
54
+ },
55
+ "8": {
56
+ "precision": 0.9947916666666666,
57
+ "recall": 0.9804928131416838,
58
+ "f1-score": 0.9875904860392969,
59
+ "support": 974.0
60
+ },
61
+ "9": {
62
+ "precision": 0.985,
63
+ "recall": 0.9762140733399405,
64
+ "f1-score": 0.9805873568939771,
65
+ "support": 1009.0
66
+ },
67
+ "accuracy": 0.989,
68
+ "macro avg": {
69
+ "precision": 0.9890924021865366,
70
+ "recall": 0.9889686019869461,
71
+ "f1-score": 0.989004598229335,
72
+ "support": 10000.0
73
+ },
74
+ "weighted avg": {
75
+ "precision": 0.989058814881263,
76
+ "recall": 0.989,
77
+ "f1-score": 0.9890032775348696,
78
+ "support": 10000.0
79
+ }
80
+ },
81
+ "confusion_matrix": [
82
+ [
83
+ 975,
84
+ 0,
85
+ 2,
86
+ 0,
87
+ 1,
88
+ 0,
89
+ 1,
90
+ 1,
91
+ 0,
92
+ 0
93
+ ],
94
+ [
95
+ 0,
96
+ 1125,
97
+ 6,
98
+ 0,
99
+ 1,
100
+ 0,
101
+ 1,
102
+ 2,
103
+ 0,
104
+ 0
105
+ ],
106
+ [
107
+ 1,
108
+ 0,
109
+ 1028,
110
+ 0,
111
+ 1,
112
+ 0,
113
+ 0,
114
+ 2,
115
+ 0,
116
+ 0
117
+ ],
118
+ [
119
+ 0,
120
+ 0,
121
+ 1,
122
+ 1004,
123
+ 0,
124
+ 1,
125
+ 0,
126
+ 2,
127
+ 1,
128
+ 1
129
+ ],
130
+ [
131
+ 0,
132
+ 0,
133
+ 0,
134
+ 0,
135
+ 975,
136
+ 0,
137
+ 0,
138
+ 0,
139
+ 1,
140
+ 6
141
+ ],
142
+ [
143
+ 0,
144
+ 0,
145
+ 0,
146
+ 7,
147
+ 0,
148
+ 883,
149
+ 1,
150
+ 1,
151
+ 0,
152
+ 0
153
+ ],
154
+ [
155
+ 1,
156
+ 2,
157
+ 2,
158
+ 0,
159
+ 4,
160
+ 2,
161
+ 947,
162
+ 0,
163
+ 0,
164
+ 0
165
+ ],
166
+ [
167
+ 0,
168
+ 1,
169
+ 8,
170
+ 0,
171
+ 3,
172
+ 0,
173
+ 0,
174
+ 1013,
175
+ 1,
176
+ 2
177
+ ],
178
+ [
179
+ 1,
180
+ 0,
181
+ 7,
182
+ 0,
183
+ 1,
184
+ 2,
185
+ 1,
186
+ 1,
187
+ 955,
188
+ 6
189
+ ],
190
+ [
191
+ 1,
192
+ 0,
193
+ 0,
194
+ 7,
195
+ 8,
196
+ 1,
197
+ 0,
198
+ 5,
199
+ 2,
200
+ 985
201
+ ]
202
+ ]
203
+ }
results/mlflow_training_history.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train_loss": [
3
+ 0.623640251447595,
4
+ 0.25853686837229406,
5
+ 0.20275674461196627,
6
+ 0.16697734022463953,
7
+ 0.15257755666153233
8
+ ],
9
+ "train_accuracy": [
10
+ 79.71565839788215,
11
+ 92.16001568781253,
12
+ 93.85233846455534,
13
+ 94.94460241200117,
14
+ 95.43092460045102
15
+ ],
16
+ "val_loss": [
17
+ 0.091112698253297,
18
+ 0.06359739341700436,
19
+ 0.0590373575881434,
20
+ 0.05036303665776281,
21
+ 0.05376061073117163
22
+ ],
23
+ "val_accuracy": [
24
+ 97.22376457523598,
25
+ 98.06774014436424,
26
+ 98.33425874514158,
27
+ 98.75624652970572,
28
+ 98.5341476957246
29
+ ],
30
+ "learning_rate": [
31
+ 0.001,
32
+ 0.001,
33
+ 0.001,
34
+ 0.001,
35
+ 0.001
36
+ ]
37
+ }
scripts/inference.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference module for MNIST digit classification.
3
+
4
+ Provides a clean API for making predictions with the trained model.
5
+ Handles image preprocessing and returns predictions with confidence scores.
6
+ """
7
+
8
+ import torch
9
+ from PIL import Image
10
+ import numpy as np
11
+ from pathlib import Path
12
+ from typing import Union, Dict
13
+
14
+
15
class DigitClassifier:
    """Production inference wrapper for the MNIST digit classifier."""

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the digit classifier.

        Args:
            model_path: Path to model checkpoint (.pt file)
            device: Device to run inference on ('cuda' or 'cpu').
                If None, auto-detects CUDA availability.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found at {model_path}")

        self.model = self._load_model()
        self.model.eval()  # inference mode: freezes dropout / batch-norm statistics

        # Normalization values (same as training)
        self.mean = 0.1307
        self.std = 0.3081

    def _load_model(self) -> torch.nn.Module:
        """Load the CNN from the checkpoint and move it to the target device."""
        from scripts.models import BaselineCNN

        model = BaselineCNN()

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True on
        # PyTorch versions that support it).
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Handle both a raw state_dict and a dict that wraps one.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        return model.to(self.device)

    def preprocess(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Preprocess an image into a normalized model-input tensor.

        Handles:
        - RGB to grayscale conversion
        - Resizing to 28x28
        - Normalization
        - Inversion if needed (white digit on black background)

        Args:
            image: PIL Image or numpy array

        Returns:
            Preprocessed tensor of shape (1, 1, 28, 28) on self.device
        """
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Convert to grayscale if RGB
        if image.mode != 'L':
            image = image.convert('L')

        # Resize to 28x28 if needed
        if image.size != (28, 28):
            image = image.resize((28, 28), Image.Resampling.LANCZOS)

        img_array = np.array(image).astype(np.float32)

        # Scale to [0, 1]
        img_array = img_array / 255.0

        # MNIST digits are white-on-black; a mostly-bright image is almost
        # certainly black-on-white, so invert it.
        if img_array.mean() > 0.5:
            img_array = 1.0 - img_array

        # Apply the same normalization used during training
        img_array = (img_array - self.mean) / self.std

        # Add batch and channel dimensions: (28, 28) -> (1, 1, 28, 28)
        img_tensor = torch.tensor(img_array).unsqueeze(0).unsqueeze(0)

        return img_tensor.to(self.device)

    def predict(self, image: Union[Image.Image, np.ndarray]) -> Dict:
        """
        Predict digit from image.

        Args:
            image: PIL Image or numpy array containing digit

        Returns:
            Dictionary with:
            - digit: Predicted digit (0-9)
            - confidence: Confidence score (0-1)
            - probabilities: List of probabilities for each digit
        """
        img_tensor = self.preprocess(image)

        with torch.no_grad():
            outputs = self.model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0]
            confidence, predicted = torch.max(probabilities, dim=0)

        return {
            'digit': int(predicted.item()),
            'confidence': float(confidence.item()),
            'probabilities': probabilities.cpu().numpy().tolist()
        }

    def predict_batch(self, images: list) -> list:
        """
        Predict digits for a batch of images (sequentially, one at a time).

        Args:
            images: List of PIL Images or numpy arrays

        Returns:
            List of prediction dictionaries, one per input image
        """
        return [self.predict(img) for img in images]
144
+
145
+
146
def test_inference():
    """Smoke-test the inference module against real MNIST test images.

    Loads the best checkpoint, predicts 10 random test digits, then probes
    a few edge cases (blank, all-white, and resized inputs). Prints results;
    returns early with a message if no trained model is found.
    """
    import sys

    # Add project root to path so `scripts.*` imports resolve when this file
    # is executed directly. (FIX: dropped a redundant local
    # `from pathlib import Path` that shadowed the module-level import.)
    project_root = Path(__file__).parent.parent
    sys.path.insert(0, str(project_root))

    from scripts.data_loader import MnistDataloader

    print("Testing Inference Module")
    print("=" * 50)

    # Check if model exists
    model_path = project_root / 'models' / 'best_model.pt'
    if not model_path.exists():
        print(f"Error: Model not found at {model_path}")
        print("Please train a model first.")
        return

    # Load MNIST test data
    data_path = project_root / 'data' / 'raw'
    loader = MnistDataloader(
        training_images_filepath=str(data_path / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(data_path / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(data_path / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
    )
    _, (x_test, y_test) = loader.load_data()

    # Initialize classifier
    print(f"\n1. Loading model from: {model_path}")
    classifier = DigitClassifier(str(model_path))
    print(f" Device: {classifier.device}")

    # Test on a few images
    print("\n2. Testing predictions on 10 random test images:")
    print("-" * 50)

    indices = np.random.choice(len(x_test), 10, replace=False)
    correct = 0

    for i, idx in enumerate(indices, 1):
        image = x_test[idx]
        true_label = y_test[idx]

        # The loader may return plain lists; normalize to numpy arrays.
        if isinstance(image, list):
            image = np.array(image)

        # Convert to PIL Image
        img = Image.fromarray(image.astype(np.uint8), mode='L')

        result = classifier.predict(img)

        is_correct = result['digit'] == true_label
        correct += is_correct

        print(f" Image {i}: True={true_label}, Pred={result['digit']}, "
              f"Conf={result['confidence']:.4f} {'✓' if is_correct else '✗'}")

    accuracy = correct / len(indices) * 100
    print(f"\nAccuracy on {len(indices)} samples: {accuracy:.1f}%")

    # Test edge cases
    print("\n3. Testing edge cases:")
    print("-" * 50)

    # Blank image
    blank = np.zeros((28, 28), dtype=np.uint8)
    blank_img = Image.fromarray(blank, mode='L')
    result = classifier.predict(blank_img)
    print(f" Blank image: Pred={result['digit']}, Conf={result['confidence']:.4f}")

    # All white image
    white = np.ones((28, 28), dtype=np.uint8) * 255
    white_img = Image.fromarray(white, mode='L')
    result = classifier.predict(white_img)
    print(f" White image: Pred={result['digit']}, Conf={result['confidence']:.4f}")

    # Different size image (exercises the automatic resize back to 28x28)
    test_img = x_test[0]
    if isinstance(test_img, list):
        test_img = np.array(test_img)
    large = Image.fromarray(test_img.astype(np.uint8), mode='L').resize((56, 56))
    result = classifier.predict(large)
    print(f" Resized image (56x56): Pred={result['digit']}, Conf={result['confidence']:.4f}")

    print("\n✓ Inference module test complete!")
237
+
238
+
239
# Run the smoke test when the module is executed directly.
if __name__ == '__main__':
    test_inference()
scripts/launch_mlflow_ui.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch MLflow UI dashboard
#
# Usage:
#   ./scripts/launch_mlflow_ui.sh
#
# Access at: http://localhost:5000

# Abort on errors, unset variables, and failed pipeline stages.
set -euo pipefail

# Fail fast with a clear message if MLflow is not installed.
if ! command -v mlflow >/dev/null 2>&1; then
    echo "Error: 'mlflow' command not found. Install it with: pip install mlflow" >&2
    exit 1
fi

echo "Starting MLflow UI..."
echo "Access dashboard at: http://localhost:5000"
echo "Press Ctrl+C to stop"
echo ""

# Serve the local file-based tracking store on all interfaces.
mlflow ui --backend-store-uri file:./mlruns --host 0.0.0.0 --port 5000
scripts/mlflow_setup.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MLflow Setup and Configuration
3
+
4
+ Utilities for MLflow experiment tracking with MLOps best practices:
5
+ - Automatic experiment naming and organization
6
+ - Parameter and metric logging
7
+ - Model registry integration
8
+ - Artifact tracking
9
+ """
10
+
11
+ import mlflow
12
+ from pathlib import Path
13
+ from typing import Optional, Dict, Any
14
+ import os
15
+
16
+
17
+ # MLflow configuration
18
+ MLFLOW_TRACKING_URI = "file:./mlruns"
19
+ DEFAULT_EXPERIMENT_NAME = "mnist-digit-classification"
20
+
21
+
22
def setup_mlflow(
    experiment_name: str = DEFAULT_EXPERIMENT_NAME,
    tracking_uri: Optional[str] = None
) -> str:
    """
    Configure MLflow tracking and make sure the experiment exists.

    Args:
        experiment_name: Name of the experiment to select or create.
        tracking_uri: MLflow tracking URI; defaults to the local ./mlruns store.

    Returns:
        The MLflow experiment ID (falls back to "0", the default experiment,
        if creation fails).
    """
    uri = MLFLOW_TRACKING_URI if tracking_uri is None else tracking_uri
    mlflow.set_tracking_uri(uri)

    # Look the experiment up by name; create it (with project tags) if absent.
    try:
        existing = mlflow.get_experiment_by_name(experiment_name)
        if existing is not None:
            experiment_id = existing.experiment_id
        else:
            experiment_id = mlflow.create_experiment(
                experiment_name,
                tags={
                    "project": "mnist-classification",
                    "framework": "pytorch",
                    "model_type": "cnn"
                }
            )
    except Exception as e:  # best-effort: fall back to the default experiment
        print(f"Warning: Could not create experiment: {e}")
        experiment_id = "0"

    mlflow.set_experiment(experiment_name)

    print(f"MLflow tracking URI: {uri}")
    print(f"Experiment: {experiment_name} (ID: {experiment_id})")

    return experiment_id
66
+
67
+
68
def log_model_params(model: Any, prefix: str = "model") -> Dict[str, Any]:
    """
    Log basic model statistics (class name, parameter counts) to MLflow.

    Args:
        model: PyTorch model whose parameters are inspected
        prefix: Prefix applied to every logged parameter name

    Returns:
        The dictionary of parameters that was logged
    """
    from scripts.models import count_parameters

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    params = {
        f"{prefix}_name": type(model).__name__,
        f"{prefix}_total_params": count_parameters(model),
        f"{prefix}_trainable_params": trainable,
    }

    mlflow.log_params(params)
    return params
89
+
90
+
91
def log_training_config(config: Dict[str, Any]) -> None:
    """
    Log training configuration to MLflow.

    Nested dictionaries are flattened into underscore-joined parameter names
    at any depth (the original implementation only flattened one level), e.g.
    {"optimizer": {"sgd": {"lr": 0.1}}} -> {"optimizer_sgd_lr": 0.1}.
    One-level configs produce exactly the same keys as before.

    Args:
        config: Dictionary of training hyperparameters; values may be
            arbitrarily nested dictionaries.
    """
    flat_config: Dict[str, Any] = {}

    def _flatten(prefix: str, value: Any) -> None:
        # Recurse into dicts; everything else is logged as a leaf value.
        if isinstance(value, dict):
            for subkey, subvalue in value.items():
                _flatten(f"{prefix}_{subkey}", subvalue)
        else:
            flat_config[prefix] = value

    for key, value in config.items():
        _flatten(str(key), value)

    mlflow.log_params(flat_config)
108
+
109
+
110
def log_data_info(
    train_size: int,
    val_size: int,
    test_size: int,
    num_classes: int = 10,
    augmentation: bool = False
) -> None:
    """
    Log dataset split sizes and settings to MLflow as parameters.

    Args:
        train_size: Number of training samples
        val_size: Number of validation samples
        test_size: Number of test samples
        num_classes: Number of classes
        augmentation: Whether data augmentation is used
    """
    data_params = {
        "data_train_size": train_size,
        "data_val_size": val_size,
        "data_test_size": test_size,
        "data_num_classes": num_classes,
        "data_augmentation": augmentation,
    }
    mlflow.log_params(data_params)
134
+
135
+
136
def log_system_info() -> Dict[str, Any]:
    """
    Collect host/runtime information and log it to MLflow as parameters.

    Returns:
        Dictionary of the system information that was logged
    """
    import platform

    import torch

    cuda_ok = torch.cuda.is_available()
    system_info: Dict[str, Any] = {
        "system_platform": platform.system(),
        "system_python_version": platform.python_version(),
        "system_pytorch_version": torch.__version__,
        "system_cuda_available": cuda_ok,
        "system_cuda_version": torch.version.cuda if cuda_ok else "N/A",
        "system_device": "cuda" if cuda_ok else "cpu",
    }

    # GPU details only make sense when CUDA is present.
    if cuda_ok:
        system_info["system_gpu_name"] = torch.cuda.get_device_name(0)
        system_info["system_gpu_count"] = torch.cuda.device_count()

    mlflow.log_params(system_info)
    return system_info
161
+
162
+
163
def log_metrics_epoch(metrics: Dict[str, float], step: int) -> None:
    """
    Log metrics for a specific epoch.

    Thin wrapper over mlflow.log_metrics that records the epoch number as
    the MLflow step, so per-epoch curves render correctly in the UI.

    Args:
        metrics: Dictionary of metric names and values
        step: Epoch number
    """
    mlflow.log_metrics(metrics, step=step)
172
+
173
+
174
def log_artifact_path(path: str, artifact_path: Optional[str] = None) -> None:
    """
    Log a file or directory as an MLflow artifact, warning if it is missing.

    Args:
        path: Path to the file or directory to upload
        artifact_path: Optional destination path inside the MLflow run
    """
    # Guard clause: a missing artifact is a warning, never an exception.
    if not Path(path).exists():
        print(f"Warning: Artifact not found: {path}")
        return

    mlflow.log_artifact(path, artifact_path=artifact_path)
186
+
187
+
188
def log_model_to_registry(
    model: Any,
    model_name: str,
    artifact_path: str = "model",
    registered_model_name: Optional[str] = None
) -> None:
    """
    Log a PyTorch model to MLflow, optionally registering it in the registry.

    Args:
        model: PyTorch model
        model_name: Name for the model artifact.
            NOTE(review): currently unused — kept only for interface
            compatibility with existing callers; confirm before removing.
        artifact_path: Artifact path in MLflow
        registered_model_name: Name for model registry (optional). When
            given, mlflow also registers the logged model under this name.
    """
    # (FIX: removed an unused local `import torch` from the original.)
    mlflow.pytorch.log_model(
        pytorch_model=model,
        artifact_path=artifact_path,
        registered_model_name=registered_model_name
    )
211
+
212
+
213
def get_or_create_run(
    run_name: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None
) -> mlflow.ActiveRun:
    """
    Start an MLflow run with the given name and tags.

    NOTE(review): despite the function's name, this simply delegates to
    mlflow.start_run — it does not look up existing runs by name.

    Args:
        run_name: Name for the run
        tags: Tags for the run

    Returns:
        MLflow active run context (usable as a context manager)
    """
    return mlflow.start_run(run_name=run_name, tags=tags)
228
+
229
+
230
def end_run() -> None:
    """End the current MLflow run. Thin wrapper over mlflow.end_run()."""
    mlflow.end_run()
233
+
234
+
235
def test_mlflow_setup():
    """Smoke-test MLflow setup: experiment creation plus basic logging."""
    print("Testing MLflow Setup")
    print("=" * 50)

    # Setup MLflow
    exp_id = setup_mlflow("test-experiment")
    # FIX: exp_id was previously computed but never used.
    print(f"Experiment ID: {exp_id}")

    # Test logging inside a short-lived run
    with mlflow.start_run(run_name="test-run"):
        # Log parameters
        mlflow.log_params({
            "learning_rate": 0.001,
            "batch_size": 64,
            "epochs": 10
        })

        # Log a few fake epochs of steadily improving metrics
        for epoch in range(3):
            mlflow.log_metrics({
                "train_loss": 0.5 - epoch * 0.1,
                "val_loss": 0.6 - epoch * 0.1,
                "train_accuracy": 0.8 + epoch * 0.05,
                "val_accuracy": 0.75 + epoch * 0.05
            }, step=epoch)

        # Log system info
        system_info = log_system_info()
        print("\nSystem Info:")
        for key, value in system_info.items():
            print(f" {key}: {value}")

    print("\n✓ MLflow test complete!")
    print(f"View results at: mlflow ui --backend-store-uri {MLFLOW_TRACKING_URI}")
269
+
270
+
271
# Run the smoke test when the module is executed directly.
if __name__ == "__main__":
    test_mlflow_setup()
scripts/train_with_mlflow.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MLflow-Enabled Training Script for MNIST CNN
3
+
4
+ Full training script with comprehensive MLflow tracking:
5
+ - Hyperparameters and model architecture
6
+ - Per-epoch metrics (loss, accuracy, learning rate)
7
+ - System information and environment
8
+ - Model artifacts and checkpoints
9
+ - Training visualizations
10
+ - Confusion matrix and classification report
11
+
12
+ Usage:
13
+ python scripts/train_with_mlflow.py --epochs 20 --lr 0.001 --augment
14
+ python scripts/train_with_mlflow.py --help
15
+ """
16
+
17
+ import argparse
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.optim as optim
21
+ from pathlib import Path
22
+ import json
23
+ import sys
24
+ import numpy as np
25
+ import mlflow
26
+
27
+ # Add project root to path
28
+ project_root = Path(__file__).parent.parent
29
+ sys.path.insert(0, str(project_root))
30
+
31
+ from scripts.models import BaselineCNN, count_parameters
32
+ from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val
33
+ from scripts.train import train_epoch, validate, evaluate_model, save_training_history
34
+ from scripts.data_loader import MnistDataloader
35
+ from scripts.augmentation import get_train_augmentation
36
+ from scripts.mlflow_setup import (
37
+ setup_mlflow, log_model_params, log_training_config,
38
+ log_data_info, log_system_info, log_metrics_epoch,
39
+ log_artifact_path, log_model_to_registry
40
+ )
41
+
42
+
43
def train_with_mlflow(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
    config: dict,
    run_name: str = None
) -> dict:
    """
    Train model with full MLflow tracking.

    Logs hyperparameters, data/system info, per-epoch metrics, the best
    checkpoint (also registered in the model registry), and final test-set
    evaluation (classification report + confusion matrix) to one MLflow run.
    Uses ReduceLROnPlateau scheduling and early stopping on validation loss.

    Args:
        model: PyTorch model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        test_loader: Test data loader
        config: Training configuration dictionary; must provide
            'device', 'num_epochs' and 'learning_rate'
        run_name: Optional name for MLflow run

    Returns:
        Training history dictionary with per-epoch loss/accuracy/LR lists
    """
    device = config['device']
    num_epochs = config['num_epochs']
    learning_rate = config['learning_rate']

    # Setup MLflow (creates the experiment if it does not exist yet)
    setup_mlflow("mnist-digit-classification")

    # Start MLflow run
    with mlflow.start_run(run_name=run_name):
        print("\n" + "="*70)
        print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")
        print("="*70 + "\n")

        # Log all configuration up front so even aborted runs are comparable
        print("Logging configuration to MLflow...")
        log_training_config(config)
        log_model_params(model)
        log_data_info(
            train_size=len(train_loader.dataset),
            val_size=len(val_loader.dataset),
            test_size=len(test_loader.dataset),
            num_classes=10,
            augmentation=config.get('augmentation', False)
        )
        log_system_info()

        # Log model architecture as a text artifact
        total_params, trainable_params = count_parameters(model)
        model_summary = f"""
Model: {model.__class__.__name__}
Total Parameters: {total_params:,}
Trainable Parameters: {trainable_params:,}
Device: {device}

Architecture:
{str(model)}
"""
        mlflow.log_text(model_summary, "model_architecture.txt")

        # Setup training
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        # Fix: dropped the deprecated (and, in recent PyTorch, removed)
        # 'verbose' kwarg; the LR is printed explicitly each epoch below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5
        )

        # Training history
        history = {
            'train_loss': [],
            'train_accuracy': [],
            'val_loss': [],
            'val_accuracy': [],
            'learning_rate': []
        }

        best_val_loss = float('inf')
        best_epoch = 0  # fix: guarantee a value even if the loop never runs
        patience = 5
        patience_counter = 0

        print(f"\nStarting training for {num_epochs} epochs...")
        print(f"Device: {device}")
        total_p, _ = count_parameters(model)
        print(f"Model: {model.__class__.__name__} ({total_p:,} parameters)")
        print("-" * 70)

        for epoch in range(num_epochs):
            # Train and validate one epoch
            train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
            val_metrics = validate(model, val_loader, criterion, device)

            # Read the LR *before* stepping the scheduler, so the logged
            # value is the one actually used during this epoch
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_metrics['loss'])

            # Save history
            history['train_loss'].append(train_metrics['loss'])
            history['train_accuracy'].append(train_metrics['accuracy'])
            history['val_loss'].append(val_metrics['loss'])
            history['val_accuracy'].append(val_metrics['accuracy'])
            history['learning_rate'].append(current_lr)

            # Log per-epoch metrics to MLflow
            log_metrics_epoch({
                'train_loss': train_metrics['loss'],
                'train_accuracy': train_metrics['accuracy'],
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['accuracy'],
                'learning_rate': current_lr,
                'epoch': epoch + 1
            }, step=epoch)

            # Print progress
            print(f"Epoch {epoch+1}/{num_epochs} | "
                  f"Train Loss: {train_metrics['loss']:.4f} ({train_metrics['accuracy']:.2f}%) | "
                  f"Val Loss: {val_metrics['loss']:.4f} ({val_metrics['accuracy']:.2f}%) | "
                  f"LR: {current_lr:.6f}")

            # Checkpoint on validation-loss improvement
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                best_epoch = epoch + 1
                patience_counter = 0

                # Save checkpoint; fix: ensure the target directory exists
                checkpoint_path = project_root / 'models' / 'best_model_mlflow.pt'
                checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_metrics['loss'],
                    'val_loss': val_metrics['loss'],
                    'val_accuracy': val_metrics['accuracy'],
                }, checkpoint_path)

                print(f"  → New best model! (Val Loss: {best_val_loss:.4f})")

                # Register each improved model as a new registry version
                mlflow.pytorch.log_model(
                    model,
                    "model",
                    registered_model_name="mnist-cnn-baseline"
                )
            else:
                patience_counter += 1

            # Early stopping once validation loss stalls
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                mlflow.log_param("early_stopped", True)
                mlflow.log_param("early_stop_epoch", epoch + 1)
                break

        print("-" * 70)
        print(f"\nTraining complete!")
        print(f"Best epoch: {best_epoch} (Val Loss: {best_val_loss:.4f})")

        # Log best/final summary metrics; fix: guard the [-1] lookups so a
        # zero-epoch run does not raise IndexError
        mlflow.log_metrics({
            'best_epoch': best_epoch,
            'best_val_loss': best_val_loss
        })
        if history['train_loss']:
            mlflow.log_metrics({
                'final_train_loss': history['train_loss'][-1],
                'final_val_loss': history['val_loss'][-1]
            })

        # Evaluate on test set
        print("\nEvaluating on test set...")
        test_metrics = evaluate_model(model, test_loader, device)

        test_accuracy = test_metrics['accuracy']
        test_report = test_metrics['classification_report']

        # Macro averages summarize the per-class precision/recall/F1
        test_precision = test_report['macro avg']['precision']
        test_recall = test_report['macro avg']['recall']
        test_f1_score = test_report['macro avg']['f1-score']

        print(f"Test Accuracy: {test_accuracy:.2f}%")
        print(f"Test Precision: {test_precision:.4f}")
        print(f"Test Recall: {test_recall:.4f}")
        print(f"Test F1-Score: {test_f1_score:.4f}")

        # Log test metrics to MLflow
        mlflow.log_metrics({
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1_score': test_f1_score
        })

        # Save and log artifacts
        print("\nSaving artifacts...")

        # Save history; fix: parents=True so missing ancestors don't fail
        history_path = project_root / 'results' / 'mlflow_training_history.json'
        history_path.parent.mkdir(parents=True, exist_ok=True)
        save_training_history(history, history_path)
        log_artifact_path(str(history_path))

        # Save test metrics alongside the full report and confusion matrix
        metrics_to_save = {
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1_score': test_f1_score,
            'classification_report': test_report,
            'confusion_matrix': test_metrics['confusion_matrix'].tolist()
        }
        metrics_path = project_root / 'results' / 'mlflow_test_metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_to_save, f, indent=2)
        log_artifact_path(str(metrics_path))

        # Log the best checkpoint file itself
        log_artifact_path(str(project_root / 'models' / 'best_model_mlflow.pt'))

        # Log confusion matrix as JSON (one entry per true-label row)
        conf_matrix_dict = {
            f"row_{i}": row.tolist()
            for i, row in enumerate(test_metrics['confusion_matrix'])
        }
        mlflow.log_dict(conf_matrix_dict, "confusion_matrix.json")

        # Log classification report
        mlflow.log_dict(test_report, "classification_report.json")

        print(f"\n✓ All artifacts logged to MLflow")
        print(f"View results: mlflow ui --backend-store-uri file:./mlruns")

        return history
279
+
280
+
281
def main():
    """CLI entry point: parse args, prepare MNIST data, train with MLflow."""
    parser = argparse.ArgumentParser(description='Train MNIST CNN with MLflow tracking')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs (default: 20)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size (default: 64)')
    parser.add_argument('--augment', action='store_true', help='Use data augmentation')
    parser.add_argument('--run-name', type=str, default=None, help='MLflow run name')
    parser.add_argument('--seed', type=int, default=42, help='Random seed (default: 42)')

    args = parser.parse_args()

    # Seed every RNG we touch for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Assemble the run configuration (key order is preserved in the printout)
    config = {
        'num_epochs': args.epochs,
        'learning_rate': args.lr,
        'batch_size': args.batch_size,
        'augmentation': args.augment,
        'random_seed': args.seed,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'optimizer': 'Adam',
        'scheduler': 'ReduceLROnPlateau',
        'early_stopping_patience': 5
    }

    print("Training Configuration:")
    print(json.dumps(config, indent=2))

    # Load the raw MNIST IDX files
    print("\nLoading MNIST data...")
    raw_dir = project_root / 'data' / 'raw'
    mnist = MnistDataloader(
        training_images_filepath=str(raw_dir / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(raw_dir / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(raw_dir / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(raw_dir / 't10k-labels.idx1-ubyte')
    )
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Carve a validation split out of the training set
    (x_tr, y_tr), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=args.seed
    )

    # Datasets — augmentation (if any) is applied to training data only
    train_transform = get_train_augmentation() if args.augment else None
    train_dataset = MnistDataset(x_tr, y_tr, transform=train_transform)
    val_dataset = MnistDataset(x_val, y_val, transform=None)
    test_dataset = MnistDataset(x_test, y_test, transform=None)

    # Data loaders
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
    )

    print(f"Train: {len(train_loader.dataset)} samples")
    print(f"Val: {len(val_loader.dataset)} samples")
    print(f"Test: {len(test_loader.dataset)} samples")

    # Build the model and run the tracked training loop
    model = BaselineCNN().to(config['device'])
    train_with_mlflow(
        model, train_loader, val_loader, test_loader,
        config, run_name=args.run_name
    )

    print("\n" + "="*70)
    print("Training complete! View MLflow dashboard:")
    print("  ./scripts/launch_mlflow_ui.sh")
    print("  or: mlflow ui --backend-store-uri file:./mlruns")
    print("="*70)
362
+
363
+
364
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()