MNIST RNN Classifier

This repository contains a validation-selected MNIST RNN digit classifier trained with dlab.

Architecture

[Figure: MNIST RNN architecture]

Results

3-seed confirmation and test audit for the selected RNN recipe:

| Metric | Value |
|---|---|
| validation accuracy | 97.9167% ± 0.0491 pp |
| validation loss | 0.07222 ± 0.00309 |
| test accuracy | 97.9733% ± 0.1370 pp |
| test loss | 0.06857 ± 0.00544 |

Representative checkpoint selected by validation accuracy:

| Metric | Value |
|---|---|
| seed | 9001 |
| selected validation accuracy | 97.9667% |
| selected validation loss | 0.06796 |
| test accuracy | 97.7800% |
| test loss | 0.07622 |

The ONNX model was exported from the validation-selected checkpoint. Test metrics were produced after the recipe was selected and were logged in W&B test-audit run vdt5duxq.

Model Details

  • Dataset: MNIST
  • Architecture: RNN sequence classifier
  • Sequence axis: rows
  • Pooling: mean
  • Hidden width: 512
  • Recurrent layers: 2
  • Bidirectional: false
  • Dropout: 0
  • Optimizer: AdamW
  • Learning rate: 0.003
  • Weight decay: 0.0001
  • Scheduler: constant
  • Label smoothing: 0
  • Batch size: 256
  • Training augmentation: false
  • Checkpoint selection: max validation accuracy
  • Source W&B run: t8u8llqo
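To make the shapes in Model Details concrete, here is a minimal NumPy sketch of a forward pass consistent with them. The 28 image rows are fed as 28 time steps of 28 features each. They pass through 2 stacked unidirectional recurrent layers of width 512, are mean-pooled over time, and are projected to 10 logits. The card does not state the recurrent cell type, so this sketch assumes a plain tanh (Elman) RNN cell with random weights. The names `init_params` and `forward` are hypothetical. This illustrates the data flow only; it is not the trained model or the dlab implementation.

```python
import numpy as np

HIDDEN = 512      # hidden width (Model Details)
NUM_LAYERS = 2    # recurrent layers (Model Details)
NUM_CLASSES = 10  # MNIST digits

def init_params(rng, input_size=28, hidden=HIDDEN, layers=NUM_LAYERS, classes=NUM_CLASSES):
    """Random weights for a hypothetical stacked tanh RNN plus a linear head."""
    params = []
    scale = 1.0 / np.sqrt(hidden)
    for layer in range(layers):
        in_size = input_size if layer == 0 else hidden
        params.append({
            "W_ih": rng.uniform(-scale, scale, (hidden, in_size)).astype(np.float32),
            "W_hh": rng.uniform(-scale, scale, (hidden, hidden)).astype(np.float32),
            "b": np.zeros(hidden, dtype=np.float32),
        })
    head_w = rng.uniform(-scale, scale, (classes, hidden)).astype(np.float32)
    head_b = np.zeros(classes, dtype=np.float32)
    return params, head_w, head_b

def forward(images, params, head_w, head_b):
    """images: [batch, 1, 28, 28] float32 -> logits: [batch, 10]."""
    # Sequence axis = rows: each of the 28 rows is one time step with 28 features.
    seq = images[:, 0, :, :]                          # [batch, 28, 28]
    for layer in params:
        batch, steps, _ = seq.shape
        h = np.zeros((batch, layer["W_hh"].shape[0]), dtype=np.float32)
        outputs = []
        for t in range(steps):
            # Unidirectional tanh recurrence over the rows, top to bottom.
            h = np.tanh(seq[:, t, :] @ layer["W_ih"].T + h @ layer["W_hh"].T + layer["b"])
            outputs.append(h)
        seq = np.stack(outputs, axis=1)               # [batch, 28, hidden]
    pooled = seq.mean(axis=1)                         # mean pooling over time steps
    return pooled @ head_w.T + head_b                 # [batch, 10]

rng = np.random.default_rng(0)
params, head_w, head_b = init_params(rng)
x = rng.standard_normal((4, 1, 28, 28)).astype(np.float32)
logits = forward(x, params, head_w, head_b)
print(logits.shape)  # (4, 10)
```

With dropout at 0 and no bidirectionality, each layer is one left-to-right pass. Mean pooling then averages all 28 hidden states rather than using only the last one.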

Input / Output

Use model.onnx for code-independent inference.

  • Input name: images
  • Input shape: [batch, 1, 28, 28]
  • Input dtype: float32
  • Output name: logits
  • Output shape: [batch, 10]
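Two frequent mismatches with this input contract are a missing batch or channel axis and a float64 array (for example, dividing a uint8 NumPy array by 255.0 yields float64). The hypothetical helper below checks only the shape and dtype listed above before anything is passed to ONNX Runtime.

```python
import numpy as np

def check_onnx_input(x: np.ndarray) -> None:
    """Raise ValueError if x does not match the documented `images` input."""
    if x.dtype != np.float32:
        raise ValueError(f"expected float32, got {x.dtype}")
    if x.ndim != 4 or x.shape[1:] != (1, 28, 28):
        raise ValueError(f"expected [batch, 1, 28, 28], got {list(x.shape)}")

check_onnx_input(np.zeros((8, 1, 28, 28), dtype=np.float32))  # OK: a batch of 8
try:
    check_onnx_input(np.zeros((28, 28), dtype=np.float32))    # missing batch/channel axes
except ValueError as err:
    print(err)  # expected [batch, 1, 28, 28], got [28, 28]
```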

Preprocessing:

  • Convert image to grayscale.
  • Resize to 28 x 28.
  • Scale pixel values to [0, 1].
  • Normalize with mean 0.1307 and standard deviation 0.3081.
  • Arrange the tensor as channels-first [batch, 1, 28, 28].
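The scaling, normalization, and layout steps can be sketched for a whole batch in NumPy. This assumes the images have already been converted to grayscale and resized to 28 x 28 as uint8 arrays (for example, with Pillow as in Usage). The function name `preprocess_batch` is hypothetical; the mean and standard deviation are the values listed above.

```python
import numpy as np

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def preprocess_batch(gray_uint8: np.ndarray) -> np.ndarray:
    """[batch, 28, 28] uint8 grayscale images -> [batch, 1, 28, 28] float32 model input."""
    if gray_uint8.shape[1:] != (28, 28):
        raise ValueError("resize images to 28 x 28 before calling this function")
    x = gray_uint8.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD           # normalize with MNIST statistics
    return x[:, None, :, :]                    # insert channel axis (channels-first)

batch = np.stack([np.zeros((28, 28), np.uint8), np.full((28, 28), 255, np.uint8)])
x = preprocess_batch(batch)
print(x.shape, x.dtype)          # (2, 1, 28, 28) float32
print(x[0, 0, 0, 0], x[1, 0, 0, 0])  # approx. -0.4242 and 2.8215
```

After normalization a black pixel maps to about -0.42 and a white pixel to about 2.82, so inputs well outside that range usually indicate a skipped or repeated step.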

Usage

Install the runtime dependencies:

pip install huggingface_hub onnxruntime pillow numpy

Run inference with the ONNX model:

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# MNIST class ids map directly to digit strings.
LABELS = {i: str(i) for i in range(10)}

model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-rnn",
    filename="model.onnx",
)

# Preprocess: grayscale, 28 x 28, scale to [0, 1], normalize with MNIST mean/std.
# Replace example.png with the path to your own digit image.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
# Add batch and channel axes: [28, 28] -> [1, 1, 28, 28], float32.
x = x[None, None, :, :].astype(np.float32)

# Run the exported graph on CPU; "images" and "logits" are the ONNX tensor names.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])

print(prediction, LABELS[prediction])
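The script above prints only the arg-max class, because the model returns raw logits. If per-class scores are useful, a softmax converts each row of the [batch, 10] logits into values that sum to 1. The NumPy sketch below (helper names are hypothetical) also reports the top-k classes per image. Softmax scores from a trained classifier are not guaranteed to be calibrated probabilities.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over the class axis of a [batch, 10] logits array."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def top_k(logits: np.ndarray, k: int = 2):
    """Return (class ids, scores) of the k highest-scoring digits per row."""
    probs = softmax(logits)
    ids = np.argsort(-probs, axis=1)[:, :k]
    return ids, np.take_along_axis(probs, ids, axis=1)

# Stand-in logits for a batch of 2 images; in practice use session.run(["logits"], ...)[0].
logits = np.array([[0.0] * 10, [0.0] * 7 + [5.0] + [0.0] * 2], dtype=np.float32)
ids, scores = top_k(logits)
print(ids[1, 0])  # 7
```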

Labels

MNIST labels:

| id | label |
|---|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |

Files

  • model.onnx: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
  • model.ckpt: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
  • config.yaml: resolved Hydra training config.
  • metrics.csv: training metrics from the uploaded checkpoint run.
  • metadata.json: compact metadata for inference and provenance.

Limitations

This RNN model reads each image as a sequence of 28 rows, with 28 pixels per step, rather than relying on the spatial inductive bias of a convolutional network. It is intended for normalized 28 x 28 grayscale MNIST-style images. Remaining errors are expected to concentrate on ambiguous handwritten digits, and accuracy may degrade on inputs that shift away from that format.
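One common format mismatch is polarity. MNIST digits are light strokes on a dark background, while scanned or photographed digits are usually dark ink on light paper. The documented preprocessing does not handle this case. A simple heuristic is to invert a grayscale uint8 image whose border is mostly bright before normalizing. The function name, the border-as-background assumption, and the threshold of 127 are all assumptions rather than part of this model's pipeline.

```python
import numpy as np

def match_mnist_polarity(gray_uint8: np.ndarray, threshold: float = 127.0) -> np.ndarray:
    """Heuristically invert dark-on-light digits so they match MNIST's light-on-dark style.

    Assumption: the image border is mostly background, so a bright border
    indicates a light background that should be inverted.
    """
    border = np.concatenate([
        gray_uint8[0, :], gray_uint8[-1, :], gray_uint8[:, 0], gray_uint8[:, -1]
    ])
    if border.mean() > threshold:
        return 255 - gray_uint8  # stays uint8: values remain in [0, 255]
    return gray_uint8

paper = np.full((28, 28), 255, dtype=np.uint8)  # white page
paper[10:18, 13:15] = 0                         # dark vertical stroke
fixed = match_mnist_polarity(paper)
print(fixed[0, 0], fixed[12, 13])  # 0 255
```

Apply this to the resized grayscale array before scaling and normalization. An image that already matches MNIST's polarity is returned unchanged.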
