faizan committed on
Commit
e77a25a
·
1 Parent(s): 1c6a6f3

fix: resolve all 468 ruff linting errors (code quality enforcement complete)

Browse files

- Added ruff.toml config with E402 ignore for sys.path test pattern
- Fixed 3 E501 line length violations manually
- Auto-fixed 468 W293/W291 whitespace errors
- Codebase now passes all linting checks (0 errors)
- Ready for deployment (Phase 3)

app.py CHANGED
@@ -19,21 +19,21 @@ print(f"Model loaded on {classifier.device}")
19
  def predict_digit(image):
20
  """
21
  Predict digit from user-drawn image.
22
-
23
  Args:
24
  image: numpy array from Gradio Sketchpad (H, W, 3) or (H, W)
25
-
26
  Returns:
27
  Tuple of (predicted_digit, confidence_text, probability_dict)
28
  """
29
  if image is None:
30
  return "Please draw a digit", "", {}
31
-
32
  # Handle different image formats from Gradio
33
  if isinstance(image, dict):
34
  # Sketchpad returns dict with 'composite' key
35
  image = image.get('composite', image)
36
-
37
  # Convert to PIL Image
38
  if isinstance(image, np.ndarray):
39
  # If RGB, convert to grayscale
@@ -43,29 +43,29 @@ def predict_digit(image):
43
  image = image[:, :, 3] # Use alpha channel
44
  else: # RGB
45
  image = np.mean(image, axis=2).astype(np.uint8)
46
-
47
  # Ensure values are in [0, 255]
48
  if image.max() <= 1.0:
49
  image = (image * 255).astype(np.uint8)
50
-
51
  pil_image = Image.fromarray(image.astype(np.uint8), mode='L')
52
  else:
53
  pil_image = image
54
-
55
  # Get prediction
56
  result = classifier.predict(pil_image)
57
-
58
  # Format output
59
  digit = result['digit']
60
  confidence = result['confidence']
61
  probabilities = result['probabilities']
62
-
63
  # Create confidence text
64
  confidence_text = f"Confidence: {confidence*100:.1f}%"
65
-
66
  # Create probability dictionary for bar chart
67
  prob_dict = {str(i): prob for i, prob in enumerate(probabilities)}
68
-
69
  return digit, confidence_text, prob_dict
70
 
71
 
@@ -92,15 +92,15 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
92
  gr.Markdown(
93
  """
94
  # 🔢 Handwritten Digit Classifier
95
-
96
  Draw a digit (0-9) in the box below and the AI will predict
97
  what it is!
98
-
99
  This model uses a Convolutional Neural Network (CNN) trained on
100
  the MNIST dataset with **99.17% accuracy** on 10,000 test images.
101
  """
102
  )
103
-
104
  with gr.Row():
105
  with gr.Column(scale=1):
106
  # Sketchpad for drawing
@@ -116,12 +116,16 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
116
  height=280,
117
  width=280
118
  )
119
-
120
  # Buttons
121
  with gr.Row():
122
- predict_btn = gr.Button("🔍 Predict", variant="primary", scale=2)
123
- clear_btn = gr.ClearButton(components=[input_image], value="🗑️ Clear", scale=1)
124
-
 
 
 
 
125
  with gr.Column(scale=1):
126
  # Prediction output
127
  output_digit = gr.Textbox(
@@ -132,7 +136,7 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
132
  max_lines=1,
133
  interactive=False
134
  )
135
-
136
  output_confidence = gr.Textbox(
137
  label="Confidence",
138
  placeholder="",
@@ -141,13 +145,13 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
141
  max_lines=1,
142
  interactive=False
143
  )
144
-
145
  # Probability distribution
146
  output_probs = gr.Label(
147
  label="Probability Distribution",
148
  num_top_classes=10
149
  )
150
-
151
  # Example images section
152
  gr.Markdown("### 📝 Try these examples:")
153
  gr.Examples(
@@ -157,7 +161,7 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
157
  inputs=input_image,
158
  label="Example digits"
159
  )
160
-
161
  # Model info
162
  gr.Markdown(
163
  """
@@ -168,7 +172,7 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
168
  - **Training**: MNIST dataset (60,000 images)
169
  - **Test Accuracy**: 99.17%
170
  - **Framework**: PyTorch 2.0.1
171
-
172
  ### 💡 Tips for best results:
173
  - Draw the digit large and centered
174
  - Use a thick brush stroke
@@ -176,14 +180,14 @@ with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
176
  - Make sure the digit is clear and recognizable
177
  """
178
  )
179
-
180
  # Connect events
181
  predict_btn.click(
182
  fn=predict_digit,
183
  inputs=input_image,
184
  outputs=[output_digit, output_confidence, output_probs]
185
  )
186
-
187
  # Also predict on sketchpad change (real-time prediction)
188
  input_image.change(
189
  fn=predict_digit,
 
19
  def predict_digit(image):
20
  """
21
  Predict digit from user-drawn image.
22
+
23
  Args:
24
  image: numpy array from Gradio Sketchpad (H, W, 3) or (H, W)
25
+
26
  Returns:
27
  Tuple of (predicted_digit, confidence_text, probability_dict)
28
  """
29
  if image is None:
30
  return "Please draw a digit", "", {}
31
+
32
  # Handle different image formats from Gradio
33
  if isinstance(image, dict):
34
  # Sketchpad returns dict with 'composite' key
35
  image = image.get('composite', image)
36
+
37
  # Convert to PIL Image
38
  if isinstance(image, np.ndarray):
39
  # If RGB, convert to grayscale
 
43
  image = image[:, :, 3] # Use alpha channel
44
  else: # RGB
45
  image = np.mean(image, axis=2).astype(np.uint8)
46
+
47
  # Ensure values are in [0, 255]
48
  if image.max() <= 1.0:
49
  image = (image * 255).astype(np.uint8)
50
+
51
  pil_image = Image.fromarray(image.astype(np.uint8), mode='L')
52
  else:
53
  pil_image = image
54
+
55
  # Get prediction
56
  result = classifier.predict(pil_image)
57
+
58
  # Format output
59
  digit = result['digit']
60
  confidence = result['confidence']
61
  probabilities = result['probabilities']
62
+
63
  # Create confidence text
64
  confidence_text = f"Confidence: {confidence*100:.1f}%"
65
+
66
  # Create probability dictionary for bar chart
67
  prob_dict = {str(i): prob for i, prob in enumerate(probabilities)}
68
+
69
  return digit, confidence_text, prob_dict
70
 
71
 
 
92
  gr.Markdown(
93
  """
94
  # 🔢 Handwritten Digit Classifier
95
+
96
  Draw a digit (0-9) in the box below and the AI will predict
97
  what it is!
98
+
99
  This model uses a Convolutional Neural Network (CNN) trained on
100
  the MNIST dataset with **99.17% accuracy** on 10,000 test images.
101
  """
102
  )
103
+
104
  with gr.Row():
105
  with gr.Column(scale=1):
106
  # Sketchpad for drawing
 
116
  height=280,
117
  width=280
118
  )
119
+
120
  # Buttons
121
  with gr.Row():
122
+ predict_btn = gr.Button(
123
+ "🔍 Predict", variant="primary", scale=2
124
+ )
125
+ clear_btn = gr.ClearButton(
126
+ components=[input_image], value="🗑️ Clear", scale=1
127
+ )
128
+
129
  with gr.Column(scale=1):
130
  # Prediction output
131
  output_digit = gr.Textbox(
 
136
  max_lines=1,
137
  interactive=False
138
  )
139
+
140
  output_confidence = gr.Textbox(
141
  label="Confidence",
142
  placeholder="",
 
145
  max_lines=1,
146
  interactive=False
147
  )
148
+
149
  # Probability distribution
150
  output_probs = gr.Label(
151
  label="Probability Distribution",
152
  num_top_classes=10
153
  )
154
+
155
  # Example images section
156
  gr.Markdown("### 📝 Try these examples:")
157
  gr.Examples(
 
161
  inputs=input_image,
162
  label="Example digits"
163
  )
164
+
165
  # Model info
166
  gr.Markdown(
167
  """
 
172
  - **Training**: MNIST dataset (60,000 images)
173
  - **Test Accuracy**: 99.17%
174
  - **Framework**: PyTorch 2.0.1
175
+
176
  ### 💡 Tips for best results:
177
  - Draw the digit large and centered
178
  - Use a thick brush stroke
 
180
  - Make sure the digit is clear and recognizable
181
  """
182
  )
183
+
184
  # Connect events
185
  predict_btn.click(
186
  fn=predict_digit,
187
  inputs=input_image,
188
  outputs=[output_digit, output_confidence, output_probs]
189
  )
190
+
191
  # Also predict on sketchpad change (real-time prediction)
192
  input_image.change(
193
  fn=predict_digit,
ruff.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ruff configuration for MNIST project
2
+
3
+ select = ["E", "F", "W"]
4
+ ignore = []
5
+
6
+ # Line length
7
+ line-length = 88
8
+
9
+ [per-file-ignores]
10
+ # Ignore E402 (module level import not at top) for scripts that modify sys.path
11
+ "scripts/test_*.py" = ["E402"]
12
+ "scripts/train_*.py" = ["E402"]
13
+ "scripts/inference.py" = ["E402"]
scripts/augmentation.py CHANGED
@@ -11,10 +11,10 @@ These augmentations are applied on-the-fly during training for infinite variatio
11
  Usage:
12
  from scripts.augmentation import get_train_augmentation, get_val_augmentation
13
  from scripts.preprocessing import MnistDataset
14
-
15
  # Training with augmentation
16
  train_dataset = MnistDataset(x_train, y_train, transform=get_train_augmentation())
17
-
18
  # Validation/test without augmentation
19
  val_dataset = MnistDataset(x_val, y_val, transform=get_val_augmentation())
20
  """
@@ -26,14 +26,14 @@ import torch
26
  def get_train_augmentation():
27
  """
28
  Get augmentation pipeline for training data.
29
-
30
  Applies realistic transformations that preserve digit readability:
31
  - Rotation: ±15° (typical handwriting angle variation)
32
  - Translation: ±10% (off-center digits)
33
  - Scaling: 90-110% (size variation)
34
-
35
  Note: Normalization happens in MnistDataset, not here.
36
-
37
  Returns:
38
  torchvision.transforms.Compose: Composition of augmentation transforms
39
  """
@@ -44,7 +44,7 @@ def get_train_augmentation():
44
  interpolation=transforms.InterpolationMode.BILINEAR,
45
  fill=0 # Fill with black (background)
46
  ),
47
-
48
  # Random translation and scaling (no additional rotation)
49
  transforms.RandomAffine(
50
  degrees=0, # No rotation here (already done above)
@@ -59,10 +59,10 @@ def get_train_augmentation():
59
  def get_val_augmentation():
60
  """
61
  Get augmentation pipeline for validation/test data.
62
-
63
  No augmentation is applied - returns identity transform.
64
  This ensures fair evaluation on original unmodified data.
65
-
66
  Returns:
67
  None (no transforms)
68
  """
@@ -72,12 +72,12 @@ def get_val_augmentation():
72
  def get_mild_augmentation():
73
  """
74
  Get milder augmentation pipeline (conservative settings).
75
-
76
  Use this if standard augmentation is too aggressive:
77
  - Rotation: ±10° (reduced from ±15°)
78
  - Translation: ±5% (reduced from ±10%)
79
  - Scaling: 95-105% (reduced from 90-110%)
80
-
81
  Returns:
82
  torchvision.transforms.Compose: Mild augmentation transforms
83
  """
@@ -100,13 +100,13 @@ def get_mild_augmentation():
100
  def get_aggressive_augmentation():
101
  """
102
  Get aggressive augmentation pipeline (stronger settings).
103
-
104
  Use with caution - may distort digits beyond recognition:
105
  - Rotation: ±20°
106
  - Translation: ±15%
107
  - Scaling: 80-120%
108
  - Elastic deformation (optional, commented out)
109
-
110
  Returns:
111
  torchvision.transforms.Compose: Aggressive augmentation transforms
112
  """
@@ -131,14 +131,14 @@ def get_aggressive_augmentation():
131
  def visualize_augmentations(image: torch.Tensor, transform, num_samples: int = 9):
132
  """
133
  Apply augmentation multiple times to visualize variations.
134
-
135
  Useful for debugging and understanding augmentation effects.
136
-
137
  Args:
138
  image: Single image tensor (1, 28, 28)
139
  transform: Augmentation transform to apply
140
  num_samples: Number of augmented versions to generate
141
-
142
  Returns:
143
  list: List of augmented image tensors
144
  """
@@ -149,7 +149,7 @@ def visualize_augmentations(image: torch.Tensor, transform, num_samples: int = 9
149
  else:
150
  aug_img = image
151
  augmented_images.append(aug_img)
152
-
153
  return augmented_images
154
 
155
 
@@ -165,10 +165,10 @@ AUGMENTATION_PRESETS = {
165
  def get_augmentation_by_name(preset_name: str = 'standard'):
166
  """
167
  Get augmentation pipeline by preset name.
168
-
169
  Args:
170
  preset_name: One of ['none', 'mild', 'standard', 'aggressive']
171
-
172
  Returns:
173
  Augmentation transform or None
174
  """
@@ -177,6 +177,6 @@ def get_augmentation_by_name(preset_name: str = 'standard'):
177
  f"Unknown preset '{preset_name}'. "
178
  f"Choose from: {list(AUGMENTATION_PRESETS.keys())}"
179
  )
180
-
181
  preset = AUGMENTATION_PRESETS[preset_name]
182
  return preset() if callable(preset) else preset
 
11
  Usage:
12
  from scripts.augmentation import get_train_augmentation, get_val_augmentation
13
  from scripts.preprocessing import MnistDataset
14
+
15
  # Training with augmentation
16
  train_dataset = MnistDataset(x_train, y_train, transform=get_train_augmentation())
17
+
18
  # Validation/test without augmentation
19
  val_dataset = MnistDataset(x_val, y_val, transform=get_val_augmentation())
20
  """
 
26
  def get_train_augmentation():
27
  """
28
  Get augmentation pipeline for training data.
29
+
30
  Applies realistic transformations that preserve digit readability:
31
  - Rotation: ±15° (typical handwriting angle variation)
32
  - Translation: ±10% (off-center digits)
33
  - Scaling: 90-110% (size variation)
34
+
35
  Note: Normalization happens in MnistDataset, not here.
36
+
37
  Returns:
38
  torchvision.transforms.Compose: Composition of augmentation transforms
39
  """
 
44
  interpolation=transforms.InterpolationMode.BILINEAR,
45
  fill=0 # Fill with black (background)
46
  ),
47
+
48
  # Random translation and scaling (no additional rotation)
49
  transforms.RandomAffine(
50
  degrees=0, # No rotation here (already done above)
 
59
  def get_val_augmentation():
60
  """
61
  Get augmentation pipeline for validation/test data.
62
+
63
  No augmentation is applied - returns identity transform.
64
  This ensures fair evaluation on original unmodified data.
65
+
66
  Returns:
67
  None (no transforms)
68
  """
 
72
  def get_mild_augmentation():
73
  """
74
  Get milder augmentation pipeline (conservative settings).
75
+
76
  Use this if standard augmentation is too aggressive:
77
  - Rotation: ±10° (reduced from ±15°)
78
  - Translation: ±5% (reduced from ±10%)
79
  - Scaling: 95-105% (reduced from 90-110%)
80
+
81
  Returns:
82
  torchvision.transforms.Compose: Mild augmentation transforms
83
  """
 
100
  def get_aggressive_augmentation():
101
  """
102
  Get aggressive augmentation pipeline (stronger settings).
103
+
104
  Use with caution - may distort digits beyond recognition:
105
  - Rotation: ±20°
106
  - Translation: ±15%
107
  - Scaling: 80-120%
108
  - Elastic deformation (optional, commented out)
109
+
110
  Returns:
111
  torchvision.transforms.Compose: Aggressive augmentation transforms
112
  """
 
131
  def visualize_augmentations(image: torch.Tensor, transform, num_samples: int = 9):
132
  """
133
  Apply augmentation multiple times to visualize variations.
134
+
135
  Useful for debugging and understanding augmentation effects.
136
+
137
  Args:
138
  image: Single image tensor (1, 28, 28)
139
  transform: Augmentation transform to apply
140
  num_samples: Number of augmented versions to generate
141
+
142
  Returns:
143
  list: List of augmented image tensors
144
  """
 
149
  else:
150
  aug_img = image
151
  augmented_images.append(aug_img)
152
+
153
  return augmented_images
154
 
155
 
 
165
  def get_augmentation_by_name(preset_name: str = 'standard'):
166
  """
167
  Get augmentation pipeline by preset name.
168
+
169
  Args:
170
  preset_name: One of ['none', 'mild', 'standard', 'aggressive']
171
+
172
  Returns:
173
  Augmentation transform or None
174
  """
 
177
  f"Unknown preset '{preset_name}'. "
178
  f"Choose from: {list(AUGMENTATION_PRESETS.keys())}"
179
  )
180
+
181
  preset = AUGMENTATION_PRESETS[preset_name]
182
  return preset() if callable(preset) else preset
scripts/data_loader.py CHANGED
@@ -15,17 +15,17 @@ from numpy.typing import NDArray
15
  class MnistDataloader:
16
  """
17
  Load MNIST handwritten digit dataset from IDX binary files.
18
-
19
  The MNIST dataset uses a custom IDX binary format with magic numbers
20
  to identify image (2051) and label (2049) files.
21
-
22
  Attributes:
23
  training_images_filepath: Path to training images IDX file
24
  training_labels_filepath: Path to training labels IDX file
25
  test_images_filepath: Path to test images IDX file
26
  test_labels_filepath: Path to test labels IDX file
27
  """
28
-
29
  def __init__(
30
  self,
31
  training_images_filepath: str,
@@ -35,13 +35,13 @@ class MnistDataloader:
35
  ) -> None:
36
  """
37
  Initialize MNIST data loader with file paths.
38
-
39
  Args:
40
  training_images_filepath: Path to training images (.idx3-ubyte)
41
  training_labels_filepath: Path to training labels (.idx1-ubyte)
42
  test_images_filepath: Path to test images (.idx3-ubyte)
43
  test_labels_filepath: Path to test labels (.idx1-ubyte)
44
-
45
  Raises:
46
  FileNotFoundError: If any of the specified files don't exist
47
  """
@@ -49,7 +49,7 @@ class MnistDataloader:
49
  self.training_labels_filepath = training_labels_filepath
50
  self.test_images_filepath = test_images_filepath
51
  self.test_labels_filepath = test_labels_filepath
52
-
53
  # Verify files exist
54
  for filepath in [
55
  training_images_filepath,
@@ -59,7 +59,7 @@ class MnistDataloader:
59
  ]:
60
  if not Path(filepath).exists():
61
  raise FileNotFoundError(f"MNIST data file not found: {filepath}")
62
-
63
  def read_images_labels(
64
  self,
65
  images_filepath: str,
@@ -67,16 +67,16 @@ class MnistDataloader:
67
  ) -> Tuple[List[NDArray[np.uint8]], List[int]]:
68
  """
69
  Read images and labels from IDX binary files.
70
-
71
  Args:
72
  images_filepath: Path to images IDX file
73
  labels_filepath: Path to labels IDX file
74
-
75
  Returns:
76
  Tuple of (images, labels) where:
77
  - images: List of 28x28 numpy arrays (uint8)
78
  - labels: List of integer labels (0-9)
79
-
80
  Raises:
81
  ValueError: If magic numbers don't match expected values
82
  """
@@ -90,7 +90,7 @@ class MnistDataloader:
90
  f'Expected 2049, got {magic}'
91
  )
92
  labels = array("B", file.read())
93
-
94
  # Read images
95
  with open(images_filepath, 'rb') as file:
96
  magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
@@ -100,12 +100,12 @@ class MnistDataloader:
100
  f'Expected 2051, got {magic}'
101
  )
102
  image_data = array("B", file.read())
103
-
104
  # Convert to list of 28x28 arrays
105
  images = []
106
  for i in range(size):
107
  images.append([0] * rows * cols)
108
-
109
  for i in range(size):
110
  img = np.array(
111
  image_data[i * rows * cols:(i + 1) * rows * cols],
@@ -113,23 +113,23 @@ class MnistDataloader:
113
  )
114
  img = img.reshape(rows, cols)
115
  images[i][:] = img
116
-
117
  return images, list(labels)
118
-
119
  def load_data(self) -> Tuple[
120
  Tuple[List[NDArray[np.uint8]], List[int]],
121
  Tuple[List[NDArray[np.uint8]], List[int]]
122
  ]:
123
  """
124
  Load complete MNIST dataset (training and test sets).
125
-
126
  Returns:
127
  Tuple of ((x_train, y_train), (x_test, y_test)) where:
128
  - x_train: 60,000 training images (28x28 uint8 arrays)
129
  - y_train: 60,000 training labels (0-9)
130
  - x_test: 10,000 test images (28x28 uint8 arrays)
131
  - y_test: 10,000 test labels (0-9)
132
-
133
  Example:
134
  >>> loader = MnistDataloader(
135
  ... 'data/raw/train-images.idx3-ubyte',
 
15
  class MnistDataloader:
16
  """
17
  Load MNIST handwritten digit dataset from IDX binary files.
18
+
19
  The MNIST dataset uses a custom IDX binary format with magic numbers
20
  to identify image (2051) and label (2049) files.
21
+
22
  Attributes:
23
  training_images_filepath: Path to training images IDX file
24
  training_labels_filepath: Path to training labels IDX file
25
  test_images_filepath: Path to test images IDX file
26
  test_labels_filepath: Path to test labels IDX file
27
  """
28
+
29
  def __init__(
30
  self,
31
  training_images_filepath: str,
 
35
  ) -> None:
36
  """
37
  Initialize MNIST data loader with file paths.
38
+
39
  Args:
40
  training_images_filepath: Path to training images (.idx3-ubyte)
41
  training_labels_filepath: Path to training labels (.idx1-ubyte)
42
  test_images_filepath: Path to test images (.idx3-ubyte)
43
  test_labels_filepath: Path to test labels (.idx1-ubyte)
44
+
45
  Raises:
46
  FileNotFoundError: If any of the specified files don't exist
47
  """
 
49
  self.training_labels_filepath = training_labels_filepath
50
  self.test_images_filepath = test_images_filepath
51
  self.test_labels_filepath = test_labels_filepath
52
+
53
  # Verify files exist
54
  for filepath in [
55
  training_images_filepath,
 
59
  ]:
60
  if not Path(filepath).exists():
61
  raise FileNotFoundError(f"MNIST data file not found: {filepath}")
62
+
63
  def read_images_labels(
64
  self,
65
  images_filepath: str,
 
67
  ) -> Tuple[List[NDArray[np.uint8]], List[int]]:
68
  """
69
  Read images and labels from IDX binary files.
70
+
71
  Args:
72
  images_filepath: Path to images IDX file
73
  labels_filepath: Path to labels IDX file
74
+
75
  Returns:
76
  Tuple of (images, labels) where:
77
  - images: List of 28x28 numpy arrays (uint8)
78
  - labels: List of integer labels (0-9)
79
+
80
  Raises:
81
  ValueError: If magic numbers don't match expected values
