---
license: mit
library_name: onnx
pipeline_tag: image-classification
tags:
- image-classification
- mnist
- multilayer-perceptron
- onnx
- onnxruntime
- pytorch
- dlab
datasets:
- mnist
metrics:
- accuracy
---

# MNIST MLP Classifier

This repository contains a validation-selected MNIST MLP digit classifier trained with [dlab](https://github.com/tsilva/dlab).

## Architecture

![MNIST MLP architecture](assets/architecture.png)

## Results

10-seed confirmation sweep:

| metric | value |
|---|---:|
| validation accuracy | 99.3600% ± 0.0817 pp |
| validation loss | 0.15172 ± 0.00235 |
| test accuracy | 99.4470% ± 0.0195 pp |
| test loss | 0.14746 ± 0.00034 |
| test errors | 55.3 ± 1.95 / 10000 |

The ONNX model was exported from the checkpoint of the best run. Test metrics were computed only after the training recipe had been selected, and are logged in W&B sweep [`xa56lubb`](https://wandb.ai/tsilva/dlab/sweeps/xa56lubb).
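
As a consistency check, the test accuracy and test error rows describe the same quantity on the 10,000-image MNIST test set. A minimal sketch of the arithmetic, using only the figures from the table above (the variable names are illustrative):

```python
# Figures copied from the results table.
test_accuracy_pct = 99.4470   # mean test accuracy, in percent
accuracy_spread_pp = 0.0195   # reported +/- on accuracy, in percentage points
test_set_size = 10_000        # number of MNIST test images

# Convert accuracy (percent) into an error count on the test set.
mean_errors = (100.0 - test_accuracy_pct) / 100.0 * test_set_size
errors_spread = accuracy_spread_pp / 100.0 * test_set_size

print(f"{mean_errors:.1f} +/- {errors_spread:.2f} errors out of {test_set_size}")
```

The result matches the `55.3 ± 1.95 / 10000` row, so the two rows can be read interchangeably.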

## Model Details

- Dataset: MNIST
- Architecture: MLP
- Hidden width: `1024`
- Hidden layers: `3`
- Activation: ReLU
- Batch normalization: enabled
- Dropout: `0.2`
- Optimizer: Adam
- Learning rate: `0.001`
- Weight decay: `0.0001`
- Scheduler: OneCycleLR
- Label smoothing: `0.02`
- Weight averaging: EMA
- Batch size: `512`
- Training augmentation: random affine rotation/translation/scale
- Source W&B run: [`gsuy1ifx`](https://wandb.ai/tsilva/dlab/runs/gsuy1ifx)
- Source W&B sweep: [`xa56lubb`](https://wandb.ai/tsilva/dlab/sweeps/xa56lubb)
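
To make the label smoothing entry concrete, here is a small NumPy sketch. It assumes the common convention, also used by PyTorch's `CrossEntropyLoss(label_smoothing=...)`, in which the one-hot target is mixed with a uniform distribution over the classes. Whether dlab computes the loss exactly this way is not confirmed by this card, and the function name `smoothed_cross_entropy` is hypothetical.

```python
import numpy as np

def smoothed_cross_entropy(logits, labels, smoothing=0.02):
    """Mean cross-entropy against label-smoothed targets.

    logits: float array [batch, classes]; labels: int array [batch].
    Assumed convention: target = (1 - smoothing) * one_hot + smoothing / classes.
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    num_classes = logits.shape[1]

    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Every class gets smoothing / K; the true class gets the remaining mass.
    targets = np.full_like(log_probs, smoothing / num_classes)
    targets[np.arange(len(labels)), labels] += 1.0 - smoothing

    return float(-(targets * log_probs).sum(axis=1).mean())

# Made-up logits for two confident, correct predictions over three classes.
logits = np.array([[4.0, 0.0, 0.0], [0.0, 3.0, 1.0]])
labels = np.array([0, 1])
print(smoothed_cross_entropy(logits, labels, smoothing=0.0))
print(smoothed_cross_entropy(logits, labels, smoothing=0.02))
```

With smoothing enabled, confident correct predictions incur a slightly higher loss than without it, which discourages the logits from growing without bound.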

## Input / Output

Use `model.onnx` for inference that does not depend on the training code.

- Input name: `images`
- Input shape: `[batch, 1, 28, 28]`
- Input dtype: `float32`
- Output name: `logits`
- Output shape: `[batch, 10]`

Preprocessing:

- Convert image to grayscale.
- Resize to `28 x 28`.
- Scale pixel values to `[0, 1]`.
- Normalize with mean `0.1307` and standard deviation `0.3081`.
- Arrange the tensor as channels-first `[batch, 1, 28, 28]`.
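
For batches, the last three steps can be sketched as one NumPy helper. It assumes the images are already grayscale and resized to `28 x 28`; the function name `to_model_input` and the constant names are hypothetical, not part of this repository.

```python
import numpy as np

MNIST_MEAN = 0.1307  # normalization constants from the list above
MNIST_STD = 0.3081

def to_model_input(images_uint8):
    """Convert uint8 grayscale images [batch, 28, 28] to float32 [batch, 1, 28, 28]."""
    images = np.asarray(images_uint8)
    if images.ndim != 3 or images.shape[1:] != (28, 28):
        raise ValueError(f"expected shape [batch, 28, 28], got {images.shape}")
    x = images.astype(np.float32) / 255.0        # scale pixel values to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD             # normalize with the MNIST statistics
    return x[:, None, :, :].astype(np.float32)   # insert the channel axis

# Four blank (all-zero) images as a stand-in for real data.
batch = np.zeros((4, 28, 28), dtype=np.uint8)
x = to_model_input(batch)
print(x.shape, x.dtype)
```

The shape check fails early on inputs that have not been resized, which is easier to debug than a shape error raised inside the runtime.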

## Usage

Install the runtime dependencies:

```bash
pip install huggingface_hub onnxruntime pillow numpy
```

Run inference with the ONNX model:

```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Each class id maps to the digit it represents.
LABELS = {i: str(i) for i in range(10)}

model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-mlp",
    filename="model.onnx",
)

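# Preprocess: grayscale, 28 x 28, scale to [0, 1], then normalize.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
# Add batch and channel axes: [28, 28] -> [1, 1, 28, 28].
x = x[None, None, :, :].astype(np.float32)

# Run the model and take the highest-scoring class.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])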

print(prediction, LABELS[prediction])
```
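
The model returns raw logits, so ranking several candidate digits requires a softmax. A minimal sketch, assuming a `[batch, 10]` logits array like the one produced above; the helper names and the example logits are made up for illustration.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax over a [batch, classes] array, computed stably."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def top_k(logits, k=3):
    """For each row, return the k most likely (class_id, probability) pairs."""
    probs = softmax(logits)
    order = np.argsort(probs, axis=1)[:, ::-1][:, :k]  # indices by descending probability
    return [[(int(c), float(p[c])) for c in row] for row, p in zip(order, probs)]

# Made-up logits for one image; in practice pass the `logits` array from the model.
example_logits = np.array(
    [[0.1, 0.2, 5.0, 0.3, 0.0, 0.1, 0.2, 1.5, 0.4, 0.1]], dtype=np.float32
)
for class_id, prob in top_k(example_logits)[0]:
    print(class_id, round(prob, 3))
```

These probabilities are best read as relative scores. Nothing in this card establishes that they are calibrated confidences.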

## Labels

MNIST labels:

| id | label |
|---:|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |

## Files

- `model.onnx`: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
- `model.ckpt`: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
- `config.yaml`: resolved Hydra training config.
- `metrics.csv`: training metrics from the uploaded checkpoint run.
- `metadata.json`: compact metadata for inference and provenance.

## Limitations

This MLP has no convolutional inductive bias. It still reaches about 99.45% test accuracy on MNIST (roughly 55 errors per 10,000 test images), and the remaining errors are mostly concentrated in ambiguous or unusually written digits.
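
To see which digits account for the remaining errors on your own data, one option is a confusion matrix plus the indices of the misclassified images. A minimal NumPy sketch with hypothetical helper names and made-up labels; substitute real labels and the model's predictions:

```python
import numpy as np

def confusion_matrix(true_labels, predicted_labels, num_classes=10):
    """Count matrix where entry [t, p] is how often true digit t was predicted as p."""
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when (t, p) pairs repeat.
    np.add.at(matrix, (np.asarray(true_labels), np.asarray(predicted_labels)), 1)
    return matrix

def misclassified_indices(true_labels, predicted_labels):
    """Indices of the examples whose prediction differs from the label."""
    return np.flatnonzero(np.asarray(true_labels) != np.asarray(predicted_labels))

# Made-up labels and predictions for five images.
true_labels = np.array([4, 9, 7, 1, 4])
predicted = np.array([4, 4, 7, 7, 9])

matrix = confusion_matrix(true_labels, predicted)
wrong = misclassified_indices(true_labels, predicted)
print("correct:", int(np.trace(matrix)), "wrong at indices:", wrong.tolist())
```

Off-diagonal entries of the matrix show which pairs of digits are confused, and the returned indices point to the images worth inspecting by eye.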