Spaces:
Sleeping
Sleeping
| """Train the from-scratch MLP on MNIST and save the weights. | |
| Usage: python -m nn.train [--epochs 6] [--batch 128] [--lr 1e-3] | |
| Saves to weights/mnist_mlp.npz. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| from .data import load_mnist | |
| from .model import MLP, Adam | |
def accuracy(model: MLP, x: np.ndarray, y: np.ndarray, batch: int = 1000) -> float:
    """Return the fraction of rows of ``x`` whose predicted label equals ``y``.

    ``model.predict`` is called on consecutive slices of at most ``batch``
    rows. The per-slice match counts are summed, then divided by the total
    number of rows, so an empty ``x`` raises ``ZeroDivisionError``.
    """
    n_rows = len(x)
    chunk_starts = range(0, n_rows, batch)
    n_correct = sum(
        int((model.predict(x[lo:lo + batch]) == y[lo:lo + batch]).sum())
        for lo in chunk_starts
    )
    return n_correct / n_rows
def train(epochs: int = 6, batch: int = 128, lr: float = 1e-3, seed: int = 0,
          out: str = "weights/mnist_mlp.npz") -> float:
    """Train an MLP on MNIST with Adam, save its weights, and return test accuracy.

    Args:
        epochs: Number of passes over the training set; must be >= 1.
        batch: Mini-batch size; must be >= 1.
        lr: Learning rate passed to ``Adam``.
        seed: Passed to ``MLP`` and used to seed the per-epoch shuffling RNG.
        out: Destination path for ``np.savez`` of ``model.state()``; missing
            parent directories are created.

    Returns:
        Test-set accuracy measured after the final epoch.

    Raises:
        ValueError: If ``epochs`` or ``batch`` is less than 1.
    """
    # Validate before the potentially slow dataset load. With epochs < 1 the
    # loop never runs and ``acc`` would be unbound at the final print (after
    # untrained weights were already written). With batch == 0 ``range``
    # rejects the zero step. With batch < 0 no steps run at all.
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")

    x_train, y_train, x_test, y_test = load_mnist()
    model = MLP(seed=seed)
    opt = Adam(model, lr=lr)
    rng = np.random.default_rng(seed)
    n = len(x_train)

    for epoch in range(1, epochs + 1):
        # Reshuffle each epoch. The shuffled copy replaces the previous order,
        # so the same seed reproduces the same sequence of batches.
        order = rng.permutation(n)
        x_train, y_train = x_train[order], y_train[order]
        t0 = time.perf_counter()  # monotonic clock for interval timing
        running = 0.0
        steps = 0
        for i in range(0, n, batch):
            xb, yb = x_train[i:i + batch], y_train[i:i + batch]
            # NOTE(review): loss_and_grad presumably stores gradients on the
            # model for opt.step() to consume — confirm in nn.model.
            loss = model.loss_and_grad(xb, yb)
            opt.step()
            running += loss
            steps += 1
        acc = accuracy(model, x_test, y_test)
        # An empty training set yields zero steps; avoid dividing by zero.
        mean_loss = running / steps if steps else float("nan")
        print(f"epoch {epoch} loss={mean_loss:.4f} test_acc={acc:.4f} "
              f"({time.perf_counter() - t0:.1f}s)")

    Path(out).parent.mkdir(parents=True, exist_ok=True)
    np.savez(out, **model.state())
    print(f"saved weights -> {out} (final test acc {acc:.4f})")
    return acc
| if __name__ == "__main__": | |
| ap = argparse.ArgumentParser() | |
| ap.add_argument("--epochs", type=int, default=6) | |
| ap.add_argument("--batch", type=int, default=128) | |
| ap.add_argument("--lr", type=float, default=1e-3) | |
| ap.parse_args() | |
| args = ap.parse_args() | |
| train(epochs=args.epochs, batch=args.batch, lr=args.lr) | |