---
license: mit
library_name: onnx
pipeline_tag: image-classification
tags:
- image-classification
- mnist
- multilayer-perceptron
- onnx
- onnxruntime
- pytorch
- dlab
datasets:
- mnist
metrics:
- accuracy
---
# MNIST MLP Classifier
This repository contains a multilayer perceptron (MLP) that classifies MNIST handwritten digits. It was trained with [dlab](https://github.com/tsilva/dlab), and its training recipe was selected on the validation split.
## Architecture
![MNIST MLP architecture](assets/architecture.png)
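
The diagram pairs with the Model Details below. Here is a minimal NumPy sketch of the inference-time forward pass those details imply: three hidden layers of width 1024 with batch normalization and ReLU, followed by a 10-way linear head. It is a hypothetical reconstruction, not code from dlab. The layer order (Linear, then BatchNorm, then ReLU), the presence of biases, and the BatchNorm epsilon are assumptions. Dropout is omitted because it is inactive at inference. The weights are random, so the sketch only illustrates shapes and parameter counts.

```python
import numpy as np

# Hypothetical reconstruction: layer order (Linear -> BatchNorm -> ReLU), biases,
# and BN_EPS are assumptions, not values read from model.ckpt or model.onnx.
IN_FEATURES = 28 * 28
HIDDEN_WIDTH = 1024
HIDDEN_LAYERS = 3
NUM_CLASSES = 10
BN_EPS = 1e-5  # PyTorch BatchNorm1d default, assumed here


def init_params(rng):
    """Random weights with the assumed shapes (illustrates structure only)."""
    params, fan_in = [], IN_FEATURES
    for _ in range(HIDDEN_LAYERS):
        params.append({
            "W": (rng.standard_normal((fan_in, HIDDEN_WIDTH)) * np.sqrt(2.0 / fan_in)).astype(np.float32),
            "b": np.zeros(HIDDEN_WIDTH, dtype=np.float32),
            # BatchNorm affine parameters and running statistics
            "gamma": np.ones(HIDDEN_WIDTH, dtype=np.float32),
            "beta": np.zeros(HIDDEN_WIDTH, dtype=np.float32),
            "mean": np.zeros(HIDDEN_WIDTH, dtype=np.float32),
            "var": np.ones(HIDDEN_WIDTH, dtype=np.float32),
        })
        fan_in = HIDDEN_WIDTH
    head = {
        "W": (rng.standard_normal((fan_in, NUM_CLASSES)) * np.sqrt(1.0 / fan_in)).astype(np.float32),
        "b": np.zeros(NUM_CLASSES, dtype=np.float32),
    }
    return params, head


def forward(x, params, head):
    """Inference-mode pass: [batch, 1, 28, 28] -> logits [batch, 10]."""
    h = x.reshape(x.shape[0], -1)  # flatten to [batch, 784]
    for p in params:
        h = h @ p["W"] + p["b"]
        # BatchNorm with running statistics (inference behaviour)
        h = (h - p["mean"]) / np.sqrt(p["var"] + BN_EPS) * p["gamma"] + p["beta"]
        h = np.maximum(h, 0.0)  # ReLU; dropout is a no-op at inference
    return h @ head["W"] + head["b"]


def count_trainable(params, head):
    """Weights, biases and BatchNorm gamma/beta (running stats excluded)."""
    n = sum(p[k].size for p in params for k in ("W", "b", "gamma", "beta"))
    return n + head["W"].size + head["b"].size


rng = np.random.default_rng(0)
params, head = init_params(rng)
logits = forward(np.zeros((2, 1, 28, 28), dtype=np.float32), params, head)
print(logits.shape)                   # (2, 10)
print(count_trainable(params, head))  # 2919434
```

Under these assumptions the network has roughly 2.9 million trainable parameters, almost all of them in the 1024-wide linear layers.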
## Results
Aggregated over a 10-seed confirmation sweep of the selected recipe:
| metric | value |
|---|---:|
| validation accuracy | 99.3600% ± 0.0817 pp |
| validation loss | 0.15172 ± 0.00235 |
| test accuracy | 99.4470% ± 0.0195 pp |
| test loss | 0.14746 ± 0.00034 |
| test errors | 55.3 ± 1.95 / 10000 |
The ONNX model was exported from the checkpoint of the best run. Test metrics were computed only after the recipe had been selected, and they are logged in W&B sweep [`xa56lubb`](https://wandb.ai/tsilva/dlab/sweeps/xa56lubb).
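
The accuracy and error-count rows describe the same quantity on the 10,000-image MNIST test set, so one can be checked against the other. This short sketch reproduces the conversion using only numbers from the table. The helper names are illustrative.

```python
# Figures copied from the results table (means and spreads over the 10 seeds).
TEST_SET_SIZE = 10_000
mean_errors = 55.3
errors_spread = 1.95


def errors_to_accuracy_pct(errors, n=TEST_SET_SIZE):
    """Accuracy in percent implied by a number of misclassified test images."""
    return 100.0 * (1.0 - errors / n)


def errors_to_pp(errors, n=TEST_SET_SIZE):
    """Percentage points of accuracy represented by a number of test images."""
    return 100.0 * errors / n


print(f"{errors_to_accuracy_pct(mean_errors):.4f}%")  # 99.4470%
print(f"{errors_to_pp(errors_spread):.4f} pp")        # 0.0195 pp
```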
## Model Details
- Dataset: MNIST
- Architecture: MLP
- Hidden width: `1024`
- Hidden layers: `3`
- Activation: ReLU
- Batch normalization: enabled
- Dropout: `0.2`
- Optimizer: Adam
- Learning rate: `0.001`
- Weight decay: `0.0001`
- Scheduler: OneCycleLR
- Label smoothing: `0.02`
- Weight averaging: EMA
- Batch size: `512`
- Training augmentation: random affine rotation/translation/scale
- Source W&B run: [`gsuy1ifx`](https://wandb.ai/tsilva/dlab/runs/gsuy1ifx)
- Source W&B sweep: [`xa56lubb`](https://wandb.ai/tsilva/dlab/sweeps/xa56lubb)
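
The label-smoothing setting also helps when reading the loss rows in Results, because with smoothed targets even a perfect model cannot reach zero cross-entropy. The sketch below assumes the PyTorch-style formulation used by `torch.nn.CrossEntropyLoss(label_smoothing=...)`, where the one-hot target is mixed with a uniform distribution over all 10 classes. It also assumes the logged losses include this term. Neither detail is confirmed by this card, and `smoothed_target` and `loss_floor` are illustrative helper names.

```python
import math

NUM_CLASSES = 10
EPSILON = 0.02  # label smoothing from the training config above


def smoothed_target(true_class, eps=EPSILON, k=NUM_CLASSES):
    """Mix the one-hot target with a uniform distribution over k classes."""
    return [(1.0 - eps) * (i == true_class) + eps / k for i in range(k)]


def loss_floor(eps=EPSILON, k=NUM_CLASSES):
    """Minimum achievable cross-entropy: the entropy of the smoothed target,
    reached only when the predicted distribution equals that target."""
    target = smoothed_target(0, eps, k)
    return -sum(p * math.log(p) for p in target if p > 0)


print([round(p, 3) for p in smoothed_target(7)])
# [0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.982, 0.002, 0.002]
print(f"{loss_floor():.4f}")  # 0.1297
```

Under those assumptions, the reported validation and test losses (about 0.15) sit only about 0.02 above a floor of roughly 0.13.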
## Input / Output
Use `model.onnx` for inference that does not depend on the training code.
- Input name: `images`
- Input shape: `[batch, 1, 28, 28]`
- Input dtype: `float32`
- Output name: `logits`
- Output shape: `[batch, 10]`
Preprocessing:
- Convert image to grayscale.
- Resize to `28 x 28`.
- Scale pixel values to `[0, 1]`.
- Normalize with mean `0.1307` and standard deviation `0.3081`.
- Arrange the tensor as channels-first `[batch, 1, 28, 28]`.
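
The normalization and layout steps above can be sketched in NumPy for a whole batch. This sketch assumes grayscale conversion and resizing to 28 x 28 have already been done (for example with Pillow, as in Usage). The `invert` flag is a hypothetical convenience that is not part of the original recipe. MNIST digits are light strokes on a dark background, so dark-on-light images would need inverting before they match the training data.

```python
import numpy as np

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess(gray_images, invert=False):
    """Turn 28x28 grayscale uint8 arrays into a [batch, 1, 28, 28] float32 tensor."""
    batch = np.stack([np.asarray(img, dtype=np.float32) for img in gray_images])
    if batch.shape[1:] != (28, 28):
        raise ValueError(f"expected 28x28 images, got {batch.shape[1:]}")
    batch = batch / 255.0  # scale pixel values to [0, 1]
    if invert:
        # Hypothetical option: flip dark-on-light inputs to MNIST's light-on-dark.
        batch = 1.0 - batch
    batch = (batch - MNIST_MEAN) / MNIST_STD  # normalize with MNIST mean/std
    return batch[:, None, :, :].astype(np.float32)  # add the channel axis


blank = np.zeros((28, 28), dtype=np.uint8)
x = preprocess([blank, blank])
print(x.shape, x.dtype)  # (2, 1, 28, 28) float32
```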
## Usage
Install the runtime dependencies:
```bash
pip install huggingface_hub onnxruntime pillow numpy
```
Run inference with the ONNX model:
```python
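import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Class ids map one-to-one to digit labels.
LABELS = {i: str(i) for i in range(10)}

# Download the ONNX export from the Hugging Face Hub (cached after the first call).
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-mlp",
    filename="model.onnx",
)

# Preprocess: grayscale, 28x28, scale to [0, 1], normalize with the MNIST mean/std.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
# Add batch and channel axes: [28, 28] -> [1, 1, 28, 28].
x = x[None, None, :, :].astype(np.float32)

# Run the model on CPU; the output is a [batch, 10] array of logits.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]

# The predicted digit is the class with the highest logit.
prediction = int(logits.argmax(axis=1)[0])
print(prediction, LABELS[prediction])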
```
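
The model returns raw logits. To get class probabilities or a ranked list of candidate digits, apply a softmax over the class axis. This is a small NumPy sketch. `softmax` and `top_k` are helper names introduced here, not part of the repository, and the example logits are made up for illustration.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the class axis of a [batch, 10] array."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def top_k(logits, k=3):
    """Return (class ids, probabilities) of the k most likely digits per image."""
    probs = softmax(logits)
    ids = np.argsort(-probs, axis=1)[:, :k]
    return ids, np.take_along_axis(probs, ids, axis=1)


example = np.array([[0.0, 1.0, 5.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0]], dtype=np.float32)
ids, probs = top_k(example)
print(ids)  # [[2 7 1]]
```

The same helpers work unchanged on the `logits` array returned by `session.run` above, including for batches larger than one.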
## Labels
MNIST labels:
| id | label |
|---:|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |
## Files
- `model.onnx`: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
- `model.ckpt`: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
- `config.yaml`: resolved Hydra training config.
- `metrics.csv`: training metrics from the uploaded checkpoint run.
- `metadata.json`: compact metadata for inference and provenance.
## Limitations
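This MLP has no convolutional inductive bias: it operates on the flattened 784-pixel vector, so spatial structure such as locality must be learned from data rather than built into the architecture. It still performs strongly on MNIST, with about 55 test errors out of 10,000 on average across seeds, and the remaining errors are mostly concentrated in ambiguous or unusually written digits.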