Eli181927 committed on
Commit 6764326 (verified)
1 Parent(s): 8b835bf

Upload 2 files

2.CNN/trained_model_mnist100.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c3c08456f9f9d0b8ff934e8df4f78966a77c97f7a39b19a786df621dfae42347
+ size 3349284
2.CNN/training-100.py ADDED
@@ -0,0 +1,1049 @@
+ """
+ Section 1: Imports and network configurations
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+ import argparse
+ import csv
+ from pathlib import Path
+ from copy import deepcopy
+ from numpy.lib.stride_tricks import sliding_window_view
+
+
+ BASE_DIR = Path(__file__).resolve().parent
+ ARCHIVE_DIR = BASE_DIR / "archive"
+ DATASET_PATH = ARCHIVE_DIR / "mnist_compressed.npz"
+
+ np.random.seed(42)
+
+
+ # Network configuration
+ IMAGE_CHANNELS = 1
+ IMAGE_HEIGHT = 28
+ IMAGE_WIDTH = 56
+ INPUT_DIM = IMAGE_HEIGHT * IMAGE_WIDTH  # flattened input for compatibility
+ CONV_FILTERS = (16, 32)
+ KERNEL_SIZE = 3
+ POOL_SIZE = 2
+ FC_HIDDEN_DIM = 256
+ OUTPUT_DIM = 100
+ EPOCHS = 20
+ BATCH_SIZE = 256
+ LEARNING_RATE = 1e-3
+ REG_LAMBDA = 1e-4
+ DROP_RATE_FC = 0.4
+ EARLY_STOP_PATIENCE = 5
+ EARLY_STOP_MIN_DELTA = 1e-3
+ MAX_SHIFT_PIXELS = 2
+ CONTRAST_JITTER_STD = 0.1
+ BETA1 = 0.9
+ BETA2 = 0.999
+ EPSILON = 1e-8
+ DEV_SIZE = 10_000  # held-out validation set size
+
+
+ def save_history_to_csv(history, filepath):
+     target_path = Path(filepath)
+     target_path.parent.mkdir(parents=True, exist_ok=True)
+     with target_path.open("w", newline="") as csvfile:
+         writer = csv.DictWriter(csvfile, fieldnames=("epoch", "loss", "train_acc", "dev_acc"))
+         writer.writeheader()
+         for row in history:
+             writer.writerow(row)
+
+
+ def save_sweep_summary(results, filepath, *, include_trial=False):
+     target_path = Path(filepath)
+     target_path.parent.mkdir(parents=True, exist_ok=True)
+     fieldnames = ["learning_rate", "reg_lambda", "dev_acc"]
+     if include_trial:
+         fieldnames.insert(0, "trial")
+     with target_path.open("w", newline="") as csvfile:
+         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+         writer.writeheader()
+         for entry in results:
+             row = {
+                 "learning_rate": float(entry["learning_rate"]),
+                 "reg_lambda": float(entry["reg_lambda"]),
+                 "dev_acc": float(entry["dev_acc"]),
+             }
+             if include_trial:
+                 row["trial"] = int(entry["trial"])
+             writer.writerow(row)
+
+ """
+ Section 2: Loads the input data and transposes it so arrays are (features, samples)
+ """
+ def load_data(path: Path, dev_size: int = DEV_SIZE):
+     """
+     Load the MNIST-100 dataset from the compressed archive and return
+     training / validation splits flattened to (features, samples).
+     """
+     path = Path(path)
+     if not path.exists():
+         raise FileNotFoundError(f"Dataset not found at '{path}'")
+
+     with np.load(path) as data:
+         train_images = data["train_images"].astype(np.float32)
+         train_labels = data["train_labels"].astype(np.int64)
+         test_images = data["test_images"].astype(np.float32)
+         test_labels = data["test_labels"].astype(np.int64)
+
+     # Flatten images to column-major format (features, samples)
+     X_full = train_images.reshape(train_images.shape[0], -1).T  # (input_dim, m)
+     Y_full = train_labels
+
+     # Shuffle before splitting to validation
+     permutation = np.random.permutation(X_full.shape[1])
+     X_full = X_full[:, permutation]
+     Y_full = Y_full[permutation]
+
+     X_dev = X_full[:, :dev_size]
+     Y_dev = Y_full[:dev_size]
+     X_train = X_full[:, dev_size:]
+     Y_train = Y_full[dev_size:]
+
+     # Also flatten the test set for later reuse if needed.
+     X_test = test_images.reshape(test_images.shape[0], -1).T
+
+     return X_train, Y_train, X_dev, Y_dev, X_test, test_labels
+
+
+ """
+ Section 3: Scales features from [0, 255] to [0, 1], then standardizes them to zero mean and unit variance
+ """
+ def normalize_features(X_train, X_dev):
+     """
+     Normalize features to zero mean and unit variance using the training set.
+     """
+     X_train /= 255.0
+     X_dev /= 255.0
+
+     mean = np.mean(X_train, axis=1, keepdims=True)
+     std = np.std(X_train, axis=1, keepdims=True) + 1e-8
+
+     X_train = (X_train - mean) / std
+     X_dev = (X_dev - mean) / std
+
+     return X_train, X_dev, mean, std
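+
+ # Illustrative note: any later split must reuse the training-set statistics
+ # rather than computing its own. A hypothetical sketch for the test set
+ # returned by load_data (not executed anywhere in this script):
+ #     X_test = X_test / 255.0
+ #     X_test = (X_test - mean) / std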
+
+
+ """
+ Section 4: Initialises the parameters (weights and biases) and the Adam optimizer state
+ """
+ def init_params():
+     params = {}
+     conv1_fan_in = IMAGE_CHANNELS * KERNEL_SIZE * KERNEL_SIZE
+     params["conv1_W"] = (
+         np.random.randn(CONV_FILTERS[0], IMAGE_CHANNELS, KERNEL_SIZE, KERNEL_SIZE) * np.sqrt(2.0 / conv1_fan_in)
+     ).astype(np.float32)
+     params["conv1_b"] = np.zeros((CONV_FILTERS[0], 1), dtype=np.float32)
+
+     conv2_fan_in = CONV_FILTERS[0] * KERNEL_SIZE * KERNEL_SIZE
+     params["conv2_W"] = (
+         np.random.randn(CONV_FILTERS[1], CONV_FILTERS[0], KERNEL_SIZE, KERNEL_SIZE) * np.sqrt(2.0 / conv2_fan_in)
+     ).astype(np.float32)
+     params["conv2_b"] = np.zeros((CONV_FILTERS[1], 1), dtype=np.float32)
+
+     height_after_pool1 = IMAGE_HEIGHT // POOL_SIZE
+     width_after_pool1 = IMAGE_WIDTH // POOL_SIZE
+     height_after_pool2 = height_after_pool1 // POOL_SIZE
+     width_after_pool2 = width_after_pool1 // POOL_SIZE
+     flattened_dim = CONV_FILTERS[1] * height_after_pool2 * width_after_pool2
+
+     params["fc1_W"] = (
+         np.random.randn(FC_HIDDEN_DIM, flattened_dim) * np.sqrt(2.0 / flattened_dim)
+     ).astype(np.float32)
+     params["fc1_b"] = np.zeros((FC_HIDDEN_DIM, 1), dtype=np.float32)
+
+     params["fc2_W"] = (
+         np.random.randn(OUTPUT_DIM, FC_HIDDEN_DIM) * np.sqrt(2.0 / FC_HIDDEN_DIM)
+     ).astype(np.float32)
+     params["fc2_b"] = np.zeros((OUTPUT_DIM, 1), dtype=np.float32)
+
+     return params
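+
+ # Shape check: with 28x56 inputs and two 2x2 pools, the final feature map is
+ # 32 x 7 x 14, so flattened_dim = 3136 and fc1_W has shape (256, 3136). The
+ # sqrt(2 / fan_in) factors above are He initialization for ReLU layers.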
+
+
+ def init_adam(params):
+     v = {}
+     s = {}
+     for key, value in params.items():
+         v[key] = np.zeros_like(value)
+         s[key] = np.zeros_like(value)
+     return v, s
+
+
+ """
+ Section 5: ReLU activation function and its backward pass
+ """
+ def relu(Z):
+     return np.maximum(0.0, Z)
+
+
+ def relu_backward(Z):
+     return (Z > 0).astype(np.float32)
+
+
+ """
+ Section 6: Reshapes the flattened input to 4D tensors (batch, channels, height, width) for the convolutional layers
+ """
+ def reshape_flat_to_images(X: np.ndarray, *, batch_size: int | None = None):
+     """
+     Convert flattened columns (features, batch) into 4D tensors (batch, channels, height, width).
+     """
+     _, m = X.shape
+     if batch_size is not None and m != batch_size:
+         raise ValueError(f"Expected batch size {batch_size}, got {m}")
+     images = X.T.reshape(m, IMAGE_HEIGHT, IMAGE_WIDTH)
+     return images[:, None, :, :]  # add channel dim
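+
+ # Illustrative sketch, a hypothetical helper not called by the pipeline:
+ # round-trip a small batch to confirm the (features, batch) <-> NCHW layout.
+ def _demo_reshape_roundtrip():
+     X = np.random.rand(INPUT_DIM, 4).astype(np.float32)  # 4 flattened images
+     images = reshape_flat_to_images(X)
+     assert images.shape == (4, IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH)
+     assert np.array_equal(images.reshape(4, -1).T, X)  # flattening inverts it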
+
+
+ """
+ Section 7: Convolutional layer forward pass and backward pass
+ """
+
+ def im2col(X, kernel_h, kernel_w, stride, padding):
+     X_padded = np.pad(
+         X,
+         ((0, 0), (0, 0), (padding, padding), (padding, padding)),
+         mode="constant",
+     )
+     windows = sliding_window_view(X_padded, (kernel_h, kernel_w), axis=(2, 3))
+     # windows shape: (batch, channels, out_height, out_width, kernel_h, kernel_w)
+     batch_size, channels, out_height, out_width, _, _ = windows.shape
+     cols = windows.transpose(0, 2, 3, 1, 4, 5).reshape(batch_size * out_height * out_width, channels * kernel_h * kernel_w)
+     return X_padded, cols, out_height, out_width
+
+
+ def col2im(cols, X_shape, kernel_h, kernel_w, stride, padding, out_height, out_width):
+     batch_size, channels, height, width = X_shape
+     cols_reshaped = cols.reshape(batch_size, out_height, out_width, channels, kernel_h, kernel_w)
+     cols_reshaped = cols_reshaped.transpose(0, 3, 1, 2, 4, 5)
+     X_padded = np.zeros((batch_size, channels, height + 2 * padding, width + 2 * padding), dtype=np.float32)
+
+     for h_idx in range(out_height):
+         h_start = h_idx * stride
+         h_end = h_start + kernel_h
+         for w_idx in range(out_width):
+             w_start = w_idx * stride
+             w_end = w_start + kernel_w
+             X_padded[:, :, h_start:h_end, w_start:w_end] += cols_reshaped[:, :, h_idx, w_idx, :, :]
+
+     if padding > 0:
+         return X_padded[:, :, padding:-padding, padding:-padding]
+     return X_padded
+
+
+ def conv_forward(X, W, b, *, stride: int = 1, padding: int = 0):
+     batch_size, in_channels, height, width = X.shape
+     num_filters, _, kernel_h, kernel_w = W.shape
+
+     X_padded, cols, out_height, out_width = im2col(X, kernel_h, kernel_w, stride, padding)
+     W_col = W.reshape(num_filters, -1)
+     out_cols = cols @ W_col.T  # (batch*out_height*out_width, num_filters)
+     out = out_cols.reshape(batch_size, out_height, out_width, num_filters).transpose(0, 3, 1, 2)
+     out = out.astype(np.float32, copy=False)
+     out += b.reshape(1, num_filters, 1, 1)
+
+     cache = {
+         "X": X,
+         "X_padded": X_padded,
+         "W": W,
+         "stride": stride,
+         "padding": padding,
+         "kernel_h": kernel_h,
+         "kernel_w": kernel_w,
+         "out_height": out_height,
+         "out_width": out_width,
+         "cols": cols,
+         "W_col": W_col,
+         "output_shape": out.shape,
+     }
+     return out, cache
+
+
+ def conv_backward(dout, cache):
+     X = cache["X"]
+     W = cache["W"]
+     stride = cache["stride"]
+     padding = cache["padding"]
+     kernel_h = cache["kernel_h"]
+     kernel_w = cache["kernel_w"]
+     out_height = cache["out_height"]
+     out_width = cache["out_width"]
+     cols = cache["cols"]
+     W_col = cache["W_col"]
+
+     batch_size, _, _, _ = X.shape
+     num_filters = W.shape[0]
+
+     dout_cols = dout.transpose(0, 2, 3, 1).reshape(batch_size * out_height * out_width, num_filters)
+     dW_col = dout_cols.T @ cols
+     dW = dW_col.reshape(W.shape)
+     db = np.sum(dout, axis=(0, 2, 3)).reshape(num_filters, 1)
+
+     dcols = dout_cols @ W_col
+     dX = col2im(dcols, X.shape, kernel_h, kernel_w, stride, padding, out_height, out_width)
+
+     return dX, dW, db
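+
+ # Illustrative sketch, a hypothetical helper not called by the pipeline:
+ # cross-check conv_forward against a direct windowed dot product. Note that
+ # im2col enumerates stride-1 windows via sliding_window_view, so these helpers
+ # assume stride == 1 (the only value used in this file).
+ def _demo_conv_check():
+     X = np.random.rand(2, 3, 8, 8).astype(np.float32)
+     W = np.random.rand(4, 3, 3, 3).astype(np.float32)
+     b = np.zeros((4, 1), dtype=np.float32)
+     out, _ = conv_forward(X, W, b, stride=1, padding=0)
+     assert out.shape == (2, 4, 6, 6)
+     expected = np.sum(X[0, :, 0:3, 0:3] * W[0])  # batch 0, filter 0, top-left
+     assert np.isclose(out[0, 0, 0, 0], expected, atol=1e-4)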
+
+
+
+ """
+ Section 8: Max pooling layer forward pass and backward pass
+ """
+ def maxpool_forward(X, *, pool_size: int = 2, stride: int = 2):
+     batch_size, channels, height, width = X.shape
+     out_height = (height - pool_size) // stride + 1
+     out_width = (width - pool_size) // stride + 1
+
+     out = np.zeros((batch_size, channels, out_height, out_width), dtype=np.float32)
+
+     for h_idx in range(out_height):
+         h_start = h_idx * stride
+         h_end = h_start + pool_size
+         for w_idx in range(out_width):
+             w_start = w_idx * stride
+             w_end = w_start + pool_size
+             window = X[:, :, h_start:h_end, w_start:w_end]
+             max_vals = np.max(window, axis=(2, 3))
+             out[:, :, h_idx, w_idx] = max_vals
+
+     cache = {
+         "X": X,
+         "pool_size": pool_size,
+         "stride": stride,
+         "output_shape": out.shape,
+     }
+     return out, cache
+
+
+ def maxpool_backward(dout, cache):
+     X = cache["X"]
+     pool_size = cache["pool_size"]
+     stride = cache["stride"]
+     batch_size, channels, out_height, out_width = dout.shape
+
+     dX = np.zeros_like(X)
+     for h_idx in range(out_height):
+         h_start = h_idx * stride
+         h_end = h_start + pool_size
+         for w_idx in range(out_width):
+             w_start = w_idx * stride
+             w_end = w_start + pool_size
+             window = X[:, :, h_start:h_end, w_start:w_end]
+             max_vals = np.max(window, axis=(2, 3), keepdims=True)
+             mask = (window == max_vals).astype(np.float32)
+             mask_sum = np.sum(mask, axis=(2, 3), keepdims=True)
+             mask /= np.maximum(mask_sum, 1.0)
+             grad_slice = dout[:, :, h_idx, w_idx][:, :, None, None]
+             dX[:, :, h_start:h_end, w_start:w_end] += mask * grad_slice
+     return dX
+
+
+ def softmax(Z):
+     Z_shift = Z - np.max(Z, axis=0, keepdims=True)
+     expZ = np.exp(Z_shift)
+     return expZ / np.sum(expZ, axis=0, keepdims=True)
+
+
+ def one_hot(Y, num_classes=OUTPUT_DIM):
+     one_hot_y = np.zeros((num_classes, Y.size), dtype=np.float32)
+     one_hot_y[Y, np.arange(Y.size)] = 1.0
+     return one_hot_y
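+
+ # Illustrative sketch, a hypothetical helper not called by the pipeline:
+ # softmax columns sum to 1 and one_hot places a single 1 per column, which is
+ # what makes probs - Y_batch in back_prop the cross-entropy gradient.
+ def _demo_softmax_one_hot():
+     Z = np.random.randn(OUTPUT_DIM, 5).astype(np.float32)
+     assert np.allclose(softmax(Z).sum(axis=0), 1.0, atol=1e-5)
+     Y = np.array([3, 7, 0, 42, 99])
+     assert np.array_equal(one_hot(Y).sum(axis=0), np.ones(5, dtype=np.float32))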
+
+
+
+ """
+ Section 9: Forward propagation and loss computation
+ """
+ def forward_prop(
+     X,
+     params,
+     *,
+     training: bool = False,
+     dropout_rate: float = DROP_RATE_FC,
+ ):
+     batch_size = X.shape[1]
+     images = reshape_flat_to_images(X, batch_size=batch_size)
+     padding = KERNEL_SIZE // 2
+
+     conv1_out, conv1_cache = conv_forward(images, params["conv1_W"], params["conv1_b"], stride=1, padding=padding)
+     relu1 = relu(conv1_out)
+     pool1_out, pool1_cache = maxpool_forward(relu1, pool_size=POOL_SIZE, stride=POOL_SIZE)
+
+     conv2_out, conv2_cache = conv_forward(pool1_out, params["conv2_W"], params["conv2_b"], stride=1, padding=padding)
+     relu2 = relu(conv2_out)
+     pool2_out, pool2_cache = maxpool_forward(relu2, pool_size=POOL_SIZE, stride=POOL_SIZE)
+
+     flattened = pool2_out.reshape(batch_size, -1).T  # (features_flat, batch)
+
+     Z_fc1 = params["fc1_W"] @ flattened + params["fc1_b"]
+     A_fc1 = relu(Z_fc1)
+
+     dropout_mask = None
+     keep_prob = 1.0 - dropout_rate
+     if training and dropout_rate > 0.0:
+         dropout_mask = (np.random.rand(*A_fc1.shape) >= dropout_rate).astype(np.float32)
+         A_fc1 = (A_fc1 * dropout_mask) / keep_prob
+
+     Z_fc2 = params["fc2_W"] @ A_fc1 + params["fc2_b"]
+     probs = softmax(Z_fc2)
+
+     cache = {
+         "X": X,
+         "images": images,
+         "conv1_out": conv1_out,
+         "conv1_cache": conv1_cache,
+         "pool1_cache": pool1_cache,
+         "conv2_out": conv2_out,
+         "conv2_cache": conv2_cache,
+         "pool2_cache": pool2_cache,
+         "flattened": flattened,
+         "Z_fc1": Z_fc1,
+         "A_fc1": A_fc1,
+         "dropout_mask": dropout_mask,
+         "keep_prob": keep_prob,
+         "dropout_rate": dropout_rate,
+         "Z_fc2": Z_fc2,
+         "probs": probs,
+     }
+
+     return cache, probs
+
+
+ def compute_loss(probs, Y_batch, params, reg_lambda):
+     m = Y_batch.shape[1]
+     log_likelihood = -np.log(probs + 1e-9) * Y_batch
+     data_loss = np.sum(log_likelihood) / m
+
+     l2_penalty = 0.0
+     for key in ("conv1_W", "conv2_W", "fc1_W", "fc2_W"):
+         l2_penalty += np.sum(np.square(params[key]))
+     l2_loss = (reg_lambda / (2 * m)) * l2_penalty
+
+     return data_loss + l2_loss
+
+
+ """
+ Section 10: Back propagation for the CNN model
+ """
+ def back_prop(cache, Y_batch, params, reg_lambda, dropout_rate):
+     m = Y_batch.shape[1]
+     grads = {}
+
+     probs = cache["probs"]
+     A_fc1 = cache["A_fc1"]
+     Z_fc1 = cache["Z_fc1"]
+     flattened = cache["flattened"]
+     dropout_mask = cache["dropout_mask"]
+     keep_prob = cache["keep_prob"]
+
+     dZ_fc2 = probs - Y_batch
+     grads["fc2_W"] = (dZ_fc2 @ A_fc1.T) / m + (reg_lambda / m) * params["fc2_W"]
+     grads["fc2_b"] = np.sum(dZ_fc2, axis=1, keepdims=True) / m
+
+     dA_fc1 = params["fc2_W"].T @ dZ_fc2
+     if dropout_mask is not None:
+         dA_fc1 = (dA_fc1 * dropout_mask) / keep_prob
+     dZ_fc1 = dA_fc1 * relu_backward(Z_fc1)
+     grads["fc1_W"] = (dZ_fc1 @ flattened.T) / m + (reg_lambda / m) * params["fc1_W"]
+     grads["fc1_b"] = np.sum(dZ_fc1, axis=1, keepdims=True) / m
+
+     dFlatten = params["fc1_W"].T @ dZ_fc1  # (flatten_dim, batch)
+     pool2_shape = cache["pool2_cache"]["output_shape"]
+     dPool2 = dFlatten.T.reshape(pool2_shape)
+
+     dRelu2_input = maxpool_backward(dPool2, cache["pool2_cache"])
+     dConv2 = dRelu2_input * relu_backward(cache["conv2_out"])
+     dPool1_input, dConv2_W, dConv2_b = conv_backward(dConv2, cache["conv2_cache"])
+     grads["conv2_W"] = dConv2_W / m + (reg_lambda / m) * params["conv2_W"]
+     grads["conv2_b"] = dConv2_b / m
+
+     dRelu1_input = maxpool_backward(dPool1_input, cache["pool1_cache"])
+     dConv1 = dRelu1_input * relu_backward(cache["conv1_out"])
+     _, dConv1_W, dConv1_b = conv_backward(dConv1, cache["conv1_cache"])
+     grads["conv1_W"] = dConv1_W / m + (reg_lambda / m) * params["conv1_W"]
+     grads["conv1_b"] = dConv1_b / m
+
+     return grads
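+
+ # Illustrative sketch, a hypothetical debugging helper (slow; not called by
+ # the pipeline): compare one analytic gradient from back_prop against a
+ # centered finite difference. Uses reg_lambda=0 and no dropout so the two
+ # losses match directly; agreement to ~1e-3 suggests the backprop is correct.
+ def _demo_gradient_check(params, X_small, Y_small, key="fc2_W", eps=1e-4):
+     Y_oh = one_hot(Y_small)
+     cache, _ = forward_prop(X_small, params, training=False)
+     grads = back_prop(cache, Y_oh, params, reg_lambda=0.0, dropout_rate=0.0)
+     original = params[key][0, 0]
+     params[key][0, 0] = original + eps
+     _, probs_plus = forward_prop(X_small, params, training=False)
+     loss_plus = compute_loss(probs_plus, Y_oh, params, reg_lambda=0.0)
+     params[key][0, 0] = original - eps
+     _, probs_minus = forward_prop(X_small, params, training=False)
+     loss_minus = compute_loss(probs_minus, Y_oh, params, reg_lambda=0.0)
+     params[key][0, 0] = original  # restore the weight
+     numeric = (loss_plus - loss_minus) / (2 * eps)
+     return numeric, grads[key][0, 0]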
+
+
+ """
+ Section 11: Updates the parameters using the Adam optimizer
+ """
+
+ def update_params_adam(params, grads, v, s, t, learning_rate):
+     updated_params = {}
+     for key in params:
+         v[key] = BETA1 * v[key] + (1 - BETA1) * grads[key]
+         s[key] = BETA2 * s[key] + (1 - BETA2) * (grads[key] ** 2)
+
+         v_corrected = v[key] / (1 - BETA1 ** t)
+         s_corrected = s[key] / (1 - BETA2 ** t)
+
+         updated_params[key] = params[key] - learning_rate * v_corrected / (np.sqrt(s_corrected) + EPSILON)
+
+     return updated_params, v, s
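+
+ # Illustrative check, a hypothetical helper not called by the pipeline: one
+ # Adam step on a scalar. At t=1 the bias correction undoes the (1 - beta)
+ # factors, so the step is roughly learning_rate * grad / |grad|.
+ def _demo_adam_step():
+     p = {"w": np.array([1.0], dtype=np.float32)}
+     g = {"w": np.array([0.5], dtype=np.float32)}
+     v, s = init_adam(p)
+     updated, _, _ = update_params_adam(p, g, v, s, 1, 0.1)
+     assert np.isclose(updated["w"][0], 0.9, atol=1e-3)  # 1.0 - 0.1 * 0.5/0.5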
+
+
+ def get_predictions(probs):
+     return np.argmax(probs, axis=0)
+
+
+ def get_accuracy(probs, labels):
+     predictions = get_predictions(probs)
+     return np.mean(predictions == labels)
+
+
+ """
+ Section 12: Augments the batch with horizontal shifts and contrast/brightness jitter
+ """
+
+ def augment_batch(
+     X_batch,
+     *,
+     image_shape: tuple[int, int] = (28, 56),
+     max_shift: int = MAX_SHIFT_PIXELS,
+     contrast_jitter_std: float = CONTRAST_JITTER_STD,
+ ):
+     """
+     Apply lightweight augmentation: horizontal shifts and contrast/brightness jitter.
+     """
+     if max_shift <= 0 and contrast_jitter_std <= 0.0:
+         return X_batch
+
+     batch_size = X_batch.shape[1]
+     images = X_batch.T.reshape(batch_size, *image_shape)
+
+     if max_shift > 0:
+         shifts = np.random.randint(-max_shift, max_shift + 1, size=batch_size)
+         for idx, shift in enumerate(shifts):
+             if shift > 0:
+                 shifted = np.roll(images[idx], shift, axis=1)
+                 shifted[:, :shift] = 0.0
+                 images[idx] = shifted
+             elif shift < 0:
+                 shift = -shift
+                 shifted = np.roll(images[idx], -shift, axis=1)
+                 shifted[:, -shift:] = 0.0
+                 images[idx] = shifted
+
+     if contrast_jitter_std > 0.0:
+         scale = 1.0 + np.random.normal(0.0, contrast_jitter_std, size=batch_size)
+         bias = np.random.normal(0.0, contrast_jitter_std, size=batch_size)
+         images *= scale[:, None, None]
+         images += bias[:, None, None]
+         np.clip(images, -3.0, 3.0, out=images)
+
+     return images.reshape(batch_size, -1).T
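+
+ # Usage note: augment_batch preserves the (features, batch) layout, so it can
+ # be applied to a copied mini-batch just before forward_prop (see train_model
+ # below). Shifted-in columns are filled with 0, which matches the per-feature
+ # mean after standardization in normalize_features.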
+
+
+ """
+ Section 13: Trains and evaluates the model
+ """
+ def train_model(
+     X_train,
+     Y_train,
+     X_dev,
+     Y_dev,
+     *,
+     epochs: int = EPOCHS,
+     batch_size: int = BATCH_SIZE,
+     learning_rate: float = LEARNING_RATE,
+     reg_lambda: float = REG_LAMBDA,
+     dropout_rate: float = DROP_RATE_FC,
+     early_stop_patience: int = EARLY_STOP_PATIENCE,
+     early_stop_min_delta: float = EARLY_STOP_MIN_DELTA,
+     use_augmentation: bool = True,
+ ):
+     params = init_params()
+     v, s = init_adam(params)
+     m_train = X_train.shape[1]
+     global_step = 0
+     best_dev_acc = -np.inf
+     best_params = deepcopy(params)
+     patience_counter = 0
+     history = []
+
+     for epoch in range(1, epochs + 1):
+         permutation = np.random.permutation(m_train)
+         X_shuffled = X_train[:, permutation]
+         Y_shuffled = Y_train[permutation]
+
+         epoch_loss = 0.0
+
+         for start in range(0, m_train, batch_size):
+             end = min(start + batch_size, m_train)
+             X_batch = X_shuffled[:, start:end]
+             Y_batch_indices = Y_shuffled[start:end]
+             Y_batch = one_hot(Y_batch_indices)
+
+             if use_augmentation:
+                 X_batch = augment_batch(X_batch.copy())
+
+             cache, probs = forward_prop(
+                 X_batch,
+                 params,
+                 training=True,
+                 dropout_rate=dropout_rate,
+             )
+             loss = compute_loss(probs, Y_batch, params, reg_lambda)
+             grads = back_prop(cache, Y_batch, params, reg_lambda, dropout_rate)
+
+             global_step += 1
+             params, v, s = update_params_adam(params, grads, v, s, global_step, learning_rate)
+
+             epoch_loss += loss * (end - start)
+
+         epoch_loss /= m_train
+
+         _, train_probs = forward_prop(X_train, params, training=False, dropout_rate=dropout_rate)
+         train_accuracy = get_accuracy(train_probs, Y_train)
+
+         _, dev_probs = forward_prop(X_dev, params, training=False, dropout_rate=dropout_rate)
+         dev_accuracy = get_accuracy(dev_probs, Y_dev)
+
+         print(
+             f"Epoch {epoch:02d} - loss: {epoch_loss:.4f} "
+             f"- train_acc: {train_accuracy:.4f} - dev_acc: {dev_accuracy:.4f}"
+         )
+
+         history.append(
+             {
+                 "epoch": epoch,
+                 "loss": epoch_loss,
+                 "train_acc": train_accuracy,
+                 "dev_acc": dev_accuracy,
+             }
+         )
+
+         if dev_accuracy > best_dev_acc + early_stop_min_delta:
+             best_dev_acc = dev_accuracy
+             best_params = deepcopy(params)
+             patience_counter = 0
+         else:
+             patience_counter += 1
+             if patience_counter >= early_stop_patience:
+                 print(
+                     f"Early stopping triggered at epoch {epoch:02d}. "
+                     f"Best dev_acc={best_dev_acc:.4f}"
+                 )
+                 break
+
+     return best_params, history
+
+
+ def evaluate(params, X, Y):
+     _, probs = forward_prop(X, params, training=False)
+     predictions = get_predictions(probs)
+     accuracy = np.mean(predictions == Y)
+     return predictions, accuracy
+
+
+ """
+ Section 14: Trains the model once
+ """
+ def train_once(
+     learning_rate: float,
+     reg_lambda: float,
+     *,
+     epochs: int = EPOCHS,
+     batch_size: int = BATCH_SIZE,
+     dropout_rate: float = DROP_RATE_FC,
+     history_path: Path | None = None,
+ ):
+     """
+     Convenience wrapper for hyperparameter sweeps. Returns trained params and dev accuracy.
+     """
+     X_train, Y_train, X_dev, Y_dev, _, _ = load_data(DATASET_PATH)
+     X_train, X_dev, mean, std = normalize_features(X_train, X_dev)
+
+     params, history = train_model(
+         X_train,
+         Y_train,
+         X_dev,
+         Y_dev,
+         epochs=epochs,
+         batch_size=batch_size,
+         learning_rate=learning_rate,
+         reg_lambda=reg_lambda,
+         dropout_rate=dropout_rate,
+     )
+
+     _, dev_accuracy = evaluate(params, X_dev, Y_dev)
+
+     if history_path is not None:
+         save_history_to_csv(history, history_path)
+
+     return params, dev_accuracy, mean, std, history
+
+ """
+ Section 15: Hyperparameter sweep for learning rate, regularization and dropout rate
+ """
+
+ def lr_sweep(
+     learning_rates: list[float],
+     *,
+     reg_lambda: float = REG_LAMBDA,
+     epochs: int = EPOCHS,
+     batch_size: int = BATCH_SIZE,
+     dropout_rate: float = DROP_RATE_FC,
+     history_dir: Path | None = None,
+     summary_path: Path | None = None,
+ ):
+     results = []
+     history_directory = Path(history_dir) if history_dir is not None else None
+     if history_directory is not None:
+         history_directory.mkdir(parents=True, exist_ok=True)
+
+     for lr in learning_rates:
+         history_path = None
+         if history_directory is not None:
+             safe_lr = f"{lr:.2e}".replace("+", "").replace("-", "m")
+             history_path = history_directory / f"lr_{safe_lr}.csv"
+         _, dev_acc, _, _, history = train_once(
+             lr,
+             reg_lambda,
+             epochs=epochs,
+             batch_size=batch_size,
+             dropout_rate=dropout_rate,
+             history_path=history_path,
+         )
+         results.append(
+             {
+                 "learning_rate": float(lr),
+                 "reg_lambda": float(reg_lambda),
+                 "dev_acc": float(dev_acc),
+                 "history": history,
+             }
+         )
+     if summary_path is not None:
+         save_sweep_summary(results, summary_path)
+     return results
+
+
+ def random_search_hparams(
+     num_trials: int,
+     lr_bounds: tuple[float, float],
+     reg_bounds: tuple[float, float],
+     *,
+     epochs: int = EPOCHS,
+     batch_size: int = BATCH_SIZE,
+     dropout_rate: float = DROP_RATE_FC,
+     seed: int | None = None,
+     history_dir: Path | None = None,
+     summary_path: Path | None = None,
+ ):
+     if num_trials <= 0:
+         raise ValueError("num_trials must be positive")
+
+     lr_min, lr_max = lr_bounds
+     reg_min, reg_max = reg_bounds
+     if lr_min <= 0 or lr_max <= 0:
+         raise ValueError("Learning rate bounds must be positive")
+     if reg_min <= 0 or reg_max <= 0:
+         raise ValueError("Regularization bounds must be positive")
+
+     rng = np.random.default_rng(seed)
+     history_directory = Path(history_dir) if history_dir is not None else None
+     if history_directory is not None:
+         history_directory.mkdir(parents=True, exist_ok=True)
+
+     results = []
+     log_lr_min, log_lr_max = np.log(lr_min), np.log(lr_max)
+     log_reg_min, log_reg_max = np.log(reg_min), np.log(reg_max)
+
+     for trial in range(1, num_trials + 1):
+         lr_sample = float(np.exp(rng.uniform(log_lr_min, log_lr_max)))
+         reg_sample = float(np.exp(rng.uniform(log_reg_min, log_reg_max)))
+         history_path = None
+         if history_directory is not None:
+             safe_lr = f"{lr_sample:.2e}".replace("+", "").replace("-", "m")
+             safe_reg = f"{reg_sample:.2e}".replace("+", "").replace("-", "m")
+             history_path = history_directory / f"trial_{trial:02d}_lr-{safe_lr}_reg-{safe_reg}.csv"
+
+         _, dev_acc, _, _, history = train_once(
+             lr_sample,
+             reg_sample,
+             epochs=epochs,
+             batch_size=batch_size,
+             dropout_rate=dropout_rate,
+             history_path=history_path,
+         )
+
+         results.append(
+             {
+                 "trial": trial,
+                 "learning_rate": lr_sample,
+                 "reg_lambda": reg_sample,
+                 "dev_acc": float(dev_acc),
+                 "history": history,
+             }
+         )
+
+     results.sort(key=lambda item: item["dev_acc"], reverse=True)
+     if summary_path is not None:
+         save_sweep_summary(results, summary_path, include_trial=True)
+     return results
+
+
+ def auto_train_pipeline(
+     *,
+     trials: int,
+     lr_bounds: tuple[float, float],
+     reg_bounds: tuple[float, float],
+     search_epochs: int,
+     final_epochs: int,
+     batch_size: int,
+     dropout_rate: float,
+     final_batch_size: int | None,
+     final_dropout_rate: float | None,
+     history_dir: Path | None,
+     seed: int | None,
+     output_model_path: Path | None,
+ ):
+     history_directory = Path(history_dir) if history_dir is not None else None
+     if history_directory is not None:
+         history_directory.mkdir(parents=True, exist_ok=True)
+
+     search_summary_path = None
+     if history_directory is not None:
+         search_summary_path = history_directory / "random_search_summary.csv"
+
+     results = random_search_hparams(
+         trials,
+         lr_bounds,
+         reg_bounds,
+         epochs=search_epochs,
+         batch_size=batch_size,
+         dropout_rate=dropout_rate,
+         seed=seed,
+         history_dir=history_directory / "search_histories" if history_directory is not None else None,
+         summary_path=search_summary_path,
+     )
+     best = results[0]
+     print(
+         f"\nBest search trial -> LR={best['learning_rate']:.3e}, "
+         f"reg={best['reg_lambda']:.3e}, dev_acc={best['dev_acc']:.4f}"
+     )
+
+     final_dropout = final_dropout_rate if final_dropout_rate is not None else dropout_rate
+     final_history_path = None
+     if history_directory is not None:
+         final_history_path = history_directory / "final_train_history.csv"
+
+     params, final_dev_acc, mean, std, final_history = train_once(
+         best["learning_rate"],
+         best["reg_lambda"],
+         epochs=final_epochs,
+         batch_size=final_batch_size or batch_size,
+         dropout_rate=final_dropout,
+         history_path=final_history_path,
+     )
+
+     model_output_path = output_model_path if output_model_path is not None else ARCHIVE_DIR / "trained_model_mnist100.npz"
+     save_model(params, mean, std, model_output_path)
+
+     return {
+         "best_trial": best,
+         "final_dev_acc": final_dev_acc,
+         "model_path": Path(model_output_path),
+         "final_history": final_history,
+     }
857
+
858
+ """
859
+ Section 16: Saves the model
860
+ """
861
+ def save_model(params, mean, std, filepath=None):
862
+ target_path = Path(filepath) if filepath is not None else ARCHIVE_DIR / "trained_model_mnist100.npz"
863
+ target_path.parent.mkdir(parents=True, exist_ok=True)
864
+ print(f"\nSaving trained model to '{target_path}'...")
865
+ np.savez(target_path, **params, mean=mean, std=std)
866
+ print("Model saved successfully!")
+
+
+ """
+ Section 17: Main function
+ """
+
+ def main():
+     parser = argparse.ArgumentParser(description="MNIST-100 training and tuning utilities.")
+     parser.add_argument(
+         "--mode",
+         choices=("train", "lr-sweep", "random-search", "auto-train"),
+         default="train",
+         help="Select high-level action.",
+     )
+     parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, help="Base learning rate.")
+     parser.add_argument("--learning-rates", type=str, help="Comma-separated list for LR sweep.")
+     parser.add_argument("--reg-lambda", type=float, default=REG_LAMBDA, help="L2 regularization strength.")
+     parser.add_argument("--lr-min", type=float, default=1e-4, help="Min LR for random search (random-search and auto-train modes).")
+     parser.add_argument("--lr-max", type=float, default=5e-3, help="Max LR for random search.")
+     parser.add_argument("--reg-min", type=float, default=1e-5, help="Min lambda for random search.")
+     parser.add_argument("--reg-max", type=float, default=1e-3, help="Max lambda for random search.")
+     parser.add_argument("--trials", type=int, default=5, help="Number of random-search trials.")
+     parser.add_argument("--epochs", type=int, default=EPOCHS, help="Train epochs per run.")
+     parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Mini-batch size.")
+     parser.add_argument(
+         "--final-epochs",
+         type=int,
+         default=40,
+         help="Epoch budget for the final training run in auto-train mode.",
+     )
+     parser.add_argument(
+         "--final-batch-size",
+         type=int,
+         help="Mini-batch size for the final training run (defaults to --batch-size).",
+     )
+     parser.add_argument(
+         "--dropout",
+         type=float,
+         help="Override dropout rate for the fully connected layer.",
+     )
+     parser.add_argument(
+         "--final-dropout",
+         type=float,
+         help="Dropout rate for the final training pass in auto-train mode.",
+     )
+     parser.add_argument(
+         "--history-dir",
+         type=Path,
+         help="Directory for saving training histories (CSV).",
+     )
+     parser.add_argument(
+         "--output-model",
+         type=Path,
+         help="Path to save the trained model (.npz). Defaults to archive/trained_model_mnist100.npz.",
+     )
+     parser.add_argument("--seed", type=int, help="Random seed for random search.")
+     args = parser.parse_args()
+
+     dropout_rate = DROP_RATE_FC if args.dropout is None else float(args.dropout)
+     if not 0.0 <= dropout_rate < 1.0:
+         raise ValueError("Dropout rate must be in [0, 1).")
+
+     final_dropout_rate = None
+     if args.final_dropout is not None:
+         final_dropout_rate = float(args.final_dropout)
+         if not 0.0 <= final_dropout_rate < 1.0:
+             raise ValueError("Final dropout rate must be in [0, 1).")
+
+     history_dir = args.history_dir
+     if history_dir is not None:
+         history_dir = Path(history_dir)
+         history_dir.mkdir(parents=True, exist_ok=True)
+
+     if args.mode == "train":
+         print(f"Loading dataset from '{DATASET_PATH}'...")
+         X_train, Y_train, X_dev, Y_dev, _, _ = load_data(DATASET_PATH)
+         X_train, X_dev, mean, std = normalize_features(X_train, X_dev)
+
+         print(
+             f"Training samples: {X_train.shape[1]}, features: {X_train.shape[0]} "
+             f"| Dev samples: {X_dev.shape[1]}"
+         )
+
+         params, history = train_model(
+             X_train,
+             Y_train,
+             X_dev,
+             Y_dev,
+             epochs=args.epochs,
+             batch_size=args.batch_size,
+             learning_rate=args.learning_rate,
+             reg_lambda=args.reg_lambda,
+             dropout_rate=dropout_rate,
+         )
+
+         _, dev_accuracy = evaluate(params, X_dev, Y_dev)
+         print(f"\nFinal Dev Accuracy: {dev_accuracy:.4f}")
+
+         if history_dir is not None:
+             save_history_to_csv(history, history_dir / "train_history.csv")
+
+         save_model(params, mean, std, args.output_model or ARCHIVE_DIR / "trained_model_mnist100.npz")
+
+     elif args.mode == "lr-sweep":
+         if args.learning_rates is None:
+             raise ValueError("LR sweep mode requires --learning-rates.")
+         lr_values = [float(value.strip()) for value in args.learning_rates.split(",") if value.strip()]
+         print(f"Running LR sweep over {lr_values}...")
+         summary_path = history_dir / "lr_sweep_summary.csv" if history_dir is not None else None
+         results = lr_sweep(
+             lr_values,
+             reg_lambda=args.reg_lambda,
+             epochs=args.epochs,
+             batch_size=args.batch_size,
+             dropout_rate=dropout_rate,
+             history_dir=history_dir,
+             summary_path=summary_path,
+         )
+         for entry in results:
+             print(
+                 f"LR={entry['learning_rate']:.3e} | reg={entry['reg_lambda']:.3e} "
+                 f"| dev_acc={entry['dev_acc']:.4f}"
+             )
+
+     elif args.mode == "random-search":
+         print(
+             f"Running random search ({args.trials} trials) "
+             f"LR∈[{args.lr_min:.2e},{args.lr_max:.2e}], "
+             f"λ∈[{args.reg_min:.2e},{args.reg_max:.2e}]..."
+         )
+         summary_path = history_dir / "random_search_summary.csv" if history_dir is not None else None
+         results = random_search_hparams(
+             args.trials,
+             (args.lr_min, args.lr_max),
+             (args.reg_min, args.reg_max),
+             epochs=args.epochs,
+             batch_size=args.batch_size,
+             dropout_rate=dropout_rate,
+             seed=args.seed,
+             history_dir=history_dir,
+             summary_path=summary_path,
+         )
+         for entry in results:
+             print(
+                 f"Trial {entry['trial']:02d} | LR={entry['learning_rate']:.3e} "
+                 f"| reg={entry['reg_lambda']:.3e} | dev_acc={entry['dev_acc']:.4f}"
+             )
+         best = results[0]
+         print(
+             f"\nBest trial -> LR={best['learning_rate']:.3e}, "
+             f"reg={best['reg_lambda']:.3e}, dev_acc={best['dev_acc']:.4f}"
+         )
+
+     elif args.mode == "auto-train":
+         print(
+             f"Auto-train pipeline: {args.trials} search trials "
+             f"(epochs={args.epochs}) followed by final training (epochs={args.final_epochs})."
+         )
+         results = auto_train_pipeline(
+             trials=args.trials,
+             lr_bounds=(args.lr_min, args.lr_max),
+             reg_bounds=(args.reg_min, args.reg_max),
+             search_epochs=args.epochs,
+             final_epochs=args.final_epochs,
+             batch_size=args.batch_size,
+             dropout_rate=dropout_rate,
+             final_batch_size=args.final_batch_size,
+             final_dropout_rate=final_dropout_rate,
+             history_dir=history_dir,
+             seed=args.seed,
+             output_model_path=args.output_model,
+         )
+         best = results["best_trial"]
+         print(
+             f"\nAuto-train complete. "
+             f"Best trial LR={best['learning_rate']:.3e}, reg={best['reg_lambda']:.3e}. "
+             f"Final dev_acc={results['final_dev_acc']:.4f}. "
+             f"Model saved to '{results['model_path']}'."
+         )
+
+
+ if __name__ == "__main__":
+     main()