CIFAR-10 Classifier WRN-28-10

ONNX image classifier for CIFAR-10: a WideResNet (WRN-28-10) trained with CutMix in dlab.

At a Glance

| Item | Value |
| --- | --- |
| Task | Image classification |
| Dataset | CIFAR-10 |
| Model | WRN-28-10 |
| Format | ONNX |
| Input | RGB image, 3 x 32 x 32 |
| Output | 10 CIFAR-10 logits |
| Recipe test accuracy | 0.9525 +/- 0.0062 |
| Uploaded checkpoint | seed 9001, test accuracy 0.9561 |

Quick Start

pip install huggingface_hub onnxruntime pillow numpy torch torchvision

Single-image inference:

import json
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

repo_id = "tsilva/cifar10-classifier-wrn28-10"
model_path = hf_hub_download(repo_id, "model.onnx")
labels_path = hf_hub_download(repo_id, "labels.json")
with open(labels_path) as f:
    labels = json.load(f)

mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
std = np.array([0.247, 0.243, 0.261], dtype=np.float32)

image = Image.open("image.png").convert("RGB").resize((32, 32))
x = np.asarray(image).astype("float32") / 255.0
x = (x - mean) / std
x = np.transpose(x, (2, 0, 1))[None, ...]

session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
pred_id = int(logits.argmax(axis=1)[0])
print(pred_id, labels[str(pred_id)])
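
The logits printed above are unnormalized scores. As a minimal sketch, they can be turned into probabilities and a top-k ranking as follows. The `softmax` and `top_k` helpers and the `CLASS_NAMES` list are illustrative names that are not part of the repository, and the class order is assumed to match labels.json.

```python
import numpy as np

# Class order as listed in this model card; assumed to match labels.json.
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def top_k(logits: np.ndarray, k: int = 3):
    """Return the k most probable (class name, probability) pairs per row.

    Expects logits of shape (N, 10), as produced by the ONNX session.
    """
    probs = softmax(logits)
    order = np.argsort(-probs, axis=-1)[:, :k]
    return [[(CLASS_NAMES[j], float(probs[i, j])) for j in row]
            for i, row in enumerate(order)]
```

Calling `top_k(logits, k=3)` on the Quick Start output gives one ranked list per image in the batch.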

Validate on CIFAR-10 Test Set

This evaluates model.onnx on the official CIFAR-10 test split using the same normalization as training, with no augmentation.

import numpy as np
import onnxruntime as ort
import torch
from huggingface_hub import hf_hub_download
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

repo_id = "tsilva/cifar10-classifier-wrn28-10"
model_path = hf_hub_download(repo_id, "model.onnx")

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.247, 0.243, 0.261),
        ),
    ]
)
test_set = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform,
)
loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)

session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

correct = 0
total = 0
loss_sum = 0.0
for images, labels in loader:
    logits = session.run(["logits"], {"images": images.numpy().astype(np.float32)})[0]
    logits_t = torch.from_numpy(logits)
    loss_sum += float(F.cross_entropy(logits_t, labels, reduction="sum"))
    preds = logits_t.argmax(dim=1)
    correct += int((preds == labels).sum())
    total += int(labels.numel())

print(f"test_acc={correct / total:.4f}")
print(f"test_loss={loss_sum / total:.4f}")

Expected result for the uploaded seed 9001 ONNX checkpoint:

test_acc=0.9561
test_loss=0.2244

Results

Checkpoint selection used validation loss. Test metrics were computed after model selection and were not used to choose checkpoints.

Three-seed recipe summary:

| Split | Accuracy (mean +/- std) | Loss (mean +/- std) |
| --- | --- | --- |
| Validation | 0.9563 +/- 0.0070 | 0.4189 +/- 0.0160 |
| Test | 0.9525 +/- 0.0062 | 0.2292 +/- 0.0119 |

The uploaded ONNX export is the seed 9001 selected checkpoint from this WRN-28-10 + CutMix p=0.5 recipe.

Input / Output

  • Input name: images
  • Input shape: dynamic batch, N x 3 x 32 x 32
  • Input dtype: float32
  • Output name: logits
  • Output shape: N x 10
  • Output dtype: float32
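
A mismatched layout or dtype is a common cause of ONNX Runtime errors, so it can help to check a batch against the spec above before calling the session. The sketch below assumes a NumPy array as input; `check_images` is a hypothetical helper, not something shipped with the model.

```python
import numpy as np


def check_images(x: np.ndarray) -> np.ndarray:
    """Raise ValueError unless x matches the documented `images` input spec."""
    if x.dtype != np.float32:
        raise ValueError(f"expected float32, got {x.dtype}")
    if x.ndim != 4 or x.shape[1:] != (3, 32, 32):
        raise ValueError(f"expected N x 3 x 32 x 32, got {x.shape}")
    # ONNX Runtime works on contiguous buffers; this is a no-op if x already is.
    return np.ascontiguousarray(x)
```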

Preprocessing:

  1. Resize/crop input to 32 x 32 RGB.
  2. Convert to float32 tensor in CHW layout with values in [0, 1].
  3. Normalize with CIFAR-10 statistics:
    • mean: [0.4914, 0.4822, 0.4465]
    • std: [0.247, 0.243, 0.261]
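
For a batch that is already 32 x 32 RGB (for example raw CIFAR-10 arrays), steps 2 and 3 can be sketched in NumPy as below. The function name `preprocess_batch` and the uint8 NHWC input layout are assumptions made for illustration. Normalizing before the transpose is equivalent to the order listed above because the statistics are per channel.

```python
import numpy as np

MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
STD = np.array([0.247, 0.243, 0.261], dtype=np.float32)


def preprocess_batch(images_u8: np.ndarray) -> np.ndarray:
    """Map uint8 images (N, 32, 32, 3) to normalized float32 (N, 3, 32, 32)."""
    if images_u8.ndim != 4 or images_u8.shape[1:] != (32, 32, 3):
        raise ValueError(f"expected N x 32 x 32 x 3, got {images_u8.shape}")
    x = images_u8.astype(np.float32) / 255.0  # scale to [0, 1]
    x = (x - MEAN) / STD                      # per-channel normalize (broadcast over HWC)
    return np.ascontiguousarray(x.transpose(0, 3, 1, 2))  # NHWC -> NCHW
```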

CIFAR-10 labels are stored in labels.json: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.

Architecture

[Figure: WRN-28-10 architecture diagram]

This model is a CIFAR-style WideResNet:

  • Initial convolution: 3x3, 3 -> 16, stride 1, padding 1.
  • Stage 1: 4 pre-activation residual blocks, width 160, stride 1.
  • Stage 2: 4 pre-activation residual blocks, width 320, first block stride 2.
  • Stage 3: 4 pre-activation residual blocks, width 640, first block stride 2.
  • Head: BatchNorm2d, ReLU, global average pool, linear 640 -> 10.
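
The layer list above can be sketched in PyTorch as follows. This is an illustrative reconstruction, not the training code behind the checkpoint: the class names, the bias-free convolutions, the 1x1 projection shortcut, and the absence of dropout are all assumptions, so the state dict keys will not necessarily match checkpoint.ckpt.

```python
import torch
import torch.nn.functional as F
from torch import nn


class PreActBlock(nn.Module):
    """Pre-activation residual block: (BN, ReLU, 3x3 conv) twice, plus a shortcut."""

    def __init__(self, in_ch: int, out_ch: int, stride: int):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        # Assumption: 1x1 projection shortcut whenever the shape changes.
        self.shortcut = None
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        skip = x if self.shortcut is None else self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        return out + skip


class WideResNetSketch(nn.Module):
    """CIFAR-style WideResNet; depth 28 and widen factor 10 give WRN-28-10."""

    def __init__(self, depth: int = 28, widen_factor: int = 10, num_classes: int = 10):
        super().__init__()
        if (depth - 4) % 6 != 0:
            raise ValueError("depth must be 6n + 4")
        n = (depth - 4) // 6  # blocks per stage: 4 for depth 28
        widths = [16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        self.stem = nn.Conv2d(3, 16, 3, stride=1, padding=1, bias=False)
        blocks, in_ch = [], 16
        for stage, width in enumerate(widths):
            for i in range(n):
                # Only the first block of stages 2 and 3 downsamples.
                stride = 2 if (stage > 0 and i == 0) else 1
                blocks.append(PreActBlock(in_ch, width, stride))
                in_ch = width
        self.blocks = nn.Sequential(*blocks)
        self.bn = nn.BatchNorm2d(in_ch)
        self.fc = nn.Linear(in_ch, num_classes)

    def forward(self, x):
        x = self.blocks(self.stem(x))
        x = F.relu(self.bn(x))
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)  # global average pool
        return self.fc(x)
```

With the defaults, `WideResNetSketch()` maps an `N x 3 x 32 x 32` tensor to `N x 10` logits, matching the input and output shapes documented for model.onnx.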

Training Recipe

  • Optimizer: AdamW, learning rate 0.0003, weight decay 0.01.
  • Schedule: cosine.
  • Batch size: 128.
  • Training augmentation: random crop, horizontal flip, medium ColorJitter, RandAugment N=1/M=7.
  • Loss regularization: label smoothing 0.05; each batch uses CutMix (alpha=1.0) with probability 0.5 and falls back to MixUp (alpha=0.2) otherwise.
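
The per-batch choice between CutMix and MixUp can be sketched in NumPy as below. This is a simplified illustration, not the dlab implementation: the function name `mix_batch`, the single mixing coefficient shared by the whole batch, and the box-sampling details are assumptions. Targets are expected as one-hot (or label-smoothed) rows.

```python
import numpy as np


def mix_batch(images, targets, rng, p_cutmix=0.5, cutmix_alpha=1.0, mixup_alpha=0.2):
    """Apply CutMix with probability p_cutmix, otherwise MixUp, to one batch.

    images: float array (N, C, H, W); targets: float array (N, K).
    rng: a numpy.random.Generator.
    """
    n, _, h, w = images.shape
    perm = rng.permutation(n)  # partner sample for each image
    if rng.random() < p_cutmix:
        lam = float(rng.beta(cutmix_alpha, cutmix_alpha))
        cut = np.sqrt(1.0 - lam)  # box side ratio so box area is about 1 - lam
        bh, bw = int(h * cut), int(w * cut)
        cy, cx = int(rng.integers(h)), int(rng.integers(w))
        y0, y1 = max(cy - bh // 2, 0), min(cy + bh // 2, h)
        x0, x1 = max(cx - bw // 2, 0), min(cx + bw // 2, w)
        mixed = images.copy()
        mixed[:, :, y0:y1, x0:x1] = images[perm][:, :, y0:y1, x0:x1]
        # Recompute lam from the clipped box so targets match the pixels kept.
        lam = 1.0 - (y1 - y0) * (x1 - x0) / (h * w)
    else:
        lam = float(rng.beta(mixup_alpha, mixup_alpha))
        mixed = (lam * images + (1.0 - lam) * images[perm]).astype(images.dtype)
    mixed_targets = lam * targets + (1.0 - lam) * targets[perm]
    return mixed, mixed_targets
```

Because the targets are convex combinations of two label rows, each mixed target row still sums to 1 and can be fed to a soft-label cross-entropy loss.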

Files

  • model.onnx: portable ONNX inference artifact.
  • checkpoint.ckpt: PyTorch Lightning checkpoint from the selected W&B run.
  • config.yaml: resolved training/export configuration.
  • labels.json: CIFAR-10 class mapping.
  • assets/architecture.png: architecture diagram.

Provenance

Limitations

  • Checkpoints were selected using validation loss, not test-set performance.
  • Inputs outside the CIFAR-10 image distribution may be unreliable.
  • The PyTorch checkpoint is code-dependent; use model.onnx for portable inference.