---
license: mit
library_name: onnx
pipeline_tag: image-classification
tags:
- image-classification
- mnist
- convolutional-neural-network
- onnx
- onnxruntime
- pytorch
- dlab
datasets:
- mnist
metrics:
- accuracy
---

# MNIST CNN Classifier

This repository contains a validation-selected MNIST CNN digit classifier trained with [dlab](https://github.com/tsilva/dlab).
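
## Architecture

A three-stage CNN with channel widths `[32, 64, 128]`, two convolutions per stage, batch normalization, and dropout `0.1`. The full configuration is listed under Model Details.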

## Results

5-seed held-out test evaluation (mean ± sample standard deviation across seeds):

| metric | value |
|---|---:|
| test accuracy | 99.6140% ± 0.0802 pp |
| test loss | 0.14677 ± 0.00272 |
| best validation loss | 0.14282 ± 0.00138 |

Per-seed held-out test results:

| seed | W&B run | test accuracy | test loss | best validation loss |
|---:|---|---:|---:|---:|
| 1 | [`5um57rnu`](https://wandb.ai/tsilva/dlab/runs/5um57rnu) | 99.6100% | 0.14772 | 0.14124 |
| 2 | [`23f1frqb`](https://wandb.ai/tsilva/dlab/runs/23f1frqb) | 99.6800% | 0.14437 | 0.14317 |
| 3 | [`25yaaj1o`](https://wandb.ai/tsilva/dlab/runs/25yaaj1o) | 99.4800% | 0.15107 | 0.14403 |
| 4 | [`3rrlxghp`](https://wandb.ai/tsilva/dlab/runs/3rrlxghp) | 99.6300% | 0.14573 | 0.14150 |
| 5 | [`y51200ov`](https://wandb.ai/tsilva/dlab/runs/y51200ov) | 99.6700% | 0.14494 | 0.14418 |

The ONNX model was exported from the seed-1 checkpoint, which had the lowest validation loss in the final 5-seed evaluation sweep. Test metrics were not used for checkpoint selection; they are logged in W&B sweep [`ikfs5ox8`](https://wandb.ai/tsilva/dlab/sweeps/ikfs5ox8).
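The aggregate figures can be re-derived from the per-seed table. The sketch below assumes the ± values are sample standard deviations (n - 1 denominator), which is consistent with the reported test-accuracy and test-loss spreads. Variable names are illustrative, and the numbers are copied from the table rather than read from W&B.

```python
from statistics import mean, stdev

# Per-seed values copied from the table:
# seed -> (test accuracy %, test loss, best validation loss)
per_seed = {
    1: (99.61, 0.14772, 0.14124),
    2: (99.68, 0.14437, 0.14317),
    3: (99.48, 0.15107, 0.14403),
    4: (99.63, 0.14573, 0.14150),
    5: (99.67, 0.14494, 0.14418),
}

accuracies = [row[0] for row in per_seed.values()]
test_losses = [row[1] for row in per_seed.values()]

# statistics.stdev uses the n - 1 (sample) denominator.
acc_mean, acc_std = mean(accuracies), stdev(accuracies)
loss_mean, loss_std = mean(test_losses), stdev(test_losses)

# Checkpoint selection looks at validation loss only: pick the seed with the lowest value.
best_seed = min(per_seed, key=lambda seed: per_seed[seed][2])

print(f"test accuracy: {acc_mean:.4f}% ± {acc_std:.4f} pp")
print(f"test loss: {loss_mean:.5f} ± {loss_std:.5f}")
print(f"selected seed: {best_seed}")
```

Selecting on validation loss alone keeps the test numbers an unbiased estimate for the exported checkpoint.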

## Model Details

- Dataset: MNIST
- Architecture: CNN
- Channels: `[32, 64, 128]`
- Convolutions per stage: `2`
- Batch normalization: enabled
- Dropout: `0.1`
- Optimizer: Adam
- Learning rate: `0.001`
- Weight decay: `0.0001`
- Scheduler: OneCycleLR
- Label smoothing: `0.02`
- Weight averaging: EMA
- Batch size: `512`
- Training augmentation: random affine rotation/translation/scale
- Early stopping: validation loss, patience `8`, min delta `0.0005`
- Source W&B run: [`5um57rnu`](https://wandb.ai/tsilva/dlab/runs/5um57rnu)
- Source W&B sweep: [`ikfs5ox8`](https://wandb.ai/tsilva/dlab/sweeps/ikfs5ox8)
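To make the label-smoothing setting concrete, the sketch below builds the smoothed target for one digit. It assumes the convention used by PyTorch's `CrossEntropyLoss(label_smoothing=...)`, where the smoothing mass is spread uniformly over all classes. Whether dlab follows exactly this convention is an assumption, and `smoothed_target` is an illustrative helper, not part of the repository.

```python
import math

NUM_CLASSES = 10
LABEL_SMOOTHING = 0.02  # value from the list above


def smoothed_target(label, eps=LABEL_SMOOTHING, k=NUM_CLASSES):
    """Return (1 - eps) * one_hot(label) + eps / k as a list of k probabilities."""
    target = [eps / k] * k
    target[label] += 1.0 - eps
    return target


target = smoothed_target(7)

# Cross-entropy against a fixed target is minimized when the prediction equals
# the target, so the target's entropy is the lowest loss reachable per example.
loss_floor = -sum(p * math.log(p) for p in target)

print([round(p, 3) for p in target])
print(round(loss_floor, 4))
```

Under these assumptions the true class receives 0.982, every other class 0.002, and the per-example loss floor is roughly 0.13. If the reported losses include the smoothing term (the card does not say), loss values near 0.14 alongside accuracy above 99.6% would be consistent with that floor.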

## Input / Output

Use `model.onnx` for inference that does not depend on the training code.

- Input name: `images`
- Input shape: `[batch, 1, 28, 28]`
- Input dtype: `float32`
- Output name: `logits`
- Output shape: `[batch, 10]`

Preprocessing:

- Convert image to grayscale.
- Resize to `28 x 28`.
- Scale pixel values to `[0, 1]`.
- Normalize with mean `0.1307` and standard deviation `0.3081`.
- Arrange the tensor as channels-first `[batch, 1, 28, 28]`.
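The preprocessing steps above can be sketched as a small batch helper that starts from already-resized `uint8` grayscale arrays. The function name and the `invert` flag are illustrative additions, not part of the repository. The flag reflects the fact that MNIST digits are bright strokes on a dark background, so dark-on-light scans may need inverting first.

```python
import numpy as np

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess_batch(images, invert=False):
    """Turn a sequence of 28x28 uint8 grayscale arrays into a [batch, 1, 28, 28] float32 tensor."""
    x = np.stack([np.asarray(img, dtype=np.uint8) for img in images]).astype(np.float32)
    if x.shape[1:] != (28, 28):
        raise ValueError(f"expected 28x28 images, got {x.shape[1:]}")
    if invert:
        x = 255.0 - x  # hypothetical option for dark-on-light input
    x = x / 255.0  # scale to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD  # normalize with the MNIST statistics
    return x[:, None, :, :].astype(np.float32)  # add the channel axis


batch = preprocess_batch([
    np.zeros((28, 28), dtype=np.uint8),
    np.full((28, 28), 255, dtype=np.uint8),
])
print(batch.shape, batch.dtype)
```

The returned array can be passed directly as the `images` input of the ONNX model.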

## Usage

Install the runtime dependencies:

```bash
pip install huggingface_hub onnxruntime pillow numpy
```

Run inference with the ONNX model:

```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Class index -> label; for MNIST the label is the digit itself.
LABELS = {
    0: "0",
    1: "1",
    2: "2",
    3: "3",
    4: "4",
    5: "5",
    6: "6",
    7: "7",
    8: "8",
    9: "9",
}

# Download the ONNX export from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-cnn",
    filename="model.onnx",
)

# Preprocess: grayscale, 28 x 28, scale to [0, 1], normalize, add batch and channel axes.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
x = x[None, None, :, :].astype(np.float32)

# Run the model on CPU and take the highest-scoring class.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])

print(prediction, LABELS[prediction])
```
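The model returns raw logits, so class probabilities require a softmax. A minimal sketch follows, using made-up logits in place of a real `session.run` result; `softmax` is a local helper, not part of onnxruntime.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the class axis of a [batch, 10] array."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoids overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Placeholder logits for one image; in practice use the `logits` array from the example above.
logits = np.array(
    [[-2.0, -1.0, 0.5, 8.0, -3.0, 0.0, -1.5, 1.0, -0.5, -2.5]],
    dtype=np.float32,
)
probs = softmax(logits)

# Class ids sorted by descending probability, keeping the three most likely.
top3 = np.argsort(probs[0])[::-1][:3]
for class_id in top3:
    print(int(class_id), float(probs[0, class_id]))
```

Inspecting the top few probabilities is a simple way to flag ambiguous digits, for example when the two highest values are close.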

## Labels

MNIST labels:

| id | label |
|---:|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |

## Files

- `model.onnx`: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
- `model.ckpt`: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
- `config.yaml`: resolved Hydra training config.
- `metrics.csv`: training metrics from the uploaded checkpoint run.
- `metrics_summary.csv`: compact 5-seed final evaluation summary.
- `metadata.json`: compact metadata for inference and provenance.

## Limitations

This compact CNN is near MNIST saturation. Remaining errors are expected to be rare and to involve mostly visually ambiguous or unusually written digits.