MNIST LSTM Classifier

This repository contains a validation-selected MNIST LSTM digit classifier trained with dlab.

Architecture

[Figure: MNIST LSTM architecture]

Results

3-seed confirmation and test audit for the selected LSTM recipe:

metric                 value
validation accuracy    99.2500% ± 0.1163 pp
validation loss        0.09547 ± 0.00342
test accuracy          99.2433% ± 0.0573 pp
test loss              0.09554 ± 0.00185
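The loss values stay well above zero even at about 99.2% accuracy. One plausible reason follows from the label smoothing of 0.01 listed under Model Details. Assume the logged loss is the same label-smoothed cross-entropy used for training, and that smoothing follows the PyTorch convention of mixing the one-hot target with a uniform distribution over the 10 classes. Under those assumptions, even a perfect prediction cannot reach zero loss. The minimum is the entropy of the smoothed target, which this sketch computes:

```python
import math


def label_smoothing_loss_floor(epsilon: float, num_classes: int) -> float:
    """Entropy of a label-smoothed target, i.e. the lowest cross-entropy any
    prediction can reach against it.

    Assumes the PyTorch-style target (1 - epsilon) * one_hot + epsilon / num_classes.
    """
    p_true = 1.0 - epsilon + epsilon / num_classes  # mass on the correct class
    p_other = epsilon / num_classes                 # mass on each other class
    return -(p_true * math.log(p_true)
             + (num_classes - 1) * p_other * math.log(p_other))


floor = label_smoothing_loss_floor(epsilon=0.01, num_classes=10)
print(f"{floor:.4f}")
```

If those assumptions hold, about 0.071 of the roughly 0.095 reported loss is a floor set by the smoothed targets. Only the remainder reflects misclassifications and less-than-maximal confidence.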

Representative checkpoint selected by validation accuracy:

metric                         value
seed                           9001
selected validation accuracy   99.3833%
selected validation loss       0.09157
test accuracy                  99.3100%
test loss                      0.09294

The ONNX model was exported from the validation-selected checkpoint. Test metrics were produced after the recipe was selected and were logged in W&B test-audit run ntzpryf9.

Model Details

  • Dataset: MNIST
  • Architecture: LSTM sequence classifier
  • Sequence axis: columns
  • Pooling: last
  • Hidden width: 384
  • Recurrent layers: 2
  • Bidirectional: false
  • Dropout: 0.1
  • Optimizer: AdamW
  • Learning rate: 0.003
  • Weight decay: 0.001
  • Scheduler: cosine
  • Label smoothing: 0.01
  • Batch size: 512
  • Training augmentation: true
  • Checkpoint selection: max validation accuracy
  • Source W&B run: 652i33os
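The following minimal NumPy sketch makes the listed hyperparameters concrete by tracing the forward pass they describe. The 28 x 28 image is read column by column as 28 time steps of 28 pixels. These steps pass through a 2-layer unidirectional LSTM with hidden width 384. The hidden state at the last step then goes to a linear layer that produces 10 logits. This is an illustrative reconstruction with random weights, not dlab's implementation. The single linear head, the combined per-gate bias, and the weight layout (modeled on PyTorch's `nn.LSTM`, gates ordered i, f, g, o) are assumptions. Dropout is omitted because it is inactive at inference.

```python
import numpy as np

HIDDEN = 384      # hidden width from Model Details
LAYERS = 2        # recurrent layers from Model Details
NUM_CLASSES = 10  # digits 0-9


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_params(rng, input_size=28):
    """Random weights; shapes modeled on PyTorch's nn.LSTM (gates stacked i, f, g, o).

    Assumption: one combined bias per layer instead of PyTorch's separate b_ih and b_hh.
    """
    layers = []
    for layer in range(LAYERS):
        in_size = input_size if layer == 0 else HIDDEN
        layers.append({
            "w_ih": rng.normal(0.0, 0.05, (4 * HIDDEN, in_size)),
            "w_hh": rng.normal(0.0, 0.05, (4 * HIDDEN, HIDDEN)),
            "b": np.zeros(4 * HIDDEN),
        })
    head = {"w": rng.normal(0.0, 0.05, (NUM_CLASSES, HIDDEN)), "b": np.zeros(NUM_CLASSES)}
    return layers, head


def forward(images, layers, head):
    """images: [batch, 1, 28, 28] -> logits: [batch, 10]."""
    # Sequence axis = columns: step t sees column t, i.e. 28 pixel values.
    seq = images[:, 0].transpose(0, 2, 1)  # [batch, 28 steps, 28 features]
    batch, steps, _ = seq.shape
    for p in layers:
        h = np.zeros((batch, HIDDEN))
        c = np.zeros((batch, HIDDEN))
        outputs = []
        for t in range(steps):
            gates = seq[:, t] @ p["w_ih"].T + h @ p["w_hh"].T + p["b"]
            i, f, g, o = np.split(gates, 4, axis=1)
            c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # update cell state
            h = sigmoid(o) * np.tanh(c)                   # new hidden state
            outputs.append(h)
        seq = np.stack(outputs, axis=1)  # this layer's outputs feed the next layer
    last = seq[:, -1]  # "last" pooling: hidden state at the final step
    return last @ head["w"].T + head["b"]


rng = np.random.default_rng(0)
layers, head = init_params(rng)
logits = forward(rng.random((2, 1, 28, 28)), layers, head)
print(logits.shape)
```

Because pooling is "last" and the LSTM is unidirectional, the classifier only sees the hidden state after the rightmost column. Information from earlier columns must survive in that state through the recurrent connections.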

Input / Output

Use model.onnx for code-independent inference.

  • Input name: images
  • Input shape: [batch, 1, 28, 28]
  • Input dtype: float32
  • Output name: logits
  • Output shape: [batch, 10]
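The model returns raw logits, not probabilities. If you need per-class scores, a numerically stable softmax over the class axis converts them. This is a generic post-processing sketch; the card does not say the ONNX graph includes a softmax, and `logits_to_probs` is an illustrative name.

```python
import numpy as np


def logits_to_probs(logits: np.ndarray) -> np.ndarray:
    """Softmax over the class axis of a [batch, 10] logits array."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Example logits for one image (made-up values, not model output).
logits = np.array([[2.0, 0.5, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype=np.float32)
probs = logits_to_probs(logits)
prediction = int(probs.argmax(axis=1)[0])
print(prediction, float(probs[0, prediction]))
```

Softmax preserves the ordering of the logits, so the predicted class is the same either way. It is only needed when you want scores rather than just the top class.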

Preprocessing:

  • Convert image to grayscale.
  • Resize to 28 x 28.
  • Scale pixel values to [0, 1].
  • Normalize with mean 0.1307 and standard deviation 0.3081.
  • Arrange the tensor as channels-first [batch, 1, 28, 28].
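The preprocessing steps above can be wrapped in one function that stacks several images into a single batch, since the first input dimension is the batch size. This sketch assumes Pillow images as input. The name `preprocess_batch` is illustrative, and the function uses Pillow's default resampling filter, as the usage example in this card does.

```python
from typing import Sequence

import numpy as np
from PIL import Image

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess_batch(images: Sequence[Image.Image]) -> np.ndarray:
    """Turn PIL images into a float32 array of shape [batch, 1, 28, 28]."""
    arrays = []
    for img in images:
        gray = img.convert("L")                          # grayscale
        small = gray.resize((28, 28))                    # 28 x 28
        x = np.asarray(small, dtype=np.float32) / 255.0  # scale to [0, 1]
        x = (x - MNIST_MEAN) / MNIST_STD                 # normalize
        arrays.append(x)
    batch = np.stack(arrays)[:, None, :, :]  # add channel axis -> [batch, 1, 28, 28]
    return batch.astype(np.float32)


imgs = [Image.new("RGB", (64, 64), color=(0, 0, 0)), Image.new("L", (20, 30), color=255)]
batch = preprocess_batch(imgs)
print(batch.shape, batch.dtype)
```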

Usage

Install the runtime dependencies:

pip install huggingface_hub onnxruntime pillow numpy

Run inference with the ONNX model:

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

LABELS = {
    0: "0",
    1: "1",
    2: "2",
    3: "3",
    4: "4",
    5: "5",
    6: "6",
    7: "7",
    8: "8",
    9: "9",
}

model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-lstm",
    filename="model.onnx",
)

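# Load a digit image and apply the preprocessing listed above.
# MNIST digits are light strokes on a dark background; if your image is
# dark-on-light, invert it first (for example with PIL.ImageOps.invert).
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0  # scale to [0, 1]
x = (x - 0.1307) / 0.3081                         # normalize with MNIST mean/std
x = x[None, None, :, :].astype(np.float32)        # -> [1, 1, 28, 28]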

session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]  # shape [1, 10]
prediction = int(logits.argmax(axis=1)[0])          # index of the highest logit

print(prediction, LABELS[prediction])

Labels

MNIST labels:

id  label
0   0
1   1
2   2
3   3
4   4
5   5
6   6
7   7
8   8
9   9

Files

  • model.onnx: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
  • model.ckpt: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
  • config.yaml: resolved Hydra training config.
  • metrics.csv: training metrics from the uploaded checkpoint run.
  • metadata.json: compact metadata for inference and provenance.

Limitations

This LSTM model treats each 28 x 28 MNIST image as a sequence of 28 columns (28 pixel values per step) rather than relying on the spatial inductive bias of a convolutional network. It is intended for normalized 28 x 28 grayscale MNIST-style images. Remaining errors are expected to concentrate on ambiguous handwritten digits and on inputs that fall outside that format.