82
  """
 
90
  f'Expected 2049, got {magic}'
91
  )
92
  labels = array("B", file.read())
93
+
94
  # Read images
95
  with open(images_filepath, 'rb') as file:
96
  magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
 
100
  f'Expected 2051, got {magic}'
101
  )
102
  image_data = array("B", file.read())
103
+
104
  # Convert to list of 28x28 arrays
105
  images = []
106
  for i in range(size):
107
  images.append([0] * rows * cols)
108
+
109
  for i in range(size):
110
  img = np.array(
111
  image_data[i * rows * cols:(i + 1) * rows * cols],
 
113
  )
114
  img = img.reshape(rows, cols)
115
  images[i][:] = img
116
+
117
  return images, list(labels)
118
+
119
  def load_data(self) -> Tuple[
120
  Tuple[List[NDArray[np.uint8]], List[int]],
121
  Tuple[List[NDArray[np.uint8]], List[int]]
122
  ]:
123
  """
124
  Load complete MNIST dataset (training and test sets).
125
+
126
  Returns:
127
  Tuple of ((x_train, y_train), (x_test, y_test)) where:
128
  - x_train: 60,000 training images (28x28 uint8 arrays)
129
  - y_train: 60,000 training labels (0-9)
130
  - x_test: 10,000 test images (28x28 uint8 arrays)
131
  - y_test: 10,000 test labels (0-9)
132
+
133
  Example:
