MNIST MLP Classifier

This repository contains a validation-selected MNIST MLP digit classifier trained with dlab.

Architecture

[Figure: MNIST MLP architecture]

Results

10-seed confirmation sweep:

metric                value
validation accuracy   99.3600% ± 0.0817 pp
validation loss       0.15172 ± 0.00235
test accuracy         99.4470% ± 0.0195 pp
test loss             0.14746 ± 0.00034
test errors           55.3 ± 1.95 / 10000

The ONNX model was exported from the checkpoint of the best run, selected on validation metrics (W&B run gsuy1ifx). Test metrics were computed only after the recipe had been selected, and were logged in W&B sweep xa56lubb.
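
The test-errors row is the test accuracy restated as a count over the 10,000 test images. A minimal sketch of that conversion, using only the figures from the table; the function names `accuracy_to_errors` and `spread_pp_to_errors` are illustrative and not part of dlab:

```python
TEST_SET_SIZE = 10_000  # denominator shown in the "test errors" row


def accuracy_to_errors(accuracy_pct: float, n_examples: int = TEST_SET_SIZE) -> float:
    """Convert an accuracy in percent into a number of misclassified examples."""
    return (100.0 - accuracy_pct) / 100.0 * n_examples


def spread_pp_to_errors(spread_pp: float, n_examples: int = TEST_SET_SIZE) -> float:
    """Convert a spread in percentage points into the same spread in example counts."""
    return spread_pp / 100.0 * n_examples


mean_errors = accuracy_to_errors(99.4470)
spread_errors = spread_pp_to_errors(0.0195)
print(f"{mean_errors:.1f} ± {spread_errors:.2f} / {TEST_SET_SIZE}")
```

Read this way, a spread of 0.0195 pp corresponds to roughly two images out of 10,000 between seeds.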

Model Details

  • Dataset: MNIST
  • Architecture: MLP
  • Hidden width: 1024
  • Hidden layers: 3
  • Activation: ReLU
  • Batch normalization: enabled
  • Dropout: 0.2
  • Optimizer: Adam
  • Learning rate: 0.001
  • Weight decay: 0.0001
  • Scheduler: OneCycleLR
  • Label smoothing: 0.02
  • Weight averaging: EMA
  • Batch size: 512
  • Training augmentation: random affine rotation/translation/scale
  • Source W&B run: gsuy1ifx
  • Source W&B sweep: xa56lubb
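
For a sense of scale, the layer sizes in the list above can be tallied into a parameter count. This is a rough sketch under assumptions the card does not state: the 28 x 28 input is flattened to 784 features, every linear layer has a bias, and batch normalization follows each hidden linear layer with a learnable scale and shift. The helper name `count_parameters` is illustrative.

```python
INPUT_FEATURES = 28 * 28   # assumption: the [1, 28, 28] image is flattened
HIDDEN_WIDTH = 1024        # from "Hidden width"
HIDDEN_LAYERS = 3          # from "Hidden layers"
NUM_CLASSES = 10           # MNIST digits 0-9


def count_parameters(input_features, hidden_width, hidden_layers, num_classes,
                     batch_norm=True):
    """Count trainable parameters of a plain MLP (weights, biases, BN scale/shift)."""
    total = 0
    fan_in = input_features
    for _ in range(hidden_layers):
        total += fan_in * hidden_width + hidden_width  # linear weight + bias
        if batch_norm:
            total += 2 * hidden_width                  # assumed BN scale + shift
        fan_in = hidden_width
    total += fan_in * num_classes + num_classes        # output layer weight + bias
    return total


n_params = count_parameters(INPUT_FEATURES, HIDDEN_WIDTH, HIDDEN_LAYERS, NUM_CLASSES)
print(f"{n_params:,} trainable parameters under these assumptions")
```

ReLU and dropout contribute no parameters. The count in the actual checkpoint may differ if any of the assumptions above does not hold.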

Input / Output

Use model.onnx for code-independent inference.

  • Input name: images
  • Input shape: [batch, 1, 28, 28]
  • Input dtype: float32
  • Output name: logits
  • Output shape: [batch, 10]
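
The names and shapes listed above can be checked against the exported graph. A small sketch, assuming `session` is an `onnxruntime.InferenceSession` opened on a local copy of model.onnx; `describe_io` is a hypothetical helper name introduced here.

```python
def describe_io(session):
    """Return {"inputs": [...], "outputs": [...]} with name, shape and type per tensor."""
    def as_dict(node):
        return {"name": node.name, "shape": list(node.shape), "type": node.type}

    return {
        "inputs": [as_dict(node) for node in session.get_inputs()],
        "outputs": [as_dict(node) for node in session.get_outputs()],
    }


# Usage (requires onnxruntime and a local copy of model.onnx):
#   import onnxruntime as ort
#   session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
#   print(describe_io(session))
```

ONNX Runtime names float32 tensors `tensor(float)`, and a dynamic batch dimension typically shows up as a string or None rather than a number.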

Preprocessing:

  • Convert image to grayscale.
  • Resize to 28 x 28.
  • Scale pixel values to [0, 1].
  • Normalize with mean 0.1307 and standard deviation 0.3081.
  • Arrange the tensor as channels-first [batch, 1, 28, 28].
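
The steps above can be sketched as a single NumPy function for a batch of images that are already grayscale and 28 x 28. The function name `preprocess` and the uint8 input convention are assumptions made for illustration; the mean and standard deviation are the ones listed above.

```python
import numpy as np

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess(gray_uint8: np.ndarray) -> np.ndarray:
    """Map uint8 grayscale images [batch, 28, 28] to float32 [batch, 1, 28, 28]."""
    if gray_uint8.ndim != 3 or gray_uint8.shape[1:] != (28, 28):
        raise ValueError(f"expected [batch, 28, 28], got {gray_uint8.shape}")
    x = gray_uint8.astype(np.float32) / 255.0   # scale to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD            # normalize with MNIST statistics
    return x[:, None, :, :].astype(np.float32)  # add the channel axis


batch = preprocess(np.zeros((4, 28, 28), dtype=np.uint8))
print(batch.shape, batch.dtype)
```

MNIST digits are light strokes on a dark background, so dark-on-light images (for example, scans of pen on paper) should be inverted before this step.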

Usage

Install the runtime dependencies:

pip install huggingface_hub onnxruntime pillow numpy

Run inference with the ONNX model:

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

LABELS = {
    0: "0",
    1: "1",
    2: "2",
    3: "3",
    4: "4",
    5: "5",
    6: "6",
    7: "7",
    8: "8",
    9: "9",
}

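# Download the ONNX file from the Hugging Face Hub (cached after the first call).
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-mlp",
    filename="model.onnx",
)

# Preprocess: grayscale, 28 x 28, [0, 1] scaling, then MNIST mean/std normalization.
image = Image.open("example.png").convert("L").resize((28, 28))
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
# Add batch and channel axes: [28, 28] -> [1, 1, 28, 28].
x = x[None, None, :, :].astype(np.float32)

# Run the model on CPU and take the highest-scoring class.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])

print(prediction, LABELS[prediction])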
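
The script above reports only the arg-max class. If per-class scores are useful, the logits can be turned into probabilities with a softmax. A minimal sketch; `softmax` and `top_k` are helper names introduced here, and the example logits are made up rather than produced by the model.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over [batch, classes], shifted by the row max for stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def top_k(probs: np.ndarray, k: int = 3):
    """Return (class_id, probability) pairs for the k most likely classes of one row."""
    order = np.argsort(probs)[::-1][:k]
    return [(int(i), float(probs[i])) for i in order]


# Made-up logits for a single image, shaped like the model output [batch, 10].
example_logits = np.array(
    [[0.1, 0.3, 6.2, 1.5, -0.7, 0.0, -1.2, 2.4, 0.9, 0.2]], dtype=np.float32
)
probs = softmax(example_logits)
print(top_k(probs[0]))
```

These values are best read as relative scores; nothing in this card establishes that they are calibrated probabilities.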

Labels

MNIST labels:

id label
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9

Files

  • model.onnx: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
  • model.ckpt: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
  • config.yaml: resolved Hydra training config.
  • metrics.csv: training metrics from the uploaded checkpoint run.
  • metadata.json: compact metadata for inference and provenance.

Limitations

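This MLP has no convolutional inductive bias, so it does not exploit the spatial structure of the image directly. It still reaches 99.447% mean test accuracy on MNIST across the 10-seed sweep; the remaining errors are concentrated mostly in ambiguous or unusually written digits.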
