Upload 2 files
- 2.CNN/trained_model_mnist100.npz +3 -0
- 2.CNN/training-100.py +1049 -0
2.CNN/trained_model_mnist100.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3c08456f9f9d0b8ff934e8df4f78966a77c97f7a39b19a786df621dfae42347
size 3349284
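The three lines above are a Git LFS pointer, not the weights themselves: the 3,349,284-byte .npz archive is stored out of band and materialised on checkout when LFS is enabled. Once fetched, it can be read back with NumPy; a minimal sketch, assuming the key layout written by save_model in training-100.py below:

import numpy as np

# "mean"/"std" are the saved normalization statistics; every other key is a
# weight or bias tensor written by save_model via np.savez.
with np.load("2.CNN/trained_model_mnist100.npz") as model:
    mean, std = model["mean"], model["std"]
    params = {k: model[k] for k in model.files if k not in ("mean", "std")}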
2.CNN/training-100.py
ADDED
@@ -0,0 +1,1049 @@
"""
Section 1: Imports and network configurations
"""

from __future__ import annotations

import numpy as np
import argparse
import csv
from pathlib import Path
from copy import deepcopy
from numpy.lib.stride_tricks import sliding_window_view


BASE_DIR = Path(__file__).resolve().parent
ARCHIVE_DIR = BASE_DIR / "archive"
DATASET_PATH = ARCHIVE_DIR / "mnist_compressed.npz"

np.random.seed(42)


# Network configuration
IMAGE_CHANNELS = 1
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 56
INPUT_DIM = IMAGE_HEIGHT * IMAGE_WIDTH  # flattened input for compatibility
CONV_FILTERS = (16, 32)
KERNEL_SIZE = 3
POOL_SIZE = 2
FC_HIDDEN_DIM = 256
OUTPUT_DIM = 100
EPOCHS = 20
BATCH_SIZE = 256
LEARNING_RATE = 1e-3
REG_LAMBDA = 1e-4
DROP_RATE_FC = 0.4
EARLY_STOP_PATIENCE = 5
EARLY_STOP_MIN_DELTA = 1e-3
MAX_SHIFT_PIXELS = 2
CONTRAST_JITTER_STD = 0.1
BETA1 = 0.9
BETA2 = 0.999
EPSILON = 1e-8
DEV_SIZE = 10_000  # held-out validation set size


def save_history_to_csv(history, filepath):
    target_path = Path(filepath)
    target_path.parent.mkdir(parents=True, exist_ok=True)
    with target_path.open("w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=("epoch", "loss", "train_acc", "dev_acc"))
        writer.writeheader()
        for row in history:
            writer.writerow(row)


def save_sweep_summary(results, filepath, *, include_trial=False):
    target_path = Path(filepath)
    target_path.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = ["learning_rate", "reg_lambda", "dev_acc"]
    if include_trial:
        fieldnames.insert(0, "trial")
    with target_path.open("w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for entry in results:
            row = {
                "learning_rate": float(entry["learning_rate"]),
                "reg_lambda": float(entry["reg_lambda"]),
                "dev_acc": float(entry["dev_acc"]),
            }
            if include_trial:
                row["trial"] = int(entry["trial"])
            writer.writerow(row)


"""
Section 2: Loads the input data and transposes it, so arrays are (features, samples); normalisation happens in Section 3
"""
def load_data(path: Path, dev_size: int = DEV_SIZE):
    """
    Load the MNIST-100 dataset from the compressed archive and return
    training / validation splits flattened to (features, samples).
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Dataset not found at '{path}'")

    with np.load(path) as data:
        train_images = data["train_images"].astype(np.float32)
        train_labels = data["train_labels"].astype(np.int64)
        test_images = data["test_images"].astype(np.float32)
        test_labels = data["test_labels"].astype(np.int64)

    # Flatten images to column-major format (features, samples)
    X_full = train_images.reshape(train_images.shape[0], -1).T  # (input_dim, m)
    Y_full = train_labels

    # Shuffle before splitting off the validation set
    permutation = np.random.permutation(X_full.shape[1])
    X_full = X_full[:, permutation]
    Y_full = Y_full[permutation]

    X_dev = X_full[:, :dev_size]
    Y_dev = Y_full[:dev_size]
    X_train = X_full[:, dev_size:]
    Y_train = Y_full[dev_size:]

    # Also flatten the test set for later reuse if needed.
    X_test = test_images.reshape(test_images.shape[0], -1).T

    return X_train, Y_train, X_dev, Y_dev, X_test, test_labels


"""
Section 3: Scales the features from [0, 255] to [0, 1], then standardises them to zero mean and unit variance
"""
def normalize_features(X_train, X_dev):
    """
    Normalize features to zero mean and unit variance using the training set.
    """
    X_train /= 255.0
    X_dev /= 255.0

    mean = np.mean(X_train, axis=1, keepdims=True)
    std = np.std(X_train, axis=1, keepdims=True) + 1e-8

    X_train = (X_train - mean) / std
    X_dev = (X_dev - mean) / std

    return X_train, X_dev, mean, std

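# Reuse note (illustrative): mean/std are computed on the training split only
# and are saved next to the weights by save_model, so inference code must
# apply the same two steps to new inputs instead of recomputing statistics:
#
#   X_new = X_new_raw.astype(np.float32) / 255.0   # same [0, 255] -> [0, 1] scaling
#   X_new = (X_new - mean) / std                   # training-set statistics, shape (features, 1)
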
"""
Section 4: Initialises the parameters (layers, weights and biases) and the Adam optimizer state
"""
def init_params():
    params = {}
    conv1_fan_in = IMAGE_CHANNELS * KERNEL_SIZE * KERNEL_SIZE
    params["conv1_W"] = (
        np.random.randn(CONV_FILTERS[0], IMAGE_CHANNELS, KERNEL_SIZE, KERNEL_SIZE) * np.sqrt(2.0 / conv1_fan_in)
    ).astype(np.float32)
    params["conv1_b"] = np.zeros((CONV_FILTERS[0], 1), dtype=np.float32)

    conv2_fan_in = CONV_FILTERS[0] * KERNEL_SIZE * KERNEL_SIZE
    params["conv2_W"] = (
        np.random.randn(CONV_FILTERS[1], CONV_FILTERS[0], KERNEL_SIZE, KERNEL_SIZE) * np.sqrt(2.0 / conv2_fan_in)
    ).astype(np.float32)
    params["conv2_b"] = np.zeros((CONV_FILTERS[1], 1), dtype=np.float32)

    height_after_pool1 = IMAGE_HEIGHT // POOL_SIZE
    width_after_pool1 = IMAGE_WIDTH // POOL_SIZE
    height_after_pool2 = height_after_pool1 // POOL_SIZE
    width_after_pool2 = width_after_pool1 // POOL_SIZE
    flattened_dim = CONV_FILTERS[1] * height_after_pool2 * width_after_pool2

    params["fc1_W"] = (
        np.random.randn(FC_HIDDEN_DIM, flattened_dim) * np.sqrt(2.0 / flattened_dim)
    ).astype(np.float32)
    params["fc1_b"] = np.zeros((FC_HIDDEN_DIM, 1), dtype=np.float32)

    params["fc2_W"] = (
        np.random.randn(OUTPUT_DIM, FC_HIDDEN_DIM) * np.sqrt(2.0 / FC_HIDDEN_DIM)
    ).astype(np.float32)
    params["fc2_b"] = np.zeros((OUTPUT_DIM, 1), dtype=np.float32)

    return params


def init_adam(params):
    v = {}
    s = {}
    for key, value in params.items():
        v[key] = np.zeros_like(value)
        s[key] = np.zeros_like(value)
    return v, s


"""
Section 5: ReLU activation function and backward ReLU function
"""
def relu(Z):
    return np.maximum(0.0, Z)


def relu_backward(Z):
    return (Z > 0).astype(np.float32)


"""
Section 6: Reshapes the flattened input to 4D tensors (batch, channels, height, width) for the convolutional layers
"""
def reshape_flat_to_images(X: np.ndarray, *, batch_size: int | None = None):
    """
    Convert flattened columns (features, batch) into 4D tensors (batch, channels, height, width).
    """
    _, m = X.shape
    if batch_size is not None and m != batch_size:
        raise ValueError(f"Expected batch size {batch_size}, got {m}")
    images = X.T.reshape(m, IMAGE_HEIGHT, IMAGE_WIDTH)
    return images[:, None, :, :]  # add channel dim


"""
Section 7: Convolutional layer forward pass and backward pass
"""

def im2col(X, kernel_h, kernel_w, stride, padding):
    # Note: sliding_window_view enumerates stride-1 windows, so this helper
    # assumes stride == 1 (the only value used in this script).
    X_padded = np.pad(
        X,
        ((0, 0), (0, 0), (padding, padding), (padding, padding)),
        mode="constant",
    )
    windows = sliding_window_view(X_padded, (kernel_h, kernel_w), axis=(2, 3))
    # windows shape: (batch, channels, out_height, out_width, kernel_h, kernel_w)
    batch_size, channels, out_height, out_width, _, _ = windows.shape
    cols = windows.transpose(0, 2, 3, 1, 4, 5).reshape(batch_size * out_height * out_width, channels * kernel_h * kernel_w)
    return X_padded, cols, out_height, out_width


def col2im(cols, X_shape, kernel_h, kernel_w, stride, padding, out_height, out_width):
    batch_size, channels, height, width = X_shape
    cols_reshaped = cols.reshape(batch_size, out_height, out_width, channels, kernel_h, kernel_w)
    cols_reshaped = cols_reshaped.transpose(0, 3, 1, 2, 4, 5)
    X_padded = np.zeros((batch_size, channels, height + 2 * padding, width + 2 * padding), dtype=np.float32)

    for h_idx in range(out_height):
        h_start = h_idx * stride
        h_end = h_start + kernel_h
        for w_idx in range(out_width):
            w_start = w_idx * stride
            w_end = w_start + kernel_w
            X_padded[:, :, h_start:h_end, w_start:w_end] += cols_reshaped[:, :, h_idx, w_idx, :, :]

    if padding > 0:
        return X_padded[:, :, padding:-padding, padding:-padding]
    return X_padded


def conv_forward(X, W, b, *, stride: int = 1, padding: int = 0):
    batch_size, in_channels, height, width = X.shape
    num_filters, _, kernel_h, kernel_w = W.shape

    X_padded, cols, out_height, out_width = im2col(X, kernel_h, kernel_w, stride, padding)
    W_col = W.reshape(num_filters, -1)
    out_cols = cols @ W_col.T  # (batch*out_height*out_width, num_filters)
    out = out_cols.reshape(batch_size, out_height, out_width, num_filters).transpose(0, 3, 1, 2)
    out = out.astype(np.float32, copy=False)
    out += b.reshape(1, num_filters, 1, 1)

    cache = {
        "X": X,
        "X_padded": X_padded,
        "W": W,
        "stride": stride,
        "padding": padding,
        "kernel_h": kernel_h,
        "kernel_w": kernel_w,
        "out_height": out_height,
        "out_width": out_width,
        "cols": cols,
        "W_col": W_col,
        "output_shape": out.shape,
    }
    return out, cache


def conv_backward(dout, cache):
    X = cache["X"]
    W = cache["W"]
    stride = cache["stride"]
    padding = cache["padding"]
    kernel_h = cache["kernel_h"]
    kernel_w = cache["kernel_w"]
    out_height = cache["out_height"]
    out_width = cache["out_width"]
    cols = cache["cols"]
    W_col = cache["W_col"]

    batch_size, _, _, _ = X.shape
    num_filters = W.shape[0]

    dout_cols = dout.transpose(0, 2, 3, 1).reshape(batch_size * out_height * out_width, num_filters)
    dW_col = dout_cols.T @ cols
    dW = dW_col.reshape(W.shape)
    db = np.sum(dout, axis=(0, 2, 3)).reshape(num_filters, 1)

    dcols = dout_cols @ W_col
    dX = col2im(dcols, X.shape, kernel_h, kernel_w, stride, padding, out_height, out_width)

    return dX, dW, db

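# Shape sketch for the im2col path above (illustrative, assuming the
# configured 28x56 inputs, 3x3 kernels, stride 1 and "same" padding 1):
#
#   X = np.zeros((8, 1, 28, 56), dtype=np.float32)   # (batch, C, H, W)
#   _, cols, out_h, out_w = im2col(X, 3, 3, 1, 1)
#   # out_h == 28, out_w == 56, cols.shape == (8 * 28 * 56, 1 * 3 * 3)
#   # conv_forward then reduces each 9-element row against every filter
#   # with a single matmul: cols @ W_col.T.
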
"""
Section 8: Max pooling layer forward pass and backward pass
"""
def maxpool_forward(X, *, pool_size: int = 2, stride: int = 2):
    batch_size, channels, height, width = X.shape
    out_height = (height - pool_size) // stride + 1
    out_width = (width - pool_size) // stride + 1

    out = np.zeros((batch_size, channels, out_height, out_width), dtype=np.float32)

    for h_idx in range(out_height):
        h_start = h_idx * stride
        h_end = h_start + pool_size
        for w_idx in range(out_width):
            w_start = w_idx * stride
            w_end = w_start + pool_size
            window = X[:, :, h_start:h_end, w_start:w_end]
            max_vals = np.max(window, axis=(2, 3))
            out[:, :, h_idx, w_idx] = max_vals

    cache = {
        "X": X,
        "pool_size": pool_size,
        "stride": stride,
        "output_shape": out.shape,
    }
    return out, cache


def maxpool_backward(dout, cache):
    X = cache["X"]
    pool_size = cache["pool_size"]
    stride = cache["stride"]
    batch_size, channels, out_height, out_width = dout.shape

    dX = np.zeros_like(X)
    for h_idx in range(out_height):
        h_start = h_idx * stride
        h_end = h_start + pool_size
        for w_idx in range(out_width):
            w_start = w_idx * stride
            w_end = w_start + pool_size
            window = X[:, :, h_start:h_end, w_start:w_end]
            max_vals = np.max(window, axis=(2, 3), keepdims=True)
            mask = (window == max_vals).astype(np.float32)
            mask_sum = np.sum(mask, axis=(2, 3), keepdims=True)
            mask /= np.maximum(mask_sum, 1.0)
            grad_slice = dout[:, :, h_idx, w_idx][:, :, None, None]
            dX[:, :, h_start:h_end, w_start:w_end] += mask * grad_slice
    return dX


def softmax(Z):
    Z_shift = Z - np.max(Z, axis=0, keepdims=True)
    expZ = np.exp(Z_shift)
    return expZ / np.sum(expZ, axis=0, keepdims=True)


def one_hot(Y, num_classes=OUTPUT_DIM):
    one_hot_y = np.zeros((num_classes, Y.size), dtype=np.float32)
    one_hot_y[Y, np.arange(Y.size)] = 1.0
    return one_hot_y

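# Worked example for the two helpers above (illustrative): softmax operates
# column-wise, one column per sample, and one_hot matches that layout.
#
#   Z = np.array([[2.0, 0.0], [1.0, 0.0]])    # (classes, samples) = (2, 2)
#   softmax(Z)[:, 0]                           # [e^2, e^1] / (e^2 + e^1) ~ [0.731, 0.269]
#   one_hot(np.array([1, 0]), num_classes=2)   # [[0., 1.], [1., 0.]]
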
"""
Section 9: Forward propagation and loss computation
"""
def forward_prop(
    X,
    params,
    *,
    training: bool = False,
    dropout_rate: float = DROP_RATE_FC,
):
    batch_size = X.shape[1]
    images = reshape_flat_to_images(X, batch_size=batch_size)
    padding = KERNEL_SIZE // 2

    conv1_out, conv1_cache = conv_forward(images, params["conv1_W"], params["conv1_b"], stride=1, padding=padding)
    relu1 = relu(conv1_out)
    pool1_out, pool1_cache = maxpool_forward(relu1, pool_size=POOL_SIZE, stride=POOL_SIZE)

    conv2_out, conv2_cache = conv_forward(pool1_out, params["conv2_W"], params["conv2_b"], stride=1, padding=padding)
    relu2 = relu(conv2_out)
    pool2_out, pool2_cache = maxpool_forward(relu2, pool_size=POOL_SIZE, stride=POOL_SIZE)

    flattened = pool2_out.reshape(batch_size, -1).T  # (features_flat, batch)

    Z_fc1 = params["fc1_W"] @ flattened + params["fc1_b"]
    A_fc1 = relu(Z_fc1)

    dropout_mask = None
    keep_prob = 1.0 - dropout_rate
    if training and dropout_rate > 0.0:
        dropout_mask = (np.random.rand(*A_fc1.shape) >= dropout_rate).astype(np.float32)
        A_fc1 = (A_fc1 * dropout_mask) / keep_prob

    Z_fc2 = params["fc2_W"] @ A_fc1 + params["fc2_b"]
    probs = softmax(Z_fc2)

    cache = {
        "X": X,
        "images": images,
        "conv1_out": conv1_out,
        "conv1_cache": conv1_cache,
        "pool1_cache": pool1_cache,
        "conv2_out": conv2_out,
        "conv2_cache": conv2_cache,
        "pool2_cache": pool2_cache,
        "flattened": flattened,
        "Z_fc1": Z_fc1,
        "A_fc1": A_fc1,
        "dropout_mask": dropout_mask,
        "keep_prob": keep_prob,
        "dropout_rate": dropout_rate,
        "Z_fc2": Z_fc2,
        "probs": probs,
    }

    return cache, probs


def compute_loss(probs, Y_batch, params, reg_lambda):
    m = Y_batch.shape[1]
    log_likelihood = -np.log(probs + 1e-9) * Y_batch
    data_loss = np.sum(log_likelihood) / m

    l2_penalty = 0.0
    for key in ("conv1_W", "conv2_W", "fc1_W", "fc2_W"):
        l2_penalty += np.sum(np.square(params[key]))
    l2_loss = (reg_lambda / (2 * m)) * l2_penalty

    return data_loss + l2_loss


"""
Section 10: Back propagation for the CNN model
"""
def back_prop(cache, Y_batch, params, reg_lambda, dropout_rate):
    m = Y_batch.shape[1]
    grads = {}

    probs = cache["probs"]
    A_fc1 = cache["A_fc1"]
    Z_fc1 = cache["Z_fc1"]
    flattened = cache["flattened"]
    dropout_mask = cache["dropout_mask"]
    keep_prob = cache["keep_prob"]

    dZ_fc2 = probs - Y_batch
    grads["fc2_W"] = (dZ_fc2 @ A_fc1.T) / m + (reg_lambda / m) * params["fc2_W"]
    grads["fc2_b"] = np.sum(dZ_fc2, axis=1, keepdims=True) / m

    dA_fc1 = params["fc2_W"].T @ dZ_fc2
    if dropout_mask is not None:
        dA_fc1 = (dA_fc1 * dropout_mask) / keep_prob
    dZ_fc1 = dA_fc1 * relu_backward(Z_fc1)
    grads["fc1_W"] = (dZ_fc1 @ flattened.T) / m + (reg_lambda / m) * params["fc1_W"]
    grads["fc1_b"] = np.sum(dZ_fc1, axis=1, keepdims=True) / m

    dFlatten = params["fc1_W"].T @ dZ_fc1  # (flatten_dim, batch)
    pool2_shape = cache["pool2_cache"]["output_shape"]
    dPool2 = dFlatten.T.reshape(pool2_shape)

    dRelu2_input = maxpool_backward(dPool2, cache["pool2_cache"])
    dConv2 = dRelu2_input * relu_backward(cache["conv2_out"])
    dPool1_input, dConv2_W, dConv2_b = conv_backward(dConv2, cache["conv2_cache"])
    grads["conv2_W"] = dConv2_W / m + (reg_lambda / m) * params["conv2_W"]
    grads["conv2_b"] = dConv2_b / m

    dRelu1_input = maxpool_backward(dPool1_input, cache["pool1_cache"])
    dConv1 = dRelu1_input * relu_backward(cache["conv1_out"])
    _, dConv1_W, dConv1_b = conv_backward(dConv1, cache["conv1_cache"])
    grads["conv1_W"] = dConv1_W / m + (reg_lambda / m) * params["conv1_W"]
    grads["conv1_b"] = dConv1_b / m

    return grads


"""
Section 11: Updates the parameters using the Adam optimizer
"""

def update_params_adam(params, grads, v, s, t, learning_rate):
    updated_params = {}
    for key in params:
        v[key] = BETA1 * v[key] + (1 - BETA1) * grads[key]
        s[key] = BETA2 * s[key] + (1 - BETA2) * (grads[key] ** 2)

        v_corrected = v[key] / (1 - BETA1 ** t)
        s_corrected = s[key] / (1 - BETA2 ** t)

        updated_params[key] = params[key] - learning_rate * v_corrected / (np.sqrt(s_corrected) + EPSILON)

    return updated_params, v, s


def get_predictions(probs):
    return np.argmax(probs, axis=0)


def get_accuracy(probs, labels):
    predictions = get_predictions(probs)
    return np.mean(predictions == labels)

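# Numeric sanity check for Adam's bias correction (illustrative): at step
# t = 1 with BETA1 = 0.9 and gradient g, the first moment is v = 0.1 * g,
# and dividing by (1 - 0.9**1) = 0.1 recovers v_corrected = g exactly;
# without the correction, early updates would be shrunk roughly 10x.
#
#   g, t = 1.0, 1
#   v_t = 0.9 * 0.0 + 0.1 * g      # 0.1
#   v_hat = v_t / (1 - 0.9 ** t)   # 0.1 / 0.1 = 1.0 == g
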
"""
Section 12: Augments the batch with horizontal shifts and contrast/brightness jitter
"""

def augment_batch(
    X_batch,
    *,
    image_shape: tuple[int, int] = (28, 56),
    max_shift: int = MAX_SHIFT_PIXELS,
    contrast_jitter_std: float = CONTRAST_JITTER_STD,
):
    """
    Apply lightweight augmentation: horizontal shifts and contrast/brightness jitter.
    """
    if max_shift <= 0 and contrast_jitter_std <= 0.0:
        return X_batch

    batch_size = X_batch.shape[1]
    images = X_batch.T.reshape(batch_size, *image_shape)

    if max_shift > 0:
        shifts = np.random.randint(-max_shift, max_shift + 1, size=batch_size)
        for idx, shift in enumerate(shifts):
            if shift > 0:
                shifted = np.roll(images[idx], shift, axis=1)
                shifted[:, :shift] = 0.0
                images[idx] = shifted
            elif shift < 0:
                shift = -shift
                shifted = np.roll(images[idx], -shift, axis=1)
                shifted[:, -shift:] = 0.0
                images[idx] = shifted

    if contrast_jitter_std > 0.0:
        scale = 1.0 + np.random.normal(0.0, contrast_jitter_std, size=batch_size)
        bias = np.random.normal(0.0, contrast_jitter_std, size=batch_size)
        images *= scale[:, None, None]
        images += bias[:, None, None]
        np.clip(images, -3.0, 3.0, out=images)

    return images.reshape(batch_size, -1).T

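# Usage sketch (illustrative): the augmentation keeps the (features, batch)
# column layout used everywhere else; the training loop below passes a
# defensive copy because the shift/jitter steps can write into the buffer.
#
#   X_aug = augment_batch(X_train[:, :256].copy())   # (28*56, 256) in, same shape out
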
"""
Section 13: Trains the model + evaluates the model
"""
def train_model(
    X_train,
    Y_train,
    X_dev,
    Y_dev,
    *,
    epochs: int = EPOCHS,
    batch_size: int = BATCH_SIZE,
    learning_rate: float = LEARNING_RATE,
    reg_lambda: float = REG_LAMBDA,
    dropout_rate: float = DROP_RATE_FC,
    early_stop_patience: int = EARLY_STOP_PATIENCE,
    early_stop_min_delta: float = EARLY_STOP_MIN_DELTA,
    use_augmentation: bool = True,
):
    params = init_params()
    v, s = init_adam(params)
    m_train = X_train.shape[1]
    global_step = 0
    best_dev_acc = -np.inf
    best_params = deepcopy(params)
    patience_counter = 0
    history = []

    for epoch in range(1, epochs + 1):
        permutation = np.random.permutation(m_train)
        X_shuffled = X_train[:, permutation]
        Y_shuffled = Y_train[permutation]

        epoch_loss = 0.0

        for start in range(0, m_train, batch_size):
            end = min(start + batch_size, m_train)
            X_batch = X_shuffled[:, start:end]
            Y_batch_indices = Y_shuffled[start:end]
            Y_batch = one_hot(Y_batch_indices)

            if use_augmentation:
                X_batch = augment_batch(X_batch.copy())

            cache, probs = forward_prop(
                X_batch,
                params,
                training=True,
                dropout_rate=dropout_rate,
            )
            loss = compute_loss(probs, Y_batch, params, reg_lambda)
            grads = back_prop(cache, Y_batch, params, reg_lambda, dropout_rate)

            global_step += 1
            params, v, s = update_params_adam(params, grads, v, s, global_step, learning_rate)

            epoch_loss += loss * (end - start)

        epoch_loss /= m_train

        _, train_probs = forward_prop(X_train, params, training=False, dropout_rate=dropout_rate)
        train_accuracy = get_accuracy(train_probs, Y_train)

        _, dev_probs = forward_prop(X_dev, params, training=False, dropout_rate=dropout_rate)
        dev_accuracy = get_accuracy(dev_probs, Y_dev)

        print(
            f"Epoch {epoch:02d} - loss: {epoch_loss:.4f} "
            f"- train_acc: {train_accuracy:.4f} - dev_acc: {dev_accuracy:.4f}"
        )

        history.append(
            {
                "epoch": epoch,
                "loss": epoch_loss,
                "train_acc": train_accuracy,
                "dev_acc": dev_accuracy,
            }
        )

        if dev_accuracy > best_dev_acc + early_stop_min_delta:
            best_dev_acc = dev_accuracy
            best_params = deepcopy(params)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                print(
                    f"Early stopping triggered at epoch {epoch:02d}. "
                    f"Best dev_acc={best_dev_acc:.4f}"
                )
                break

    return best_params, history


def evaluate(params, X, Y):
    _, probs = forward_prop(X, params, training=False)
    predictions = get_predictions(probs)
    accuracy = np.mean(predictions == Y)
    return predictions, accuracy

"""
Section 14: Trains the model once
"""
def train_once(
    learning_rate: float,
    reg_lambda: float,
    *,
    epochs: int = EPOCHS,
    batch_size: int = BATCH_SIZE,
    dropout_rate: float = DROP_RATE_FC,
    history_path: Path | None = None,
):
    """
    Convenience wrapper for hyperparameter sweeps. Returns trained params and dev accuracy.
    """
    X_train, Y_train, X_dev, Y_dev, _, _ = load_data(DATASET_PATH)
    X_train, X_dev, mean, std = normalize_features(X_train, X_dev)

    params, history = train_model(
        X_train,
        Y_train,
        X_dev,
        Y_dev,
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        reg_lambda=reg_lambda,
        dropout_rate=dropout_rate,
    )

    _, dev_accuracy = evaluate(params, X_dev, Y_dev)

    if history_path is not None:
        save_history_to_csv(history, history_path)

    return params, dev_accuracy, mean, std, history


"""
Section 15: Hyperparameter sweep for learning rate, regularization and dropout rate
"""

def lr_sweep(
    learning_rates: list[float],
    *,
    reg_lambda: float = REG_LAMBDA,
    epochs: int = EPOCHS,
    batch_size: int = BATCH_SIZE,
    dropout_rate: float = DROP_RATE_FC,
    history_dir: Path | None = None,
    summary_path: Path | None = None,
):
    results = []
    history_directory = Path(history_dir) if history_dir is not None else None
    if history_directory is not None:
        history_directory.mkdir(parents=True, exist_ok=True)

    for lr in learning_rates:
        history_path = None
        if history_directory is not None:
            safe_lr = f"{lr:.2e}".replace("+", "").replace("-", "m")
            history_path = history_directory / f"lr_{safe_lr}.csv"
        _, dev_acc, _, _, history = train_once(
            lr,
            reg_lambda,
            epochs=epochs,
            batch_size=batch_size,
            dropout_rate=dropout_rate,
            history_path=history_path,
        )
        results.append(
            {
                "learning_rate": float(lr),
                "reg_lambda": float(reg_lambda),
                "dev_acc": float(dev_acc),
                "history": history,
            }
        )
    if summary_path is not None:
        save_sweep_summary(results, summary_path)
    return results


def random_search_hparams(
    num_trials: int,
    lr_bounds: tuple[float, float],
    reg_bounds: tuple[float, float],
    *,
    epochs: int = EPOCHS,
    batch_size: int = BATCH_SIZE,
    dropout_rate: float = DROP_RATE_FC,
    seed: int | None = None,
    history_dir: Path | None = None,
    summary_path: Path | None = None,
):
    if num_trials <= 0:
        raise ValueError("num_trials must be positive")

    lr_min, lr_max = lr_bounds
    reg_min, reg_max = reg_bounds
    if lr_min <= 0 or lr_max <= 0:
        raise ValueError("Learning rate bounds must be positive")
    if reg_min <= 0 or reg_max <= 0:
        raise ValueError("Regularization bounds must be positive")

    rng = np.random.default_rng(seed)
    history_directory = Path(history_dir) if history_dir is not None else None
    if history_directory is not None:
        history_directory.mkdir(parents=True, exist_ok=True)

    results = []
    log_lr_min, log_lr_max = np.log(lr_min), np.log(lr_max)
    log_reg_min, log_reg_max = np.log(reg_min), np.log(reg_max)

    for trial in range(1, num_trials + 1):
        lr_sample = float(np.exp(rng.uniform(log_lr_min, log_lr_max)))
        reg_sample = float(np.exp(rng.uniform(log_reg_min, log_reg_max)))
        history_path = None
        if history_directory is not None:
            safe_lr = f"{lr_sample:.2e}".replace("+", "").replace("-", "m")
            safe_reg = f"{reg_sample:.2e}".replace("+", "").replace("-", "m")
            history_path = history_directory / f"trial_{trial:02d}_lr-{safe_lr}_reg-{safe_reg}.csv"

        _, dev_acc, _, _, history = train_once(
            lr_sample,
            reg_sample,
            epochs=epochs,
            batch_size=batch_size,
            dropout_rate=dropout_rate,
            history_path=history_path,
        )

        results.append(
            {
                "trial": trial,
                "learning_rate": lr_sample,
                "reg_lambda": reg_sample,
                "dev_acc": float(dev_acc),
                "history": history,
            }
        )

    results.sort(key=lambda item: item["dev_acc"], reverse=True)
    if summary_path is not None:
        save_sweep_summary(results, summary_path, include_trial=True)
    return results


def auto_train_pipeline(
    *,
    trials: int,
    lr_bounds: tuple[float, float],
    reg_bounds: tuple[float, float],
    search_epochs: int,
    final_epochs: int,
    batch_size: int,
    dropout_rate: float,
    final_batch_size: int | None,
    final_dropout_rate: float | None,
    history_dir: Path | None,
    seed: int | None,
    output_model_path: Path | None,
):
    history_directory = Path(history_dir) if history_dir is not None else None
    if history_directory is not None:
        history_directory.mkdir(parents=True, exist_ok=True)

    search_summary_path = None
    if history_directory is not None:
        search_summary_path = history_directory / "random_search_summary.csv"

    results = random_search_hparams(
        trials,
        lr_bounds,
        reg_bounds,
        epochs=search_epochs,
        batch_size=batch_size,
        dropout_rate=dropout_rate,
        seed=seed,
        history_dir=history_directory / "search_histories" if history_directory is not None else None,
        summary_path=search_summary_path,
    )
    best = results[0]
    print(
        f"\nBest search trial -> LR={best['learning_rate']:.3e}, "
        f"reg={best['reg_lambda']:.3e}, dev_acc={best['dev_acc']:.4f}"
    )

    final_dropout = final_dropout_rate if final_dropout_rate is not None else dropout_rate
    final_history_path = None
    if history_directory is not None:
        final_history_path = history_directory / "final_train_history.csv"

    params, final_dev_acc, mean, std, final_history = train_once(
        best["learning_rate"],
        best["reg_lambda"],
        epochs=final_epochs,
        batch_size=final_batch_size or batch_size,
        dropout_rate=final_dropout,
        history_path=final_history_path,
    )

    model_output_path = output_model_path if output_model_path is not None else ARCHIVE_DIR / "trained_model_mnist100.npz"
    save_model(params, mean, std, model_output_path)

    return {
        "best_trial": best,
        "final_dev_acc": final_dev_acc,
        "model_path": Path(model_output_path),
        "final_history": final_history,
    }

"""
Section 16: Saves the model
"""
def save_model(params, mean, std, filepath=None):
    target_path = Path(filepath) if filepath is not None else ARCHIVE_DIR / "trained_model_mnist100.npz"
    target_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving trained model to '{target_path}'...")
    np.savez(target_path, **params, mean=mean, std=std)
    print("Model saved successfully!")


"""
Section 17: Main function
"""

def main():
    parser = argparse.ArgumentParser(description="MNIST-100 training and tuning utilities.")
    parser.add_argument(
        "--mode",
        choices=("train", "lr-sweep", "random-search", "auto-train"),
        default="train",
        help="Select high-level action.",
    )
    parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, help="Base learning rate.")
    parser.add_argument("--learning-rates", type=str, help="Comma-separated list for LR sweep.")
    parser.add_argument("--reg-lambda", type=float, default=REG_LAMBDA, help="L2 regularization strength.")
    parser.add_argument("--lr-min", type=float, default=1e-4, help="Min LR for random search (exclusive mode).")
    parser.add_argument("--lr-max", type=float, default=5e-3, help="Max LR for random search.")
    parser.add_argument("--reg-min", type=float, default=1e-5, help="Min lambda for random search.")
    parser.add_argument("--reg-max", type=float, default=1e-3, help="Max lambda for random search.")
    parser.add_argument("--trials", type=int, default=5, help="Number of random-search trials.")
    parser.add_argument("--epochs", type=int, default=EPOCHS, help="Train epochs per run.")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Mini-batch size.")
    parser.add_argument(
        "--final-epochs",
        type=int,
        default=40,
        help="Epoch budget for the final training run in auto-train mode.",
    )
    parser.add_argument(
        "--final-batch-size",
        type=int,
        help="Mini-batch size for the final training run (defaults to --batch-size).",
    )
    parser.add_argument(
        "--dropout",
        type=float,
        help="Override dropout rate for the fully connected layer.",
    )
    parser.add_argument(
        "--final-dropout",
        type=float,
        help="Dropout rate for the final training pass in auto-train mode.",
    )
    parser.add_argument(
        "--history-dir",
        type=Path,
        help="Directory for saving training histories (CSV).",
    )
    parser.add_argument(
        "--output-model",
        type=Path,
        help="Path to save the trained model (.npz). Defaults to archive/trained_model_mnist100.npz.",
    )
    parser.add_argument("--seed", type=int, help="Random seed for random search.")
    args = parser.parse_args()

    dropout_rate = DROP_RATE_FC if args.dropout is None else float(args.dropout)
    if not 0.0 <= dropout_rate < 1.0:
        raise ValueError("Dropout rate must be in [0, 1).")

    final_dropout_rate = None
    if args.final_dropout is not None:
        final_dropout_rate = float(args.final_dropout)
        if not 0.0 <= final_dropout_rate < 1.0:
            raise ValueError("Final dropout rate must be in [0, 1).")

    history_dir = args.history_dir
    if history_dir is not None:
        history_dir = Path(history_dir)
        history_dir.mkdir(parents=True, exist_ok=True)

    if args.mode == "train":
        print(f"Loading dataset from '{DATASET_PATH}'...")
        X_train, Y_train, X_dev, Y_dev, _, _ = load_data(DATASET_PATH)
        X_train, X_dev, mean, std = normalize_features(X_train, X_dev)

        print(
            f"Training samples: {X_train.shape[1]}, features: {X_train.shape[0]} "
            f"| Dev samples: {X_dev.shape[1]}"
        )

        params, history = train_model(
            X_train,
            Y_train,
            X_dev,
            Y_dev,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            reg_lambda=args.reg_lambda,
            dropout_rate=dropout_rate,
        )

        _, dev_accuracy = evaluate(params, X_dev, Y_dev)
        print(f"\nFinal Dev Accuracy: {dev_accuracy:.4f}")

        if history_dir is not None:
            save_history_to_csv(history, history_dir / "train_history.csv")

        save_model(params, mean, std, args.output_model or ARCHIVE_DIR / "trained_model_mnist100.npz")

    elif args.mode == "lr-sweep":
        if args.learning_rates is None:
            raise ValueError("LR sweep mode requires --learning-rates.")
        lr_values = [float(value.strip()) for value in args.learning_rates.split(",") if value.strip()]
        print(f"Running LR sweep over {lr_values}...")
        summary_path = history_dir / "lr_sweep_summary.csv" if history_dir is not None else None
        results = lr_sweep(
            lr_values,
            reg_lambda=args.reg_lambda,
            epochs=args.epochs,
            batch_size=args.batch_size,
            dropout_rate=dropout_rate,
            history_dir=history_dir,
            summary_path=summary_path,
        )
        for entry in results:
            print(
                f"LR={entry['learning_rate']:.3e} | reg={entry['reg_lambda']:.3e} "
                f"| dev_acc={entry['dev_acc']:.4f}"
            )

    elif args.mode == "random-search":
        print(
            f"Running random search ({args.trials} trials) "
            f"LR∈[{args.lr_min:.2e}, {args.lr_max:.2e}], "
            f"λ∈[{args.reg_min:.2e}, {args.reg_max:.2e}]..."
        )
        summary_path = history_dir / "random_search_summary.csv" if history_dir is not None else None
        results = random_search_hparams(
            args.trials,
            (args.lr_min, args.lr_max),
            (args.reg_min, args.reg_max),
            epochs=args.epochs,
            batch_size=args.batch_size,
            dropout_rate=dropout_rate,
            seed=args.seed,
            history_dir=history_dir,
            summary_path=summary_path,
        )
        for entry in results:
            print(
                f"Trial {entry['trial']:02d} | LR={entry['learning_rate']:.3e} "
                f"| reg={entry['reg_lambda']:.3e} | dev_acc={entry['dev_acc']:.4f}"
            )
        best = results[0]
        print(
            f"\nBest trial -> LR={best['learning_rate']:.3e}, "
            f"reg={best['reg_lambda']:.3e}, dev_acc={best['dev_acc']:.4f}"
        )

    elif args.mode == "auto-train":
        print(
            f"Auto-train pipeline: {args.trials} search trials "
            f"(epochs={args.epochs}) followed by final training (epochs={args.final_epochs})."
        )
        results = auto_train_pipeline(
            trials=args.trials,
            lr_bounds=(args.lr_min, args.lr_max),
            reg_bounds=(args.reg_min, args.reg_max),
            search_epochs=args.epochs,
            final_epochs=args.final_epochs,
            batch_size=args.batch_size,
            dropout_rate=dropout_rate,
            final_batch_size=args.final_batch_size,
            final_dropout_rate=final_dropout_rate,
            history_dir=history_dir,
            seed=args.seed,
            output_model_path=args.output_model,
        )
        best = results["best_trial"]
        print(
            f"\nAuto-train complete. "
            f"Best trial LR={best['learning_rate']:.3e}, reg={best['reg_lambda']:.3e}. "
            f"Final dev_acc={results['final_dev_acc']:.4f}. "
            f"Model saved to '{results['model_path']}'."
        )


if __name__ == "__main__":
    main()
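For reference, the argparse interface above exposes four entry points; a few illustrative invocations (paths and values are examples, not part of the commit):

python 2.CNN/training-100.py --mode train --epochs 20 --history-dir runs
python 2.CNN/training-100.py --mode lr-sweep --learning-rates 1e-4,5e-4,1e-3
python 2.CNN/training-100.py --mode random-search --trials 10 --seed 0 --history-dir runs
python 2.CNN/training-100.py --mode auto-train --trials 8 --final-epochs 40 --output-model archive/model.npz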