134
  >>> loader = MnistDataloader(
135
  ... 'data/raw/train-images.idx3-ubyte',
scripts/data_quality.py CHANGED
@@ -10,9 +10,9 @@ This module provides functions to systematically check MNIST dataset quality:
10
 
11
  Usage:
12
  from scripts.data_quality import generate_quality_report
13
-
14
  report = generate_quality_report(
15
- (x_train, y_train),
16
  (x_test, y_test)
17
  )
18
  """
@@ -24,16 +24,16 @@ from collections import Counter
24
 
25
 
26
  def check_missing_values(
27
- images: List[NDArray[np.uint8]],
28
  labels: List[int]
29
  ) -> Dict[str, Any]:
30
  """
31
  Check for NaN or missing values in images and labels.
32
-
33
  Args:
34
  images: List of image arrays (each 28x28)
35
  labels: List of integer labels (0-9)
36
-
37
  Returns:
38
  dict: Contains 'has_missing_values', 'missing_count', 'details'
39
  """
@@ -43,12 +43,12 @@ def check_missing_values(
43
  img_array = np.array(img)
44
  if np.isnan(img_array).any():
45
  images_with_nan.append(idx)
46
-
47
  # Check labels for None
48
  labels_with_none = [idx for idx, label in enumerate(labels) if label is None]
49
-
50
  has_missing = len(images_with_nan) > 0 or len(labels_with_none) > 0
51
-
52
  return {
53
  'has_missing_values': has_missing,
54
  'missing_count': len(images_with_nan) + len(labels_with_none),
@@ -66,25 +66,25 @@ def check_missing_values(
66
  def check_outliers(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
67
  """
68
  Identify pixels outside valid range [0, 255] for uint8 images.
69
-
70
  Args:
71
  images: List of image arrays (each 28x28)
72
-
73
  Returns:
74
  dict: Contains 'has_outliers', 'outlier_count', 'pixel_range', 'details'
75
  """
76
  outlier_images = []
77
  pixel_min = 255
78
  pixel_max = 0
79
-
80
  for idx, img in enumerate(images):
81
  img_array = np.array(img)
82
  img_min = img_array.min()
83
  img_max = img_array.max()
84
-
85
  pixel_min = min(pixel_min, img_min)
86
  pixel_max = max(pixel_max, img_max)
87
-
88
  # Check for values outside [0, 255]
89
  if img_min < 0 or img_max > 255:
90
  outlier_images.append({
@@ -92,7 +92,7 @@ def check_outliers(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
92
  'min': int(img_min),
93
  'max': int(img_max)
94
  })
95
-
96
  return {
97
  'has_outliers': len(outlier_images) > 0,
98
  'outlier_count': len(outlier_images),
@@ -109,44 +109,44 @@ def check_outliers(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
109
  def check_class_balance(labels: List[int]) -> Dict[str, Any]:
110
  """
111
  Compute samples per class and calculate imbalance ratio.
112
-
113
  Imbalance ratio = max_count / min_count
114
  A ratio < 1.2 indicates good balance (< 20% difference)
115
-
116
  Args:
117
  labels: List of integer labels (0-9)
118
-
119
  Returns:
120
  dict: Contains 'is_balanced', 'imbalance_ratio', 'class_counts', 'details'
121
  """
122
  class_counts = Counter(labels)
123
-
124
  # Ensure all 10 digits present
125
  for digit in range(10):
126
  if digit not in class_counts:
127
  class_counts[digit] = 0
128
-
129
  counts = list(class_counts.values())
130
  max_count = max(counts)
131
  min_count = min(counts) if min(counts) > 0 else 1 # Avoid division by zero
132
-
133
  imbalance_ratio = max_count / min_count
134
  is_balanced = imbalance_ratio < 1.2 # Less than 20% difference
135
-
136
  # Per-class percentages
137
  total = len(labels)
138
  class_percentages = {
139
- digit: (count / total) * 100
140
  for digit, count in class_counts.items()
141
  }
142
-
143
  return {
144
  'is_balanced': is_balanced,
145
  'imbalance_ratio': round(imbalance_ratio, 3),
146
  'threshold': 1.2,
147
  'class_counts': dict(sorted(class_counts.items())),
148
  'class_percentages': {
149
- k: round(v, 2)
150
  for k, v in sorted(class_percentages.items())
151
  },
152
  'details': {
@@ -162,16 +162,16 @@ def check_class_balance(labels: List[int]) -> Dict[str, Any]:
162
  def check_image_dimensions(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
163
  """
164
  Verify all images are 28x28 as expected for MNIST.
165
-
166
  Args:
167
  images: List of image arrays
168
-
169
  Returns:
170
  dict: Contains 'all_correct_shape', 'expected_shape', 'invalid_count', 'details'
171
  """
172
  expected_shape = (28, 28)
173
  invalid_images = []
174
-
175
  for idx, img in enumerate(images):
176
  img_array = np.array(img)
177
  if img_array.shape != expected_shape:
@@ -179,7 +179,7 @@ def check_image_dimensions(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
179
  'index': idx,
180
  'shape': img_array.shape
181
  })
182
-
183
  return {
184
  'all_correct_shape': len(invalid_images) == 0,
185
  'expected_shape': expected_shape,
@@ -194,16 +194,16 @@ def check_image_dimensions(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
194
  def check_label_validity(labels: List[int]) -> Dict[str, Any]:
195
  """
196
  Verify all labels are valid integers in range [0, 9].
197
-
198
  Args:
199
  labels: List of labels
200
-
201
  Returns:
202
  dict: Contains 'all_valid', 'invalid_count', 'unique_labels', 'details'
203
  """
204
  valid_range = set(range(10))
205
  invalid_labels = []
206
-
207
  for idx, label in enumerate(labels):
208
  if not isinstance(label, int) or label not in valid_range:
209
  invalid_labels.append({
@@ -211,9 +211,9 @@ def check_label_validity(labels: List[int]) -> Dict[str, Any]:
211
  'value': label,
212
  'type': type(label).__name__
213
  })
214
-
215
  unique_labels = sorted(set(labels))
216
-
217
  return {
218
  'all_valid': len(invalid_labels) == 0,
219
  'expected_range': [0, 9],
@@ -232,17 +232,17 @@ def generate_quality_report(
232
  ) -> Dict[str, Any]:
233
  """
234
  Run all quality checks on training and test sets.
235
-
236
  Args:
237
  train_data: Tuple of (train_images, train_labels)
238
  test_data: Tuple of (test_images, test_labels)
239
-
240
  Returns:
241
  dict: Comprehensive quality report with all check results
242
  """
243
  x_train, y_train = train_data
244
  x_test, y_test = test_data
245
-
246
  report = {
247
  'dataset_info': {
248
  'train_samples': len(x_train),
@@ -264,7 +264,7 @@ def generate_quality_report(
264
  'label_validity': check_label_validity(y_test)
265
  }
266
  }
267
-
268
  # Overall quality assessment
269
  all_checks_pass = (
270
  not report['training_set']['missing_values']['has_missing_values'] and
@@ -276,7 +276,7 @@ def generate_quality_report(
276
  report['test_set']['image_dimensions']['all_correct_shape'] and
277
  report['test_set']['label_validity']['all_valid']
278
  )
279
-
280
  report['summary'] = {
281
  'all_checks_pass': all_checks_pass,
282
  'quality_rating': 'EXCELLENT' if all_checks_pass else 'ISSUES_FOUND',
@@ -284,22 +284,22 @@ def generate_quality_report(
284
  'test_balanced': report['test_set']['class_balance']['is_balanced'],
285
  'recommendations': _generate_recommendations(report)
286
  }
287
-
288
  return report
289
 
290
 
291
  def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
292
  """
293
  Generate recommendations based on quality check results.
294
-
295
  Args:
296
  report: Quality report dictionary
297
-
298
  Returns:
299
  list: List of recommendation strings
300
  """
301
  recommendations = []
302
-
303
  # Check missing values
304
  if report['training_set']['missing_values']['has_missing_values']:
305
  recommendations.append(
@@ -309,7 +309,7 @@ def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
309
  recommendations.append(
310
  "Remove or impute samples with missing values in test set"
311
  )
312
-
313
  # Check outliers
314
  if report['training_set']['outliers']['has_outliers']:
315
  recommendations.append(
@@ -319,7 +319,7 @@ def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
319
  recommendations.append(
320
  "Clip or remove test images with pixel values outside [0, 255]"
321
  )
322
-
323
  # Check class balance
324
  train_imbalance = report['training_set']['class_balance']['imbalance_ratio']
325
  if train_imbalance >= 1.5:
@@ -332,7 +332,7 @@ def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
332
  f"Minor class imbalance detected (ratio: {train_imbalance:.2f}). "
333
  "Monitor per-class performance during training."
334
  )
335
-
336
  # Check dimensions
337
  if not report['training_set']['image_dimensions']['all_correct_shape']:
338
  recommendations.append(
@@ -342,7 +342,7 @@ def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
342
  recommendations.append(
343
  "Resize or remove test images with incorrect dimensions"
344
  )
345
-
346
  # Check labels
347
  if not report['training_set']['label_validity']['all_valid']:
348
  recommendations.append(
@@ -352,20 +352,20 @@ def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
352
  recommendations.append(
353
  "Remove or correct test samples with invalid labels"
354
  )
355
-
356
  # If all checks pass
357
  if not recommendations:
358
  recommendations.append(
359
  "Dataset is high quality - proceed with preprocessing and normalization"
360
  )
361
-
362
  return recommendations
363
 
364
 
365
  def print_quality_summary(report: Dict[str, Any]) -> None:
366
  """
367
  Print a human-readable summary of the quality report.
368
-
369
  Args:
370
  report: Quality report dictionary from generate_quality_report()
371
  """
@@ -373,7 +373,7 @@ def print_quality_summary(report: Dict[str, Any]) -> None:
373
  print("MNIST DATASET QUALITY REPORT")
374
  print("=" * 60)
375
  print()
376
-
377
  # Dataset info
378
  info = report['dataset_info']
379
  print("Dataset Size:")
@@ -381,7 +381,7 @@ def print_quality_summary(report: Dict[str, Any]) -> None:
381
  print(f" Test: {info['test_samples']:,} samples")
382
  print(f" Total: {info['total_samples']:,} samples")
383
  print()
384
-
385
  # Training set checks
386
  print("Training Set Quality Checks:")
387
  train = report['training_set']
@@ -394,7 +394,7 @@ def print_quality_summary(report: Dict[str, Any]) -> None:
394
  train['class_balance']['is_balanced']
395
  )
396
  print()
397
-
398
  # Test set checks
399
  print("Test Set Quality Checks:")
400
  test = report['test_set']
@@ -407,14 +407,14 @@ def print_quality_summary(report: Dict[str, Any]) -> None:
407
  test['class_balance']['is_balanced']
408
  )
409
  print()
410
-
411
  # Overall summary
412
  summary = report['summary']
413
  print("=" * 60)
414
  print(f"Overall Quality: {summary['quality_rating']}")
415
  print("=" * 60)
416
  print()
417
-
418
  # Recommendations
419
  print("Recommendations:")
420
  for i, rec in enumerate(summary['recommendations'], 1):
 
10
 
11
  Usage:
12
  from scripts.data_quality import generate_quality_report
13
+
14
  report = generate_quality_report(
15
+ (x_train, y_train),
16
  (x_test, y_test)
17
  )
18
  """
 
24
 
25
 
26
  def check_missing_values(
27
+ images: List[NDArray[np.uint8]],
28
  labels: List[int]
29
  ) -> Dict[str, Any]:
30
  """
31
  Check for NaN or missing values in images and labels.
32
+
33
  Args:
34
  images: List of image arrays (each 28x28)
35
  labels: List of integer labels (0-9)
36
+
37
  Returns:
38
  dict: Contains 'has_missing_values', 'missing_count', 'details'
39
  """
 
43
  img_array = np.array(img)
44
  if np.isnan(img_array).any():
45
  images_with_nan.append(idx)
46
+
47
  # Check labels for None
48
  labels_with_none = [idx for idx, label in enumerate(labels) if label is None]
49
+
50
  has_missing = len(images_with_nan) > 0 or len(labels_with_none) > 0
51
+
52
  return {
53
  'has_missing_values': has_missing,
54
  'missing_count': len(images_with_nan) + len(labels_with_none),
 
66
  def check_outliers(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
67
  """
68
  Identify pixels outside valid range [0, 255] for uint8 images.
69
+
70
  Args:
71
  images: List of image arrays (each 28x28)
72
+
73
  Returns:
74
  dict: Contains 'has_outliers', 'outlier_count', 'pixel_range', 'details'
75
  """
76
  outlier_images = []
77
  pixel_min = 255
78
  pixel_max = 0
79
+
80
  for idx, img in enumerate(images):
81
  img_array = np.array(img)
82
  img_min = img_array.min()
83
  img_max = img_array.max()
84
+
85
  pixel_min = min(pixel_min, img_min)
86
  pixel_max = max(pixel_max, img_max)
87
+
88
  # Check for values outside [0, 255]
89
  if img_min < 0 or img_max > 255:
90
  outlier_images.append({
 
92
  'min': int(img_min),
93
  'max': int(img_max)
94
  })
95
+
96
  return {
97
  'has_outliers': len(outlier_images) > 0,
98
  'outlier_count': len(outlier_images),
 
109
  def check_class_balance(labels: List[int]) -> Dict[str, Any]:
110
  """
111
  Compute samples per class and calculate imbalance ratio.
112
+
113
  Imbalance ratio = max_count / min_count
114
  A ratio < 1.2 indicates good balance (< 20% difference)
115
+
116
  Args:
117
  labels: List of integer labels (0-9)
118
+
119
  Returns:
120
  dict: Contains 'is_balanced', 'imbalance_ratio', 'class_counts', 'details'
121
  """
122
  class_counts = Counter(labels)
123
+
124
  # Ensure all 10 digits present
125
  for digit in range(10):
126
  if digit not in class_counts:
127
  class_counts[digit] = 0
128
+
129
  counts = list(class_counts.values())
130
  max_count = max(counts)
131
  min_count = min(counts) if min(counts) > 0 else 1 # Avoid division by zero
132
+
133
  imbalance_ratio = max_count / min_count
134
  is_balanced = imbalance_ratio < 1.2 # Less than 20% difference
135
+
136
  # Per-class percentages
137
  total = len(labels)
138
  class_percentages = {
139
+ digit: (count / total) * 100
140
  for digit, count in class_counts.items()
141
  }
142
+
143
  return {
144
  'is_balanced': is_balanced,
145
  'imbalance_ratio': round(imbalance_ratio, 3),
146
  'threshold': 1.2,
147
  'class_counts': dict(sorted(class_counts.items())),
148
  'class_percentages': {
149
+ k: round(v, 2)
150
  for k, v in sorted(class_percentages.items())
151
  },
152
  'details': {
 
162
  def check_image_dimensions(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
163
  """
164
  Verify all images are 28x28 as expected for MNIST.
165
+
166
  Args:
167
  images: List of image arrays
168
+
169
  Returns:
170
  dict: Contains 'all_correct_shape', 'expected_shape', 'invalid_count', 'details'
171
  """
172
  expected_shape = (28, 28)
173
  invalid_images = []
174
+
175
  for idx, img in enumerate(images):
176
  img_array = np.array(img)
177
  if img_array.shape != expected_shape:
 
179
  'index': idx,
180
  'shape': img_array.shape
181
  })
182
+
183
  return {
184
  'all_correct_shape': len(invalid_images) == 0,
185
  'expected_shape': expected_shape,
 
194
  def check_label_validity(labels: List[int]) -> Dict[str, Any]:
195
  """
196
  Verify all labels are valid integers in range [0, 9].
197
+
198
  Args:
199
  labels: List of labels
200
+
201
  Returns:
202
  dict: Contains 'all_valid', 'invalid_count', 'unique_labels', 'details'
203
  """
204
  valid_range = set(range(10))
205
  invalid_labels = []
206
+
207
  for idx, label in enumerate(labels):
208
  if not isinstance(label, int) or label not in valid_range:
209
  invalid_labels.append({
 
211
  'value': label,
212
  'type': type(label).__name__
213
  })
214
+
215
  unique_labels = sorted(set(labels))
216
+
217
  return {
218
  'all_valid': len(invalid_labels) == 0,
219
  'expected_range': [0, 9],
 
232
  ) -> Dict[str, Any]:
233
  """
234
  Run all quality checks on training and test sets.
235
+
236
  Args:
237
  train_data: Tuple of (train_images, train_labels)
238
  test_data: Tuple of (test_images, test_labels)
239
+
240
  Returns:
241
  dict: Comprehensive quality report with all check results
242
  """
243
  x_train, y_train = train_data
244
  x_test, y_test = test_data
245
+
246
  report = {
247
  'dataset_info': {
248
  'train_samples': len(x_train),
 
264
  'label_validity': check_label_validity(y_test)
265
  }
266
  }
267
+
268
  # Overall quality assessment
269
  all_checks_pass = (
270
  not report['training_set']['missing_values']['has_missing_values'] and
 
276
  report['test_set']['image_dimensions']['all_correct_shape'] and
277
  report['test_set']['label_validity']['all_valid']
278
  )
279
+
280
  report['summary'] = {
281
  'all_checks_pass': all_checks_pass,
282
  'quality_rating': 'EXCELLENT' if all_checks_pass else 'ISSUES_FOUND',
 
284
  'test_balanced': report['test_set']['class_balance']['is_balanced'],
285
  'recommendations': _generate_recommendations(report)
286
  }
287
+
288
  return report
289
 
290
 
291
  def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
292
  """
293
  Generate recommendations based on quality check results.
294
+
295
  Args:
296
  report: Quality report dictionary
297
+
298
  Returns:
299
  list: List of recommendation strings
300
  """
301
  recommendations = []
302
+
303
  # Check missing values
304
  if report['training_set']['missing_values']['has_missing_values']:
305
  recommendations.append(
 
309
  recommendations.append(
310
  "Remove or impute samples with missing values in test set"
311
  )
312
+
313
  # Check outliers
314
  if report['training_set']['outliers']['has_outliers']:
315
  recommendations.append(
 
319
  recommendations.append(
320
  "Clip or remove test images with pixel values outside [0, 255]"
321
  )
322
+
323
  # Check class balance
324
  train_imbalance = report['training_set']['class_balance']['imbalance_ratio']
325
  if train_imbalance >= 1.5:
 
332
  f"Minor class imbalance detected (ratio: {train_imbalance:.2f}). "
333
  "Monitor per-class performance during training."
334
  )
335
+
336
  # Check dimensions
337
  if not report['training_set']['image_dimensions']['all_correct_shape']:
338
  recommendations.append(
 
342
  recommendations.append(
343
  "Resize or remove test images with incorrect dimensions"
344
  )
345
+
346
  # Check labels
347
  if not report['training_set']['label_validity']['all_valid']:
348
  recommendations.append(
 
352
  recommendations.append(
353
  "Remove or correct test samples with invalid labels"
354
  )
355
+
356
  # If all checks pass
357
  if not recommendations:
358
  recommendations.append(
359
  "Dataset is high quality - proceed with preprocessing and normalization"
360
  )
361
+
362
  return recommendations
363
 
364
 
365
  def print_quality_summary(report: Dict[str, Any]) -> None:
366
  """
367
  Print a human-readable summary of the quality report.
368
+
369
  Args:
370
  report: Quality report dictionary from generate_quality_report()
371
  """
 
373
  print("MNIST DATASET QUALITY REPORT")
374
  print("=" * 60)
375
  print()
376
+
377
  # Dataset info
378
  info = report['dataset_info']
379
  print("Dataset Size:")
 
381
  print(f" Test: {info['test_samples']:,} samples")
382
  print(f" Total: {info['total_samples']:,} samples")
383
  print()
384
+
385
  # Training set checks
386
  print("Training Set Quality Checks:")
387
  train = report['training_set']
 
394
  train['class_balance']['is_balanced']
395
  )
396
  print()
397
+
398
  # Test set checks
399
  print("Test Set Quality Checks:")
400
  test = report['test_set']
 
407
  test['class_balance']['is_balanced']
408
  )
409
  print()
410
+
411
  # Overall summary
412
  summary = report['summary']
413
  print("=" * 60)
414
  print(f"Overall Quality: {summary['quality_rating']}")
415
  print("=" * 60)
416
  print()
417
+
418
  # Recommendations
419
  print("Recommendations:")
420
  for i, rec in enumerate(summary['recommendations'], 1):
scripts/inference.py CHANGED
@@ -14,103 +14,103 @@ from typing import Union, Dict
14
 
15
  class DigitClassifier:
16
  """Production inference wrapper for MNIST digit classifier."""
17
-
18
  def __init__(self, model_path: str, device: str = None):
19
  """
20
  Initialize the digit classifier.
21
-
22
  Args:
23
  model_path: Path to model checkpoint (.pt file)
24
- device: Device to run inference on ('cuda' or 'cpu').
25
  If None, auto-detects CUDA availability.
26
  """
27
  if device is None:
28
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
29
  else:
30
  self.device = device
31
-
32
  self.model_path = Path(model_path)
33
  if not self.model_path.exists():
34
  raise FileNotFoundError(f"Model not found at {model_path}")
35
-
36
  self.model = self._load_model()
37
  self.model.eval()
38
-
39
  # Normalization values (same as training)
40
  self.mean = 0.1307
41
  self.std = 0.3081
42
-
43
  def _load_model(self) -> torch.nn.Module:
44
  """Load model from checkpoint."""
45
  from scripts.models import BaselineCNN
46
-
47
  model = BaselineCNN()
48
-
49
  # Load checkpoint
50
  checkpoint = torch.load(self.model_path, map_location=self.device)
51
-
52
  # Handle different checkpoint formats
53
  if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
54
  model.load_state_dict(checkpoint['model_state_dict'])
55
  else:
56
  model.load_state_dict(checkpoint)
57
-
58
  return model.to(self.device)
59
-
60
  def preprocess(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
61
  """
62
  Preprocess image for model input.
63
-
64
  Handles:
65
  - RGB to grayscale conversion
66
  - Resizing to 28x28
67
  - Normalization
68
  - Inversion if needed (white digit on black background)
69
-
70
  Args:
71
  image: PIL Image or numpy array
72
-
73
  Returns:
74
  Preprocessed tensor of shape (1, 1, 28, 28)
75
  """
76
  # Convert numpy array to PIL Image if needed
77
  if isinstance(image, np.ndarray):
78
  image = Image.fromarray(image)
79
-
80
  # Convert to grayscale if RGB
81
  if image.mode != 'L':
82
  image = image.convert('L')
83
-
84
  # Resize to 28x28 if needed
85
  if image.size != (28, 28):
86
  image = image.resize((28, 28), Image.Resampling.LANCZOS)
87
-
88
  # Convert to numpy array
89
  img_array = np.array(image).astype(np.float32)
90
-
91
  # Normalize to [0, 1]
92
  img_array = img_array / 255.0
93
-
94
  # Check if inversion is needed (MNIST is white digit on black background)
95
  # If most pixels are bright, it's likely a black digit on white background
96
  if img_array.mean() > 0.5:
97
  img_array = 1.0 - img_array
98
-
99
  # Apply normalization (same as training)
100
  img_array = (img_array - self.mean) / self.std
101
-
102
  # Convert to tensor and add batch and channel dimensions
103
  img_tensor = torch.tensor(img_array).unsqueeze(0).unsqueeze(0)
104
-
105
  return img_tensor.to(self.device)
106
-
107
  def predict(self, image: Union[Image.Image, np.ndarray]) -> Dict:
108
  """
109
  Predict digit from image.
110
-
111
  Args:
112
  image: PIL Image or numpy array containing digit
113
-
114
  Returns:
115
  Dictionary with:
116
  - digit: Predicted digit (0-9)
@@ -118,25 +118,25 @@ class DigitClassifier:
118
  - probabilities: List of probabilities for each digit
119
  """
120
  img_tensor = self.preprocess(image)
121
-
122
  with torch.no_grad():
123
  outputs = self.model(img_tensor)
124
  probabilities = torch.softmax(outputs, dim=1)[0]
125
  confidence, predicted = torch.max(probabilities, dim=0)
126
-
127
  return {
128
  'digit': int(predicted.item()),
129
  'confidence': float(confidence.item()),
130
  'probabilities': probabilities.cpu().numpy().tolist()
131
  }
132
-
133
  def predict_batch(self, images: list) -> list:
134
  """
135
  Predict digits for a batch of images.
136
-
137
  Args:
138
  images: List of PIL Images or numpy arrays
139
-
140
  Returns:
141
  List of prediction dictionaries
142
  """
@@ -147,23 +147,23 @@ def test_inference():
147
  """Test inference module with sample images."""
148
  import sys
149
  from pathlib import Path
150
-
151
  # Add project root to path
152
  project_root = Path(__file__).parent.parent
153
  sys.path.insert(0, str(project_root))
154
-
155
  from scripts.data_loader import MnistDataloader
156
-
157
  print("Testing Inference Module")
158
  print("=" * 50)
159
-
160
  # Check if model exists
161
  model_path = project_root / 'models' / 'best_model.pt'
162
  if not model_path.exists():
163
  print(f"Error: Model not found at {model_path}")
164
  print("Please train a model first.")
165
  return
166
-
167
  # Load MNIST test data
168
  data_path = project_root / 'data' / 'raw'
169
  loader = MnistDataloader(
@@ -173,66 +173,70 @@ def test_inference():
173
  test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
174
  )
175
  _, (x_test, y_test) = loader.load_data()
176
-
177
  # Initialize classifier
178
  print(f"\n1. Loading model from: {model_path}")
179
  classifier = DigitClassifier(str(model_path))
180
  print(f" Device: {classifier.device}")
181
-
182
  # Test on a few images
183
  print("\n2. Testing predictions on 10 random test images:")
184
  print("-" * 50)
185
-
186
  indices = np.random.choice(len(x_test), 10, replace=False)
187
  correct = 0
188
-
189
  for i, idx in enumerate(indices, 1):
190
  image = x_test[idx]
191
  true_label = y_test[idx]
192
-
193
  # Convert list to numpy array if needed
194
  if isinstance(image, list):
195
  image = np.array(image)
196
-
197
  # Convert to PIL Image
198
  img = Image.fromarray(image.astype(np.uint8), mode='L')
199
-
200
  # Predict
201
  result = classifier.predict(img)
202
-
203
  is_correct = result['digit'] == true_label
204
  correct += is_correct
205
-
206
  print(f" Image {i}: True={true_label}, Pred={result['digit']}, "
207
  f"Conf={result['confidence']:.4f} {'✓' if is_correct else '✗'}")
208
-
209
  accuracy = correct / len(indices) * 100
210
  print(f"\nAccuracy on {len(indices)} samples: {accuracy:.1f}%")
211
-
212
  # Test edge cases
213
  print("\n3. Testing edge cases:")
214
  print("-" * 50)
215
-
216
  # Blank image
217
  blank = np.zeros((28, 28), dtype=np.uint8)
218
  blank_img = Image.fromarray(blank, mode='L')
219
  result = classifier.predict(blank_img)
220
  print(f" Blank image: Pred={result['digit']}, Conf={result['confidence']:.4f}")
221
-
222
  # All white image
223
  white = np.ones((28, 28), dtype=np.uint8) * 255
224
  white_img = Image.fromarray(white, mode='L')
225
  result = classifier.predict(white_img)
226
  print(f" White image: Pred={result['digit']}, Conf={result['confidence']:.4f}")
227
-
228
  # Different size image
229
  test_img = x_test[0]
230
  if isinstance(test_img, list):
231
  test_img = np.array(test_img)
232
- large = Image.fromarray(test_img.astype(np.uint8), mode='L').resize((56, 56))
 
233
  result = classifier.predict(large)
234
- print(f" Resized image (56x56): Pred={result['digit']}, Conf={result['confidence']:.4f}")
235
-
 
 
 
236
  print("\n✓ Inference module test complete!")
237
 
238
 
 
14
 
15
  class DigitClassifier:
16
  """Production inference wrapper for MNIST digit classifier."""
17
+
18
  def __init__(self, model_path: str, device: str = None):
19
  """
20
  Initialize the digit classifier.
21
+
22
  Args:
23
  model_path: Path to model checkpoint (.pt file)
24
+ device: Device to run inference on ('cuda' or 'cpu').
25
  If None, auto-detects CUDA availability.
26
  """
27
  if device is None:
28
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
29
  else:
30
  self.device = device
31
+
32
  self.model_path = Path(model_path)
33
  if not self.model_path.exists():
34
  raise FileNotFoundError(f"Model not found at {model_path}")
35
+
36
  self.model = self._load_model()
37
  self.model.eval()
38
+
39
  # Normalization values (same as training)
40
  self.mean = 0.1307
41
  self.std = 0.3081
42
+
43
  def _load_model(self) -> torch.nn.Module:
44
  """Load model from checkpoint."""
45
  from scripts.models import BaselineCNN
46
+
47
  model = BaselineCNN()
48
+
49
  # Load checkpoint
50
  checkpoint = torch.load(self.model_path, map_location=self.device)
51
+
52
  # Handle different checkpoint formats
53
  if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
54
  model.load_state_dict(checkpoint['model_state_dict'])
55
  else:
56
  model.load_state_dict(checkpoint)
57
+
58
  return model.to(self.device)
59
+
60
  def preprocess(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
61
  """
62
  Preprocess image for model input.
63
+
64
  Handles:
65
  - RGB to grayscale conversion
66
  - Resizing to 28x28
67
  - Normalization
68
  - Inversion if needed (white digit on black background)
69
+
70
  Args:
71
  image: PIL Image or numpy array
72
+
73
  Returns:
74
  Preprocessed tensor of shape (1, 1, 28, 28)
75
  """
76
  # Convert numpy array to PIL Image if needed
77
  if isinstance(image, np.ndarray):
78
  image = Image.fromarray(image)
79
+
80
  # Convert to grayscale if RGB
81
  if image.mode != 'L':
82
  image = image.convert('L')
83
+
84
  # Resize to 28x28 if needed
85
  if image.size != (28, 28):
86
  image = image.resize((28, 28), Image.Resampling.LANCZOS)
87
+
88
  # Convert to numpy array
89
  img_array = np.array(image).astype(np.float32)
90
+
91
  # Normalize to [0, 1]
92
  img_array = img_array / 255.0
93
+
94
  # Check if inversion is needed (MNIST is white digit on black background)
95
  # If most pixels are bright, it's likely a black digit on white background
96
  if img_array.mean() > 0.5:
97
  img_array = 1.0 - img_array
98
+
99
  # Apply normalization (same as training)
100
  img_array = (img_array - self.mean) / self.std
101
+
102
  # Convert to tensor and add batch and channel dimensions
103
  img_tensor = torch.tensor(img_array).unsqueeze(0).unsqueeze(0)
104
+
105
  return img_tensor.to(self.device)
106
+
107
  def predict(self, image: Union[Image.Image, np.ndarray]) -> Dict:
108
  """
109
  Predict digit from image.
110
+
111
  Args:
112
  image: PIL Image or numpy array containing digit
113
+
114
  Returns:
115
  Dictionary with:
116
  - digit: Predicted digit (0-9)
 
118
  - probabilities: List of probabilities for each digit
119
  """
120
  img_tensor = self.preprocess(image)
121
+
122
  with torch.no_grad():
123
  outputs = self.model(img_tensor)
124
  probabilities = torch.softmax(outputs, dim=1)[0]
125
  confidence, predicted = torch.max(probabilities, dim=0)
126
+
127
  return {
128
  'digit': int(predicted.item()),
129
  'confidence': float(confidence.item()),
130
  'probabilities': probabilities.cpu().numpy().tolist()
131
  }
132
+
133
  def predict_batch(self, images: list) -> list:
134
  """
135
  Predict digits for a batch of images.
136
+
137
  Args:
138
  images: List of PIL Images or numpy arrays
139
+
140
  Returns:
141
  List of prediction dictionaries
142
  """
 
147
  """Test inference module with sample images."""
148
  import sys
149
  from pathlib import Path
150
+
151
  # Add project root to path
152
  project_root = Path(__file__).parent.parent
153
  sys.path.insert(0, str(project_root))
154
+
155
  from scripts.data_loader import MnistDataloader
156
+
157
  print("Testing Inference Module")
158
  print("=" * 50)
159
+
160
  # Check if model exists
161
  model_path = project_root / 'models' / 'best_model.pt'
162
  if not model_path.exists():
163
  print(f"Error: Model not found at {model_path}")
164
  print("Please train a model first.")
165
  return
166
+
167
  # Load MNIST test data
168
  data_path = project_root / 'data' / 'raw'
169
  loader = MnistDataloader(
 
173
  test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
174
  )
175
  _, (x_test, y_test) = loader.load_data()
176
+
177
  # Initialize classifier
178
  print(f"\n1. Loading model from: {model_path}")
179
  classifier = DigitClassifier(str(model_path))
180
  print(f" Device: {classifier.device}")
181
+
182
  # Test on a few images
183
  print("\n2. Testing predictions on 10 random test images:")
184
  print("-" * 50)
185
+
186
  indices = np.random.choice(len(x_test), 10, replace=False)
187
  correct = 0
188
+
189
  for i, idx in enumerate(indices, 1):
190
  image = x_test[idx]
191
  true_label = y_test[idx]
192
+
193
  # Convert list to numpy array if needed
194
  if isinstance(image, list):
195
  image = np.array(image)
196
+
197
  # Convert to PIL Image
198
  img = Image.fromarray(image.astype(np.uint8), mode='L')
199
+
200
  # Predict
201
  result = classifier.predict(img)
202
+
203
  is_correct = result['digit'] == true_label
204
  correct += is_correct
205
+
206
  print(f" Image {i}: True={true_label}, Pred={result['digit']}, "
207
  f"Conf={result['confidence']:.4f} {'✓' if is_correct else '✗'}")
208
+
209
  accuracy = correct / len(indices) * 100
210
  print(f"\nAccuracy on {len(indices)} samples: {accuracy:.1f}%")
211
+
212
  # Test edge cases
213
  print("\n3. Testing edge cases:")
214
  print("-" * 50)
215
+
216
  # Blank image
217
  blank = np.zeros((28, 28), dtype=np.uint8)
218
  blank_img = Image.fromarray(blank, mode='L')
219
  result = classifier.predict(blank_img)
220
  print(f" Blank image: Pred={result['digit']}, Conf={result['confidence']:.4f}")
221
+
222
  # All white image
223
  white = np.ones((28, 28), dtype=np.uint8) * 255
224
  white_img = Image.fromarray(white, mode='L')
225
  result = classifier.predict(white_img)
226
  print(f" White image: Pred={result['digit']}, Conf={result['confidence']:.4f}")
227
+
228
  # Different size image
229
  test_img = x_test[0]
230
  if isinstance(test_img, list):
231
  test_img = np.array(test_img)
232
+ large = Image.fromarray(test_img.astype(np.uint8), mode='L')
233
+ large = large.resize((56, 56))
234
  result = classifier.predict(large)
235
+ print(
236
+ f" Resized image (56x56): "
237
+ f"Pred={result['digit']}, Conf={result['confidence']:.4f}"
238
+ )
239
+
240
  print("\n✓ Inference module test complete!")
241
 
242
 
scripts/mlflow_setup.py CHANGED
@@ -24,20 +24,20 @@ def setup_mlflow(
24
  ) -> str:
25
  """
26
  Setup MLflow tracking with best practices.
27
-
28
  Args:
29
  experiment_name: Name of the experiment
30
  tracking_uri: MLflow tracking URI (default: local ./mlruns)
31
-
32
  Returns:
33
  experiment_id: MLflow experiment ID
34
  """
35
  # Set tracking URI
36
  if tracking_uri is None:
37
  tracking_uri = MLFLOW_TRACKING_URI
38
-
39
  mlflow.set_tracking_uri(tracking_uri)
40
-
41
  # Create or get experiment
42
  try:
43
  experiment = mlflow.get_experiment_by_name(experiment_name)
@@ -55,34 +55,36 @@ def setup_mlflow(
55
  except Exception as e:
56
  print(f"Warning: Could not create experiment: {e}")
57
  experiment_id = "0" # Default experiment
58
-
59
  mlflow.set_experiment(experiment_name)
60
-
61
  print(f"MLflow tracking URI: {tracking_uri}")
62
  print(f"Experiment: {experiment_name} (ID: {experiment_id})")
63
-
64
  return experiment_id
65
 
66
 
67
  def log_model_params(model: Any, prefix: str = "model") -> Dict[str, Any]:
68
  """
69
  Log model parameters to MLflow.
70
-
71
  Args:
72
  model: PyTorch model
73
  prefix: Prefix for parameter names
74
-
75
  Returns:
76
  Dictionary of logged parameters
77
  """
78
  from scripts.models import count_parameters
79
-
80
  params = {
81
  f"{prefix}_name": model.__class__.__name__,
82
  f"{prefix}_total_params": count_parameters(model),
83
- f"{prefix}_trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 
84
  }
85
-
86
  mlflow.log_params(params)
87
  return params
88
 
@@ -90,7 +92,7 @@ def log_model_params(model: Any, prefix: str = "model") -> Dict[str, Any]:
90
  def log_training_config(config: Dict[str, Any]) -> None:
91
  """
92
  Log training configuration to MLflow.
93
-
94
  Args:
95
  config: Dictionary of training hyperparameters
96
  """
@@ -102,7 +104,7 @@ def log_training_config(config: Dict[str, Any]) -> None:
102
  flat_config[f"{key}_{subkey}"] = subvalue
103
  else:
104
  flat_config[key] = value
105
-
106
  mlflow.log_params(flat_config)
107
 
108
 
@@ -115,7 +117,7 @@ def log_data_info(
115
  ) -> None:
116
  """
117
  Log dataset information to MLflow.
118
-
119
  Args:
120
  train_size: Number of training samples
121
  val_size: Number of validation samples
@@ -135,13 +137,13 @@ def log_data_info(
135
  def log_system_info() -> Dict[str, Any]:
136
  """
137
  Log system information to MLflow.
138
-
139
  Returns:
140
  Dictionary of system information
141
  """
142
  import torch
143
  import platform
144
-
145
  system_info = {
146
  "system_platform": platform.system(),
147
  "system_python_version": platform.python_version(),
@@ -152,11 +154,11 @@ def log_system_info() -> Dict[str, Any]:
152
  ),
153
  "system_device": "cuda" if torch.cuda.is_available() else "cpu"
154
  }
155
-
156
  if torch.cuda.is_available():
157
  system_info["system_gpu_name"] = torch.cuda.get_device_name(0)
158
  system_info["system_gpu_count"] = torch.cuda.device_count()
159
-
160
  mlflow.log_params(system_info)
161
  return system_info
162
 
@@ -164,7 +166,7 @@ def log_system_info() -> Dict[str, Any]:
164
  def log_metrics_epoch(metrics: Dict[str, float], step: int) -> None:
165
  """
166
  Log metrics for a specific epoch.
167
-
168
  Args:
169
  metrics: Dictionary of metric names and values
170
  step: Epoch number
@@ -175,7 +177,7 @@ def log_metrics_epoch(metrics: Dict[str, float], step: int) -> None:
175
  def log_artifact_path(path: str, artifact_path: Optional[str] = None) -> None:
176
  """
177
  Log a file or directory as an artifact.
178
-
179
  Args:
180
  path: Path to file or directory
181
  artifact_path: Optional artifact path in MLflow
@@ -194,14 +196,14 @@ def log_model_to_registry(
194
  ) -> None:
195
  """
196
  Log model to MLflow with model registry integration.
197
-
198
  Args:
199
  model: PyTorch model
200
  model_name: Name for the model artifact
201
  artifact_path: Artifact path in MLflow
202
  registered_model_name: Name for model registry (optional)
203
  """
204
-
205
  # Log model
206
  mlflow.pytorch.log_model(
207
  pytorch_model=model,
@@ -216,11 +218,11 @@ def get_or_create_run(
216
  ) -> mlflow.ActiveRun:
217
  """
218
  Get existing run or create a new one.
219
-
220
  Args:
221
  run_name: Name for the run
222
  tags: Tags for the run
223
-
224
  Returns:
225
  MLflow active run context
226
  """
@@ -236,10 +238,10 @@ def test_mlflow_setup():
236
  """Test MLflow setup and basic logging."""
237
  print("Testing MLflow Setup")
238
  print("=" * 50)
239
-
240
  # Setup MLflow
241
  setup_mlflow("test-experiment")
242
-
243
  # Test logging
244
  with mlflow.start_run(run_name="test-run"):
245
  # Log parameters
@@ -248,7 +250,7 @@ def test_mlflow_setup():
248
  "batch_size": 64,
249
  "epochs": 10
250
  })
251
-
252
  # Log metrics
253
  for epoch in range(3):
254
  mlflow.log_metrics({
@@ -257,13 +259,13 @@ def test_mlflow_setup():
257
  "train_accuracy": 0.8 + epoch * 0.05,
258
  "val_accuracy": 0.75 + epoch * 0.05
259
  }, step=epoch)
260
-
261
  # Log system info
262
  system_info = log_system_info()
263
  print("\nSystem Info:")
264
  for key, value in system_info.items():
265
  print(f" {key}: {value}")
266
-
267
  print("\n✓ MLflow test complete!")
268
  print(f"View results at: mlflow ui --backend-store-uri {MLFLOW_TRACKING_URI}")
269
 
 
24
  ) -> str:
25
  """
26
  Setup MLflow tracking with best practices.
27
+
28
  Args:
29
  experiment_name: Name of the experiment
30
  tracking_uri: MLflow tracking URI (default: local ./mlruns)
31
+
32
  Returns:
33
  experiment_id: MLflow experiment ID
34
  """
35
  # Set tracking URI
36
  if tracking_uri is None:
37
  tracking_uri = MLFLOW_TRACKING_URI
38
+
39
  mlflow.set_tracking_uri(tracking_uri)
40
+
41
  # Create or get experiment
42
  try:
43
  experiment = mlflow.get_experiment_by_name(experiment_name)
 
55
  except Exception as e:
56
  print(f"Warning: Could not create experiment: {e}")
57
  experiment_id = "0" # Default experiment
58
+
59
  mlflow.set_experiment(experiment_name)
60
+
61
  print(f"MLflow tracking URI: {tracking_uri}")
62
  print(f"Experiment: {experiment_name} (ID: {experiment_id})")
63
+
64
  return experiment_id
65
 
66
 
67
  def log_model_params(model: Any, prefix: str = "model") -> Dict[str, Any]:
68
  """
69
  Log model parameters to MLflow.
70
+
71
  Args:
72
  model: PyTorch model
73
  prefix: Prefix for parameter names
74
+
75
  Returns:
76
  Dictionary of logged parameters
77
  """
78
  from scripts.models import count_parameters
79
+
80
  params = {
81
  f"{prefix}_name": model.__class__.__name__,
82
  f"{prefix}_total_params": count_parameters(model),
83
+ f"{prefix}_trainable_params": sum(
84
+ p.numel() for p in model.parameters() if p.requires_grad
85
+ )
86
  }
87
+
88
  mlflow.log_params(params)
89
  return params
90
 
 
92
  def log_training_config(config: Dict[str, Any]) -> None:
93
  """
94
  Log training configuration to MLflow.
95
+
96
  Args:
97
  config: Dictionary of training hyperparameters
98
  """
 
104
  flat_config[f"{key}_{subkey}"] = subvalue
105
  else:
106
  flat_config[key] = value
107
+
108
  mlflow.log_params(flat_config)
109
 
110
 
 
117
  ) -> None:
118
  """
119
  Log dataset information to MLflow.
120
+
121
  Args:
122
  train_size: Number of training samples
123
  val_size: Number of validation samples
 
137
  def log_system_info() -> Dict[str, Any]:
138
  """
139
  Log system information to MLflow.
140
+
141
  Returns:
142
  Dictionary of system information
143
  """
144
  import torch
145
  import platform
146
+
147
  system_info = {
148
  "system_platform": platform.system(),
149
  "system_python_version": platform.python_version(),
 
154
  ),
155
  "system_device": "cuda" if torch.cuda.is_available() else "cpu"
156
  }
157
+
158
  if torch.cuda.is_available():
159
  system_info["system_gpu_name"] = torch.cuda.get_device_name(0)
160
  system_info["system_gpu_count"] = torch.cuda.device_count()
161
+
162
  mlflow.log_params(system_info)
163
  return system_info
164
 
 
166
  def log_metrics_epoch(metrics: Dict[str, float], step: int) -> None:
167
  """
168
  Log metrics for a specific epoch.
169
+
170
  Args:
171
  metrics: Dictionary of metric names and values
172
  step: Epoch number
 
177
  def log_artifact_path(path: str, artifact_path: Optional[str] = None) -> None:
178
  """
179
  Log a file or directory as an artifact.
180
+
181
  Args:
182
  path: Path to file or directory
183
  artifact_path: Optional artifact path in MLflow
 
196
  ) -> None:
197
  """
198
  Log model to MLflow with model registry integration.
199
+
200
  Args:
201
  model: PyTorch model
202
  model_name: Name for the model artifact
203
  artifact_path: Artifact path in MLflow
204
  registered_model_name: Name for model registry (optional)
205
  """
206
+
207
  # Log model
208
  mlflow.pytorch.log_model(
209
  pytorch_model=model,
 
218
  ) -> mlflow.ActiveRun:
219
  """
220
  Get existing run or create a new one.
221
+
222
  Args:
223
  run_name: Name for the run
224
  tags: Tags for the run
225
+
226
  Returns:
227
  MLflow active run context
228
  """
 
238
  """Test MLflow setup and basic logging."""
239
  print("Testing MLflow Setup")
240
  print("=" * 50)
241
+
242
  # Setup MLflow
243
  setup_mlflow("test-experiment")
244
+
245
  # Test logging
246
  with mlflow.start_run(run_name="test-run"):
247
  # Log parameters
 
250
  "batch_size": 64,
251
  "epochs": 10
252
  })
253
+
254
  # Log metrics
255
  for epoch in range(3):
256
  mlflow.log_metrics({
 
259
  "train_accuracy": 0.8 + epoch * 0.05,
260
  "val_accuracy": 0.75 + epoch * 0.05
261
  }, step=epoch)
262
+
263
  # Log system info
264
  system_info = log_system_info()
265
  print("\nSystem Info:")
266
  for key, value in system_info.items():
267
  print(f" {key}: {value}")
268
+
269
  print("\n✓ MLflow test complete!")
270
  print(f"View results at: mlflow ui --backend-store-uri {MLFLOW_TRACKING_URI}")
271
 
scripts/models.py CHANGED
@@ -8,7 +8,7 @@ This module provides CNN models for digit recognition:
8
 
9
  Usage:
10
  from scripts.models import BaselineCNN
11
-
12
  model = BaselineCNN()
13
  output = model(images) # (batch, 10) logits
14
  """
@@ -22,7 +22,7 @@ from typing import Tuple
22
  class BaselineCNN(nn.Module):
23
  """
24
  Baseline CNN for MNIST classification.
25
-
26
  Architecture:
27
  Input: (batch, 1, 28, 28)
28
  Conv1: 1 -> 32 channels, 3x3 kernel, padding=1
@@ -32,28 +32,28 @@ class BaselineCNN(nn.Module):
32
  Flatten -> (batch, 3136)
33
  FC1: 3136 -> 128, ReLU, Dropout(0.5)
34
  FC2: 128 -> 10 (output logits)
35
-
36
  Design Rationale:
37
  - 2 conv layers: Balance between simplicity and capacity
38
  - 32->64 filters: Standard progression, proven effective
39
  - Dropout 0.5: Prevent overfitting on small dataset
40
  - No batch norm: Keep baseline simple
41
-
42
  Expected Performance:
43
  - Parameters: ~110k
44
  - Test accuracy: 98-99%
45
  - Training time: ~5-10 min on GPU
46
  """
47
-
48
  def __init__(self, dropout_rate: float = 0.5):
49
  """
50
  Initialize baseline CNN.
51
-
52
  Args:
53
  dropout_rate: Dropout probability (default 0.5)
54
  """
55
  super(BaselineCNN, self).__init__()
56
-
57
  # Convolutional layers
58
  self.conv1 = nn.Conv2d(
59
  in_channels=1,
@@ -67,25 +67,25 @@ class BaselineCNN(nn.Module):
67
  kernel_size=3,
68
  padding=1
69
  )
70
-
71
  # Pooling layer (shared)
72
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
73
-
74
  # Fully connected layers
75
  # After two pooling layers: 28->14->7, so 64*7*7 = 3136
76
  self.fc1 = nn.Linear(64 * 7 * 7, 128)
77
  self.fc2 = nn.Linear(128, 10)
78
-
79
  # Dropout for regularization
80
  self.dropout = nn.Dropout(p=dropout_rate)
81
-
82
  def forward(self, x: torch.Tensor) -> torch.Tensor:
83
  """
84
  Forward pass.
85
-
86
  Args:
87
  x: Input tensor of shape (batch, 1, 28, 28)
88
-
89
  Returns:
90
  Output logits of shape (batch, 10)
91
  """
@@ -93,28 +93,28 @@ class BaselineCNN(nn.Module):
93
  x = self.conv1(x) # (batch, 32, 28, 28)
94
  x = F.relu(x)
95
  x = self.pool(x) # (batch, 32, 14, 14)
96
-
97
  # Conv block 2: Conv -> ReLU -> Pool
98
  x = self.conv2(x) # (batch, 64, 14, 14)
99
  x = F.relu(x)
100
  x = self.pool(x) # (batch, 64, 7, 7)
101
-
102
  # Flatten
103
  x = x.view(-1, 64 * 7 * 7) # (batch, 3136)
104
-
105
  # Fully connected layers
106
  x = self.fc1(x) # (batch, 128)
107
  x = F.relu(x)
108
  x = self.dropout(x)
109
  x = self.fc2(x) # (batch, 10)
110
-
111
  return x
112
 
113
 
114
  class ImprovedCNN(nn.Module):
115
  """
116
  Enhanced CNN with batch normalization and deeper architecture.
117
-
118
  Architecture:
119
  Conv1: 1 -> 32, BatchNorm, ReLU, MaxPool
120
  Conv2: 32 -> 64, BatchNorm, ReLU, MaxPool
@@ -122,49 +122,49 @@ class ImprovedCNN(nn.Module):
122
  Flatten
123
  FC1: 128*3*3 -> 256, BatchNorm, ReLU, Dropout(0.5)
124
  FC2: 256 -> 10
125
-
126
  Expected Performance:
127
  - Parameters: ~200k
128
  - Test accuracy: 99%+
129
  - Converges faster than baseline
130
  """
131
-
132
  def __init__(self, dropout_rate: float = 0.5):
133
  """
134
  Initialize improved CNN.
135
-
136
  Args:
137
  dropout_rate: Dropout probability (default 0.5)
138
  """
139
  super(ImprovedCNN, self).__init__()
140
-
141
  # Convolutional layers with batch normalization
142
  self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
143
  self.bn1 = nn.BatchNorm2d(32)
144
-
145
  self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
146
  self.bn2 = nn.BatchNorm2d(64)
147
-
148
  self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
149
  self.bn3 = nn.BatchNorm2d(128)
150
-
151
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
152
-
153
  # Fully connected layers
154
  # After three pooling layers: 28->14->7->3, so 128*3*3 = 1152
155
  self.fc1 = nn.Linear(128 * 3 * 3, 256)
156
  self.bn_fc = nn.BatchNorm1d(256)
157
  self.fc2 = nn.Linear(256, 10)
158
-
159
  self.dropout = nn.Dropout(p=dropout_rate)
160
-
161
  def forward(self, x: torch.Tensor) -> torch.Tensor:
162
  """
163
  Forward pass.
164
-
165
  Args:
166
  x: Input tensor of shape (batch, 1, 28, 28)
167
-
168
  Returns:
169
  Output logits of shape (batch, 10)
170
  """
@@ -173,39 +173,39 @@ class ImprovedCNN(nn.Module):
173
  x = self.bn1(x)
174
  x = F.relu(x)
175
  x = self.pool(x) # (batch, 32, 14, 14)
176
-
177
  # Conv block 2
178
  x = self.conv2(x)
179
  x = self.bn2(x)
180
  x = F.relu(x)
181
  x = self.pool(x) # (batch, 64, 7, 7)
182
-
183
  # Conv block 3
184
  x = self.conv3(x)
185
  x = self.bn3(x)
186
  x = F.relu(x)
187
  x = self.pool(x) # (batch, 128, 3, 3)
188
-
189
  # Flatten
190
  x = x.view(-1, 128 * 3 * 3)
191
-
192
  # Fully connected layers
193
  x = self.fc1(x)
194
  x = self.bn_fc(x)
195
  x = F.relu(x)
196
  x = self.dropout(x)
197
  x = self.fc2(x)
198
-
199
  return x
200
 
201
 
202
  def count_parameters(model: nn.Module) -> Tuple[int, int]:
203
  """
204
  Count total and trainable parameters in model.
205
-
206
  Args:
207
  model: PyTorch model
208
-
209
  Returns:
210
  Tuple of (total_params, trainable_params)
211
  """
@@ -214,19 +214,21 @@ def count_parameters(model: nn.Module) -> Tuple[int, int]:
214
  return total_params, trainable_params
215
 
216
 
217
- def get_model_summary(model: nn.Module, input_size: Tuple[int, ...] = (1, 1, 28, 28)) -> str:
 
 
218
  """
219
  Generate model architecture summary.
220
-
221
  Args:
222
  model: PyTorch model
223
  input_size: Input tensor size (batch, channels, height, width)
224
-
225
  Returns:
226
  Formatted string with model summary
227
  """
228
  total_params, trainable_params = count_parameters(model)
229
-
230
  summary = []
231
  summary.append("=" * 60)
232
  summary.append(f"Model: {model.__class__.__name__}")
@@ -234,47 +236,49 @@ def get_model_summary(model: nn.Module, input_size: Tuple[int, ...] = (1, 1, 28,
234
  summary.append(f"Input size: {input_size}")
235
  summary.append(f"Total parameters: {total_params:,}")
236
  summary.append(f"Trainable parameters: {trainable_params:,}")
237
- summary.append(f"Model size (MB): {total_params * 4 / (1024**2):.2f}") # Assuming float32
 
 
238
  summary.append("=" * 60)
239
-
240
  return "\n".join(summary)
241
 
242
 
243
  def test_model(model: nn.Module, device: str = 'cpu') -> bool:
244
  """
245
  Test model with dummy input.
246
-
247
  Args:
248
  model: PyTorch model
249
  device: Device to run on ('cpu' or 'cuda')
250
-
251
  Returns:
252
  True if test passes, False otherwise
253
  """
254
  try:
255
  model = model.to(device)
256
  model.eval()
257
-
258
  # Create dummy input
259
  dummy_input = torch.randn(4, 1, 28, 28).to(device)
260
-
261
  # Forward pass
262
  with torch.no_grad():
263
  output = model(dummy_input)
264
-
265
  # Check output shape
266
  assert output.shape == (4, 10), f"Expected shape (4, 10), got {output.shape}"
267
-
268
  # Check output is finite
269
  assert torch.isfinite(output).all(), "Output contains NaN or Inf"
270
-
271
  print("✓ Model test passed")
272
  print(f" Input shape: {dummy_input.shape}")
273
  print(f" Output shape: {output.shape}")
274
  print(f" Output range: [{output.min():.4f}, {output.max():.4f}]")
275
-
276
  return True
277
-
278
  except Exception as e:
279
  print(f"✗ Model test failed: {e}")
280
  return False
@@ -284,23 +288,23 @@ if __name__ == "__main__":
284
  """Test model instantiation and forward pass."""
285
  print("Testing BaselineCNN:")
286
  print()
287
-
288
  # Create model
289
  model = BaselineCNN()
290
  print(get_model_summary(model))
291
  print()
292
-
293
  # Test forward pass
294
  test_model(model)
295
  print()
296
-
297
  # Test improved model
298
  print("=" * 60)
299
  print("Testing ImprovedCNN:")
300
  print()
301
-
302
  model_improved = ImprovedCNN()
303
  print(get_model_summary(model_improved))
304
  print()
305
-
306
  test_model(model_improved)
 
8
 
9
  Usage:
10
  from scripts.models import BaselineCNN
11
+
12
  model = BaselineCNN()
13
  output = model(images) # (batch, 10) logits
14
  """
 
22
  class BaselineCNN(nn.Module):
23
  """
24
  Baseline CNN for MNIST classification.
25
+
26
  Architecture:
27
  Input: (batch, 1, 28, 28)
28
  Conv1: 1 -> 32 channels, 3x3 kernel, padding=1
 
32
  Flatten -> (batch, 3136)
33
  FC1: 3136 -> 128, ReLU, Dropout(0.5)
34
  FC2: 128 -> 10 (output logits)
35
+
36
  Design Rationale:
37
  - 2 conv layers: Balance between simplicity and capacity
38
  - 32->64 filters: Standard progression, proven effective
39
  - Dropout 0.5: Prevent overfitting on small dataset
40
  - No batch norm: Keep baseline simple
41
+
42
  Expected Performance:
43
  - Parameters: ~110k
44
  - Test accuracy: 98-99%
45
  - Training time: ~5-10 min on GPU
46
  """
47
+
48
  def __init__(self, dropout_rate: float = 0.5):
49
  """
50
  Initialize baseline CNN.
51
+
52
  Args:
53
  dropout_rate: Dropout probability (default 0.5)
54
  """
55
  super(BaselineCNN, self).__init__()
56
+
57
  # Convolutional layers
58
  self.conv1 = nn.Conv2d(
59
  in_channels=1,
 
67
  kernel_size=3,
68
  padding=1
69
  )
70
+
71
  # Pooling layer (shared)
72
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
73
+
74
  # Fully connected layers
75
  # After two pooling layers: 28->14->7, so 64*7*7 = 3136
76
  self.fc1 = nn.Linear(64 * 7 * 7, 128)
77
  self.fc2 = nn.Linear(128, 10)
78
+
79
  # Dropout for regularization
80
  self.dropout = nn.Dropout(p=dropout_rate)
81
+
82
  def forward(self, x: torch.Tensor) -> torch.Tensor:
83
  """
84
  Forward pass.
85
+
86
  Args:
87
  x: Input tensor of shape (batch, 1, 28, 28)
88
+
89
  Returns:
90
  Output logits of shape (batch, 10)
91
  """
 
93
  x = self.conv1(x) # (batch, 32, 28, 28)
94
  x = F.relu(x)
95
  x = self.pool(x) # (batch, 32, 14, 14)
96
+
97
  # Conv block 2: Conv -> ReLU -> Pool
98
  x = self.conv2(x) # (batch, 64, 14, 14)
99
  x = F.relu(x)
100
  x = self.pool(x) # (batch, 64, 7, 7)
101
+
102
  # Flatten
103
  x = x.view(-1, 64 * 7 * 7) # (batch, 3136)
104
+
105
  # Fully connected layers
106
  x = self.fc1(x) # (batch, 128)
107
  x = F.relu(x)
108
  x = self.dropout(x)
109
  x = self.fc2(x) # (batch, 10)
110
+
111
  return x
112
 
113
 
114
  class ImprovedCNN(nn.Module):
115
  """
116
  Enhanced CNN with batch normalization and deeper architecture.
117
+
118
  Architecture:
119
  Conv1: 1 -> 32, BatchNorm, ReLU, MaxPool
120
  Conv2: 32 -> 64, BatchNorm, ReLU, MaxPool
 
122
  Flatten
123
  FC1: 128*3*3 -> 256, BatchNorm, ReLU, Dropout(0.5)
124
  FC2: 256 -> 10
125
+
126
  Expected Performance:
127
  - Parameters: ~200k
128
  - Test accuracy: 99%+
129
  - Converges faster than baseline
130
  """
131
+
132
  def __init__(self, dropout_rate: float = 0.5):
133
  """
134
  Initialize improved CNN.
135
+
136
  Args:
137
  dropout_rate: Dropout probability (default 0.5)
138
  """
139
  super(ImprovedCNN, self).__init__()
140
+
141
  # Convolutional layers with batch normalization
142
  self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
143
  self.bn1 = nn.BatchNorm2d(32)
144
+
145
  self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
146
  self.bn2 = nn.BatchNorm2d(64)
147
+
148
  self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
149
  self.bn3 = nn.BatchNorm2d(128)
150
+
151
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
152
+
153
  # Fully connected layers
154
  # After three pooling layers: 28->14->7->3, so 128*3*3 = 1152
155
  self.fc1 = nn.Linear(128 * 3 * 3, 256)
156
  self.bn_fc = nn.BatchNorm1d(256)
157
  self.fc2 = nn.Linear(256, 10)
158
+
159
  self.dropout = nn.Dropout(p=dropout_rate)
160
+
161
  def forward(self, x: torch.Tensor) -> torch.Tensor:
162
  """
163
  Forward pass.
164
+
165
  Args:
166
  x: Input tensor of shape (batch, 1, 28, 28)
167
+
168
  Returns:
169
  Output logits of shape (batch, 10)
170
  """
 
173
  x = self.bn1(x)
174
  x = F.relu(x)
175
  x = self.pool(x) # (batch, 32, 14, 14)
176
+
177
  # Conv block 2
178
  x = self.conv2(x)
179
  x = self.bn2(x)
180
  x = F.relu(x)
181
  x = self.pool(x) # (batch, 64, 7, 7)
182
+
183
  # Conv block 3
184
  x = self.conv3(x)
185
  x = self.bn3(x)
186
  x = F.relu(x)
187
  x = self.pool(x) # (batch, 128, 3, 3)
188
+
189
  # Flatten
190
  x = x.view(-1, 128 * 3 * 3)
191
+
192
  # Fully connected layers
193
  x = self.fc1(x)
194
  x = self.bn_fc(x)
195
  x = F.relu(x)
196
  x = self.dropout(x)
197
  x = self.fc2(x)
198
+
199
  return x
200
 
201
 
202
  def count_parameters(model: nn.Module) -> Tuple[int, int]:
203
  """
204
  Count total and trainable parameters in model.
205
+
206
  Args:
207
  model: PyTorch model
208
+
209
  Returns:
210
  Tuple of (total_params, trainable_params)
211
  """
 
214
  return total_params, trainable_params
215
 
216
 
217
+ def get_model_summary(
218
+ model: nn.Module, input_size: Tuple[int, ...] = (1, 1, 28, 28)
219
+ ) -> str:
220
  """
221
  Generate model architecture summary.
222
+
223
  Args:
224
  model: PyTorch model
225
  input_size: Input tensor size (batch, channels, height, width)
226
+
227
  Returns:
228
  Formatted string with model summary
229
  """
230
  total_params, trainable_params = count_parameters(model)
231
+
232
  summary = []
233
  summary.append("=" * 60)
234
  summary.append(f"Model: {model.__class__.__name__}")
 
236
  summary.append(f"Input size: {input_size}")
237
  summary.append(f"Total parameters: {total_params:,}")
238
  summary.append(f"Trainable parameters: {trainable_params:,}")
239
+ # Assuming float32
240
+ model_size_mb = total_params * 4 / (1024**2)
241
+ summary.append(f"Model size (MB): {model_size_mb:.2f}")
242
  summary.append("=" * 60)
243
+
244
  return "\n".join(summary)
245
 
246
 
247
  def test_model(model: nn.Module, device: str = 'cpu') -> bool:
248
  """
249
  Test model with dummy input.
250
+
251
  Args:
252
  model: PyTorch model
253
  device: Device to run on ('cpu' or 'cuda')
254
+
255
  Returns:
256
  True if test passes, False otherwise
257
  """
258
  try:
259
  model = model.to(device)
260
  model.eval()
261
+
262
  # Create dummy input
263
  dummy_input = torch.randn(4, 1, 28, 28).to(device)
264
+
265
  # Forward pass
266
  with torch.no_grad():
267
  output = model(dummy_input)
268
+
269
  # Check output shape
270
  assert output.shape == (4, 10), f"Expected shape (4, 10), got {output.shape}"
271
+
272
  # Check output is finite
273
  assert torch.isfinite(output).all(), "Output contains NaN or Inf"
274
+
275
  print("✓ Model test passed")
276
  print(f" Input shape: {dummy_input.shape}")
277
  print(f" Output shape: {output.shape}")
278
  print(f" Output range: [{output.min():.4f}, {output.max():.4f}]")
279
+
280
  return True
281
+
282
  except Exception as e:
283
  print(f"✗ Model test failed: {e}")
284
  return False
 
288
  """Test model instantiation and forward pass."""
289
  print("Testing BaselineCNN:")
290
  print()
291
+
292
  # Create model
293
  model = BaselineCNN()
294
  print(get_model_summary(model))
295
  print()
296
+
297
  # Test forward pass
298
  test_model(model)
299
  print()
300
+
301
  # Test improved model
302
  print("=" * 60)
303
  print("Testing ImprovedCNN:")
304
  print()
305
+
306
  model_improved = ImprovedCNN()
307
  print(get_model_summary(model_improved))
308
  print()
309
+
310
  test_model(model_improved)
scripts/preprocessing.py CHANGED
@@ -9,7 +9,7 @@ This module provides PyTorch Dataset and DataLoader setup for MNIST:
9
 
10
  Usage:
11
  from scripts.preprocessing import MnistDataset, create_dataloaders
12
-
13
  train_dataset = MnistDataset(x_train, y_train, transform=None)
14
  train_loader, val_loader = create_dataloaders(
15
  train_dataset, val_dataset, batch_size=64
@@ -26,10 +26,10 @@ from torch.utils.data import Dataset, DataLoader
26
  class MnistDataset(Dataset):
27
  """
28
  PyTorch Dataset for MNIST images.
29
-
30
  Handles normalization and conversion to tensors suitable for CNN training.
31
  """
32
-
33
  def __init__(
34
  self,
35
  images: List[NDArray[np.uint8]],
@@ -38,7 +38,7 @@ class MnistDataset(Dataset):
38
  ):
39
  """
40
  Initialize MNIST dataset.
41
-
42
  Args:
43
  images: List of 28x28 numpy arrays with pixel values [0, 255]
44
  labels: List of integer labels (0-9)
@@ -47,22 +47,22 @@ class MnistDataset(Dataset):
47
  self.images = images
48
  self.labels = labels
49
  self.transform = transform
50
-
51
  # Validate inputs
52
  assert len(images) == len(labels), \
53
  f"Mismatch: {len(images)} images but {len(labels)} labels"
54
-
55
  def __len__(self) -> int:
56
  """Return number of samples in dataset."""
57
  return len(self.images)
58
-
59
  def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
60
  """
61
  Get a single sample.
62
-
63
  Args:
64
  idx: Index of sample to retrieve
65
-
66
  Returns:
67
  Tuple of (image_tensor, label_tensor)
68
  - image_tensor: Shape (1, 28, 28), dtype float32, range [0, 1]
@@ -71,18 +71,18 @@ class MnistDataset(Dataset):
71
  # Get image and label
72
  image = np.array(self.images[idx])
73
  label = self.labels[idx]
74
-
75
  # Normalize to [0, 1]
76
  image = image.astype(np.float32) / 255.0
77
-
78
  # Convert to tensor and add channel dimension: (28, 28) -> (1, 28, 28)
79
  image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
80
  label = torch.tensor(label, dtype=torch.long)
81
-
82
  # Apply transforms if provided (e.g., augmentation)
83
  if self.transform:
84
  image = self.transform(image)
85
-
86
  return image, label
87
 
88
 
@@ -95,14 +95,14 @@ def create_dataloaders(
95
  ) -> Tuple[DataLoader, DataLoader]:
96
  """
97
  Create DataLoader instances for training and validation.
98
-
99
  Args:
100
  train_dataset: Training dataset
101
  val_dataset: Validation dataset
102
  batch_size: Number of samples per batch
103
  num_workers: Number of worker processes for data loading
104
  shuffle_train: Whether to shuffle training data
105
-
106
  Returns:
107
  Tuple of (train_loader, val_loader)
108
  """
@@ -113,7 +113,7 @@ def create_dataloaders(
113
  num_workers=num_workers,
114
  pin_memory=True # Faster GPU transfer
115
  )
116
-
117
  val_loader = DataLoader(
118
  val_dataset,
119
  batch_size=batch_size,
@@ -121,7 +121,7 @@ def create_dataloaders(
121
  num_workers=num_workers,
122
  pin_memory=True
123
  )
124
-
125
  return train_loader, val_loader
126
 
127
 
@@ -132,12 +132,12 @@ def create_test_dataloader(
132
  ) -> DataLoader:
133
  """
134
  Create DataLoader for test set.
135
-
136
  Args:
137
  test_dataset: Test dataset
138
  batch_size: Number of samples per batch
139
  num_workers: Number of worker processes for data loading
140
-
141
  Returns:
142
  Test DataLoader
143
  """
@@ -148,7 +148,7 @@ def create_test_dataloader(
148
  num_workers=num_workers,
149
  pin_memory=True
150
  )
151
-
152
  return test_loader
153
 
154
 
@@ -163,70 +163,70 @@ def split_train_val(
163
  ]:
164
  """
165
  Split training data into train and validation sets.
166
-
167
  Uses stratified sampling to maintain class balance.
168
-
169
  Args:
170
  images: List of training images
171
  labels: List of training labels
172
  val_split: Fraction of data to use for validation (0.15 = 15%)
173
  random_seed: Random seed for reproducibility
174
-
175
  Returns:
176
  Tuple of ((train_images, train_labels), (val_images, val_labels))
177
  """
178
  from collections import defaultdict
179
-
180
  # Group indices by class for stratified split
181
  class_indices = defaultdict(list)
182
  for idx, label in enumerate(labels):
183
  class_indices[label].append(idx)
184
-
185
  # Set random seed
186
  np.random.seed(random_seed)
187
-
188
  train_indices = []
189
  val_indices = []
190
-
191
  # Split each class separately
192
  for class_label, indices in class_indices.items():
193
  indices = np.array(indices)
194
  np.random.shuffle(indices)
195
-
196
  split_point = int(len(indices) * (1 - val_split))
197
  train_indices.extend(indices[:split_point])
198
  val_indices.extend(indices[split_point:])
199
-
200
  # Shuffle combined indices
201
  np.random.shuffle(train_indices)
202
  np.random.shuffle(val_indices)
203
-
204
  # Extract images and labels
205
  train_images = [images[i] for i in train_indices]
206
  train_labels = [labels[i] for i in train_indices]
207
  val_images = [images[i] for i in val_indices]
208
  val_labels = [labels[i] for i in val_indices]
209
-
210
  return (train_images, train_labels), (val_images, val_labels)
211
 
212
 
213
  def get_dataset_statistics(dataset: MnistDataset) -> dict:
214
  """
215
  Compute statistics for a dataset (useful for debugging).
216
-
217
  Args:
218
  dataset: MnistDataset instance
219
-
220
  Returns:
221
  Dictionary with statistics
222
  """
223
  # Sample first image to check preprocessing
224
  sample_img, sample_label = dataset[0]
225
-
226
  # Count labels
227
  from collections import Counter
228
  label_counts = Counter([dataset[i][1].item() for i in range(len(dataset))])
229
-
230
  return {
231
  'num_samples': len(dataset),
232
  'sample_image_shape': tuple(sample_img.shape),
 
9
 
10
  Usage:
11
  from scripts.preprocessing import MnistDataset, create_dataloaders
12
+
13
  train_dataset = MnistDataset(x_train, y_train, transform=None)
14
  train_loader, val_loader = create_dataloaders(
15
  train_dataset, val_dataset, batch_size=64
 
26
  class MnistDataset(Dataset):
27
  """
28
  PyTorch Dataset for MNIST images.
29
+
30
  Handles normalization and conversion to tensors suitable for CNN training.
31
  """
32
+
33
  def __init__(
34
  self,
35
  images: List[NDArray[np.uint8]],
 
38
  ):
39
  """
40
  Initialize MNIST dataset.
41
+
42
  Args:
43
  images: List of 28x28 numpy arrays with pixel values [0, 255]
44
  labels: List of integer labels (0-9)
 
47
  self.images = images
48
  self.labels = labels
49
  self.transform = transform
50
+
51
  # Validate inputs
52
  assert len(images) == len(labels), \
53
  f"Mismatch: {len(images)} images but {len(labels)} labels"
54
+
55
  def __len__(self) -> int:
56
  """Return number of samples in dataset."""
57
  return len(self.images)
58
+
59
  def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
60
  """
61
  Get a single sample.
62
+
63
  Args:
64
  idx: Index of sample to retrieve
65
+
66
  Returns:
67
  Tuple of (image_tensor, label_tensor)
68
  - image_tensor: Shape (1, 28, 28), dtype float32, range [0, 1]
 
71
  # Get image and label
72
  image = np.array(self.images[idx])
73
  label = self.labels[idx]
74
+
75
  # Normalize to [0, 1]
76
  image = image.astype(np.float32) / 255.0
77
+
78
  # Convert to tensor and add channel dimension: (28, 28) -> (1, 28, 28)
79
  image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)
80
  label = torch.tensor(label, dtype=torch.long)
81
+
82
  # Apply transforms if provided (e.g., augmentation)
83
  if self.transform:
84
  image = self.transform(image)
85
+
86
  return image, label
87
 
88
 
 
95
  ) -> Tuple[DataLoader, DataLoader]:
96
  """
97
  Create DataLoader instances for training and validation.
98
+
99
  Args:
100
  train_dataset: Training dataset
101
  val_dataset: Validation dataset
102
  batch_size: Number of samples per batch
103
  num_workers: Number of worker processes for data loading
104
  shuffle_train: Whether to shuffle training data
105
+
106
  Returns:
107
  Tuple of (train_loader, val_loader)
108
  """
 
113
  num_workers=num_workers,
114
  pin_memory=True # Faster GPU transfer
115
  )
116
+
117
  val_loader = DataLoader(
118
  val_dataset,
119
  batch_size=batch_size,
 
121
  num_workers=num_workers,
122
  pin_memory=True
123
  )
124
+
125
  return train_loader, val_loader
126
 
127
 
 
132
  ) -> DataLoader:
133
  """
134
  Create DataLoader for test set.
135
+
136
  Args:
137
  test_dataset: Test dataset
138
  batch_size: Number of samples per batch
139
  num_workers: Number of worker processes for data loading
140
+
141
  Returns:
142
  Test DataLoader
143
  """
 
148
  num_workers=num_workers,
149
  pin_memory=True
150
  )
151
+
152
  return test_loader
153
 
154
 
 
163
  ]:
164
  """
165
  Split training data into train and validation sets.
166
+
167
  Uses stratified sampling to maintain class balance.
168
+
169
  Args:
170
  images: List of training images
171
  labels: List of training labels
172
  val_split: Fraction of data to use for validation (0.15 = 15%)
173
  random_seed: Random seed for reproducibility
174
+
175
  Returns:
176
  Tuple of ((train_images, train_labels), (val_images, val_labels))
177
  """
178
  from collections import defaultdict
179
+
180
  # Group indices by class for stratified split
181
  class_indices = defaultdict(list)
182
  for idx, label in enumerate(labels):
183
  class_indices[label].append(idx)
184
+
185
  # Set random seed
186
  np.random.seed(random_seed)
187
+
188
  train_indices = []
189
  val_indices = []
190
+
191
  # Split each class separately
192
  for class_label, indices in class_indices.items():
193
  indices = np.array(indices)
194
  np.random.shuffle(indices)
195
+
196
  split_point = int(len(indices) * (1 - val_split))
197
  train_indices.extend(indices[:split_point])
198
  val_indices.extend(indices[split_point:])
199
+
200
  # Shuffle combined indices
201
  np.random.shuffle(train_indices)
202
  np.random.shuffle(val_indices)
203
+
204
  # Extract images and labels
205
  train_images = [images[i] for i in train_indices]
206
  train_labels = [labels[i] for i in train_indices]
207
  val_images = [images[i] for i in val_indices]
208
  val_labels = [labels[i] for i in val_indices]
209
+
210
  return (train_images, train_labels), (val_images, val_labels)
211
 
212
 
213
  def get_dataset_statistics(dataset: MnistDataset) -> dict:
214
  """
215
  Compute statistics for a dataset (useful for debugging).
216
+
217
  Args:
218
  dataset: MnistDataset instance
219
+
220
  Returns:
221
  Dictionary with statistics
222
  """
223
  # Sample first image to check preprocessing
224
  sample_img, sample_label = dataset[0]
225
+
226
  # Count labels
227
  from collections import Counter
228
  label_counts = Counter([dataset[i][1].item() for i in range(len(dataset))])
229
+
230
  return {
231
  'num_samples': len(dataset),
232
  'sample_image_shape': tuple(sample_img.shape),
scripts/test_data_loader.py CHANGED
@@ -19,24 +19,24 @@ def test_data_loader():
19
  """Test MNIST data loader with actual files."""
20
  print("Testing MNIST Data Loader...")
21
  print("-" * 50)
22
-
23
  # Note: Files need to be uncompressed first
24
  # If .gz files exist, uncompress with: gunzip data/raw/*.gz
25
-
26
  base_path = Path(__file__).parent.parent / "data" / "raw"
27
-
28
  # Try to find uncompressed files
29
  train_images = base_path / "train-images.idx3-ubyte"
30
  train_labels = base_path / "train-labels.idx1-ubyte"
31
  test_images = base_path / "t10k-images.idx3-ubyte"
32
  test_labels = base_path / "t10k-labels.idx1-ubyte"
33
-
34
  # Check if files exist
35
  missing_files = []
36
  for filepath in [train_images, train_labels, test_images, test_labels]:
37
  if not filepath.exists():
38
  missing_files.append(str(filepath))
39
-
40
  if missing_files:
41
  print("⚠️ Missing uncompressed data files:")
42
  for f in missing_files:
@@ -44,7 +44,7 @@ def test_data_loader():
44
  print("\nTo uncompress .gz files, run:")
45
  print(" cd data/raw && gunzip *.gz")
46
  return False
47
-
48
  try:
49
  # Initialize loader
50
  loader = MnistDataloader(
@@ -54,37 +54,37 @@ def test_data_loader():
54
  str(test_labels)
55
  )
56
  print("✓ Loader initialized successfully")
57
-
58
  # Load data
59
  print("\nLoading MNIST dataset...")
60
  (x_train, y_train), (x_test, y_test) = loader.load_data()
61
-
62
  # Verify shapes
63
  print(f"\n✓ Training set: {len(x_train):,} images, {len(y_train):,} labels")
64
  print(f"✓ Test set: {len(x_test):,} images, {len(y_test):,} labels")
65
-
66
  # Convert first image to numpy array to check
67
  import numpy as np
68
  first_img = np.array(x_train[0])
69
  print(f"\n✓ Image shape: {first_img.shape}")
70
  print(f"✓ Image dtype: {first_img.dtype}")
71
  print(f"✓ Label type: {type(y_train[0])}")
72
-
73
  # Verify label range
74
  unique_labels = set(y_train + y_test)
75
  print(f"\n✓ Unique labels: {sorted(unique_labels)}")
76
-
77
  # Verify pixel value range (convert to numpy for analysis)
78
  sample_images = [np.array(img) for img in x_train[:100]]
79
  max_val = max(img.max() for img in sample_images)
80
  min_val = min(img.min() for img in sample_images)
81
  print(f"✓ Pixel value range (sample): [{min_val}, {max_val}]")
82
-
83
  print("\n" + "=" * 50)
84
  print("✅ All tests passed!")
85
  print("=" * 50)
86
  return True
87
-
88
  except Exception as e:
89
  print(f"\n❌ Error: {e}")
90
  import traceback
 
19
  """Test MNIST data loader with actual files."""
20
  print("Testing MNIST Data Loader...")
21
  print("-" * 50)
22
+
23
  # Note: Files need to be uncompressed first
24
  # If .gz files exist, uncompress with: gunzip data/raw/*.gz
25
+
26
  base_path = Path(__file__).parent.parent / "data" / "raw"
27
+
28
  # Try to find uncompressed files
29
  train_images = base_path / "train-images.idx3-ubyte"
30
  train_labels = base_path / "train-labels.idx1-ubyte"
31
  test_images = base_path / "t10k-images.idx3-ubyte"
32
  test_labels = base_path / "t10k-labels.idx1-ubyte"
33
+
34
  # Check if files exist
35
  missing_files = []
36
  for filepath in [train_images, train_labels, test_images, test_labels]:
37
  if not filepath.exists():
38
  missing_files.append(str(filepath))
39
+
40
  if missing_files:
41
  print("⚠️ Missing uncompressed data files:")
42
  for f in missing_files:
 
44
  print("\nTo uncompress .gz files, run:")
45
  print(" cd data/raw && gunzip *.gz")
46
  return False
47
+
48
  try:
49
  # Initialize loader
50
  loader = MnistDataloader(
 
54
  str(test_labels)
55
  )
56
  print("✓ Loader initialized successfully")
57
+
58
  # Load data
59
  print("\nLoading MNIST dataset...")
60
  (x_train, y_train), (x_test, y_test) = loader.load_data()
61
+
62
  # Verify shapes
63
  print(f"\n✓ Training set: {len(x_train):,} images, {len(y_train):,} labels")
64
  print(f"✓ Test set: {len(x_test):,} images, {len(y_test):,} labels")
65
+
66
  # Convert first image to numpy array to check
67
  import numpy as np
68
  first_img = np.array(x_train[0])
69
  print(f"\n✓ Image shape: {first_img.shape}")
70
  print(f"✓ Image dtype: {first_img.dtype}")
71
  print(f"✓ Label type: {type(y_train[0])}")
72
+
73
  # Verify label range
74
  unique_labels = set(y_train + y_test)
75
  print(f"\n✓ Unique labels: {sorted(unique_labels)}")
76
+
77
  # Verify pixel value range (convert to numpy for analysis)
78
  sample_images = [np.array(img) for img in x_train[:100]]
79
  max_val = max(img.max() for img in sample_images)
80
  min_val = min(img.min() for img in sample_images)
81
  print(f"✓ Pixel value range (sample): [{min_val}, {max_val}]")
82
+
83
  print("\n" + "=" * 50)
84
  print("✅ All tests passed!")
85
  print("=" * 50)
86
  return True
87
+
88
  except Exception as e:
89
  print(f"\n❌ Error: {e}")
90
  import traceback
scripts/test_data_quality.py CHANGED
@@ -21,32 +21,32 @@ def main():
21
  """Run quality checks and save report."""
22
  print("Loading MNIST dataset...")
23
  data_path = project_root / "data" / "raw"
24
-
25
  loader = MnistDataloader(
26
  str(data_path / "train-images.idx3-ubyte"),
27
  str(data_path / "train-labels.idx1-ubyte"),
28
  str(data_path / "t10k-images.idx3-ubyte"),
29
  str(data_path / "t10k-labels.idx1-ubyte")
30
  )
31
-
32
  (x_train, y_train), (x_test, y_test) = loader.load_data()
33
  print("✓ Dataset loaded\n")
34
-
35
  # Generate quality report
36
  print("Running quality checks...")
37
  report = generate_quality_report((x_train, y_train), (x_test, y_test))
38
  print("✓ Quality checks complete\n")
39
-
40
  # Print summary
41
  print_quality_summary(report)
42
-
43
  # Save report as JSON
44
  output_path = project_root / "data" / "quality_report.json"
45
  with open(output_path, 'w') as f:
46
  json.dump(report, f, indent=2)
47
-
48
  print(f"✓ Quality report saved to: {output_path}")
49
-
50
  return 0 if report['summary']['all_checks_pass'] else 1
51
 
52
 
 
21
  """Run quality checks and save report."""
22
  print("Loading MNIST dataset...")
23
  data_path = project_root / "data" / "raw"
24
+
25
  loader = MnistDataloader(
26
  str(data_path / "train-images.idx3-ubyte"),
27
  str(data_path / "train-labels.idx1-ubyte"),
28
  str(data_path / "t10k-images.idx3-ubyte"),
29
  str(data_path / "t10k-labels.idx1-ubyte")
30
  )
31
+
32
  (x_train, y_train), (x_test, y_test) = loader.load_data()
33
  print("✓ Dataset loaded\n")
34
+
35
  # Generate quality report
36
  print("Running quality checks...")
37
  report = generate_quality_report((x_train, y_train), (x_test, y_test))
38
  print("✓ Quality checks complete\n")
39
+
40
  # Print summary
41
  print_quality_summary(report)
42
+
43
  # Save report as JSON
44
  output_path = project_root / "data" / "quality_report.json"
45
  with open(output_path, 'w') as f:
46
  json.dump(report, f, indent=2)
47
+
48
  print(f"✓ Quality report saved to: {output_path}")
49
+
50
  return 0 if report['summary']['all_checks_pass'] else 1
51
 
52
 
scripts/test_preprocessing.py CHANGED
@@ -32,7 +32,7 @@ def test_dataset():
32
  print("=" * 60)
33
  print("TEST 1: MnistDataset Initialization and Indexing")
34
  print("=" * 60)
35
-
36
  # Load data
37
  data_path = project_root / "data" / "raw"
38
  loader = MnistDataloader(
@@ -42,11 +42,11 @@ def test_dataset():
42
  str(data_path / "t10k-labels.idx1-ubyte")
43
  )
44
  (x_train, y_train), (x_test, y_test) = loader.load_data()
45
-
46
  # Create dataset (small subset for testing)
47
  dataset = MnistDataset(x_train[:1000], y_train[:1000])
48
  print(f"✓ Dataset created with {len(dataset)} samples")
49
-
50
  # Test __getitem__
51
  image, label = dataset[0]
52
  print("✓ Retrieved sample 0")
@@ -54,7 +54,7 @@ def test_dataset():
54
  print(f" Image dtype: {image.dtype}")
55
  print(f" Image range: [{image.min():.4f}, {image.max():.4f}]")
56
  print(f" Label: {label.item()} (dtype: {label.dtype})")
57
-
58
  # Verify normalization
59
  assert image.shape == (1, 28, 28), f"Wrong shape: {image.shape}"
60
  assert image.dtype == torch.float32, f"Wrong dtype: {image.dtype}"
@@ -63,7 +63,7 @@ def test_dataset():
63
  assert label.dtype == torch.long, f"Label wrong dtype: {label.dtype}"
64
  print("✓ All assertions passed")
65
  print()
66
-
67
  return dataset
68
 
69
 
@@ -72,10 +72,10 @@ def test_dataloader(dataset):
72
  print("=" * 60)
73
  print("TEST 2: DataLoader Batching")
74
  print("=" * 60)
75
-
76
  loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
77
  print("✓ DataLoader created (batch_size=32)")
78
-
79
  # Get first batch
80
  images, labels = next(iter(loader))
81
  print("✓ Retrieved first batch")
@@ -83,7 +83,7 @@ def test_dataloader(dataset):
83
  print(f" Batch labels shape: {labels.shape}")
84
  print(f" Images dtype: {images.dtype}")
85
  print(f" Labels dtype: {labels.dtype}")
86
-
87
  # Verify batch dimensions
88
  assert images.shape == (32, 1, 28, 28), f"Wrong batch shape: {images.shape}"
89
  assert labels.shape == (32,), f"Wrong labels shape: {labels.shape}"
@@ -98,7 +98,7 @@ def test_train_val_split():
98
  print("=" * 60)
99
  print("TEST 3: Train/Validation Split")
100
  print("=" * 60)
101
-
102
  # Load data
103
  data_path = project_root / "data" / "raw"
104
  loader = MnistDataloader(
@@ -108,36 +108,38 @@ def test_train_val_split():
108
  str(data_path / "t10k-labels.idx1-ubyte")
109
  )
110
  (x_train, y_train), _ = loader.load_data()
111
-
112
  # Split
113
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
114
  x_train, y_train, val_split=0.15, random_seed=42
115
  )
116
-
117
  print("✓ Split completed")
118
  print(f" Original training: {len(x_train):,} samples")
119
- print(f" New training: {len(x_train_split):,} samples ({len(x_train_split)/len(x_train)*100:.1f}%)")
120
- print(f" Validation: {len(x_val):,} samples ({len(x_val)/len(x_train)*100:.1f}%)")
121
-
 
 
122
  # Verify split ratio
123
  expected_val_size = int(len(x_train) * 0.15)
124
  assert abs(len(x_val) - expected_val_size) < 100, "Split ratio incorrect"
125
  assert len(x_train_split) + len(x_val) == len(x_train), "Data loss during split"
126
  print("✓ Split ratio correct")
127
-
128
  # Check stratification (class balance)
129
  from collections import Counter
130
  train_counts = Counter(y_train_split)
131
  val_counts = Counter(y_val)
132
-
133
  print("\n Class distribution in training set:")
134
  for digit in range(10):
135
  print(f" Digit {digit}: {train_counts[digit]:>5,} samples")
136
-
137
  print("\n Class distribution in validation set:")
138
  for digit in range(10):
139
  print(f" Digit {digit}: {val_counts[digit]:>4,} samples")
140
-
141
  # Verify each class is present in both sets
142
  assert all(train_counts[i] > 0 for i in range(10)), "Missing class in train"
143
  assert all(val_counts[i] > 0 for i in range(10)), "Missing class in validation"
@@ -150,7 +152,7 @@ def test_full_pipeline():
150
  print("=" * 60)
151
  print("TEST 4: Full Pipeline")
152
  print("=" * 60)
153
-
154
  # Load data
155
  data_path = project_root / "data" / "raw"
156
  loader = MnistDataloader(
@@ -161,19 +163,19 @@ def test_full_pipeline():
161
  )
162
  (x_train, y_train), (x_test, y_test) = loader.load_data()
163
  print("✓ Data loaded")
164
-
165
  # Split train/val
166
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
167
  x_train, y_train, val_split=0.15
168
  )
169
  print("✓ Train/val split completed")
170
-
171
  # Create datasets
172
  train_dataset = MnistDataset(x_train_split, y_train_split)
173
  val_dataset = MnistDataset(x_val, y_val)
174
  test_dataset = MnistDataset(x_test, y_test)
175
  print("✓ Datasets created")
176
-
177
  # Get statistics
178
  train_stats = get_dataset_statistics(train_dataset)
179
  print("\n Training dataset statistics:")
@@ -182,7 +184,7 @@ def test_full_pipeline():
182
  print(f" Image dtype: {train_stats['sample_image_dtype']}")
183
  print(f" Image range: {train_stats['sample_image_range']}")
184
  print(f" Label dtype: {train_stats['sample_label_dtype']}")
185
-
186
  # Create dataloaders
187
  train_loader, val_loader = create_dataloaders(
188
  train_dataset, val_dataset, batch_size=64, num_workers=0
@@ -192,12 +194,12 @@ def test_full_pipeline():
192
  print(f" Training batches: {len(train_loader)}")
193
  print(f" Validation batches: {len(val_loader)}")
194
  print(f" Test batches: {len(test_loader)}")
195
-
196
  # Test iteration
197
  train_batch = next(iter(train_loader))
198
  val_batch = next(iter(val_loader))
199
  test_batch = next(iter(test_loader))
200
-
201
  print("\n✓ Successfully iterated through all loaders")
202
  print(f" Train batch shapes: {train_batch[0].shape}, {train_batch[1].shape}")
203
  print(f" Val batch shapes: {val_batch[0].shape}, {val_batch[1].shape}")
@@ -209,19 +211,19 @@ def main():
209
  """Run all tests."""
210
  print("\nTesting MNIST Preprocessing Pipeline")
211
  print()
212
-
213
  try:
214
  dataset = test_dataset()
215
  test_dataloader(dataset)
216
  test_train_val_split()
217
  test_full_pipeline()
218
-
219
  print("=" * 60)
220
  print("✅ ALL TESTS PASSED")
221
  print("=" * 60)
222
  print("\nPreprocessing pipeline is ready for model training!")
223
  return 0
224
-
225
  except Exception as e:
226
  print(f"\n❌ TEST FAILED: {e}")
227
  import traceback
 
32
  print("=" * 60)
33
  print("TEST 1: MnistDataset Initialization and Indexing")
34
  print("=" * 60)
35
+
36
  # Load data
37
  data_path = project_root / "data" / "raw"
38
  loader = MnistDataloader(
 
42
  str(data_path / "t10k-labels.idx1-ubyte")
43
  )
44
  (x_train, y_train), (x_test, y_test) = loader.load_data()
45
+
46
  # Create dataset (small subset for testing)
47
  dataset = MnistDataset(x_train[:1000], y_train[:1000])
48
  print(f"✓ Dataset created with {len(dataset)} samples")
49
+
50
  # Test __getitem__
51
  image, label = dataset[0]
52
  print("✓ Retrieved sample 0")
 
54
  print(f" Image dtype: {image.dtype}")
55
  print(f" Image range: [{image.min():.4f}, {image.max():.4f}]")
56
  print(f" Label: {label.item()} (dtype: {label.dtype})")
57
+
58
  # Verify normalization
59
  assert image.shape == (1, 28, 28), f"Wrong shape: {image.shape}"
60
  assert image.dtype == torch.float32, f"Wrong dtype: {image.dtype}"
 
63
  assert label.dtype == torch.long, f"Label wrong dtype: {label.dtype}"
64
  print("✓ All assertions passed")
65
  print()
66
+
67
  return dataset
68
 
69
 
 
72
  print("=" * 60)
73
  print("TEST 2: DataLoader Batching")
74
  print("=" * 60)
75
+
76
  loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
77
  print("✓ DataLoader created (batch_size=32)")
78
+
79
  # Get first batch
80
  images, labels = next(iter(loader))
81
  print("✓ Retrieved first batch")
 
83
  print(f" Batch labels shape: {labels.shape}")
84
  print(f" Images dtype: {images.dtype}")
85
  print(f" Labels dtype: {labels.dtype}")
86
+
87
  # Verify batch dimensions
88
  assert images.shape == (32, 1, 28, 28), f"Wrong batch shape: {images.shape}"
89
  assert labels.shape == (32,), f"Wrong labels shape: {labels.shape}"
 
98
  print("=" * 60)
99
  print("TEST 3: Train/Validation Split")
100
  print("=" * 60)
101
+
102
  # Load data
103
  data_path = project_root / "data" / "raw"
104
  loader = MnistDataloader(
 
108
  str(data_path / "t10k-labels.idx1-ubyte")
109
  )
110
  (x_train, y_train), _ = loader.load_data()
111
+
112
  # Split
113
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
114
  x_train, y_train, val_split=0.15, random_seed=42
115
  )
116
+
117
  print("✓ Split completed")
118
  print(f" Original training: {len(x_train):,} samples")
119
+ train_pct = len(x_train_split) / len(x_train) * 100
120
+ print(f" New training: {len(x_train_split):,} samples ({train_pct:.1f}%)")
121
+ val_pct = len(x_val) / len(x_train) * 100
122
+ print(f" Validation: {len(x_val):,} samples ({val_pct:.1f}%)")
123
+
124
  # Verify split ratio
125
  expected_val_size = int(len(x_train) * 0.15)
126
  assert abs(len(x_val) - expected_val_size) < 100, "Split ratio incorrect"
127
  assert len(x_train_split) + len(x_val) == len(x_train), "Data loss during split"
128
  print("✓ Split ratio correct")
129
+
130
  # Check stratification (class balance)
131
  from collections import Counter
132
  train_counts = Counter(y_train_split)
133
  val_counts = Counter(y_val)
134
+
135
  print("\n Class distribution in training set:")
136
  for digit in range(10):
137
  print(f" Digit {digit}: {train_counts[digit]:>5,} samples")
138
+
139
  print("\n Class distribution in validation set:")
140
  for digit in range(10):
141
  print(f" Digit {digit}: {val_counts[digit]:>4,} samples")
142
+
143
  # Verify each class is present in both sets
144
  assert all(train_counts[i] > 0 for i in range(10)), "Missing class in train"
145
  assert all(val_counts[i] > 0 for i in range(10)), "Missing class in validation"
 
152
  print("=" * 60)
153
  print("TEST 4: Full Pipeline")
154
  print("=" * 60)
155
+
156
  # Load data
157
  data_path = project_root / "data" / "raw"
158
  loader = MnistDataloader(
 
163
  )
164
  (x_train, y_train), (x_test, y_test) = loader.load_data()
165
  print("✓ Data loaded")
166
+
167
  # Split train/val
168
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
169
  x_train, y_train, val_split=0.15
170
  )
171
  print("✓ Train/val split completed")
172
+
173
  # Create datasets
174
  train_dataset = MnistDataset(x_train_split, y_train_split)
175
  val_dataset = MnistDataset(x_val, y_val)
176
  test_dataset = MnistDataset(x_test, y_test)
177
  print("✓ Datasets created")
178
+
179
  # Get statistics
180
  train_stats = get_dataset_statistics(train_dataset)
181
  print("\n Training dataset statistics:")
 
184
  print(f" Image dtype: {train_stats['sample_image_dtype']}")
185
  print(f" Image range: {train_stats['sample_image_range']}")
186
  print(f" Label dtype: {train_stats['sample_label_dtype']}")
187
+
188
  # Create dataloaders
189
  train_loader, val_loader = create_dataloaders(
190
  train_dataset, val_dataset, batch_size=64, num_workers=0
 
194
  print(f" Training batches: {len(train_loader)}")
195
  print(f" Validation batches: {len(val_loader)}")
196
  print(f" Test batches: {len(test_loader)}")
197
+
198
  # Test iteration
199
  train_batch = next(iter(train_loader))
200
  val_batch = next(iter(val_loader))
201
  test_batch = next(iter(test_loader))
202
+
203
  print("\n✓ Successfully iterated through all loaders")
204
  print(f" Train batch shapes: {train_batch[0].shape}, {train_batch[1].shape}")
205
  print(f" Val batch shapes: {val_batch[0].shape}, {val_batch[1].shape}")
 
211
  """Run all tests."""
212
  print("\nTesting MNIST Preprocessing Pipeline")
213
  print()
214
+
215
  try:
216
  dataset = test_dataset()
217
  test_dataloader(dataset)
218
  test_train_val_split()
219
  test_full_pipeline()
220
+
221
  print("=" * 60)
222
  print("✅ ALL TESTS PASSED")
223
  print("=" * 60)
224
  print("\nPreprocessing pipeline is ready for model training!")
225
  return 0
226
+
227
  except Exception as e:
228
  print(f"\n❌ TEST FAILED: {e}")
229
  import traceback
scripts/test_train.py CHANGED
@@ -28,7 +28,7 @@ def main():
28
  print("Testing Training Pipeline")
29
  print("=" * 60)
30
  print()
31
-
32
  # Load data (small subset for quick test)
33
  print("1. Loading MNIST data...")
34
  data_path = project_root / "data" / "raw"
@@ -41,7 +41,7 @@ def main():
41
  (x_train, y_train), (x_test, y_test) = loader.load_data()
42
  print(f"✓ Loaded {len(x_train):,} training samples")
43
  print()
44
-
45
  # Use small subset for quick test (1000 samples)
46
  print("2. Creating train/val split...")
47
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
@@ -50,13 +50,14 @@ def main():
50
  print(f"✓ Train: {len(x_train_split)} samples")
51
  print(f"✓ Val: {len(x_val)} samples")
52
  print()
53
-
54
  # Create datasets and loaders
55
  print("3. Creating datasets and loaders...")
56
  train_dataset = MnistDataset(x_train_split, y_train_split, transform=None)
57
  val_dataset = MnistDataset(x_val, y_val, transform=None)
58
- test_dataset = MnistDataset(x_test[:200], y_test[:200], transform=None) # Small test set
59
-
 
60
  train_loader, val_loader = create_dataloaders(
61
  train_dataset, val_dataset, batch_size=32, num_workers=0
62
  )
@@ -67,7 +68,7 @@ def main():
67
  print(f"✓ Val batches: {len(val_loader)}")
68
  print(f"✓ Test batches: {len(test_loader)}")
69
  print()
70
-
71
  # Create model
72
  print("4. Creating model...")
73
  model = BaselineCNN()
@@ -75,7 +76,7 @@ def main():
75
  print(f"✓ Model: {model.__class__.__name__}")
76
  print(f"✓ Device: {device}")
77
  print()
78
-
79
  # Train model (short run for testing)
80
  print("5. Training model (3 epochs for testing)...")
81
  print("-" * 60)
@@ -93,30 +94,30 @@ def main():
93
  )
94
  print("-" * 60)
95
  print()
96
-
97
  # Check checkpoints exist
98
  print("6. Verifying checkpoints...")
99
  best_model_path = project_root / "models" / "best_model.pt"
100
  last_model_path = project_root / "models" / "last_model.pt"
101
-
102
  assert best_model_path.exists(), "Best model checkpoint not found"
103
  assert last_model_path.exists(), "Last model checkpoint not found"
104
  print("✓ Best model saved")
105
  print("✓ Last model saved")
106
  print()
107
-
108
  # Save history
109
  print("7. Saving training history...")
110
  history_path = project_root / "experiments" / "test_training_history.json"
111
  save_training_history(history, str(history_path))
112
  print()
113
-
114
  # Evaluate on test set
115
  print("8. Evaluating on test set...")
116
  results = evaluate_model(model, test_loader, device=device)
117
  print(f"✓ Test Accuracy: {results['accuracy']:.2f}%")
118
  print()
119
-
120
  # Print per-class metrics
121
  print("Per-class metrics:")
122
  report = results['classification_report']
@@ -128,7 +129,7 @@ def main():
128
  f"Recall={metrics['recall']:.3f}, "
129
  f"F1={metrics['f1-score']:.3f}")
130
  print()
131
-
132
  # Summary
133
  print("=" * 60)
134
  print("✅ ALL TESTS PASSED")
@@ -138,7 +139,7 @@ def main():
138
  print(f"Test accuracy: {results['accuracy']:.2f}%")
139
  print("\nNote: These are quick test results with limited data.")
140
  print("For full training, use complete dataset and more epochs.")
141
-
142
  return 0
143
 
144
 
 
28
  print("Testing Training Pipeline")
29
  print("=" * 60)
30
  print()
31
+
32
  # Load data (small subset for quick test)
33
  print("1. Loading MNIST data...")
34
  data_path = project_root / "data" / "raw"
 
41
  (x_train, y_train), (x_test, y_test) = loader.load_data()
42
  print(f"✓ Loaded {len(x_train):,} training samples")
43
  print()
44
+
45
  # Use small subset for quick test (1000 samples)
46
  print("2. Creating train/val split...")
47
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
 
50
  print(f"✓ Train: {len(x_train_split)} samples")
51
  print(f"✓ Val: {len(x_val)} samples")
52
  print()
53
+
54
  # Create datasets and loaders
55
  print("3. Creating datasets and loaders...")
56
  train_dataset = MnistDataset(x_train_split, y_train_split, transform=None)
57
  val_dataset = MnistDataset(x_val, y_val, transform=None)
58
+ # Small test set for quick validation
59
+ test_dataset = MnistDataset(x_test[:200], y_test[:200], transform=None)
60
+
61
  train_loader, val_loader = create_dataloaders(
62
  train_dataset, val_dataset, batch_size=32, num_workers=0
63
  )
 
68
  print(f"✓ Val batches: {len(val_loader)}")
69
  print(f"✓ Test batches: {len(test_loader)}")
70
  print()
71
+
72
  # Create model
73
  print("4. Creating model...")
74
  model = BaselineCNN()
 
76
  print(f"✓ Model: {model.__class__.__name__}")
77
  print(f"✓ Device: {device}")
78
  print()
79
+
80
  # Train model (short run for testing)
81
  print("5. Training model (3 epochs for testing)...")
82
  print("-" * 60)
 
94
  )
