ONNX CIFAR-10 image classifier trained with WideResNet WRN-28-10 + CutMix in dlab.
| Item | Value |
|---|---|
| Task | Image classification |
| Dataset | CIFAR-10 |
| Model | WRN-28-10 |
| Format | ONNX |
| Input | RGB image, 3 x 32 x 32 |
| Output | 10 CIFAR-10 logits |
| Recipe test accuracy (mean +/- std, 3 seeds) | 0.9525 +/- 0.0062 |
| Uploaded checkpoint | seed 9001, test accuracy 0.9561 |
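WRN-28-10 follows the WideResNet naming convention: network depth 28 and widening factor 10. The sketch below derives the per-stage layout from those two numbers. It assumes the standard convention `depth = 6n + 4` with stage widths `16k`, `32k`, `64k`. `wrn_layout` is a hypothetical helper and is not part of dlab.

```python
# Sketch: derive the WRN-28-10 stage layout from its name, assuming the standard
# WideResNet convention depth = 6 * blocks_per_stage + 4 and widths 16k, 32k, 64k.
def wrn_layout(depth: int, widen_factor: int, image_size: int = 32):
    if (depth - 4) % 6 != 0:
        raise ValueError("WRN depth must satisfy depth = 6n + 4")
    blocks_per_stage = (depth - 4) // 6
    widths = [16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    strides = [1, 2, 2]  # stride of the first block in each stage
    stages = []
    size = image_size
    for width, stride in zip(widths, strides):
        size //= stride  # spatial size after the stage's first block
        stages.append({"blocks": blocks_per_stage, "width": width, "out_size": size})
    return stages


for i, stage in enumerate(wrn_layout(28, 10), start=1):
    print(
        f"stage {i}: {stage['blocks']} blocks, {stage['width']} channels, "
        f"{stage['out_size']}x{stage['out_size']} output"
    )
```

Under these assumptions, the final stage has 640 channels, which matches the `linear 640 -> 10` classifier head.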
```bash
pip install huggingface_hub onnxruntime pillow numpy torch torchvision
```
Single-image inference:
```python
import json

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

repo_id = "tsilva/cifar10-classifier-wrn28-10"
model_path = hf_hub_download(repo_id, "model.onnx")
labels_path = hf_hub_download(repo_id, "labels.json")

with open(labels_path) as f:
    labels = json.load(f)

# Per-channel normalization statistics used during training.
mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
std = np.array([0.247, 0.243, 0.261], dtype=np.float32)

# Load as RGB and resize to the 32 x 32 input resolution.
image = Image.open("image.png").convert("RGB").resize((32, 32))
x = np.asarray(image).astype("float32") / 255.0  # HWC, values in [0, 1]
x = (x - mean) / std
x = np.transpose(x, (2, 0, 1))[None, ...]  # HWC -> NCHW with batch size 1
x = np.ascontiguousarray(x)  # give ONNX Runtime a contiguous buffer

session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logits = session.run(["logits"], {"images": x})[0]
pred_id = int(logits.argmax(axis=1)[0])
print(pred_id, labels[str(pred_id)])
```
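The model returns raw logits, not probabilities. To report class probabilities, you can apply a numerically stable softmax and keep the top-k classes. The following is a minimal NumPy sketch; `top_k_probs` is a hypothetical helper.

```python
import numpy as np


def top_k_probs(logits: np.ndarray, k: int = 3):
    """Return (class_ids, probabilities) for the k most likely classes per row."""
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    # argsort is ascending: reverse the columns, then keep the first k.
    ids = np.argsort(probs, axis=1)[:, ::-1][:, :k]
    return ids, np.take_along_axis(probs, ids, axis=1)
```

The single-image snippet above produces `logits` and `labels`. Pass them in as `ids, probs = top_k_probs(logits)`, then look up each class with `labels[str(int(ids[0, j]))]`.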
This evaluates model.onnx on the official CIFAR-10 test split using the same
preprocessing used during training.
```python
import numpy as np
import onnxruntime as ort
import torch
from huggingface_hub import hf_hub_download
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

repo_id = "tsilva/cifar10-classifier-wrn28-10"
model_path = hf_hub_download(repo_id, "model.onnx")

# Same preprocessing as training: [0, 1] CHW tensor, then per-channel normalization.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.247, 0.243, 0.261),
        ),
    ]
)

test_set = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform,
)
loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

correct = 0
total = 0
loss_sum = 0.0
for images, labels in loader:
    logits = session.run(["logits"], {"images": images.numpy().astype(np.float32)})[0]
    logits_t = torch.from_numpy(logits)
    # Sum per batch so the final average stays exact even with a smaller last batch.
    loss_sum += float(F.cross_entropy(logits_t, labels, reduction="sum"))
    preds = logits_t.argmax(dim=1)
    correct += int((preds == labels).sum())
    total += int(labels.numel())

print(f"test_acc={correct / total:.4f}")
print(f"test_loss={loss_sum / total:.4f}")
```
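The evaluation loop above reports only aggregate accuracy and loss. You can also compute a per-class breakdown from the concatenated predictions and targets. The NumPy sketch below assumes you append `preds.numpy()` and `labels.numpy()` from each batch to lists and concatenate them after the loop. `per_class_accuracy` is a hypothetical helper.

```python
import numpy as np


def per_class_accuracy(preds, targets, num_classes: int = 10) -> np.ndarray:
    """Accuracy per class index: correct predictions / number of examples of that class."""
    preds = np.asarray(preds).ravel()
    targets = np.asarray(targets).ravel()
    support = np.bincount(targets, minlength=num_classes)
    correct = np.bincount(targets[preds == targets], minlength=num_classes)
    # Classes with no examples get NaN instead of triggering a division by zero.
    return np.divide(
        correct, support, out=np.full(num_classes, np.nan), where=support > 0
    )
```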
Expected result for the uploaded seed 9001 ONNX checkpoint:
```text
test_acc=0.9561
test_loss=0.2244
```
Checkpoint selection used validation loss. Test metrics were computed after model selection and were not used to choose checkpoints.
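The selection rule above is an argmin over validation loss, where test metrics may be logged but never take part in the choice. The following is a minimal sketch with hypothetical record fields (`path`, `val_loss`, `test_acc`). It is not dlab's code. In PyTorch Lightning, a rule like this is typically configured with `ModelCheckpoint(monitor=..., mode="min")`, but this card does not state the monitored metric's name.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CheckpointRecord:
    path: str
    val_loss: float
    test_acc: Optional[float] = None  # may be logged, but never used for selection


def select_checkpoint(records: List[CheckpointRecord]) -> CheckpointRecord:
    """Pick the checkpoint with the lowest validation loss; test metrics are ignored."""
    if not records:
        raise ValueError("no checkpoints to select from")
    return min(records, key=lambda r: r.val_loss)
```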
Three-seed recipe summary:
| Split | Accuracy mean +/- std | Loss mean +/- std |
|---|---|---|
| Validation | 0.9563 +/- 0.0070 | 0.4189 +/- 0.0160 |
| Test | 0.9525 +/- 0.0062 | 0.2292 +/- 0.0119 |
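Each table entry is a mean and standard deviation over three training seeds. The card does not say whether it reports the sample or the population standard deviation, so the sketch below assumes the sample version. It does not list per-seed values either, so the helpers take any list of per-seed metrics, such as a hypothetical `per_seed_test_acc`.

```python
import statistics


def mean_std(values):
    """Return (mean, standard deviation) of per-seed metric values."""
    # Assumption: "+/- std" is the sample std (n - 1 denominator);
    # switch to statistics.pstdev if the population std was intended.
    return statistics.mean(values), statistics.stdev(values)


def format_mean_std(values, digits: int = 4) -> str:
    """Format per-seed values the way the table does: 'mean +/- std'."""
    mean, std = mean_std(values)
    return f"{mean:.{digits}f} +/- {std:.{digits}f}"
```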
The uploaded ONNX export is the seed 9001 selected checkpoint from this
WRN-28-10 + CutMix p=0.5 recipe.
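| Tensor | Shape | Dtype |
|---|---|---|
| Input `images` | N x 3 x 32 x 32 | float32 |
| Output `logits` | N x 10 | float32 |

Preprocessing:

- Input: 32 x 32 RGB.
- Convert to a float32 tensor in CHW layout with values in [0, 1].
- Normalize per channel with mean `[0.4914, 0.4822, 0.4465]` and std `[0.247, 0.243, 0.261]`.

CIFAR-10 labels are stored in `labels.json`: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.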
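These preprocessing steps can be applied to a whole batch at once. The NumPy sketch below assumes the images are already decoded into 32 x 32 RGB `uint8` arrays in NHWC order. Its output has the layout the `images` input expects. `preprocess_batch` is a hypothetical helper.

```python
import numpy as np

# Normalization constants from this card.
MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
STD = np.array([0.247, 0.243, 0.261], dtype=np.float32)


def preprocess_batch(images_hwc_uint8: np.ndarray) -> np.ndarray:
    """Map an (N, 32, 32, 3) uint8 RGB batch to an (N, 3, 32, 32) float32 array."""
    if images_hwc_uint8.ndim != 4 or images_hwc_uint8.shape[1:] != (32, 32, 3):
        raise ValueError(f"expected (N, 32, 32, 3), got {images_hwc_uint8.shape}")
    x = images_hwc_uint8.astype(np.float32) / 255.0  # scale to [0, 1]
    x = (x - MEAN) / STD  # per-channel normalization, broadcast over the last axis
    x = np.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
    return np.ascontiguousarray(x, dtype=np.float32)
```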
This model is a CIFAR-style WideResNet:

- Stem: 3x3 convolution, 3 -> 16 channels, stride 1, padding 1.
- Stage 1: 160 channels, stride 1.
- Stage 2: 320 channels, first block stride 2.
- Stage 3: 640 channels, first block stride 2.
- Head: BatchNorm2d, ReLU, global average pool, linear 640 -> 10.

Training settings:

- Learning rate 0.0003, weight decay 0.01.
- Batch size 128.
- Augmentation N=1/M=7.
- Mixing and regularization: 0.05; MixUp alpha=0.2; CutMix alpha=1.0 on p=0.5 of batches, with MixUp as the fallback on the remaining batches.

Repository files:

- `model.onnx`: portable ONNX inference artifact.
- `checkpoint.ckpt`: PyTorch Lightning checkpoint from the selected W&B run.
- `config.yaml`: resolved training/export configuration.
- `labels.json`: CIFAR-10 class mapping.
- `assets/architecture.png`: architecture diagram.

Selected checkpoint in the source run: `checkpoints/063.ckpt`. Use `model.onnx` for portable inference.
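The training recipe applies CutMix (alpha=1.0) to half of the batches and falls back to MixUp (alpha=0.2) on the rest. The following is a minimal NumPy sketch of that per-batch choice. It assumes one-hot targets and the usual CutMix box sampling, with the mixing weight recomputed from the clipped box. `mix_batch` is hypothetical and is not dlab's implementation.

```python
import numpy as np


def mix_batch(x, y_onehot, rng, p_cutmix=0.5, cutmix_alpha=1.0, mixup_alpha=0.2):
    """Apply CutMix to the whole batch with probability p_cutmix, otherwise MixUp.

    x: (N, C, H, W) float images; y_onehot: (N, num_classes) float targets.
    Returns the mixed images and the correspondingly mixed soft targets.
    """
    n, _, h, w = x.shape
    perm = rng.permutation(n)  # partner example for each example
    if rng.random() < p_cutmix:
        lam = rng.beta(cutmix_alpha, cutmix_alpha)
        # Box covering roughly (1 - lam) of the image, centered at a random point.
        cut_h = int(h * np.sqrt(1.0 - lam))
        cut_w = int(w * np.sqrt(1.0 - lam))
        cy, cx = rng.integers(h), rng.integers(w)
        y0, y1 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
        x0, x1 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)
        mixed = x.copy()
        mixed[:, :, y0:y1, x0:x1] = x[perm, :, y0:y1, x0:x1]
        # Recompute lam from the clipped box so targets match the pasted area.
        lam = 1.0 - ((y1 - y0) * (x1 - x0)) / (h * w)
    else:
        lam = rng.beta(mixup_alpha, mixup_alpha)
        mixed = lam * x + (1.0 - lam) * x[perm]
    targets = lam * y_onehot + (1.0 - lam) * y_onehot[perm]
    return mixed, targets
```

Call it once per training batch with a persistent generator, for example `rng = np.random.default_rng(seed)`. The mixed targets are soft labels, so the loss must accept probability targets rather than class indices.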