---
license: mit
library_name: onnx
pipeline_tag: image-classification
tags:
- image-classification
- mnist
- convolutional-neural-network
- onnx
- onnxruntime
- pytorch
- dlab
datasets:
- mnist
metrics:
- accuracy
---

# MNIST CNN Classifier

This repository contains a validation-selected MNIST CNN digit classifier trained with [dlab](https://github.com/tsilva/dlab).

## Architecture

![MNIST CNN architecture](assets/architecture.png)

## Results

5-seed held-out test evaluation (mean ± sample standard deviation across seeds):

| metric | value |
|---|---:|
| test accuracy | 99.6140% ± 0.0802 pp |
| test loss | 0.14677 ± 0.00272 |
| best validation loss | 0.14282 ± 0.00138 |

Per-seed held-out test results:

| seed | W&B run | test accuracy | test loss | best validation loss |
|---:|---|---:|---:|---:|
| 1 | [`5um57rnu`](https://wandb.ai/tsilva/dlab/runs/5um57rnu) | 99.6100% | 0.14772 | 0.14124 |
| 2 | [`23f1frqb`](https://wandb.ai/tsilva/dlab/runs/23f1frqb) | 99.6800% | 0.14437 | 0.14317 |
| 3 | [`25yaaj1o`](https://wandb.ai/tsilva/dlab/runs/25yaaj1o) | 99.4800% | 0.15107 | 0.14403 |
| 4 | [`3rrlxghp`](https://wandb.ai/tsilva/dlab/runs/3rrlxghp) | 99.6300% | 0.14573 | 0.14150 |
| 5 | [`y51200ov`](https://wandb.ai/tsilva/dlab/runs/y51200ov) | 99.6700% | 0.14494 | 0.14418 |

The ONNX model was exported from the seed-1 checkpoint, which had the lowest best validation loss of the five seeds in the final evaluation sweep. Test metrics were not used for checkpoint selection; they were logged in W&B sweep [`ikfs5ox8`](https://wandb.ai/tsilva/dlab/sweeps/ikfs5ox8).
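The summary row is the mean and sample standard deviation (n - 1 denominator) of the per-seed rows. A minimal sketch that recomputes it from the rounded per-seed values with Python's standard `statistics` module and applies the lowest-validation-loss selection rule; the variable and function names are illustrative, and last-digit differences from the summary table are possible because the per-seed inputs are themselves rounded.

```python
from statistics import mean, stdev

# Per-seed values copied from the per-seed table (seeds 1 to 5).
test_accuracy = [99.61, 99.68, 99.48, 99.63, 99.67]  # percent
test_loss = [0.14772, 0.14437, 0.15107, 0.14573, 0.14494]
best_val_loss = [0.14124, 0.14317, 0.14403, 0.14150, 0.14418]


def summarize(values):
    """Return (mean, sample standard deviation) across seeds."""
    return mean(values), stdev(values)  # stdev divides by n - 1


for name, values in [
    ("test accuracy", test_accuracy),
    ("test loss", test_loss),
    ("best validation loss", best_val_loss),
]:
    m, s = summarize(values)
    print(f"{name}: {m:.5f} ± {s:.5f}")

# Checkpoint selection uses validation loss only: pick the seed with the lowest value.
best_seed = min(range(len(best_val_loss)), key=best_val_loss.__getitem__) + 1
print("selected seed:", best_seed)
```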
## Model Details

- Dataset: MNIST
- Architecture: CNN
- Channels: `[32, 64, 128]`
- Convolutions per stage: `2`
- Batch normalization: enabled
- Dropout: `0.1`
- Optimizer: Adam
- Learning rate: `0.001`
- Weight decay: `0.0001`
- Scheduler: OneCycleLR
- Label smoothing: `0.02`
- Weight averaging: exponential moving average (EMA)
- Batch size: `512`
- Training augmentation: random affine transforms (rotation, translation, scale)
- Early stopping: on validation loss, with patience `8` and min delta `0.0005`
- Source W&B run: [`5um57rnu`](https://wandb.ai/tsilva/dlab/runs/5um57rnu)
- Source W&B sweep: [`ikfs5ox8`](https://wandb.ai/tsilva/dlab/sweeps/ikfs5ox8)

## Input / Output

Use `model.onnx` for inference that does not depend on the training code.

- Input name: `images`
- Input shape: `[batch, 1, 28, 28]`
- Input dtype: `float32`
- Output name: `logits`
- Output shape: `[batch, 10]`

Preprocessing:

- Convert the image to grayscale.
- Resize it to `28 x 28`.
- Scale pixel values to `[0, 1]` (divide by 255).
- Normalize with mean `0.1307` and standard deviation `0.3081`.
- Arrange the tensor as channels-first `[batch, 1, 28, 28]`.
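The input/output contract above can be sketched as two NumPy helpers: one that turns a list of already grayscale `28 x 28` `uint8` arrays into a batched `images` tensor, and one that converts the `logits` output into softmax probabilities and predicted digits. This is a minimal sketch, not code shipped in this repository: `preprocess_batch`, `postprocess_logits`, and the `invert` flag are hypothetical. The `invert` option assumes the input may be dark strokes on a light background, whereas MNIST digits are light strokes on a dark background.

```python
import numpy as np

# Normalization constants from the preprocessing steps.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess_batch(images, invert=False):
    """Stack 28 x 28 grayscale uint8 arrays into a float32 [batch, 1, 28, 28] tensor.

    `invert` is a hypothetical option for dark-on-light inputs; MNIST digits
    are light strokes on a dark background.
    """
    batch = np.stack([np.asarray(img, dtype=np.float32) for img in images])
    if batch.shape[1:] != (28, 28):
        raise ValueError(f"expected 28 x 28 images, got {batch.shape[1:]}")
    batch /= 255.0  # scale pixel values to [0, 1]
    if invert:
        batch = 1.0 - batch  # flip polarity to light-on-dark
    batch = (batch - MNIST_MEAN) / MNIST_STD  # normalize
    return batch[:, None, :, :].astype(np.float32)  # add the channel axis


def postprocess_logits(logits):
    """Turn [batch, 10] logits into predicted digits and softmax probabilities."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    return probs.argmax(axis=1), probs


# Usage with random stand-in images (not real MNIST data).
rng = np.random.default_rng(0)
fake = [rng.integers(0, 256, size=(28, 28), dtype=np.uint8) for _ in range(3)]
x = preprocess_batch(fake)
print(x.shape, x.dtype)  # (3, 1, 28, 28) float32

digits, probs = postprocess_logits(np.eye(10)[[3, 7]] * 5.0)
print(digits.tolist())  # [3, 7]
```

With an ONNX Runtime session like the one in the Usage section, `x` would be passed as `{"images": x}` and the returned `logits` array handed to `postprocess_logits`; the batch dimension lets several images go through in one call.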
## Usage

Install the runtime dependencies:

```bash
pip install huggingface_hub onnxruntime pillow numpy
```

Run inference with the ONNX model:

```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Class ids map one-to-one to digit labels.
LABELS = {
    0: "0", 1: "1", 2: "2", 3: "3", 4: "4",
    5: "5", 6: "6", 7: "7", 8: "8", 9: "9",
}

# Download the ONNX export from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-cnn",
    filename="model.onnx",
)

# Grayscale, resize to 28 x 28, scale to [0, 1], then normalize.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
# Add batch and channel dimensions: [1, 1, 28, 28].
x = x[None, None, :, :].astype(np.float32)

session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])
print(prediction, LABELS[prediction])
```

## Labels

Class ids map one-to-one to the MNIST digit labels:

| id | label |
|---:|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |

## Files

- `model.onnx`: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
- `model.ckpt`: PyTorch Lightning checkpoint for the same model. Restoring it requires the model's training code, so it is mainly useful for PyTorch-based inspection or continued experimentation.
- `config.yaml`: resolved Hydra training config.
- `metrics.csv`: training metrics from the run that produced the uploaded checkpoint.
- `metrics_summary.csv`: compact summary of the 5-seed final evaluation.
- `metadata.json`: compact metadata for inference and provenance.

## Limitations

This compact CNN performs close to the accuracy ceiling of MNIST. Remaining errors are expected to be rare and to involve mostly visually ambiguous or unusually written digits.