95
  print("-" * 60)
96
  print()
97
+
98
  # Check checkpoints exist
99
  print("6. Verifying checkpoints...")
100
  best_model_path = project_root / "models" / "best_model.pt"
101
  last_model_path = project_root / "models" / "last_model.pt"
102
+
103
  assert best_model_path.exists(), "Best model checkpoint not found"
104
  assert last_model_path.exists(), "Last model checkpoint not found"
105
  print("✓ Best model saved")
106
  print("✓ Last model saved")
107
  print()
108
+
109
  # Save history
110
  print("7. Saving training history...")
111
  history_path = project_root / "experiments" / "test_training_history.json"
112
  save_training_history(history, str(history_path))
113
  print()
114
+
115
  # Evaluate on test set
116
  print("8. Evaluating on test set...")
117
  results = evaluate_model(model, test_loader, device=device)
118
  print(f"✓ Test Accuracy: {results['accuracy']:.2f}%")
119
  print()
120
+
121
  # Print per-class metrics
122
  print("Per-class metrics:")
123
  report = results['classification_report']
 
129
  f"Recall={metrics['recall']:.3f}, "
130
  f"F1={metrics['f1-score']:.3f}")
131
  print()
132
+
133
  # Summary
134
  print("=" * 60)
135
  print("✅ ALL TESTS PASSED")
 
