---
license: mit
library_name: onnx
pipeline_tag: image-classification
tags:
- image-classification
- mnist
- multilayer-perceptron
- onnx
- onnxruntime
- pytorch
- dlab
datasets:
- mnist
metrics:
- accuracy
---
# MNIST MLP Classifier
This repository contains a validation-selected MNIST MLP digit classifier trained with [dlab](https://github.com/tsilva/dlab).
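## Architecture
A multilayer perceptron with 3 hidden layers of width 1024, using ReLU activations, batch normalization, and dropout of 0.2. It maps a `[batch, 1, 28, 28]` image tensor to 10 logits, one per digit.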
## Results
10-seed confirmation sweep:
| metric | value |
|---|---:|
| validation accuracy | 99.3600% ± 0.0817 pp |
| validation loss | 0.15172 ± 0.00235 |
| test accuracy | 99.4470% ± 0.0195 pp |
| test loss | 0.14746 ± 0.00034 |
| test errors | 55.3 ± 1.95 / 10000 |
The ONNX model was exported from the best run checkpoint. Test metrics were produced after the recipe was selected and were logged in W&B sweep [`xa56lubb`](https://wandb.ai/tsilva/dlab/sweeps/xa56lubb).
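The accuracy and error-count rows in the table describe the same quantity, because the MNIST test split has 10,000 images. A minimal sketch of the conversion follows; the helper name `errors_from_accuracy` is illustrative and not part of dlab.

```python
TEST_SET_SIZE = 10_000  # number of images in the MNIST test split


def errors_from_accuracy(accuracy_pct: float, n: int = TEST_SET_SIZE) -> float:
    """Convert an accuracy percentage into a count of misclassified examples."""
    return (100.0 - accuracy_pct) / 100.0 * n


# Mean test accuracy reported in the table.
mean_errors = errors_from_accuracy(99.4470)

# A spread given in percentage points converts to examples the same way.
spread_errors = 0.0195 / 100.0 * TEST_SET_SIZE

print(round(mean_errors, 1), round(spread_errors, 2))
```

This is why a spread of 0.0195 pp in test accuracy corresponds to a spread of about two test images.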
## Model Details
- Dataset: MNIST
- Architecture: MLP
- Hidden width: `1024`
- Hidden layers: `3`
- Activation: ReLU
- Batch normalization: enabled
- Dropout: `0.2`
- Optimizer: Adam
- Learning rate: `0.001`
- Weight decay: `0.0001`
- Scheduler: OneCycleLR
- Label smoothing: `0.02`
- Weight averaging: EMA
- Batch size: `512`
- Training augmentation: random affine rotation/translation/scale
- Source W&B run: [`gsuy1ifx`](https://wandb.ai/tsilva/dlab/runs/gsuy1ifx)
- Source W&B sweep: [`xa56lubb`](https://wandb.ai/tsilva/dlab/sweeps/xa56lubb)
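To give a sense of model size, the sketch below counts trainable parameters from the width and depth listed above. The exact layer layout is an assumption, not something this card states: a flattened 784-dimensional input, three hidden blocks of a biased linear layer followed by batch normalization with a learnable scale and shift, and a final linear layer to 10 logits. The resolved `config.yaml` is the authoritative source.

```python
# Assumed layout (not confirmed by the model card): flatten 28*28 -> 784 inputs,
# three hidden blocks of Linear (with bias) + BatchNorm (scale and shift),
# then a final Linear layer producing 10 logits.
INPUT_DIM = 28 * 28
HIDDEN_WIDTH = 1024
HIDDEN_LAYERS = 3
NUM_CLASSES = 10


def linear_params(fan_in: int, fan_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return fan_in * fan_out + fan_out


def count_params() -> int:
    """Total trainable parameters under the assumed layout."""
    total = 0
    fan_in = INPUT_DIM
    for _ in range(HIDDEN_LAYERS):
        total += linear_params(fan_in, HIDDEN_WIDTH)
        total += 2 * HIDDEN_WIDTH  # batch norm learnable scale and shift
        fan_in = HIDDEN_WIDTH
    total += linear_params(fan_in, NUM_CLASSES)
    return total


print(f"{count_params():,} trainable parameters under these assumptions")
```

Under these assumptions the count comes to roughly 2.9 million parameters, most of them in the two 1024 x 1024 hidden layers.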
## Input / Output
Use `model.onnx` for code-independent inference.
- Input name: `images`
- Input shape: `[batch, 1, 28, 28]`
- Input dtype: `float32`
- Output name: `logits`
- Output shape: `[batch, 10]`
Preprocessing:
- Convert image to grayscale.
- Resize to `28 x 28`.
- Scale pixel values to `[0, 1]`.
- Normalize with mean `0.1307` and standard deviation `0.3081`.
- Arrange the tensor as channels-first `[batch, 1, 28, 28]`.
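The preprocessing steps above can be sketched for a whole batch with NumPy alone. The function name `preprocess_batch` and the all-black and all-white test images are illustrative stand-ins, and the sketch assumes the images are already grayscale and `28 x 28`.

```python
import numpy as np

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess_batch(images_u8: np.ndarray) -> np.ndarray:
    """Turn uint8 grayscale images of shape [batch, 28, 28] into model input."""
    if images_u8.ndim != 3 or images_u8.shape[1:] != (28, 28):
        raise ValueError(f"expected [batch, 28, 28], got {images_u8.shape}")
    x = images_u8.astype(np.float32) / 255.0    # scale pixel values to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD            # normalize with mean and std
    return x[:, None, :, :].astype(np.float32)  # add the channel axis


# Hypothetical stand-in data: one all-black and one all-white image.
batch = np.stack([
    np.zeros((28, 28), dtype=np.uint8),
    np.full((28, 28), 255, dtype=np.uint8),
])
x = preprocess_batch(batch)
print(x.shape, x.dtype)
```

The resulting array can be passed as the `images` input of the ONNX model.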
## Usage
Install the runtime dependencies:
```bash
pip install huggingface_hub onnxruntime pillow numpy
```
Run inference with the ONNX model:
```python
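import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# MNIST class ids map directly to the digits they represent.
LABELS = {i: str(i) for i in range(10)}

# Download the ONNX export from the Hugging Face Hub (cached after the first call).
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-mlp",
    filename="model.onnx",
)

# Load as 8-bit grayscale and resize to the 28 x 28 input resolution.
# MNIST digits are light strokes on a dark background; invert dark-on-light images first.
image = Image.open("example.png").convert("L").resize((28, 28))

# Scale to [0, 1], then normalize with the MNIST mean and standard deviation.
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081

# Add batch and channel axes: [28, 28] -> [1, 1, 28, 28].
x = x[None, None, :, :].astype(np.float32)

session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]  # shape [1, 10]

prediction = int(logits.argmax(axis=1)[0])
print(prediction, LABELS[prediction])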
```
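The model returns raw logits, not probabilities. If a confidence score is needed, a softmax over the class axis converts them. The sketch below uses made-up logits in place of a real `session.run` result; the `softmax` helper is not part of this repository.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Hypothetical logits for a single image; real values come from session.run.
example_logits = np.array(
    [[0.1, 0.2, 6.0, 0.3, 0.0, 0.1, 0.2, 0.4, 0.1, 0.0]],
    dtype=np.float32,
)
probs = softmax(example_logits)
top = int(probs.argmax(axis=1)[0])
print(top, float(probs[0, top]))
```

Subtracting the row maximum before exponentiating leaves the result unchanged and avoids overflow for large logits.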
## Labels
MNIST labels:
| id | label |
|---:|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |
## Files
- `model.onnx`: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
- `model.ckpt`: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
- `config.yaml`: resolved Hydra training config.
- `metrics.csv`: training metrics from the uploaded checkpoint run.
- `metadata.json`: compact metadata for inference and provenance.
## Limitations
This MLP does not use convolutional inductive bias. It performs strongly on MNIST, but remaining errors are mostly concentrated in ambiguous or unusually written digits.