metadata
license: mit
library_name: onnx
pipeline_tag: image-classification
tags:
- image-classification
- mnist
- multilayer-perceptron
- onnx
- onnxruntime
- pytorch
- dlab
datasets:
- mnist
metrics:
- accuracy
MNIST MLP Classifier
This repository contains a validation-selected MNIST MLP digit classifier trained with dlab.
Architecture
Results
10-seed confirmation sweep:
| metric | value |
|---|---|
| validation accuracy | 99.3600% ± 0.0817 pp |
| validation loss | 0.15172 ± 0.00235 |
| test accuracy | 99.4470% ± 0.0195 pp |
| test loss | 0.14746 ± 0.00034 |
| test errors | 55.3 ± 1.95 / 10000 |
The ONNX model was exported from the best run checkpoint. Test metrics were produced after the recipe was selected and were logged in W&B sweep xa56lubb.
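The test accuracy and test error rows in the table describe the same quantity in two units. As a quick consistency check, the sketch below converts the mean error count and its spread into accuracy and percentage points, assuming the 10,000-image MNIST test split implied by the `/ 10000` in the table. The helper names are illustrative and not part of dlab.

```python
TEST_SET_SIZE = 10_000  # MNIST test split size, matching "/ 10000" in the table


def errors_to_accuracy_pct(errors: float, total: int = TEST_SET_SIZE) -> float:
    """Convert a (mean) misclassification count into accuracy in percent."""
    return 100.0 * (1.0 - errors / total)


def errors_to_pp(errors: float, total: int = TEST_SET_SIZE) -> float:
    """Convert a spread in error count into percentage points of accuracy."""
    return 100.0 * errors / total


mean_acc = errors_to_accuracy_pct(55.3)  # mean test errors from the table
std_pp = errors_to_pp(1.95)              # spread of test errors from the table
print(f"{mean_acc:.4f}% ± {std_pp:.4f} pp")  # 99.4470% ± 0.0195 pp
```

The printed value matches the test accuracy row, so the two rows agree with each other.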
Model Details
- Dataset: MNIST
- Architecture: MLP
- Hidden width: 1024
- Hidden layers: 3
- Activation: ReLU
- Batch normalization: enabled
- Dropout: 0.2
- Optimizer: Adam
- Learning rate: 0.001
- Weight decay: 0.0001
- Scheduler: OneCycleLR
- Label smoothing: 0.02
- Weight averaging: EMA
- Batch size: 512
- Training augmentation: random affine rotation/translation/scale
- Source W&B run: gsuy1ifx
- Source W&B sweep: xa56lubb
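To give a sense of model size, the following sketch estimates the trainable parameter count from the hyperparameters listed above. It rests on assumptions that this card does not confirm: the 28 x 28 image is flattened to 784 inputs, each hidden block is a linear layer with bias followed by batch normalization with a learned scale and shift, and a final linear layer produces the 10 logits. The function `mlp_param_count` is hypothetical; the authoritative layout is in config.yaml and the checkpoint.

```python
def mlp_param_count(in_features=784, width=1024, hidden_layers=3,
                    num_classes=10, batch_norm=True):
    """Estimate trainable parameters of a plain MLP (assumed layout, see above)."""
    total = 0
    prev = in_features
    for _ in range(hidden_layers):
        total += prev * width + width          # linear weight + bias
        if batch_norm:
            total += 2 * width                 # batch norm scale + shift
        prev = width
    total += prev * num_classes + num_classes  # output linear layer
    return total


print(mlp_param_count())  # 2919434 under the assumptions above
```

Under these assumptions the model has roughly 2.9 million parameters, almost all of them in the three hidden linear layers.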
Input / Output
Use model.onnx for code-independent inference.
- Input name: images
- Input shape: [batch, 1, 28, 28]
- Input dtype: float32
- Output name: logits
- Output shape: [batch, 10]
Preprocessing:
- Convert image to grayscale.
- Resize to 28 x 28.
- Scale pixel values to [0, 1].
- Normalize with mean 0.1307 and standard deviation 0.3081.
- Arrange the tensor as channels-first [batch, 1, 28, 28].
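The preprocessing steps can also be applied to several images at once. The sketch below assumes the images are already grayscale, already resized to 28 x 28, and stacked into a uint8 array of shape [batch, 28, 28]. The function name `preprocess_batch` and the constant names are illustrative and do not come from the repository.

```python
import numpy as np

MNIST_MEAN = 0.1307  # normalization mean from the preprocessing list
MNIST_STD = 0.3081   # normalization standard deviation from the same list


def preprocess_batch(images_uint8: np.ndarray) -> np.ndarray:
    """Turn grayscale uint8 images [batch, 28, 28] into float32 [batch, 1, 28, 28]."""
    images_uint8 = np.asarray(images_uint8)
    if images_uint8.ndim != 3 or images_uint8.shape[1:] != (28, 28):
        raise ValueError(f"expected [batch, 28, 28], got {images_uint8.shape}")
    x = images_uint8.astype(np.float32) / 255.0  # scale to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD             # normalize
    return x[:, None, :, :].astype(np.float32)   # add the channel axis


batch = preprocess_batch(np.zeros((4, 28, 28), dtype=np.uint8))
print(batch.shape, batch.dtype)  # (4, 1, 28, 28) float32
```

The resulting array can be passed directly as the `images` input of the ONNX session.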
Usage
Install the runtime dependencies:
pip install huggingface_hub onnxruntime pillow numpy
Run inference with the ONNX model:
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# MNIST class ids map directly to digit names.
LABELS = {
    0: "0",
    1: "1",
    2: "2",
    3: "3",
    4: "4",
    5: "5",
    6: "6",
    7: "7",
    8: "8",
    9: "9",
}

# Download the ONNX export from the Hugging Face Hub (cached after the first call).
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-mlp",
    filename="model.onnx",
)

# Preprocess: grayscale, 28 x 28, scale to [0, 1], normalize, add batch and channel axes.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
x = x[None, None, :, :].astype(np.float32)

# Run the model on CPU and pick the highest-scoring class.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])
print(prediction, LABELS[prediction])
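The model returns raw logits, so a softmax is needed to get per-class scores that sum to 1. The sketch below shows one way to do this with NumPy and to list the top-k classes per image. The helpers `softmax` and `top_k` are illustrative and not part of the repository, and the resulting scores should not be read as calibrated confidences.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over logits of shape [batch, 10]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def top_k(logits: np.ndarray, k: int = 3):
    """Return, per image, the k highest-scoring (class id, score) pairs."""
    probs = softmax(logits)
    order = np.argsort(-probs, axis=1)[:, :k]  # class ids sorted by descending score
    return [[(int(c), float(p[c])) for c in row] for row, p in zip(order, probs)]


# Hand-made logits standing in for the output of session.run(...)[0].
example_logits = np.array([[0, 0, 0, 0, 0, 0, 0, 5.0, 0, 0]], dtype=np.float32)
print(top_k(example_logits, k=1))
```

With the real model, `logits` from the inference example above can be passed to `top_k` unchanged.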
Labels
MNIST labels:
| id | label |
|---|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |
Files
- model.onnx: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
- model.ckpt: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
- config.yaml: resolved Hydra training config.
- metrics.csv: training metrics from the uploaded checkpoint run.
- metadata.json: compact metadata for inference and provenance.
Limitations
This MLP has no convolutional inductive bias. It performs strongly on MNIST (99.447% mean test accuracy across the 10-seed sweep), but the remaining errors are mostly concentrated in ambiguous or unusually written digits.