139
  print(f"Test accuracy: {results['accuracy']:.2f}%")
140
  print("\nNote: These are quick test results with limited data.")
141
  print("For full training, use complete dataset and more epochs.")
142
+
143
  return 0
144
 
145
 
scripts/train.py CHANGED
@@ -12,7 +12,7 @@ Supports MLflow experiment tracking for reproducibility.
12
  Usage:
13
  from scripts.train import train_model
14
  from scripts.models import BaselineCNN
15
-
16
  model = BaselineCNN()
17
  history = train_model(
18
  model, train_loader, val_loader,
@@ -39,14 +39,14 @@ def train_epoch(
39
  ) -> Dict[str, float]:
40
  """
41
  Train model for one epoch.
42
-
43
  Args:
44
  model: PyTorch model
45
  train_loader: Training data loader
46
  criterion: Loss function
47
  optimizer: Optimizer
48
  device: Device to train on ('cpu' or 'cuda')
49
-
50
  Returns:
51
  Dictionary with 'loss' and 'accuracy' metrics
52
  """
@@ -54,25 +54,25 @@ def train_epoch(
54
  total_loss = 0.0
55
  correct = 0
56
  total = 0
57
-
58
  for images, labels in train_loader:
59
  images, labels = images.to(device), labels.to(device)
60
-
61
  # Forward pass
62
  optimizer.zero_grad()
63
  outputs = model(images)
64
  loss = criterion(outputs, labels)
65
-
66
  # Backward pass
67
  loss.backward()
68
  optimizer.step()
69
-
70
  # Track metrics
71
  total_loss += loss.item()
72
  _, predicted = outputs.max(1)
73
  correct += predicted.eq(labels).sum().item()
74
  total += labels.size(0)
75
-
76
  return {
77
  'loss': total_loss / len(train_loader),
78
  'accuracy': 100.0 * correct / total
@@ -87,13 +87,13 @@ def validate(
87
  ) -> Dict[str, float]:
88
  """
89
  Evaluate model on validation/test set.
90
-
91
  Args:
92
  model: PyTorch model
93
  val_loader: Validation data loader
94
  criterion: Loss function
95
  device: Device to evaluate on
96
-
97
  Returns:
98
  Dictionary with 'loss' and 'accuracy' metrics
99
  """
@@ -101,21 +101,21 @@ def validate(
101
  total_loss = 0.0
102
  correct = 0
103
  total = 0
104
-
105
  with torch.no_grad():
106
  for images, labels in val_loader:
107
  images, labels = images.to(device), labels.to(device)
108
-
109
  # Forward pass
110
  outputs = model(images)
111
  loss = criterion(outputs, labels)
112
-
113
  # Track metrics
114
  total_loss += loss.item()
115
  _, predicted = outputs.max(1)
116
  correct += predicted.eq(labels).sum().item()
117
  total += labels.size(0)
118
-
119
  return {
120
  'loss': total_loss / len(val_loader),
121
  'accuracy': 100.0 * correct / total
@@ -136,7 +136,7 @@ def train_model(
136
  ) -> Dict[str, List[float]]:
137
  """
138
  Train model with early stopping and checkpointing.
139
-
140
  Args:
141
  model: PyTorch model
142
  train_loader: Training data loader
@@ -148,7 +148,7 @@ def train_model(
148
  device: Device to train on (auto-detect if None)
149
  use_scheduler: Whether to use learning rate scheduler
150
  verbose: Print training progress
151
-
152
  Returns:
153
  Dictionary with training history (losses and accuracies)
154
  """
@@ -156,30 +156,30 @@ def train_model(
156
  if device is None:
157
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
158
  model = model.to(device)
159
-
160
  if verbose:
161
  print(f"Training on device: {device}")
162
  print(f"Model: {model.__class__.__name__}")
163
  print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
164
  print()
165
-
166
  # Setup training components
167
  criterion = nn.CrossEntropyLoss()
168
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
169
-
170
  # Learning rate scheduler
171
  scheduler = None
172
  if use_scheduler:
173
  scheduler = optim.lr_scheduler.ReduceLROnPlateau(
174
  optimizer, mode='min', patience=3, factor=0.5, verbose=verbose
175
  )
176
-
177
  # Setup checkpointing
178
  checkpoint_path = Path(checkpoint_dir)
179
  checkpoint_path.mkdir(parents=True, exist_ok=True)
180
  best_model_path = checkpoint_path / 'best_model.pt'
181
  last_model_path = checkpoint_path / 'last_model.pt'
182
-
183
  # Training history
184
  history = {
185
  'train_loss': [],
@@ -188,26 +188,26 @@ def train_model(
188
  'val_accuracy': [],
189
  'learning_rate': []
190
  }
191
-
192
  # Early stopping setup
193
  best_val_loss = float('inf')
194
  epochs_without_improvement = 0
195
-
196
  # Training loop
197
  for epoch in range(num_epochs):
198
  # Train
199
  train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
200
-
201
  # Validate
202
  val_metrics = validate(model, val_loader, criterion, device)
203
-
204
  # Update history
205
  history['train_loss'].append(train_metrics['loss'])
206
  history['train_accuracy'].append(train_metrics['accuracy'])
207
  history['val_loss'].append(val_metrics['loss'])
208
  history['val_accuracy'].append(val_metrics['accuracy'])
209
  history['learning_rate'].append(optimizer.param_groups[0]['lr'])
210
-
211
  # Print progress
212
  if verbose:
213
  print(f"Epoch {epoch+1}/{num_epochs}")
@@ -217,11 +217,11 @@ def train_model(
217
  f"Val Acc: {val_metrics['accuracy']:.2f}%")
218
  print(f" LR: {optimizer.param_groups[0]['lr']:.6f}")
219
  print()
220
-
221
  # Learning rate scheduling
222
  if scheduler is not None:
223
  scheduler.step(val_metrics['loss'])
224
-
225
  # Save best model
226
  if val_metrics['loss'] < best_val_loss:
227
  best_val_loss = val_metrics['loss']
@@ -238,14 +238,14 @@ def train_model(
238
  print()
239
  else:
240
  epochs_without_improvement += 1
241
-
242
  # Early stopping
243
  if epochs_without_improvement >= patience:
244
  if verbose:
245
  print(f"Early stopping triggered after {epoch+1} epochs")
246
  print(f"Best validation loss: {best_val_loss:.4f}")
247
  break
248
-
249
  # Save last model
250
  torch.save({
251
  'epoch': epoch,
@@ -254,12 +254,12 @@ def train_model(
254
  'val_loss': val_metrics['loss'],
255
  'val_accuracy': val_metrics['accuracy']
256
  }, last_model_path)
257
-
258
  if verbose:
259
  print("Training complete!")
260
  print(f"Best validation loss: {best_val_loss:.4f}")
261
  print(f"Final validation accuracy: {history['val_accuracy'][-1]:.2f}%")
262
-
263
  return history
264
 
265
 
@@ -271,13 +271,13 @@ def evaluate_model(
271
  ) -> Dict:
272
  """
273
  Comprehensive model evaluation with per-class metrics.
274
-
275
  Args:
276
  model: Trained PyTorch model
277
  test_loader: Test data loader
278
  device: Device to evaluate on
279
  class_names: List of class names (default: digits 0-9)
280
-
281
  Returns:
282
  Dictionary with metrics, predictions, and confusion matrix
283
  """
@@ -285,42 +285,42 @@ def evaluate_model(
285
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
286
  model = model.to(device)
287
  model.eval()
288
-
289
  if class_names is None:
290
  class_names = [str(i) for i in range(10)]
291
-
292
  all_preds = []
293
  all_labels = []
294
  all_probs = []
295
-
296
  with torch.no_grad():
297
  for images, labels in test_loader:
298
  images = images.to(device)
299
  outputs = model(images)
300
  probs = torch.softmax(outputs, dim=1)
301
  _, predicted = outputs.max(1)
302
-
303
  all_preds.extend(predicted.cpu().numpy())
304
  all_labels.extend(labels.numpy())
305
  all_probs.extend(probs.cpu().numpy())
306
-
307
  all_preds = np.array(all_preds)
308
  all_labels = np.array(all_labels)
309
  all_probs = np.array(all_probs)
310
-
311
  # Overall metrics
312
  accuracy = 100.0 * (all_preds == all_labels).sum() / len(all_labels)
313
-
314
  # Classification report
315
  report = classification_report(
316
  all_labels, all_preds,
317
  target_names=class_names,
318
  output_dict=True
319
  )
320
-
321
  # Confusion matrix
322
  conf_matrix = confusion_matrix(all_labels, all_preds)
323
-
324
  return {
325
  'accuracy': accuracy,
326
  'classification_report': report,
@@ -334,7 +334,7 @@ def evaluate_model(
334
  def save_training_history(history: Dict, filepath: str) -> None:
335
  """
336
  Save training history to JSON file.
337
-
338
  Args:
339
  history: Training history dictionary
340
  filepath: Path to save JSON file
@@ -348,11 +348,11 @@ def save_training_history(history: Dict, filepath: str) -> None:
348
  def load_checkpoint(checkpoint_path: str, model: nn.Module) -> Tuple[nn.Module, Dict]:
349
  """
350
  Load model from checkpoint.
351
-
352
  Args:
353
  checkpoint_path: Path to checkpoint file
354
  model: Model instance (for loading state dict)
355
-
356
  Returns:
357
  Tuple of (loaded_model, checkpoint_dict)
358
  """
 
12
  Usage:
13
  from scripts.train import train_model
14
  from scripts.models import BaselineCNN
15
+
16
  model = BaselineCNN()
17
  history = train_model(
18
  model, train_loader, val_loader,
 
39
  ) -> Dict[str, float]:
40
  """
41
  Train model for one epoch.
42
+
43
  Args:
44
  model: PyTorch model
45
  train_loader: Training data loader
46
  criterion: Loss function
47
  optimizer: Optimizer
48
  device: Device to train on ('cpu' or 'cuda')
49
+
50
  Returns:
51
  Dictionary with 'loss' and 'accuracy' metrics
52
  """
 
54
  total_loss = 0.0
55
  correct = 0
56
  total = 0
57
+
58
  for images, labels in train_loader:
59
  images, labels = images.to(device), labels.to(device)
60
+
61
  # Forward pass
62
  optimizer.zero_grad()
63
  outputs = model(images)
64
  loss = criterion(outputs, labels)
65
+
66
  # Backward pass
67
  loss.backward()
68
  optimizer.step()
69
+
70
  # Track metrics
71
  total_loss += loss.item()
72
  _, predicted = outputs.max(1)
73
  correct += predicted.eq(labels).sum().item()
74
  total += labels.size(0)
75
+
76
  return {
77
  'loss': total_loss / len(train_loader),
78
  'accuracy': 100.0 * correct / total
 
87
  ) -> Dict[str, float]:
88
  """
89
  Evaluate model on validation/test set.
90
+
91
  Args:
92
  model: PyTorch model
93
  val_loader: Validation data loader
94
  criterion: Loss function
95
  device: Device to evaluate on
96
+
97
  Returns:
98
  Dictionary with 'loss' and 'accuracy' metrics
99
  """
 
101
  total_loss = 0.0
102
  correct = 0
103
  total = 0
104
+
105
  with torch.no_grad():
106
  for images, labels in val_loader:
107
  images, labels = images.to(device), labels.to(device)
108
+
109
  # Forward pass
110
  outputs = model(images)
111
  loss = criterion(outputs, labels)
112
+
113
  # Track metrics
114
  total_loss += loss.item()
115
  _, predicted = outputs.max(1)
116
  correct += predicted.eq(labels).sum().item()
117
  total += labels.size(0)
118
+
119
  return {
120
  'loss': total_loss / len(val_loader),
121
  'accuracy': 100.0 * correct / total
 
136
  ) -> Dict[str, List[float]]:
137
  """
138
  Train model with early stopping and checkpointing.
139
+
140
  Args:
141
  model: PyTorch model
142
  train_loader: Training data loader
 
148
  device: Device to train on (auto-detect if None)
149
  use_scheduler: Whether to use learning rate scheduler
150
  verbose: Print training progress
151
+
152
  Returns:
153
  Dictionary with training history (losses and accuracies)
154
  """
 
156
  if device is None:
157
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
158
  model = model.to(device)
159
+
160
  if verbose:
161
  print(f"Training on device: {device}")
162
  print(f"Model: {model.__class__.__name__}")
163
  print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
164
  print()
165
+
166
  # Setup training components
167
  criterion = nn.CrossEntropyLoss()
168
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
169
+
170
  # Learning rate scheduler
171
  scheduler = None
172
  if use_scheduler:
173
  scheduler = optim.lr_scheduler.ReduceLROnPlateau(
174
  optimizer, mode='min', patience=3, factor=0.5, verbose=verbose
175
  )
176
+
177
  # Setup checkpointing
178
  checkpoint_path = Path(checkpoint_dir)
179
  checkpoint_path.mkdir(parents=True, exist_ok=True)
180
  best_model_path = checkpoint_path / 'best_model.pt'
181
  last_model_path = checkpoint_path / 'last_model.pt'
182
+
183
  # Training history
184
  history = {
185
  'train_loss': [],
 
188
  'val_accuracy': [],
189
  'learning_rate': []
190
  }
191
+
192
  # Early stopping setup
193
  best_val_loss = float('inf')
194
  epochs_without_improvement = 0
195
+
196
  # Training loop
197
  for epoch in range(num_epochs):
198
  # Train
199
  train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
200
+
201
  # Validate
202
  val_metrics = validate(model, val_loader, criterion, device)
203
+
204
  # Update history
205
  history['train_loss'].append(train_metrics['loss'])
206
  history['train_accuracy'].append(train_metrics['accuracy'])
207
  history['val_loss'].append(val_metrics['loss'])
208
  history['val_accuracy'].append(val_metrics['accuracy'])
209
  history['learning_rate'].append(optimizer.param_groups[0]['lr'])
210
+
211
  # Print progress
212
  if verbose:
213
  print(f"Epoch {epoch+1}/{num_epochs}")
 
217
  f"Val Acc: {val_metrics['accuracy']:.2f}%")
218
  print(f" LR: {optimizer.param_groups[0]['lr']:.6f}")
219
  print()
220
+
221
  # Learning rate scheduling
222
  if scheduler is not None:
223
  scheduler.step(val_metrics['loss'])
224
+
225
  # Save best model
226
  if val_metrics['loss'] < best_val_loss:
227
  best_val_loss = val_metrics['loss']
 
238
  print()
239
  else:
240
  epochs_without_improvement += 1
241
+
242
  # Early stopping
243
  if epochs_without_improvement >= patience:
244
  if verbose:
245
  print(f"Early stopping triggered after {epoch+1} epochs")
246
  print(f"Best validation loss: {best_val_loss:.4f}")
247
  break
248
+
249
  # Save last model
250
  torch.save({
251
  'epoch': epoch,
 
254
  'val_loss': val_metrics['loss'],
255
  'val_accuracy': val_metrics['accuracy']
256
  }, last_model_path)
257
+
258
  if verbose:
259
  print("Training complete!")
260
  print(f"Best validation loss: {best_val_loss:.4f}")
261
  print(f"Final validation accuracy: {history['val_accuracy'][-1]:.2f}%")
262
+
263
  return history
264
 
265
 
 
271
  ) -> Dict:
272
  """
273
  Comprehensive model evaluation with per-class metrics.
274
+
275
  Args:
276
  model: Trained PyTorch model
277
  test_loader: Test data loader
278
  device: Device to evaluate on
279
  class_names: List of class names (default: digits 0-9)
280
+
281
  Returns:
282
  Dictionary with metrics, predictions, and confusion matrix
283
  """
 
285
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
286
  model = model.to(device)
287
  model.eval()
288
+
289
  if class_names is None:
290
  class_names = [str(i) for i in range(10)]
291
+
292
  all_preds = []
293
  all_labels = []
294
  all_probs = []
295
+
296
  with torch.no_grad():
297
  for images, labels in test_loader:
298
  images = images.to(device)
299
  outputs = model(images)
300
  probs = torch.softmax(outputs, dim=1)
301
  _, predicted = outputs.max(1)
302
+
303
  all_preds.extend(predicted.cpu().numpy())
304
  all_labels.extend(labels.numpy())
305
  all_probs.extend(probs.cpu().numpy())
306
+
307
  all_preds = np.array(all_preds)
308
  all_labels = np.array(all_labels)
309
  all_probs = np.array(all_probs)
310
+
311
  # Overall metrics
312
  accuracy = 100.0 * (all_preds == all_labels).sum() / len(all_labels)
313
+
314
  # Classification report
315
  report = classification_report(
316
  all_labels, all_preds,
317
  target_names=class_names,
318
  output_dict=True
319
  )
320
+
321
  # Confusion matrix
322
  conf_matrix = confusion_matrix(all_labels, all_preds)
323
+
324
  return {
325
  'accuracy': accuracy,
326
  'classification_report': report,
 
334
  def save_training_history(history: Dict, filepath: str) -> None:
335
  """
336
  Save training history to JSON file.
337
+
338
  Args:
339
  history: Training history dictionary
340
  filepath: Path to save JSON file
 
348
  def load_checkpoint(checkpoint_path: str, model: nn.Module) -> Tuple[nn.Module, Dict]:
349
  """
350
  Load model from checkpoint.
351
+
352
  Args:
353
  checkpoint_path: Path to checkpoint file
354
  model: Model instance (for loading state dict)
355
+
356
  Returns:
357
  Tuple of (loaded_model, checkpoint_dict)
358
  """
scripts/train_baseline.py CHANGED
@@ -34,9 +34,9 @@ from scripts.train import train_model, evaluate_model, save_training_history
34
  def plot_training_history(history: dict, save_path: str):
35
  """Plot and save training history curves."""
36
  fig, axes = plt.subplots(2, 2, figsize=(12, 10))
37
-
38
  epochs = range(1, len(history['train_loss']) + 1)
39
-
40
  # Loss curves
41
  axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss')
42
  axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss')
@@ -45,7 +45,7 @@ def plot_training_history(history: dict, save_path: str):
45
  axes[0, 0].set_title('Training and Validation Loss')
46
  axes[0, 0].legend()
47
  axes[0, 0].grid(True, alpha=0.3)
48
-
49
  # Accuracy curves
50
  axes[0, 1].plot(epochs, history['train_accuracy'], 'b-', label='Train Acc')
51
  axes[0, 1].plot(epochs, history['val_accuracy'], 'r-', label='Val Acc')
@@ -54,7 +54,7 @@ def plot_training_history(history: dict, save_path: str):
54
  axes[0, 1].set_title('Training and Validation Accuracy')
55
  axes[0, 1].legend()
56
  axes[0, 1].grid(True, alpha=0.3)
57
-
58
  # Learning rate
59
  axes[1, 0].plot(epochs, history['learning_rate'], 'g-')
60
  axes[1, 0].set_xlabel('Epoch')
@@ -62,7 +62,7 @@ def plot_training_history(history: dict, save_path: str):
62
  axes[1, 0].set_title('Learning Rate Schedule')
63
  axes[1, 0].set_yscale('log')
64
  axes[1, 0].grid(True, alpha=0.3)
65
-
66
  # Loss difference (overfitting indicator)
67
  loss_diff = np.array(history['val_loss']) - np.array(history['train_loss'])
68
  axes[1, 1].plot(epochs, loss_diff, 'm-')
@@ -71,7 +71,7 @@ def plot_training_history(history: dict, save_path: str):
71
  axes[1, 1].set_ylabel('Val Loss - Train Loss')
72
  axes[1, 1].set_title('Overfitting Indicator (positive = overfitting)')
73
  axes[1, 1].grid(True, alpha=0.3)
74
-
75
  plt.tight_layout()
76
  plt.savefig(save_path, dpi=300, bbox_inches='tight')
77
  print(f"Training curves saved to {save_path}")
@@ -81,10 +81,10 @@ def plot_training_history(history: dict, save_path: str):
81
  def plot_confusion_matrix(conf_matrix: np.ndarray, save_path: str):
82
  """Plot and save confusion matrix."""
83
  fig, ax = plt.subplots(figsize=(10, 8))
84
-
85
  im = ax.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
86
  ax.figure.colorbar(im, ax=ax)
87
-
88
  # Labels
89
  classes = list(range(10))
90
  ax.set(xticks=np.arange(conf_matrix.shape[1]),
@@ -93,10 +93,10 @@ def plot_confusion_matrix(conf_matrix: np.ndarray, save_path: str):
93
  title='Confusion Matrix - MNIST Digit Classification',
94
  ylabel='True Label',
95
  xlabel='Predicted Label')
96
-
97
  # Rotate the tick labels
98
  plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
99
-
100
  # Add text annotations
101
  thresh = conf_matrix.max() / 2.
102
  for i in range(conf_matrix.shape[0]):
@@ -104,7 +104,7 @@ def plot_confusion_matrix(conf_matrix: np.ndarray, save_path: str):
104
  ax.text(j, i, format(conf_matrix[i, j], 'd'),
105
  ha="center", va="center",
106
  color="white" if conf_matrix[i, j] > thresh else "black")
107
-
108
  plt.tight_layout()
109
  plt.savefig(save_path, dpi=300, bbox_inches='tight')
110
  print(f"Confusion matrix saved to {save_path}")
@@ -112,14 +112,30 @@ def plot_confusion_matrix(conf_matrix: np.ndarray, save_path: str):
112
 
113
 
114
  def main():
115
- parser = argparse.ArgumentParser(description='Train baseline CNN on MNIST')
116
- parser.add_argument('--augment', action='store_true', help='Use data augmentation')
117
- parser.add_argument('--epochs', type=int, default=20, help='Number of epochs (default: 20)')
118
- parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
119
- parser.add_argument('--batch-size', type=int, default=64, help='Batch size (default: 64)')
120
- parser.add_argument('--patience', type=int, default=5, help='Early stopping patience (default: 5)')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  args = parser.parse_args()
122
-
123
  print("=" * 60)
124
  print("MNIST CNN Training - Baseline Model")
125
  print("=" * 60)
@@ -130,7 +146,7 @@ def main():
130
  print(f" Augmentation: {'Yes' if args.augment else 'No'}")
131
  print(f" Early Stopping Patience: {args.patience}")
132
  print()
133
-
134
  # 1. Load data
135
  print("1. Loading MNIST dataset...")
136
  data_path = project_root / "data" / "raw"
@@ -144,7 +160,7 @@ def main():
144
  print(f"✓ Loaded {len(x_train):,} training samples")
145
  print(f"✓ Loaded {len(x_test):,} test samples")
146
  print()
147
-
148
  # 2. Train/val split
149
  print("2. Creating train/validation split...")
150
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
@@ -154,15 +170,15 @@ def main():
154
  print(f"✓ Validation: {len(x_val):,} samples")
155
  print(f"✓ Test: {len(x_test):,} samples")
156
  print()
157
-
158
  # 3. Create datasets with optional augmentation
159
  print("3. Creating datasets...")
160
  augmentation = get_train_augmentation() if args.augment else None
161
-
162
  train_dataset = MnistDataset(x_train_split, y_train_split, transform=augmentation)
163
  val_dataset = MnistDataset(x_val, y_val, transform=None)
164
  test_dataset = MnistDataset(x_test, y_test, transform=None)
165
-
166
  train_loader, val_loader = create_dataloaders(
167
  train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
168
  )
@@ -173,13 +189,13 @@ def main():
173
  print(f"✓ Val batches: {len(val_loader)}")
174
  print(f"✓ Test batches: {len(test_loader)}")
175
  print()
176
-
177
  # 4. Create model
178
  print("4. Creating model...")
179
  model = BaselineCNN()
180
  print(get_model_summary(model))
181
  print()
182
-
183
  # 5. Train model
184
  print("5. Training model...")
185
  print("-" * 60)
@@ -197,54 +213,69 @@ def main():
197
  )
198
  print("-" * 60)
199
  print()
200
-
201
  # 6. Load best model and evaluate
202
  print("6. Evaluating best model on test set...")
203
  checkpoint = torch.load('models/best_model.pt', map_location='cpu')
204
  model.load_state_dict(checkpoint['model_state_dict'])
205
-
206
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
207
  results = evaluate_model(model, test_loader, device=device)
208
-
209
  print(f"✓ Test Accuracy: {results['accuracy']:.2f}%")
210
  print()
211
-
212
  # 7. Print detailed metrics
213
  print("7. Per-class metrics:")
214
  print("-" * 60)
215
  report = results['classification_report']
216
- print(f"{'Digit':<8} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
 
 
 
217
  print("-" * 60)
218
  for digit in range(10):
219
  if str(digit) in report:
220
  metrics = report[str(digit)]
221
- print(f"{digit:<8} {metrics['precision']:<12.3f} {metrics['recall']:<12.3f} "
222
- f"{metrics['f1-score']:<12.3f} {metrics['support']:<10}")
223
-
 
 
 
224
  print("-" * 60)
225
- print(f"{'Accuracy':<8} {' ':<12} {' ':<12} {report['accuracy']:<12.3f} {report['macro avg']['support']:<10}")
226
- print(f"{'Macro Avg':<8} {report['macro avg']['precision']:<12.3f} "
227
- f"{report['macro avg']['recall']:<12.3f} {report['macro avg']['f1-score']:<12.3f} "
228
- f"{report['macro avg']['support']:<10}")
 
 
 
 
 
 
 
 
 
229
  print()
230
-
231
  # 8. Save results
232
  print("8. Saving results...")
233
  results_dir = project_root / "results"
234
  results_dir.mkdir(exist_ok=True)
235
-
236
  # Save history
237
  history_path = results_dir / "baseline_training_history.json"
238
  save_training_history(history, str(history_path))
239
-
240
  # Plot training curves
241
  curves_path = results_dir / "baseline_training_curves.png"
242
  plot_training_history(history, str(curves_path))
243
-
244
  # Plot confusion matrix
245
  conf_matrix_path = results_dir / "baseline_confusion_matrix.png"
246
  plot_confusion_matrix(results['confusion_matrix'], str(conf_matrix_path))
247
-
248
  # Save evaluation metrics
249
  metrics_path = results_dir / "baseline_metrics.json"
250
  # Convert numpy arrays to lists for JSON serialization
@@ -269,7 +300,7 @@ def main():
269
  json.dump(metrics_data, f, indent=2)
270
  print(f"Evaluation metrics saved to {metrics_path}")
271
  print()
272
-
273
  # 9. Summary
274
  print("=" * 60)
275
  print("✅ TRAINING COMPLETE")
@@ -285,7 +316,7 @@ def main():
285
  print(f" - Training curves: {curves_path}")
286
  print(f" - Confusion matrix: {conf_matrix_path}")
287
  print(f" - Metrics: {metrics_path}")
288
-
289
  return 0
290
 
291
 
 
34
  def plot_training_history(history: dict, save_path: str):
35
  """Plot and save training history curves."""
36
  fig, axes = plt.subplots(2, 2, figsize=(12, 10))
37
+
38
  epochs = range(1, len(history['train_loss']) + 1)
39
+
40
  # Loss curves
41
  axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss')
42
  axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss')
 
45
  axes[0, 0].set_title('Training and Validation Loss')
46
  axes[0, 0].legend()
47
  axes[0, 0].grid(True, alpha=0.3)
48
+
49
  # Accuracy curves
50
  axes[0, 1].plot(epochs, history['train_accuracy'], 'b-', label='Train Acc')
51
  axes[0, 1].plot(epochs, history['val_accuracy'], 'r-', label='Val Acc')
 
54
  axes[0, 1].set_title('Training and Validation Accuracy')
55
  axes[0, 1].legend()
56
  axes[0, 1].grid(True, alpha=0.3)
57
+
58
  # Learning rate
59
  axes[1, 0].plot(epochs, history['learning_rate'], 'g-')
60
  axes[1, 0].set_xlabel('Epoch')
 
62
  axes[1, 0].set_title('Learning Rate Schedule')
63
  axes[1, 0].set_yscale('log')
64
  axes[1, 0].grid(True, alpha=0.3)
65
+
66
  # Loss difference (overfitting indicator)
67
  loss_diff = np.array(history['val_loss']) - np.array(history['train_loss'])
68
  axes[1, 1].plot(epochs, loss_diff, 'm-')
 
71
  axes[1, 1].set_ylabel('Val Loss - Train Loss')
72
  axes[1, 1].set_title('Overfitting Indicator (positive = overfitting)')
73
  axes[1, 1].grid(True, alpha=0.3)
74
+
75
  plt.tight_layout()
76
  plt.savefig(save_path, dpi=300, bbox_inches='tight')
77
  print(f"Training curves saved to {save_path}")
 
81
  def plot_confusion_matrix(conf_matrix: np.ndarray, save_path: str):
82
  """Plot and save confusion matrix."""
83
  fig, ax = plt.subplots(figsize=(10, 8))
84
+
85
  im = ax.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
86
  ax.figure.colorbar(im, ax=ax)
87
+
88
  # Labels
89
  classes = list(range(10))
90
  ax.set(xticks=np.arange(conf_matrix.shape[1]),
 
93
  title='Confusion Matrix - MNIST Digit Classification',
94
  ylabel='True Label',
95
  xlabel='Predicted Label')
96
+
97
  # Rotate the tick labels
98
  plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
99
+
100
  # Add text annotations
101
  thresh = conf_matrix.max() / 2.
102
  for i in range(conf_matrix.shape[0]):
 
104
  ax.text(j, i, format(conf_matrix[i, j], 'd'),
105
  ha="center", va="center",
106
  color="white" if conf_matrix[i, j] > thresh else "black")
107
+
108
  plt.tight_layout()
109
  plt.savefig(save_path, dpi=300, bbox_inches='tight')
110
  print(f"Confusion matrix saved to {save_path}")
 
112
 
113
 
114
  def main():
115
+ parser = argparse.ArgumentParser(
116
+ description='Train baseline CNN on MNIST'
117
+ )
118
+ parser.add_argument(
119
+ '--augment', action='store_true', help='Use data augmentation'
120
+ )
121
+ parser.add_argument(
122
+ '--epochs', type=int, default=20,
123
+ help='Number of epochs (default: 20)'
124
+ )
125
+ parser.add_argument(
126
+ '--lr', type=float, default=0.001,
127
+ help='Learning rate (default: 0.001)'
128
+ )
129
+ parser.add_argument(
130
+ '--batch-size', type=int, default=64,
131
+ help='Batch size (default: 64)'
132
+ )
133
+ parser.add_argument(
134
+ '--patience', type=int, default=5,
135
+ help='Early stopping patience (default: 5)'
136
+ )
137
  args = parser.parse_args()
138
+
139
  print("=" * 60)
140
  print("MNIST CNN Training - Baseline Model")
141
  print("=" * 60)
 
146
  print(f" Augmentation: {'Yes' if args.augment else 'No'}")
147
  print(f" Early Stopping Patience: {args.patience}")
148
  print()
149
+
150
  # 1. Load data
151
  print("1. Loading MNIST dataset...")
152
  data_path = project_root / "data" / "raw"
 
160
  print(f"✓ Loaded {len(x_train):,} training samples")
161
  print(f"✓ Loaded {len(x_test):,} test samples")
162
  print()
163
+
164
  # 2. Train/val split
165
  print("2. Creating train/validation split...")
166
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
 
170
  print(f"✓ Validation: {len(x_val):,} samples")
171
  print(f"✓ Test: {len(x_test):,} samples")
172
  print()
173
+
174
  # 3. Create datasets with optional augmentation
175
  print("3. Creating datasets...")
176
  augmentation = get_train_augmentation() if args.augment else None
177
+
178
  train_dataset = MnistDataset(x_train_split, y_train_split, transform=augmentation)
179
  val_dataset = MnistDataset(x_val, y_val, transform=None)
180
  test_dataset = MnistDataset(x_test, y_test, transform=None)
181
+
182
  train_loader, val_loader = create_dataloaders(
183
  train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
184
  )
 
189
  print(f"✓ Val batches: {len(val_loader)}")
190
  print(f"✓ Test batches: {len(test_loader)}")
191
  print()
192
+
193
  # 4. Create model
194
  print("4. Creating model...")
195
  model = BaselineCNN()
196
  print(get_model_summary(model))
197
  print()
198
+
199
  # 5. Train model
200
  print("5. Training model...")
201
  print("-" * 60)
 
213
  )
214
  print("-" * 60)
215
  print()
216
+
217
  # 6. Load best model and evaluate
218
  print("6. Evaluating best model on test set...")
219
  checkpoint = torch.load('models/best_model.pt', map_location='cpu')
220
  model.load_state_dict(checkpoint['model_state_dict'])
221
+
222
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
223
  results = evaluate_model(model, test_loader, device=device)
224
+
225
  print(f"✓ Test Accuracy: {results['accuracy']:.2f}%")
226
  print()
227
+
228
  # 7. Print detailed metrics
229
  print("7. Per-class metrics:")
230
  print("-" * 60)
231
  report = results['classification_report']
232
+ print(
233
+ f"{'Digit':<8} {'Precision':<12} {'Recall':<12} "
234
+ f"{'F1-Score':<12} {'Support':<10}"
235
+ )
236
  print("-" * 60)
237
  for digit in range(10):
238
  if str(digit) in report:
239
  metrics = report[str(digit)]
240
+ print(
241
+ f"{digit:<8} {metrics['precision']:<12.3f} "
242
+ f"{metrics['recall']:<12.3f} "
243
+ f"{metrics['f1-score']:<12.3f} {metrics['support']:<10}"
244
+ )
245
+
246
  print("-" * 60)
247
+ acc_line = (
248
+ f"{'Accuracy':<8} {' ':<12} {' ':<12} "
249
+ f"{report['accuracy']:<12.3f} "
250
+ f"{report['macro avg']['support']:<10}"
251
+ )
252
+ print(acc_line)
253
+ macro_line = (
254
+ f"{'Macro Avg':<8} {report['macro avg']['precision']:<12.3f} "
255
+ f"{report['macro avg']['recall']:<12.3f} "
256
+ f"{report['macro avg']['f1-score']:<12.3f} "
257
+ f"{report['macro avg']['support']:<10}"
258
+ )
259
+ print(macro_line)
260
  print()
261
+
262
  # 8. Save results
263
  print("8. Saving results...")
264
  results_dir = project_root / "results"
265
  results_dir.mkdir(exist_ok=True)
266
+
267
  # Save history
268
  history_path = results_dir / "baseline_training_history.json"
269
  save_training_history(history, str(history_path))
270
+
271
  # Plot training curves
272
  curves_path = results_dir / "baseline_training_curves.png"
273
  plot_training_history(history, str(curves_path))
274
+
275
  # Plot confusion matrix
276
  conf_matrix_path = results_dir / "baseline_confusion_matrix.png"
277
  plot_confusion_matrix(results['confusion_matrix'], str(conf_matrix_path))
278
+
279
  # Save evaluation metrics
280
  metrics_path = results_dir / "baseline_metrics.json"
281
  # Convert numpy arrays to lists for JSON serialization
 
300
  json.dump(metrics_data, f, indent=2)
301
  print(f"Evaluation metrics saved to {metrics_path}")
302
  print()
303
+
304
  # 9. Summary
305
  print("=" * 60)
306
  print("✅ TRAINING COMPLETE")
 
316
  print(f" - Training curves: {curves_path}")
317
  print(f" - Confusion matrix: {conf_matrix_path}")
318
  print(f" - Metrics: {metrics_path}")
319
+
320
  return 0
321
 
322
 
scripts/train_with_mlflow.py CHANGED
@@ -50,7 +50,7 @@ def train_with_mlflow(
50
  ) -> dict:
51
  """
52
  Train model with full MLflow tracking.
53
-
54
  Args:
55
  model: PyTorch model to train
56
  train_loader: Training data loader
@@ -58,23 +58,23 @@ def train_with_mlflow(
58
  test_loader: Test data loader
59
  config: Training configuration dictionary
60
  run_name: Optional name for MLflow run
61
-
62
  Returns:
63
  Training history dictionary
64
  """
65
  device = config['device']
66
  num_epochs = config['num_epochs']
67
  learning_rate = config['learning_rate']
68
-
69
  # Setup MLflow
70
  setup_mlflow("mnist-digit-classification")
71
-
72
  # Start MLflow run
73
  with mlflow.start_run(run_name=run_name):
74
  print("\n" + "="*70)
75
  print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")
76
  print("="*70 + "\n")
77
-
78
  # Log all configuration
79
  print("Logging configuration to MLflow...")
80
  log_training_config(config)
@@ -87,7 +87,7 @@ def train_with_mlflow(
87
  augmentation=config.get('augmentation', False)
88
  )
89
  log_system_info()
90
-
91
  # Log model architecture as text
92
  total_params, trainable_params = count_parameters(model)
93
  model_summary = f"""
@@ -100,14 +100,14 @@ Architecture:
100
  {str(model)}
101
  """
102
  mlflow.log_text(model_summary, "model_architecture.txt")
103
-
104
  # Setup training
105
  criterion = nn.CrossEntropyLoss()
106
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
107
  scheduler = optim.lr_scheduler.ReduceLROnPlateau(
108
  optimizer, mode='min', patience=3, factor=0.5, verbose=True
109
  )
110
-
111
  # Training history
112
  history = {
113
  'train_loss': [],
@@ -116,37 +116,39 @@ Architecture:
116
  'val_accuracy': [],
117
  'learning_rate': []
118
  }
119
-
120
  best_val_loss = float('inf')
121
  patience = 5
122
  patience_counter = 0
123
-
124
  print(f"\nStarting training for {num_epochs} epochs...")
125
  print(f"Device: {device}")
126
  total_p, _ = count_parameters(model)
127
  print(f"Model: {model.__class__.__name__} ({total_p:,} parameters)")
128
  print("-" * 70)
129
-
130
  for epoch in range(num_epochs):
131
  # Train
132
- train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
133
-
 
 
134
  # Validate
135
  val_metrics = validate(model, val_loader, criterion, device)
136
-
137
  # Get current learning rate
138
  current_lr = optimizer.param_groups[0]['lr']
139
-
140
  # Update scheduler
141
  scheduler.step(val_metrics['loss'])
142
-
143
  # Save history
144
  history['train_loss'].append(train_metrics['loss'])
145
  history['train_accuracy'].append(train_metrics['accuracy'])
146
  history['val_loss'].append(val_metrics['loss'])
147
  history['val_accuracy'].append(val_metrics['accuracy'])
148
  history['learning_rate'].append(current_lr)
149
-
150
  # Log metrics to MLflow
151
  mlflow_metrics = {
152
  'train_loss': train_metrics['loss'],
@@ -157,19 +159,23 @@ Architecture:
157
  'epoch': epoch + 1
158
  }
159
  log_metrics_epoch(mlflow_metrics, step=epoch)
160
-
161
  # Print progress
162
- print(f"Epoch {epoch+1}/{num_epochs} | "
163
- f"Train Loss: {train_metrics['loss']:.4f} ({train_metrics['accuracy']:.2f}%) | "
164
- f"Val Loss: {val_metrics['loss']:.4f} ({val_metrics['accuracy']:.2f}%) | "
165
- f"LR: {current_lr:.6f}")
166
-
 
 
 
 
167
  # Save best model
168
  if val_metrics['loss'] < best_val_loss:
169
  best_val_loss = val_metrics['loss']
170
  best_epoch = epoch + 1
171
  patience_counter = 0
172
-
173
  # Save checkpoint
174
  checkpoint_path = project_root / 'models' / 'best_model_mlflow.pt'
175
  torch.save({
@@ -180,9 +186,9 @@ Architecture:
180
  'val_loss': val_metrics['loss'],
181
  'val_accuracy': val_metrics['accuracy'],
182
  }, checkpoint_path)
183
-
184
  print(f" → New best model! (Val Loss: {best_val_loss:.4f})")
185
-
186
  # Log model to MLflow
187
  mlflow.pytorch.log_model(
188
  model,
@@ -191,18 +197,18 @@ Architecture:
191
  )
192
  else:
193
  patience_counter += 1
194
-
195
  # Early stopping
196
  if patience_counter >= patience:
197
  print(f"\nEarly stopping triggered after {epoch+1} epochs")
198
  mlflow.log_param("early_stopped", True)
199
  mlflow.log_param("early_stop_epoch", epoch + 1)
200
  break
201
-
202
  print("-" * 70)
203
  print("\nTraining complete!")
204
  print(f"Best epoch: {best_epoch} (Val Loss: {best_val_loss:.4f})")
205
-
206
  # Log best metrics
207
  mlflow.log_metrics({
208
  'best_epoch': best_epoch,
@@ -210,24 +216,24 @@ Architecture:
210
  'final_train_loss': history['train_loss'][-1],
211
  'final_val_loss': history['val_loss'][-1]
212
  })
213
-
214
  # Evaluate on test set
215
  print("\nEvaluating on test set...")
216
  test_metrics = evaluate_model(model, test_loader, device)
217
-
218
  test_accuracy = test_metrics['accuracy']
219
  test_report = test_metrics['classification_report']
220
-
221
  # Extract macro average metrics
222
  test_precision = test_report['macro avg']['precision']
223
  test_recall = test_report['macro avg']['recall']
224
  test_f1_score = test_report['macro avg']['f1-score']
225
-
226
  print(f"Test Accuracy: {test_accuracy:.2f}%")
227
  print(f"Test Precision: {test_precision:.4f}")
228
  print(f"Test Recall: {test_recall:.4f}")
229
  print(f"Test F1-Score: {test_f1_score:.4f}")
230
-
231
  # Log test metrics to MLflow
232
  mlflow.log_metrics({
233
  'test_accuracy': test_accuracy,
@@ -235,16 +241,16 @@ Architecture:
235
  'test_recall': test_recall,
236
  'test_f1_score': test_f1_score
237
  })
238
-
239
  # Save and log artifacts
240
  print("\nSaving artifacts...")
241
-
242
  # Save history
243
  history_path = project_root / 'results' / 'mlflow_training_history.json'
244
  history_path.parent.mkdir(exist_ok=True)
245
  save_training_history(history, history_path)
246
  log_artifact_path(str(history_path))
247
-
248
  # Save test metrics
249
  metrics_to_save = {
250
  'test_accuracy': test_accuracy,
@@ -258,43 +264,63 @@ Architecture:
258
  with open(metrics_path, 'w') as f:
259
  json.dump(metrics_to_save, f, indent=2)
260
  log_artifact_path(str(metrics_path))
261
-
262
  # Save model checkpoint
263
  log_artifact_path(str(project_root / 'models' / 'best_model_mlflow.pt'))
264
-
265
  # Log confusion matrix as JSON
266
  conf_matrix_dict = {
267
  f"row_{i}": test_metrics['confusion_matrix'][i].tolist()
268
  for i in range(len(test_metrics['confusion_matrix']))
269
  }
270
  mlflow.log_dict(conf_matrix_dict, "confusion_matrix.json")
271
-
272
  # Log classification report
273
  mlflow.log_dict(test_report, "classification_report.json")
274
-
275
  print("\n✓ All artifacts logged to MLflow")
276
  print("View results: mlflow ui --backend-store-uri file:./mlruns")
277
-
278
  return history
279
 
280
 
281
  def main():
282
- parser = argparse.ArgumentParser(description='Train MNIST CNN with MLflow tracking')
283
- parser.add_argument('--epochs', type=int, default=20, help='Number of epochs (default: 20)')
284
- parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
285
- parser.add_argument('--batch-size', type=int, default=64, help='Batch size (default: 64)')
286
- parser.add_argument('--augment', action='store_true', help='Use data augmentation')
287
- parser.add_argument('--run-name', type=str, default=None, help='MLflow run name')
288
- parser.add_argument('--seed', type=int, default=42, help='Random seed (default: 42)')
289
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  args = parser.parse_args()
291
-
292
  # Set random seeds
293
  torch.manual_seed(args.seed)
294
  np.random.seed(args.seed)
295
  if torch.cuda.is_available():
296
  torch.cuda.manual_seed(args.seed)
297
-
298
  # Configuration
299
  config = {
300
  'num_epochs': args.epochs,
@@ -307,10 +333,10 @@ def main():
307
  'scheduler': 'ReduceLROnPlateau',
308
  'early_stopping_patience': 5
309
  }
310
-
311
  print("Training Configuration:")
312
  print(json.dumps(config, indent=2))
313
-
314
  # Load MNIST data
315
  print("\nLoading MNIST data...")
316
  data_path = project_root / 'data' / 'raw'
@@ -321,18 +347,18 @@ def main():
321
  test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
322
  )
323
  (x_train, y_train), (x_test, y_test) = loader.load_data()
324
-
325
  # Split train/val
326
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
327
  x_train, y_train, val_split=0.15, random_seed=args.seed
328
  )
329
-
330
  # Create datasets with optional augmentation
331
  augmentation = get_train_augmentation() if args.augment else None
332
  train_dataset = MnistDataset(x_train_split, y_train_split, transform=augmentation)
333
  val_dataset = MnistDataset(x_val, y_val, transform=None)
334
  test_dataset = MnistDataset(x_test, y_test, transform=None)
335
-
336
  # Create data loaders
337
  train_loader, val_loader = create_dataloaders(
338
  train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
@@ -340,20 +366,20 @@ def main():
340
  test_loader = torch.utils.data.DataLoader(
341
  test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
342
  )
343
-
344
  print(f"Train: {len(train_loader.dataset)} samples")
345
  print(f"Val: {len(val_loader.dataset)} samples")
346
  print(f"Test: {len(test_loader.dataset)} samples")
347
-
348
  # Create model
349
  model = BaselineCNN().to(config['device'])
350
-
351
  # Train with MLflow
352
  train_with_mlflow(
353
  model, train_loader, val_loader, test_loader,
354
  config, run_name=args.run_name
355
  )
356
-
357
  print("\n" + "="*70)
358
  print("Training complete! View MLflow dashboard:")
359
  print(" ./scripts/launch_mlflow_ui.sh")
 
50
  ) -> dict:
51
  """
52
  Train model with full MLflow tracking.
53
+
54
  Args:
55
  model: PyTorch model to train
56
  train_loader: Training data loader
 
58
  test_loader: Test data loader
59
  config: Training configuration dictionary
60
  run_name: Optional name for MLflow run
61
+
62
  Returns:
63
  Training history dictionary
64
  """
65
  device = config['device']
66
  num_epochs = config['num_epochs']
67
  learning_rate = config['learning_rate']
68
+
69
  # Setup MLflow
70
  setup_mlflow("mnist-digit-classification")
71
+
72
  # Start MLflow run
73
  with mlflow.start_run(run_name=run_name):
74
  print("\n" + "="*70)
75
  print(f"MLflow Run ID: {mlflow.active_run().info.run_id}")
76
  print("="*70 + "\n")
77
+
78
  # Log all configuration
79
  print("Logging configuration to MLflow...")
80
  log_training_config(config)
 
87
  augmentation=config.get('augmentation', False)
88
  )
89
  log_system_info()
90
+
91
  # Log model architecture as text
92
  total_params, trainable_params = count_parameters(model)
93
  model_summary = f"""
 
100
  {str(model)}
101
  """
102
  mlflow.log_text(model_summary, "model_architecture.txt")
103
+
104
  # Setup training
105
  criterion = nn.CrossEntropyLoss()
106
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
107
  scheduler = optim.lr_scheduler.ReduceLROnPlateau(
108
  optimizer, mode='min', patience=3, factor=0.5, verbose=True
109
  )
110
+
111
  # Training history
112
  history = {
113
  'train_loss': [],
 
116
  'val_accuracy': [],
117
  'learning_rate': []
118
  }
119
+
120
  best_val_loss = float('inf')
121
  patience = 5
122
  patience_counter = 0
123
+
124
  print(f"\nStarting training for {num_epochs} epochs...")
125
  print(f"Device: {device}")
126
  total_p, _ = count_parameters(model)
127
  print(f"Model: {model.__class__.__name__} ({total_p:,} parameters)")
128
  print("-" * 70)
129
+
130
  for epoch in range(num_epochs):
131
  # Train
132
+ train_metrics = train_epoch(
133
+ model, train_loader, criterion, optimizer, device
134
+ )
135
+
136
  # Validate
137
  val_metrics = validate(model, val_loader, criterion, device)
138
+
139
  # Get current learning rate
140
  current_lr = optimizer.param_groups[0]['lr']
141
+
142
  # Update scheduler
143
  scheduler.step(val_metrics['loss'])
144
+
145
  # Save history
146
  history['train_loss'].append(train_metrics['loss'])
147
  history['train_accuracy'].append(train_metrics['accuracy'])
148
  history['val_loss'].append(val_metrics['loss'])
149
  history['val_accuracy'].append(val_metrics['accuracy'])
150
  history['learning_rate'].append(current_lr)
151
+
152
  # Log metrics to MLflow
153
  mlflow_metrics = {
154
  'train_loss': train_metrics['loss'],
 
159
  'epoch': epoch + 1
160
  }
161
  log_metrics_epoch(mlflow_metrics, step=epoch)
162
+
163
  # Print progress
164
+ print(
165
+ f"Epoch {epoch+1}/{num_epochs} | "
166
+ f"Train Loss: {train_metrics['loss']:.4f} "
167
+ f"({train_metrics['accuracy']:.2f}%) | "
168
+ f"Val Loss: {val_metrics['loss']:.4f} "
169
+ f"({val_metrics['accuracy']:.2f}%) | "
170
+ f"LR: {current_lr:.6f}"
171
+ )
172
+
173
  # Save best model
174
  if val_metrics['loss'] < best_val_loss:
175
  best_val_loss = val_metrics['loss']
176
  best_epoch = epoch + 1
177
  patience_counter = 0
178
+
179
  # Save checkpoint
180
  checkpoint_path = project_root / 'models' / 'best_model_mlflow.pt'
181
  torch.save({
 
186
  'val_loss': val_metrics['loss'],
187
  'val_accuracy': val_metrics['accuracy'],
188
  }, checkpoint_path)
189
+
190
  print(f" → New best model! (Val Loss: {best_val_loss:.4f})")
191
+
192
  # Log model to MLflow
193
  mlflow.pytorch.log_model(
194
  model,
 
197
  )
198
  else:
199
  patience_counter += 1
200
+
201
  # Early stopping
202
  if patience_counter >= patience:
203
  print(f"\nEarly stopping triggered after {epoch+1} epochs")
204
  mlflow.log_param("early_stopped", True)
205
  mlflow.log_param("early_stop_epoch", epoch + 1)
206
  break
207
+
208
  print("-" * 70)
209
  print("\nTraining complete!")
210
  print(f"Best epoch: {best_epoch} (Val Loss: {best_val_loss:.4f})")
211
+
212
  # Log best metrics
213
  mlflow.log_metrics({
214
  'best_epoch': best_epoch,
 
216
  'final_train_loss': history['train_loss'][-1],
217
  'final_val_loss': history['val_loss'][-1]
218
  })
219
+
220
  # Evaluate on test set
221
  print("\nEvaluating on test set...")
222
  test_metrics = evaluate_model(model, test_loader, device)
223
+
224
  test_accuracy = test_metrics['accuracy']
225
  test_report = test_metrics['classification_report']
226
+
227
  # Extract macro average metrics
228
  test_precision = test_report['macro avg']['precision']
229
  test_recall = test_report['macro avg']['recall']
230
  test_f1_score = test_report['macro avg']['f1-score']
231
+
232
  print(f"Test Accuracy: {test_accuracy:.2f}%")
233
  print(f"Test Precision: {test_precision:.4f}")
234
  print(f"Test Recall: {test_recall:.4f}")
235
  print(f"Test F1-Score: {test_f1_score:.4f}")
236
+
237
  # Log test metrics to MLflow
238
  mlflow.log_metrics({
239
  'test_accuracy': test_accuracy,
 
241
  'test_recall': test_recall,
242
  'test_f1_score': test_f1_score
243
  })
244
+
245
  # Save and log artifacts
246
  print("\nSaving artifacts...")
247
+
248
  # Save history
249
  history_path = project_root / 'results' / 'mlflow_training_history.json'
250
  history_path.parent.mkdir(exist_ok=True)
251
  save_training_history(history, history_path)
252
  log_artifact_path(str(history_path))
253
+
254
  # Save test metrics
255
  metrics_to_save = {
256
  'test_accuracy': test_accuracy,
 
264
  with open(metrics_path, 'w') as f:
265
  json.dump(metrics_to_save, f, indent=2)
266
  log_artifact_path(str(metrics_path))
267
+
268
  # Save model checkpoint
269
  log_artifact_path(str(project_root / 'models' / 'best_model_mlflow.pt'))
270
+
271
  # Log confusion matrix as JSON
272
  conf_matrix_dict = {
273
  f"row_{i}": test_metrics['confusion_matrix'][i].tolist()
274
  for i in range(len(test_metrics['confusion_matrix']))
275
  }
276
  mlflow.log_dict(conf_matrix_dict, "confusion_matrix.json")
277
+
278
  # Log classification report
279
  mlflow.log_dict(test_report, "classification_report.json")
280
+
281
  print("\n✓ All artifacts logged to MLflow")
282
  print("View results: mlflow ui --backend-store-uri file:./mlruns")
283
+
284
  return history
285
 
286
 
287
  def main():
288
+ parser = argparse.ArgumentParser(
289
+ description='Train MNIST CNN with MLflow tracking'
290
+ )
291
+ parser.add_argument(
292
+ '--epochs', type=int, default=20,
293
+ help='Number of epochs (default: 20)'
294
+ )
295
+ parser.add_argument(
296
+ '--lr', type=float, default=0.001,
297
+ help='Learning rate (default: 0.001)'
298
+ )
299
+ parser.add_argument(
300
+ '--batch-size', type=int, default=64,
301
+ help='Batch size (default: 64)'
302
+ )
303
+ parser.add_argument(
304
+ '--augment', action='store_true',
305
+ help='Use data augmentation'
306
+ )
307
+ parser.add_argument(
308
+ '--run-name', type=str, default=None,
309
+ help='MLflow run name'
310
+ )
311
+ parser.add_argument(
312
+ '--seed', type=int, default=42,
313
+ help='Random seed (default: 42)'
314
+ )
315
+
316
  args = parser.parse_args()
317
+
318
  # Set random seeds
319
  torch.manual_seed(args.seed)
320
  np.random.seed(args.seed)
321
  if torch.cuda.is_available():
322
  torch.cuda.manual_seed(args.seed)
323
+
324
  # Configuration
325
  config = {
326
  'num_epochs': args.epochs,
 
333
  'scheduler': 'ReduceLROnPlateau',
334
  'early_stopping_patience': 5
335
  }
336
+
337
  print("Training Configuration:")
338
  print(json.dumps(config, indent=2))
339
+
340
  # Load MNIST data
341
  print("\nLoading MNIST data...")
342
  data_path = project_root / 'data' / 'raw'
 
347
  test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
348
  )
349
  (x_train, y_train), (x_test, y_test) = loader.load_data()
350
+
351
  # Split train/val
352
  (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
353
  x_train, y_train, val_split=0.15, random_seed=args.seed
354
  )
355
+
356
  # Create datasets with optional augmentation
357
  augmentation = get_train_augmentation() if args.augment else None
358
  train_dataset = MnistDataset(x_train_split, y_train_split, transform=augmentation)
359
  val_dataset = MnistDataset(x_val, y_val, transform=None)
360
  test_dataset = MnistDataset(x_test, y_test, transform=None)
361
+
362
  # Create data loaders
363
  train_loader, val_loader = create_dataloaders(
364
  train_dataset, val_dataset, batch_size=args.batch_size, num_workers=2
 
366
  test_loader = torch.utils.data.DataLoader(
367
  test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
368
  )
369
+
370
  print(f"Train: {len(train_loader.dataset)} samples")
371
  print(f"Val: {len(val_loader.dataset)} samples")
372
  print(f"Test: {len(test_loader.dataset)} samples")
373
+
374
  # Create model
375
  model = BaselineCNN().to(config['device'])
376
+
377
  # Train with MLflow
378
  train_with_mlflow(
379
  model, train_loader, val_loader, test_loader,
380
  config, run_name=args.run_name
381
  )
382
+
383
  print("\n" + "="*70)
384
  print("Training complete! View MLflow dashboard:")
385
  print(" ./scripts/launch_mlflow_ui.sh")