metadata
license: mit
library_name: onnx
pipeline_tag: image-classification
tags:
  - image-classification
  - mnist
  - convolutional-neural-network
  - onnx
  - onnxruntime
  - pytorch
  - dlab
datasets:
  - mnist
metrics:
  - accuracy

MNIST CNN Classifier

This repository contains a validation-selected MNIST CNN digit classifier trained with dlab.

Architecture

[Figure: MNIST CNN architecture diagram]

Results

5-seed held-out test evaluation (mean ± sample standard deviation across seeds):

| metric | value |
|---|---|
| test accuracy | 99.6140% ± 0.0802 pp |
| test loss | 0.14677 ± 0.00272 |
| best validation loss | 0.14282 ± 0.00138 |
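Test accuracy is 99.6%, yet the losses sit around 0.14 to 0.15. One plausible explanation is the label smoothing of 0.02 listed under Model Details, which puts a floor under cross-entropy. This assumes the logged losses use the same smoothed objective as training; the card does not say how they were computed. The NumPy sketch below computes that floor. It uses the smoothing convention of PyTorch's `CrossEntropyLoss(label_smoothing=...)`, where targets are `(1 - eps) * one_hot + eps / K`. The helper names are hypothetical.

```python
import numpy as np

def smoothed_targets(labels, num_classes=10, eps=0.02):
    """Soft targets (1 - eps) * one_hot + eps / num_classes (PyTorch's label_smoothing convention)."""
    one_hot = np.eye(num_classes)[labels]
    return (1.0 - eps) * one_hot + eps / num_classes

def smoothed_cross_entropy(logits, labels, eps=0.02):
    """Mean cross-entropy between softmax(logits) and label-smoothed targets."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # stabilize exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))  # log-softmax
    targets = smoothed_targets(labels, logits.shape[1], eps)
    return float(-(targets * log_probs).sum(axis=1).mean())

# Lowest achievable loss: the predicted distribution equals the smoothed target,
# so the loss reduces to the entropy of the target distribution.
q = smoothed_targets(np.array([3]))
floor = float(-(q * np.log(q)).sum())
print(f"loss floor with eps=0.02: {floor:.4f}")

# An extremely confident correct prediction scores worse than the floor under smoothing,
# because it assigns almost no probability to the eps / K mass on the other classes.
confident = np.full((1, 10), -20.0)
confident[0, 3] = 20.0
print(smoothed_cross_entropy(confident, np.array([3])) > floor)
```

Under this assumption, even a perfect classifier would report a loss of about 0.13. The reported ~0.14 would therefore not indicate poor calibration.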

Per-seed held-out test results:

| seed | W&B run | test accuracy | test loss | best validation loss |
|---|---|---|---|---|
| 1 | 5um57rnu | 99.6100% | 0.14772 | 0.14124 |
| 2 | 23f1frqb | 99.6800% | 0.14437 | 0.14317 |
| 3 | 25yaaj1o | 99.4800% | 0.15107 | 0.14403 |
| 4 | 3rrlxghp | 99.6300% | 0.14573 | 0.14150 |
| 5 | y51200ov | 99.6700% | 0.14494 | 0.14418 |

The ONNX model was exported from the seed-1 checkpoint, which had the best validation loss in the final 5-seed evaluation sweep. Test metrics were not used for checkpoint selection and were logged in W&B sweep ikfs5ox8.
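Both the aggregate row and the checkpoint choice can be re-derived from the per-seed table. This sketch uses Python's `statistics` module. It assumes the ± values are sample standard deviations (n - 1 denominator), which matches the reported numbers. Because the table values are themselves rounded, a recomputed value may differ from the card in the last digit.

```python
import statistics

# Per-seed held-out results copied from the table above:
# seed -> (W&B run, test accuracy %, test loss, best validation loss)
RUNS = {
    1: ("5um57rnu", 99.61, 0.14772, 0.14124),
    2: ("23f1frqb", 99.68, 0.14437, 0.14317),
    3: ("25yaaj1o", 99.48, 0.15107, 0.14403),
    4: ("3rrlxghp", 99.63, 0.14573, 0.14150),
    5: ("y51200ov", 99.67, 0.14494, 0.14418),
}

def mean_sd(values):
    """Mean and sample standard deviation (n - 1 denominator)."""
    return statistics.mean(values), statistics.stdev(values)

acc_mean, acc_sd = mean_sd([run[1] for run in RUNS.values()])
loss_mean, loss_sd = mean_sd([run[2] for run in RUNS.values()])
val_mean, val_sd = mean_sd([run[3] for run in RUNS.values()])
print(f"test accuracy {acc_mean:.4f}% ± {acc_sd:.4f} pp")
print(f"test loss {loss_mean:.5f} ± {loss_sd:.5f}")
print(f"best validation loss {val_mean:.5f} ± {val_sd:.5f}")

# Checkpoint selection looks only at validation loss; the test columns are never consulted.
best_seed = min(RUNS, key=lambda seed: RUNS[seed][3])
print("selected seed:", best_seed, "run:", RUNS[best_seed][0])
```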

Model Details

  • Dataset: MNIST
  • Architecture: CNN
  • Channels: [32, 64, 128]
  • Convolutions per stage: 2
  • Batch normalization: enabled
  • Dropout: 0.1
  • Optimizer: Adam
  • Learning rate: 0.001
  • Weight decay: 0.0001
  • Scheduler: OneCycleLR
  • Label smoothing: 0.02
  • Weight averaging: EMA
  • Batch size: 512
  • Training augmentation: random affine rotation/translation/scale
  • Early stopping: validation loss, patience 8, min delta 0.0005
  • Source W&B run: 5um57rnu
  • Source W&B sweep: ikfs5ox8
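Below is a minimal sketch of the early-stopping rule listed above (validation loss, patience 8, min delta 0.0005). It assumes a common convention, similar to PyTorch Lightning's `EarlyStopping` callback: an epoch counts as an improvement only if validation loss drops more than `min_delta` below the best value so far, and training stops after `patience` consecutive epochs without improvement. The `EarlyStopping` class here is hypothetical, not dlab's implementation.

```python
import math
from dataclasses import dataclass

@dataclass
class EarlyStopping:
    """Hypothetical helper: stop when validation loss has not improved by more than
    `min_delta` for `patience` consecutive epochs."""
    patience: int = 8
    min_delta: float = 0.0005
    best: float = math.inf
    wait: int = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            # Meaningful improvement: remember it and reset the stale-epoch counter.
            self.best = val_loss
            self.wait = 0
        else:
            # Improvement too small (or none): count one more stale epoch.
            self.wait += 1
        return self.wait >= self.patience

# Illustrative (made-up) validation curve that plateaus after epoch 3.
stopper = EarlyStopping()
val_losses = [0.30, 0.20, 0.16, 0.150, 0.1498] + [0.1497] * 10
stop_epoch = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stop_epoch = epoch
        break
print("stopped at epoch", stop_epoch, "best validation loss", stopper.best)
```

In this example the drop from 0.150 to 0.1498 is smaller than `min_delta`, so it counts as a stale epoch. Training stops eight stale epochs later.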

Input / Output

Use model.onnx for inference that does not depend on the dlab training code.

  • Input name: images
  • Input shape: [batch, 1, 28, 28]
  • Input dtype: float32
  • Output name: logits
  • Output shape: [batch, 10]
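The model returns raw logits, not probabilities. To get per-class confidences, apply a numerically stable softmax over the class axis. The helper name `logits_to_probs` is hypothetical.

```python
import numpy as np

def logits_to_probs(logits: np.ndarray) -> np.ndarray:
    """Convert a [batch, 10] logits array (the model's `logits` output) to class probabilities."""
    if logits.ndim != 2 or logits.shape[1] != 10:
        raise ValueError(f"expected shape [batch, 10], got {logits.shape}")
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max to avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Two illustrative rows: all-equal logits, and a row that strongly favors class 7.
logits = np.array([[0.0] * 10, [0.0] * 7 + [5.0] + [0.0] * 2], dtype=np.float32)
probs = logits_to_probs(logits)
print(probs.argmax(axis=1), probs.max(axis=1))
```

Softmax does not change the argmax, so the predicted digit is the same whether you take it from the logits or from the probabilities.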

Preprocessing:

  • Convert image to grayscale.
  • Resize to 28 x 28.
  • Scale pixel values to [0, 1].
  • Normalize with mean 0.1307 and standard deviation 0.3081.
  • Arrange the tensor as channels-first [batch, 1, 28, 28].
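The steps above can be applied to several images at once. This NumPy sketch assumes grayscale conversion and resizing to 28 x 28 have already been done, for example with PIL as in the usage snippet. It covers scaling, normalization, and channels-first batching. The function name and the `invert` option are hypothetical. MNIST digits are light strokes on a dark background, so dark-on-light inputs such as pen on paper may need inverting first.

```python
import numpy as np

MEAN, STD = 0.1307, 0.3081  # normalization constants from this model card

def preprocess_batch(images, invert=False):
    """Stack 28x28 uint8 grayscale arrays into a float32 [batch, 1, 28, 28] tensor."""
    batch = np.stack([np.asarray(img, dtype=np.float32) for img in images])
    if batch.shape[1:] != (28, 28):
        raise ValueError(f"expected 28x28 images, got {batch.shape[1:]}")
    batch /= 255.0  # scale pixel values to [0, 1]
    if invert:
        # Hypothetical option: flip dark-on-light inputs to MNIST's light-on-dark style.
        batch = 1.0 - batch
    batch = (batch - MEAN) / STD  # normalize with the MNIST mean and standard deviation
    return batch[:, None, :, :].astype(np.float32)  # add channel axis -> [batch, 1, 28, 28]

# Two synthetic images: all black and all white.
blank = np.zeros((28, 28), dtype=np.uint8)
full = np.full((28, 28), 255, dtype=np.uint8)
x = preprocess_batch([blank, full])
print(x.shape, x.dtype)
```

The resulting array can be passed directly as the `images` input, since the ONNX input accepts a variable batch dimension.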

Usage

Install the runtime dependencies:

pip install huggingface_hub onnxruntime pillow numpy

Run inference with the ONNX model:

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

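# Class id -> label string (for MNIST the label is simply the digit).
LABELS = {i: str(i) for i in range(10)}

# Download model.onnx from the Hugging Face Hub (cached locally after the first call).
model_path = hf_hub_download(
    repo_id="tsilva/mnist-classifier-cnn",
    filename="model.onnx",
)

# Preprocess: grayscale, 28 x 28, scale to [0, 1], normalize with the MNIST mean/std.
image = Image.open("example.png").convert("L").resize((28, 28))  # replace with your image path
x = np.asarray(image, dtype=np.float32) / 255.0
x = (x - 0.1307) / 0.3081
x = x[None, None, :, :].astype(np.float32)  # add batch and channel axes -> [1, 1, 28, 28]

# Run on CPU; the output is a [batch, 10] array of logits.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
prediction = int(logits.argmax(axis=1)[0])

print(prediction, LABELS[prediction])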

Labels

MNIST labels:

| id | label |
|---|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 2 |
| 3 | 3 |
| 4 | 4 |
| 5 | 5 |
| 6 | 6 |
| 7 | 7 |
| 8 | 8 |
| 9 | 9 |

Files

  • model.onnx: ONNX export of the validation-selected checkpoint. Prefer this file for portable inference.
  • model.ckpt: PyTorch Lightning checkpoint for the same model. This is code-dependent and mainly useful for PyTorch-based inspection or continued experimentation.
  • config.yaml: resolved Hydra training config.
  • metrics.csv: training metrics from the uploaded checkpoint run.
  • metrics_summary.csv: compact 5-seed final evaluation summary.
  • metadata.json: compact metadata for inference and provenance.

Limitations

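This compact CNN is close to the practical accuracy ceiling on MNIST. At about 99.6% test accuracy, errors are rare (roughly 4 per 1,000 test images), and many of the remaining ones are expected to be visually ambiguous or unusually written digits.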