faizan committed
Commit: e77a25a · Parent(s): 1c6a6f3
fix: resolve all 468 ruff linting errors (code quality enforcement complete)
- Added ruff.toml config with E402 ignore for sys.path test pattern
- Fixed 3 E501 line length violations manually
- Auto-fixed 468 W293/W291 whitespace errors
- Codebase now passes all linting checks (0 errors)
- Ready for deployment (Phase 3)
- app.py +29 -25
- ruff.toml +13 -0
- scripts/augmentation.py +19 -19
- scripts/data_loader.py +17 -17
- scripts/data_quality.py +54 -54
- scripts/inference.py +57 -53
- scripts/mlflow_setup.py +32 -30
- scripts/models.py +61 -57
- scripts/preprocessing.py +35 -35
- scripts/test_data_loader.py +13 -13
- scripts/test_data_quality.py +7 -7
- scripts/test_preprocessing.py +30 -28
- scripts/test_train.py +15 -14
- scripts/train.py +46 -46
- scripts/train_baseline.py +75 -44
- scripts/train_with_mlflow.py +87 -61
app.py
CHANGED
@@ -19,21 +19,21 @@ print(f"Model loaded on {classifier.device}")
def predict_digit(image):
    """
    Predict digit from user-drawn image.
-
+
    Args:
        image: numpy array from Gradio Sketchpad (H, W, 3) or (H, W)
-
+
    Returns:
        Tuple of (predicted_digit, confidence_text, probability_dict)
    """
    if image is None:
        return "Please draw a digit", "", {}
-
+
    # Handle different image formats from Gradio
    if isinstance(image, dict):
        # Sketchpad returns dict with 'composite' key
        image = image.get('composite', image)
-
+
    # Convert to PIL Image
    if isinstance(image, np.ndarray):
        # If RGB, convert to grayscale
@@ -43,29 +43,29 @@ def predict_digit(image):
                image = image[:, :, 3]  # Use alpha channel
            else:  # RGB
                image = np.mean(image, axis=2).astype(np.uint8)
-
+
        # Ensure values are in [0, 255]
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
-
+
        pil_image = Image.fromarray(image.astype(np.uint8), mode='L')
    else:
        pil_image = image
-
+
    # Get prediction
    result = classifier.predict(pil_image)
-
+
    # Format output
    digit = result['digit']
    confidence = result['confidence']
    probabilities = result['probabilities']
-
+
    # Create confidence text
    confidence_text = f"Confidence: {confidence*100:.1f}%"
-
+
    # Create probability dictionary for bar chart
    prob_dict = {str(i): prob for i, prob in enumerate(probabilities)}
-
+
    return digit, confidence_text, prob_dict


@@ -92,15 +92,15 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
    gr.Markdown(
        """
        # 🔢 Handwritten Digit Classifier
-
+
        Draw a digit (0-9) in the box below and the AI will predict
        what it is!
-
+
        This model uses a Convolutional Neural Network (CNN) trained on
        the MNIST dataset with **99.17% accuracy** on 10,000 test images.
        """
    )
-
+
    with gr.Row():
        with gr.Column(scale=1):
            # Sketchpad for drawing
@@ -116,12 +116,16 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
                height=280,
                width=280
            )
-
+
            # Buttons
            with gr.Row():
-                predict_btn = gr.Button(
-
-
+                predict_btn = gr.Button(
+                    "🔍 Predict", variant="primary", scale=2
+                )
+                clear_btn = gr.ClearButton(
+                    components=[input_image], value="🗑️ Clear", scale=1
+                )
+
        with gr.Column(scale=1):
            # Prediction output
            output_digit = gr.Textbox(
@@ -132,7 +136,7 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
                max_lines=1,
                interactive=False
            )
-
+
            output_confidence = gr.Textbox(
                label="Confidence",
                placeholder="",
@@ -141,13 +145,13 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
                max_lines=1,
                interactive=False
            )
-
+
            # Probability distribution
            output_probs = gr.Label(
                label="Probability Distribution",
                num_top_classes=10
            )
-
+
    # Example images section
    gr.Markdown("### 📝 Try these examples:")
    gr.Examples(
@@ -157,7 +161,7 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
        inputs=input_image,
        label="Example digits"
    )
-
+
    # Model info
    gr.Markdown(
        """
@@ -168,7 +172,7 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
    - **Training**: MNIST dataset (60,000 images)
    - **Test Accuracy**: 99.17%
    - **Framework**: PyTorch 2.0.1
-
+
    ### 💡 Tips for best results:
    - Draw the digit large and centered
    - Use a thick brush stroke
@@ -176,14 +180,14 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
    - Make sure the digit is clear and recognizable
    """
    )
-
+
    # Connect events
    predict_btn.click(
        fn=predict_digit,
        inputs=input_image,
        outputs=[output_digit, output_confidence, output_probs]
    )
-
+
    # Also predict on sketchpad change (real-time prediction)
    input_image.change(
        fn=predict_digit,
ruff.toml
ADDED
@@ -0,0 +1,13 @@
+# Ruff configuration for MNIST project
+
+select = ["E", "F", "W"]
+ignore = []
+
+# Line length
+line-length = 88
+
+[per-file-ignores]
+# Ignore E402 (module level import not at top) for scripts that modify sys.path
+"scripts/test_*.py" = ["E402"]
+"scripts/train_*.py" = ["E402"]
+"scripts/inference.py" = ["E402"]
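For reference, the per-file E402 ignores cover the pattern these scripts use: the project root is inserted into sys.path before anything from scripts/ is imported, so the import cannot sit above all module-level code. A minimal illustrative sketch of that pattern (not a file in this commit; the specific import is only an example):

# Illustrative sketch: the module-level sys.path pattern that makes E402 unavoidable.
import sys
from pathlib import Path

# Make the project root importable when the script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Ruff would flag this import as E402 ("module level import not at top of file")
# because it follows executable code; the per-file ignore allows it.
from scripts.data_loader import MnistDataloader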
scripts/augmentation.py
CHANGED
All 19 changed lines are whitespace-only fixes (W293/W291): trailing whitespace stripped from blank lines in the module usage docstring and in get_train_augmentation, get_val_augmentation, get_mild_augmentation, get_aggressive_augmentation, visualize_augmentations, and get_augmentation_by_name. No functional changes. Hunks cover lines 11-20, 26-39, 44-50, 59-68, 72-83, 100-112, 131-144, 149-155, 165-174, and 177-182 (line counts unchanged).
scripts/data_loader.py
CHANGED
All 17 changed lines are whitespace-only fixes (W293/W291): trailing whitespace stripped from blank lines in the MnistDataloader class docstring and in the docstrings and bodies of __init__, read_images_labels, and load_data. No functional changes. Hunks cover lines 15-31, 35-47, 49-55, 59-65, 67-82, 90-96, 100-111, and 113-135 (line counts unchanged).
scripts/data_quality.py
CHANGED
All 54 changed lines are whitespace-only fixes (W293/W291): trailing whitespace stripped from blank lines, plus a few wrapped continuation lines, in the module usage docstring and throughout check_missing_values, check_outliers, check_class_balance, check_image_dimensions, check_label_validity, generate_quality_report, _generate_recommendations, and print_quality_summary. No functional changes. Hunks cover lines 10-18, 24-39, 43-54, 66-90, 92-98, 109-152, 162-177, 179-185, 194-209, 211-219, 232-248, 264-270, 276-282, 284-305, 309-315, 319-325, 332-338, 342-348, 352-371, 373-379, 381-387, 394-400, and 407-420 (line counts unchanged).
scripts/inference.py
CHANGED
Mostly whitespace-only fixes (W293/W291) throughout DigitClassifier (__init__, _load_model, preprocess, predict, predict_batch) and the test_inference helper. The final hunk (lines 173-238) also rewrites the Image.fromarray(...) / print(...) lines at the end of test_inference (net +4 lines, presumably the manual E501 line-length fixes noted in the commit message); the replacement lines are cut off in this excerpt. Hunks cover lines 14-116, 118-142, 147-169, and 173-238.
'probabilities': probabilities.cpu().numpy().tolist()
|
| 131 |
}
|
| 132 |
+
|
| 133 |
def predict_batch(self, images: list) -> list:
|
| 134 |
"""
|
| 135 |
Predict digits for a batch of images.
|
| 136 |
+
|
| 137 |
Args:
|
| 138 |
images: List of PIL Images or numpy arrays
|
| 139 |
+
|
| 140 |
Returns:
|
| 141 |
List of prediction dictionaries
|
| 142 |
"""
|
|
|
|
| 147 |
"""Test inference module with sample images."""
|
| 148 |
import sys
|
| 149 |
from pathlib import Path
|
| 150 |
+
|
| 151 |
# Add project root to path
|
| 152 |
project_root = Path(__file__).parent.parent
|
| 153 |
sys.path.insert(0, str(project_root))
|
| 154 |
+
|
| 155 |
from scripts.data_loader import MnistDataloader
|
| 156 |
+
|
| 157 |
print("Testing Inference Module")
|
| 158 |
print("=" * 50)
|
| 159 |
+
|
| 160 |
# Check if model exists
|
| 161 |
model_path = project_root / 'models' / 'best_model.pt'
|
| 162 |
if not model_path.exists():
|
| 163 |
print(f"Error: Model not found at {model_path}")
|
| 164 |
print("Please train a model first.")
|
| 165 |
return
|
| 166 |
+
|
| 167 |
# Load MNIST test data
|
| 168 |
data_path = project_root / 'data' / 'raw'
|
| 169 |
loader = MnistDataloader(
|
|
|
|
| 173 |
test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
|
| 174 |
)
|
| 175 |
_, (x_test, y_test) = loader.load_data()
|
| 176 |
+
|
| 177 |
# Initialize classifier
|
| 178 |
print(f"\n1. Loading model from: {model_path}")
|
| 179 |
classifier = DigitClassifier(str(model_path))
|
| 180 |
print(f" Device: {classifier.device}")
|
| 181 |
+
|
| 182 |
# Test on a few images
|
| 183 |
print("\n2. Testing predictions on 10 random test images:")
|
| 184 |
print("-" * 50)
|
| 185 |
+
|
| 186 |
indices = np.random.choice(len(x_test), 10, replace=False)
|
| 187 |
correct = 0
|
| 188 |
+
|
| 189 |
for i, idx in enumerate(indices, 1):
|
| 190 |
image = x_test[idx]
|
| 191 |
true_label = y_test[idx]
|
| 192 |
+
|
| 193 |
# Convert list to numpy array if needed
|
| 194 |
if isinstance(image, list):
|
| 195 |
image = np.array(image)
|
| 196 |
+
|
| 197 |
# Convert to PIL Image
|
| 198 |
img = Image.fromarray(image.astype(np.uint8), mode='L')
|
| 199 |
+
|
| 200 |
# Predict
|
| 201 |
result = classifier.predict(img)
|
| 202 |
+
|
| 203 |
is_correct = result['digit'] == true_label
|
| 204 |
correct += is_correct
|
| 205 |
+
|
| 206 |
print(f" Image {i}: True={true_label}, Pred={result['digit']}, "
|
| 207 |
f"Conf={result['confidence']:.4f} {'✓' if is_correct else '✗'}")
|
| 208 |
+
|
| 209 |
accuracy = correct / len(indices) * 100
|
| 210 |
print(f"\nAccuracy on {len(indices)} samples: {accuracy:.1f}%")
|
| 211 |
+
|
| 212 |
# Test edge cases
|
| 213 |
print("\n3. Testing edge cases:")
|
| 214 |
print("-" * 50)
|
| 215 |
+
|
| 216 |
# Blank image
|
| 217 |
blank = np.zeros((28, 28), dtype=np.uint8)
|
| 218 |
blank_img = Image.fromarray(blank, mode='L')
|
| 219 |
result = classifier.predict(blank_img)
|
| 220 |
print(f" Blank image: Pred={result['digit']}, Conf={result['confidence']:.4f}")
|
| 221 |
+
|
| 222 |
# All white image
|
| 223 |
white = np.ones((28, 28), dtype=np.uint8) * 255
|
| 224 |
white_img = Image.fromarray(white, mode='L')
|
| 225 |
result = classifier.predict(white_img)
|
| 226 |
print(f" White image: Pred={result['digit']}, Conf={result['confidence']:.4f}")
|
| 227 |
+
|
| 228 |
# Different size image
|
| 229 |
test_img = x_test[0]
|
| 230 |
if isinstance(test_img, list):
|
| 231 |
test_img = np.array(test_img)
|
| 232 |
+
large = Image.fromarray(test_img.astype(np.uint8), mode='L')
|
| 233 |
+
large = large.resize((56, 56))
|
| 234 |
result = classifier.predict(large)
|
| 235 |
+
print(
|
| 236 |
+
f" Resized image (56x56): "
|
| 237 |
+
f"Pred={result['digit']}, Conf={result['confidence']:.4f}"
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
print("\n✓ Inference module test complete!")
|
| 241 |
|
| 242 |
|
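For orientation, a minimal usage sketch of the inference wrapper above (not part of the diff; the checkpoint path follows the repository layout, and my_digit.png is a placeholder input):

# Load the trained checkpoint and classify a single image.
from PIL import Image

from scripts.inference import DigitClassifier

classifier = DigitClassifier("models/best_model.pt")  # auto-detects CUDA vs CPU

img = Image.open("my_digit.png")   # any size/mode; preprocess() resizes and normalizes it
result = classifier.predict(img)

print(result["digit"], f"{result['confidence']:.2%}")
# result["probabilities"] holds one probability per digit class (0-9)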
scripts/mlflow_setup.py
CHANGED
(Every hunk in scripts/mlflow_setup.py is whitespace-only: trailing whitespace is removed from blank lines in setup_mlflow, log_model_params, log_training_config, log_data_info, log_system_info, log_metrics_epoch, log_artifact_path, log_model_to_registry, get_or_create_run and test_mlflow_setup. The single substantive change wraps the trainable-parameter count in log_model_params to satisfy the line-length limit:)

    params = {
        f"{prefix}_name": model.__class__.__name__,
        f"{prefix}_total_params": count_parameters(model),
        f"{prefix}_trainable_params": sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
    }
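A sketch of how these helpers compose inside a training script (not part of the diff; the experiment name, run name, config values and metric numbers are placeholders):

import mlflow

from scripts.mlflow_setup import (
    log_metrics_epoch,
    log_system_info,
    log_training_config,
    setup_mlflow,
)

setup_mlflow("mnist-baseline")

with mlflow.start_run(run_name="baseline-cnn"):
    log_system_info()
    log_training_config({"learning_rate": 1e-3, "batch_size": 64, "epochs": 3})

    for epoch in range(3):
        # Stand-in numbers; a real loop would report its own loss and accuracy here.
        log_metrics_epoch(
            {"train_loss": 0.5 - 0.1 * epoch, "val_accuracy": 0.90 + 0.02 * epoch},
            step=epoch,
        )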
scripts/models.py
CHANGED
(The hunks in scripts/models.py are whitespace-only across BaselineCNN, ImprovedCNN, count_parameters, get_model_summary, test_model and the __main__ self-test, except for two line-length fixes: the get_model_summary signature is wrapped, and its model-size computation is split across lines:)

def get_model_summary(
    model: nn.Module, input_size: Tuple[int, ...] = (1, 1, 28, 28)
) -> str:
    ...
    summary.append(f"Trainable parameters: {trainable_params:,}")
    # Assuming float32
    model_size_mb = total_params * 4 / (1024**2)
    summary.append(f"Model size (MB): {model_size_mb:.2f}")
    summary.append("=" * 60)
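To make the docstring arithmetic concrete, a short self-contained sketch (not part of the commit) that prints both summaries and checks the output shape; two 2x2 poolings take 28 -> 14 -> 7, so the baseline flatten feeds 64 * 7 * 7 = 3136 features into fc1:

import torch

from scripts.models import BaselineCNN, ImprovedCNN, get_model_summary

# Print parameter counts and estimated size for both architectures.
print(get_model_summary(BaselineCNN()))
print(get_model_summary(ImprovedCNN()))

# A dummy forward pass confirms the expected (batch, 10) logits shape.
logits = BaselineCNN()(torch.randn(8, 1, 28, 28))
assert logits.shape == (8, 10)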
scripts/preprocessing.py
CHANGED
(All hunks in scripts/preprocessing.py are whitespace-only: trailing whitespace is removed from blank lines in the module docstring, MnistDataset, create_dataloaders, create_test_dataloader, split_train_val and get_dataset_statistics; no code changes.)
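A sketch of the intended data pipeline these helpers form, raw IDX files to stratified split to Dataset objects to DataLoaders (not part of the diff; the data/raw paths follow the layout used by the test scripts):

from scripts.data_loader import MnistDataloader
from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val

loader = MnistDataloader(
    "data/raw/train-images.idx3-ubyte",
    "data/raw/train-labels.idx1-ubyte",
    "data/raw/t10k-images.idx3-ubyte",
    "data/raw/t10k-labels.idx1-ubyte",
)
(x_train, y_train), (x_test, y_test) = loader.load_data()

# Hold out 15% of the training set, stratified by class.
(x_tr, y_tr), (x_val, y_val) = split_train_val(x_train, y_train, val_split=0.15)

train_loader, val_loader = create_dataloaders(
    MnistDataset(x_tr, y_tr), MnistDataset(x_val, y_val), batch_size=64
)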
scripts/test_data_loader.py
CHANGED
(All hunks in scripts/test_data_loader.py are whitespace-only: trailing whitespace is removed from blank lines in test_data_loader; no code changes.)
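The loader test expects uncompressed IDX files; where gunzip is unavailable, the same unpacking can be done from Python's standard library (a sketch, assuming the data/raw layout above):

import gzip
import shutil
from pathlib import Path

# Expand every .gz archive in place, e.g. train-images.idx3-ubyte.gz -> train-images.idx3-ubyte
for gz in Path("data/raw").glob("*.gz"):
    with gzip.open(gz, "rb") as src, open(gz.with_suffix(""), "wb") as dst:
        shutil.copyfileobj(src, dst)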
scripts/test_data_quality.py
CHANGED
@@ -21,32 +21,32 @@ def main():
(whitespace-only cleanup: the removed lines contained nothing but trailing spaces)
     """Run quality checks and save report."""
     print("Loading MNIST dataset...")
     data_path = project_root / "data" / "raw"
-    
+
     loader = MnistDataloader(
         str(data_path / "train-images.idx3-ubyte"),
         str(data_path / "train-labels.idx1-ubyte"),
         str(data_path / "t10k-images.idx3-ubyte"),
         str(data_path / "t10k-labels.idx1-ubyte")
     )
-    
+
     (x_train, y_train), (x_test, y_test) = loader.load_data()
     print("✓ Dataset loaded\n")
-    
+
     # Generate quality report
     print("Running quality checks...")
     report = generate_quality_report((x_train, y_train), (x_test, y_test))
     print("✓ Quality checks complete\n")
-    
+
     # Print summary
     print_quality_summary(report)
-    
+
     # Save report as JSON
     output_path = project_root / "data" / "quality_report.json"
     with open(output_path, 'w') as f:
         json.dump(report, f, indent=2)
-    
+
     print(f"✓ Quality report saved to: {output_path}")
-    
+
     return 0 if report['summary']['all_checks_pass'] else 1
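Because main() returns 1 when any check fails, the quality report can gate a pipeline on its exit code. A minimal sketch of such a gate (hypothetical, not part of this commit; paths assume the repository layout shown in this diff):

    # Hypothetical CI gate: run the quality checks and stop the build if they fail.
    import subprocess
    import sys

    result = subprocess.run([sys.executable, "scripts/test_data_quality.py"])
    if result.returncode != 0:
        sys.exit("Data quality checks failed - see data/quality_report.json")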
|
scripts/test_preprocessing.py
CHANGED
@@ (hunks in test_dataset, test_dataloader, test_train_val_split, test_full_pipeline, main) @@
Whitespace-only cleanup (ruff W293): blank lines inside these test functions no
longer carry trailing spaces; the assertions and pipeline checks themselves are
unchanged.

@@ -108,36 +108,38 @@ def test_train_val_split():
The split-summary prints are rewritten with intermediate percentage variables so
each statement fits on wrapped lines:
     print("✓ Split completed")
     print(f" Original training: {len(x_train):,} samples")
+    train_pct = len(x_train_split) / len(x_train) * 100
+    print(f" New training: {len(x_train_split):,} samples ({train_pct:.1f}%)")
+    val_pct = len(x_val) / len(x_train) * 100
+    print(f" Validation: {len(x_val):,} samples ({val_pct:.1f}%)")
+
     # Verify split ratio
     expected_val_size = int(len(x_train) * 0.15)
     assert abs(len(x_val) - expected_val_size) < 100, "Split ratio incorrect"
     assert len(x_train_split) + len(x_val) == len(x_train), "Data loss during split"
     print("✓ Split ratio correct")
scripts/test_train.py
CHANGED
@@ (hunks throughout main) @@
Whitespace-only cleanup (ruff W293): blank lines inside main() no longer carry
trailing spaces; the training-pipeline smoke test is otherwise unchanged except
for the hunk below.

@@ -50,13 +50,14 @@ def main():
     train_dataset = MnistDataset(x_train_split, y_train_split, transform=None)
     val_dataset = MnistDataset(x_val, y_val, transform=None)
+    # Small test set for quick validation
+    test_dataset = MnistDataset(x_test[:200], y_test[:200], transform=None)
+
     train_loader, val_loader = create_dataloaders(
         train_dataset, val_dataset, batch_size=32, num_workers=0
     )
scripts/train.py
CHANGED
@@ -12,7 +12,7 @@ Supports MLflow experiment tracking for reproducibility.
 Usage:
     from scripts.train import train_model
     from scripts.models import BaselineCNN
-    
+
     model = BaselineCNN()
     history = train_model(
         model, train_loader, val_loader,

@@ (remaining hunks) @@
Whitespace-only cleanup (ruff W293) of the same kind in train_epoch(),
validate(), train_model(), evaluate_model(), save_training_history() and
load_checkpoint(): blank lines inside docstrings and function bodies no longer
carry trailing spaces. The training loop, early stopping, learning-rate
scheduling, checkpointing, evaluation metrics and history saving are unchanged.
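For reference, a minimal sketch of how the helpers in this file compose once a trained checkpoint exists (hypothetical, not part of this commit; test_loader is assumed to be a DataLoader over MnistDataset(x_test, y_test) as built in the scripts above):

    # Hypothetical evaluation snippet using the functions defined in scripts/train.py.
    from scripts.models import BaselineCNN
    from scripts.train import load_checkpoint, evaluate_model

    model = BaselineCNN()
    model, checkpoint = load_checkpoint("models/best_model.pt", model)
    # test_loader: assumed to exist, built as in scripts/train_baseline.py
    results = evaluate_model(model, test_loader, device="cpu")
    print(f"Test accuracy: {results['accuracy']:.2f}%")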
scripts/train_baseline.py
CHANGED
@@ (hunks in plot_training_history, plot_confusion_matrix, main) @@
Whitespace-only cleanup (ruff W293): blank lines in the plotting helpers and in
main() no longer carry trailing spaces. Two blocks in main() are also rewritten
as wrapped multi-line statements, shown below.

@@ -112,14 +112,30 @@ def main():
The argument parser is built with one wrapped add_argument call per flag:
+    parser = argparse.ArgumentParser(
+        description='Train baseline CNN on MNIST'
+    )
+    parser.add_argument(
+        '--augment', action='store_true', help='Use data augmentation'
+    )
+    parser.add_argument(
+        '--epochs', type=int, default=20,
+        help='Number of epochs (default: 20)'
+    )
+    parser.add_argument(
+        '--lr', type=float, default=0.001,
+        help='Learning rate (default: 0.001)'
+    )
+    parser.add_argument(
+        '--batch-size', type=int, default=64,
+        help='Batch size (default: 64)'
+    )
+    parser.add_argument(
+        '--patience', type=int, default=5,
+        help='Early stopping patience (default: 5)'
+    )
     args = parser.parse_args()

@@ -197,54 +213,69 @@ def main():
The per-class metrics table is printed with wrapped f-strings and intermediate
acc_line / macro_line variables:
+    print(
+        f"{'Digit':<8} {'Precision':<12} {'Recall':<12} "
+        f"{'F1-Score':<12} {'Support':<10}"
+    )
     print("-" * 60)
     for digit in range(10):
         if str(digit) in report:
             metrics = report[str(digit)]
+            print(
+                f"{digit:<8} {metrics['precision']:<12.3f} "
+                f"{metrics['recall']:<12.3f} "
+                f"{metrics['f1-score']:<12.3f} {metrics['support']:<10}"
+            )
+
     print("-" * 60)
+    acc_line = (
+        f"{'Accuracy':<8} {' ':<12} {' ':<12} "
+        f"{report['accuracy']:<12.3f} "
+        f"{report['macro avg']['support']:<10}"
+    )
+    print(acc_line)
+    macro_line = (
+        f"{'Macro Avg':<8} {report['macro avg']['precision']:<12.3f} "
+        f"{report['macro avg']['recall']:<12.3f} "
+        f"{report['macro avg']['f1-score']:<12.3f} "
+        f"{report['macro avg']['support']:<10}"
+    )
+    print(macro_line)
     print()
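With the flags above, a typical launch of the baseline training script looks like this (hypothetical example, not part of this commit):

    # Launch scripts/train_baseline.py with explicit hyperparameters.
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, "scripts/train_baseline.py",
         "--augment", "--epochs", "20", "--lr", "0.001",
         "--batch-size", "64", "--patience", "5"],
        check=True,
    )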
|
scripts/train_with_mlflow.py
CHANGED
|
@@ -50,7 +50,7 @@ def train_with_mlflow(
) -> dict:
    """
    Train model with full MLflow tracking.

    Args:
        model: PyTorch model to train
        train_loader: Training data loader
@@ -58,23 +58,23 @@ def train_with_mlflow(
        test_loader: Test data loader
        config: Training configuration dictionary
        run_name: Optional name for MLflow run

    Returns:
        Training history dictionary
    """
    device = config['device']
    num_epochs = config['num_epochs']
    learning_rate = config['learning_rate']

    # Setup MLflow
    setup_mlflow("mnist-digit-classification")

    # Start MLflow run
    with mlflow.start_run(run_name=run_name):
        print("\n" + "="*70)
        print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")
        print("="*70 + "\n")

        # Log all configuration
        print("Logging configuration to MLflow...")
        log_training_config(config)
@@ -87,7 +87,7 @@ def train_with_mlflow(
            augmentation=config.get('augmentation', False)
        )
        log_system_info()

        # Log model architecture as text
        total_params, trainable_params = count_parameters(model)
        model_summary = f"""
@@ -100,14 +100,14 @@ Architecture:
{str(model)}
"""
        mlflow.log_text(model_summary, "model_architecture.txt")

        # Setup training
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5, verbose=True
        )

        # Training history
        history = {
            'train_loss': [],
@@ -116,37 +116,39 @@ Architecture:
            'val_accuracy': [],
            'learning_rate': []
        }

        best_val_loss = float('inf')
        patience = 5
        patience_counter = 0

        print(f"\nStarting training for {num_epochs} epochs...")
        print(f"Device: {device}")
        total_p, _ = count_parameters(model)
        print(f"Model: {model.__class__.__name__} ({total_p:,} parameters)")
        print("-" * 70)

        for epoch in range(num_epochs):
            # Train
            train_metrics = train_epoch(
                model, train_loader, criterion, optimizer, device
            )

            # Validate
            val_metrics = validate(model, val_loader, criterion, device)

            # Get current learning rate
            current_lr = optimizer.param_groups[0]['lr']

            # Update scheduler
            scheduler.step(val_metrics['loss'])

            # Save history
            history['train_loss'].append(train_metrics['loss'])
            history['train_accuracy'].append(train_metrics['accuracy'])
            history['val_loss'].append(val_metrics['loss'])
            history['val_accuracy'].append(val_metrics['accuracy'])
            history['learning_rate'].append(current_lr)

            # Log metrics to MLflow
            mlflow_metrics = {
                'train_loss': train_metrics['loss'],
@@ -157,19 +159,23 @@ Architecture:
                'epoch': epoch + 1
            }
            log_metrics_epoch(mlflow_metrics, step=epoch)

            # Print progress
            print(
                f"Epoch {epoch+1}/{num_epochs} | "
                f"Train Loss: {train_metrics['loss']:.4f} "
                f"({train_metrics['accuracy']:.2f}%) | "
                f"Val Loss: {val_metrics['loss']:.4f} "
                f"({val_metrics['accuracy']:.2f}%) | "
                f"LR: {current_lr:.6f}"
            )

            # Save best model
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                best_epoch = epoch + 1
                patience_counter = 0

                # Save checkpoint
                checkpoint_path = project_root / 'models' / 'best_model_mlflow.pt'
                torch.save({
@@ -180,9 +186,9 @@ Architecture:
                    'val_loss': val_metrics['loss'],
                    'val_accuracy': val_metrics['accuracy'],
                }, checkpoint_path)

                print(f" → New best model! (Val Loss: {best_val_loss:.4f})")

                # Log model to MLflow
                mlflow.pytorch.log_model(
                    model,
@@ -191,18 +197,18 @@ Architecture:
                )
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                mlflow.log_param("early_stopped", True)
                mlflow.log_param("early_stop_epoch", epoch + 1)
                break

        print("-" * 70)
        print("\nTraining complete!")
        print(f"Best epoch: {best_epoch} (Val Loss: {best_val_loss:.4f})")

        # Log best metrics
        mlflow.log_metrics({
            'best_epoch': best_epoch,
@@ -210,24 +216,24 @@ Architecture:
            'final_train_loss': history['train_loss'][-1],
            'final_val_loss': history['val_loss'][-1]
        })

        # Evaluate on test set
        print("\nEvaluating on test set...")
        test_metrics = evaluate_model(model, test_loader, device)

        test_accuracy = test_metrics['accuracy']
        test_report = test_metrics['classification_report']

        # Extract macro average metrics
        test_precision = test_report['macro avg']['precision']
        test_recall = test_report['macro avg']['recall']
        test_f1_score = test_report['macro avg']['f1-score']

        print(f"Test Accuracy: {test_accuracy:.2f}%")
        print(f"Test Precision: {test_precision:.4f}")
        print(f"Test Recall: {test_recall:.4f}")
        print(f"Test F1-Score: {test_f1_score:.4f}")

        # Log test metrics to MLflow
        mlflow.log_metrics({
            'test_accuracy': test_accuracy,
@@ -235,16 +241,16 @@ Architecture:
            'test_recall': test_recall,
            'test_f1_score': test_f1_score
        })

        # Save and log artifacts
        print("\nSaving artifacts...")

        # Save history
        history_path = project_root / 'results' / 'mlflow_training_history.json'
        history_path.parent.mkdir(exist_ok=True)
        save_training_history(history, history_path)
        log_artifact_path(str(history_path))

        # Save test metrics
        metrics_to_save = {
            'test_accuracy': test_accuracy,
@@ -258,43 +264,63 @@ Architecture:
        with open(metrics_path, 'w') as f:
            json.dump(metrics_to_save, f, indent=2)
        log_artifact_path(str(metrics_path))

        # Save model checkpoint
        log_artifact_path(str(project_root / 'models' / 'best_model_mlflow.pt'))

        # Log confusion matrix as JSON
        conf_matrix_dict = {
            f"row_{i}": test_metrics['confusion_matrix'][i].tolist()
            for i in range(len(test_metrics['confusion_matrix']))
        }
        mlflow.log_dict(conf_matrix_dict, "confusion_matrix.json")

        # Log classification report
        mlflow.log_dict(test_report, "classification_report.json")

        print("\n✓ All artifacts logged to MLflow")
        print("View results: mlflow ui --backend-store-uri file:./mlruns")

        return history


def main():
    parser = argparse.ArgumentParser(
        description='Train MNIST CNN with MLflow tracking'
    )
    parser.add_argument(
        '--epochs', type=int, default=20,
        help='Number of epochs (default: 20)'
    )
    parser.add_argument(
        '--lr', type=float, default=0.001,
        help='Learning rate (default: 0.001)'
    )
    parser.add_argument(
        '--batch-size', type=int, default=64,
        help='Batch size (default: 64)'
    )
    parser.add_argument(
        '--augment', action='store_true',
        help='Use data augmentation'
    )
    parser.add_argument(
        '--run-name', type=str, default=None,
        help='MLflow run name'
    )
    parser.add_argument(
        '--seed', type=int, default=42,
        help='Random seed (default: 42)'
    )

    args = parser.parse_args()

    # Set random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Configuration
    config = {
        'num_epochs': args.epochs,
@@ -307,10 +333,10 @@ def main():
        'scheduler': 'ReduceLROnPlateau',
        'early_stopping_patience': 5
    }

    print("Training Configuration:")
    print(json.dumps(config, indent=2))

    # Load MNIST data
    print("\nLoading MNIST data...")
    data_path = project_root / 'data' / 'raw'
@@ -321,18 +347,18 @@ def main():
        test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
    )
    (x_train, y_train), (x_test, y_test) = loader.load_data()

    # Split train/val
    (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
        x_train, y_train, val_split=0.15, random_seed=args.seed
    )

    # Create datasets with optional augmentation
    augmentation = get_train_augmentation() if args.augment else None
    train_dataset = MnistDataset(x_train_split, y_train_split, transform=augmentation)
    val_dataset = MnistDataset(x_val, y_val, transform=None)
    test_dataset = MnistDataset(x_test, y_test, transform=None)

    # Create data loaders
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
@@ -340,20 +366,20 @@ def main():
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
    )

    print(f"Train: {len(train_loader.dataset)} samples")
    print(f"Val: {len(val_loader.dataset)} samples")
    print(f"Test: {len(test_loader.dataset)} samples")

    # Create model
    model = BaselineCNN().to(config['device'])

    # Train with MLflow
    train_with_mlflow(
        model, train_loader, val_loader, test_loader,
        config, run_name=args.run_name
    )

    print("\n" + "="*70)
    print("Training complete! View MLflow dashboard:")
    print(" ./scripts/launch_mlflow_ui.sh")
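Note (not part of the commit): a minimal sketch of how a model logged by train_with_mlflow.py could be pulled back out of MLflow for inference. The run-id placeholder, the "model" artifact path (elided in the diff above) and the 1×28×28 input shape are assumptions; mlflow.pytorch.log_model and the file:./mlruns store come from the code itself.

# Illustrative sketch only -- load a model logged by train_with_mlflow.py.
# Example invocation assumed beforehand:
#   python scripts/train_with_mlflow.py --epochs 20 --batch-size 64 --augment --run-name cnn-aug
import mlflow
import mlflow.pytorch
import torch

mlflow.set_tracking_uri("file:./mlruns")

run_id = "<run id from the MLflow UI>"  # hypothetical placeholder
model = mlflow.pytorch.load_model(f"runs:/{run_id}/model")  # "model" artifact path assumed
model.eval()

# Dummy forward pass with an MNIST-shaped input (1 x 1 x 28 x 28 assumed)
dummy = torch.zeros(1, 1, 28, 28)
with torch.no_grad():
    logits = model(dummy)
print(logits.argmax(dim=1).item())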