MNIST LSTM Classifier
This repository contains a validation-selected MNIST LSTM digit classifier trained with dlab.
Results
3-seed confirmation and test audit for the selected LSTM recipe (± gives the spread across the three seeds; pp = percentage points):
| metric | value |
|---|---|
| validation accuracy | 99.2500% ± 0.1163 pp |
| validation loss | 0.09547 ± 0.00342 |
| test accuracy | 99.2433% ± 0.0573 pp |
| test loss | 0.09554 ± 0.00185 |
Representative checkpoint selected by validation accuracy:
| metric | value |
|---|---|
| seed | 9001 |
| selected validation accuracy | 99.3833% |
| selected validation loss | 0.09157 |
| test accuracy | 99.3100% |
| test loss | 0.09294 |
The ONNX model was exported from the validation-selected checkpoint. Test metrics were produced after the recipe was selected and were logged in W&B test-audit run ntzpryf9.
Model Details
- Dataset: MNIST
- Architecture: LSTM sequence classifier
- Sequence axis: columns
- Pooling: last
- Hidden width: 384
- Recurrent layers: 2
- Bidirectional: false
- Dropout: 0.1
- Optimizer: AdamW
- Learning rate: 0.003
- Weight decay: 0.001
- Scheduler: cosine
- Label smoothing: 0.01
- Batch size: 512
- Training augmentation: true
- Checkpoint selection: max validation accuracy
- Source W&B run: 652i33os
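The scheduler and label-smoothing settings above correspond to two standard formulas. The sketch below is not taken from the training code: it assumes the cosine schedule decays from the listed learning rate (0.003) to zero with no warmup, and that label smoothing uses the common uniform form (the one PyTorch's `CrossEntropyLoss(label_smoothing=...)` implements). The card confirms neither detail, and `cosine_lr` and `smoothed_cross_entropy` are hypothetical helper names.

```python
import math

import numpy as np


def cosine_lr(step: int, total_steps: int, base_lr: float = 0.003, min_lr: float = 0.0) -> float:
    """Cosine decay from base_lr to min_lr over total_steps (assumes no warmup)."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def smoothed_cross_entropy(logits: np.ndarray, targets: np.ndarray, smoothing: float = 0.01) -> float:
    """Mean cross-entropy against targets smoothed uniformly over all classes."""
    num_classes = logits.shape[1]
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Every class gets smoothing / K; the true class additionally gets 1 - smoothing.
    soft_targets = np.full_like(log_probs, smoothing / num_classes)
    soft_targets[np.arange(len(targets)), targets] += 1.0 - smoothing
    return float(-(soft_targets * log_probs).sum(axis=1).mean())
```

With a smoothing of 0.01 over 10 classes, the true digit's target is 0.991 and every other digit's target is 0.001. This keeps the model from being pushed toward infinitely confident logits.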
Input / Output
Use `model.onnx` for code-independent inference.
- Input name: `images`
- Input shape: `[batch, 1, 28, 28]`
- Input dtype: `float32`
- Output name: `logits`
- Output shape: `[batch, 10]`
Preprocessing:
- Convert the image to grayscale.
- Resize it to 28 x 28.
- Scale pixel values to [0, 1].
- Normalize with mean 0.1307 and standard deviation 0.3081.
- Arrange the tensor channels-first as [batch, 1, 28, 28].
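The steps above can be written as a small NumPy function that produces the documented `images` input (`[batch, 1, 28, 28]`, `float32`). This is a sketch. It assumes the image has already been converted to grayscale and resized to 28 x 28, for example with Pillow. The names `preprocess` and `make_batch` are hypothetical. MNIST digits are light strokes on a dark background, so dark-on-light scans may need inverting (`255 - image`) first. The card does not say whether this model needs that, so treat it as an assumption to check.

```python
import numpy as np

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess(gray_uint8: np.ndarray) -> np.ndarray:
    """Map one 28 x 28 uint8 grayscale image to a [1, 1, 28, 28] float32 tensor."""
    if gray_uint8.shape != (28, 28):
        raise ValueError(f"expected a 28 x 28 image, got {gray_uint8.shape}")
    x = gray_uint8.astype(np.float32) / 255.0      # scale to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD               # normalize with MNIST statistics
    return x[None, None, :, :].astype(np.float32)  # add batch and channel axes


def make_batch(images: list[np.ndarray]) -> np.ndarray:
    """Stack several images into a single [batch, 1, 28, 28] model input."""
    return np.concatenate([preprocess(img) for img in images], axis=0)
```

Batching this way lets one `session.run` call score many digits at once, because the ONNX input has a dynamic `batch` dimension.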
Usage
Install the runtime dependencies:
```bash
pip install huggingface_hub onnxruntime pillow numpy
```
Run inference with the ONNX model:
```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Class id -> label (MNIST labels are the digits themselves).
LABELS = {i: str(i) for i in range(10)}

# Download the ONNX export from the Hugging Face Hub (cached locally after the first call).
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-lstm",
    filename="model.onnx",
)

# Preprocess: grayscale, 28 x 28, scale to [0, 1], normalize with the MNIST mean/std.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
x = x[None, None, :, :].astype(np.float32)  # [batch=1, channel=1, 28, 28]

# Run the model on CPU and take the highest-scoring class.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]  # shape [1, 10]
prediction = int(logits.argmax(axis=1)[0])
print(prediction, LABELS[prediction])
```
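The model returns raw `logits`. To get per-digit confidence scores instead of only the argmax, you can apply a softmax to the array returned by `session.run`. This assumes the exported graph does not already apply a softmax. The output name `logits` suggests that, but the card does not state it. `softmax` and `top_k` are illustrative helpers, not part of any published package.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over the class axis of a [batch, 10] logits array."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def top_k(logits: np.ndarray, k: int = 3) -> list[list[tuple[int, float]]]:
    """Return the k most likely (digit, probability) pairs for each batch row."""
    probs = softmax(logits)
    order = np.argsort(-probs, axis=1)[:, :k]
    return [[(int(c), float(probs[i, c])) for c in row] for i, row in enumerate(order)]
```

For example, `top_k(logits)` on the `logits` from the usage example gives the three most likely digits for the input image. A low top probability is a cheap signal that the image may be ambiguous or unlike MNIST.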
Labels
MNIST labels:
| id | label |
|---|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |
Files
- `model.onnx`: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
- `model.ckpt`: PyTorch Lightning checkpoint for the same model. It is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
- `config.yaml`: resolved Hydra training config.
- `metrics.csv`: training metrics from the uploaded checkpoint run.
- `metadata.json`: compact metadata for inference and provenance.
Limitations
This LSTM model reads each MNIST image as a short sequence of pixel columns instead of relying on the inductive bias of convolutions. It is intended for normalized 28 x 28 grayscale MNIST-style images. Its remaining errors are expected to concentrate on ambiguous handwritten digits, and accuracy may degrade on inputs that fall outside that format (distribution shift).
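Model Details lists the `columns` sequence axis and `last` pooling. One plausible reading is that each 28-pixel column is one time step and that the classifier head sees only the final step's hidden state (width 384). The NumPy sketch below illustrates that reshaping and pooling. The exact axis handling in the training code is an assumption, and `image_to_column_sequence` and `last_step_pooling` are hypothetical names.

```python
import numpy as np


def image_to_column_sequence(images: np.ndarray) -> np.ndarray:
    """[batch, 1, 28, 28] -> [batch, 28 steps, 28 features], one step per pixel column."""
    x = images[:, 0, :, :]              # drop the channel axis: [batch, rows, cols]
    return np.transpose(x, (0, 2, 1))   # make columns the time axis: [batch, cols, rows]


def last_step_pooling(outputs: np.ndarray) -> np.ndarray:
    """Keep only the final time step of [batch, steps, hidden] recurrent outputs."""
    return outputs[:, -1, :]
```

Under this reading the LSTM sees the digit left to right, one column at a time. Last-step pooling then relies on the recurrent state to carry everything needed from earlier columns. That dependence on sequence order is the missing convolutional bias the limitation above refers to.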
