resnet20-cifar10-272k

Model description

This model is a ResNet20 trained from scratch on the CIFAR-10 dataset. The architecture follows the CIFAR ResNet design (3×3 convolutions, basic residual blocks), with 1×1 convolutional projection shortcuts wherever the feature-map shape changes, for a total of 272,474 parameters. Training was performed on a consumer AMD Radeon RX 6600 GPU using PyTorch with ROCm support.

Author: sapbot from Romarchive

Intended uses

  • Image classification on 10 everyday object categories.
  • Educational use: shows that a model with about 272k parameters can reach 86.29% test accuracy on CIFAR-10 without large-scale pre-training.

Training procedure

  • Optimizer: SGD with momentum 0.9 and weight decay 1e‑4.
  • Learning rate: 0.1, dropped by factor 0.1 at epochs 80 and 120 (MultiStepLR).
  • Batch size: 128.
  • Epochs: up to 160 (early stopping with patience 10).
  • Data augmentation: random horizontal flip, random crop with 4‑pixel padding.
  • Input normalization: mean (0.4914, 0.4822, 0.4465), std (0.2023, 0.1994, 0.2010).
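
The step schedule and the early-stopping rule listed above can be sketched in plain Python. This is a minimal illustration, not the author's training script: the names `lr_at_epoch` and `EarlyStopper` are hypothetical, epochs are assumed to be counted from 0, and early stopping is assumed to monitor a validation loss (the card does not say which quantity was monitored).

```python
def lr_at_epoch(epoch, base_lr=0.1, milestones=(80, 120), gamma=0.1):
    """Learning rate in effect during `epoch` (0-based) under a step decay.

    Mirrors the behaviour of a MultiStepLR-style schedule: the rate is
    multiplied by `gamma` once for every milestone already reached.
    """
    drops = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** drops


class EarlyStopper:
    """Signal a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best = float("inf")   # lowest loss seen so far
        self.bad_epochs = 0        # epochs since the last improvement

    def step(self, loss):
        """Record one epoch's loss; return True when training should stop."""
        if loss < self.best:
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


if __name__ == "__main__":
    # Print the learning rate on either side of each milestone.
    for epoch in (0, 79, 80, 119, 120, 159):
        print(epoch, lr_at_epoch(epoch))
```

With these settings the rate is 0.1 for epochs 0 to 79, about 0.01 for epochs 80 to 119, and about 0.001 from epoch 120 onward, unless the stopper ends training earlier.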

Evaluation results

Metric                   Value
Test accuracy            86.29%
Test loss                0.4451
Full CIFAR-10 accuracy   90.17%

Per‑class performance

Class           Precision  Recall  F1-score  Support
0 (airplane)    0.7959     0.9320  0.8586    1000
1 (automobile)  0.9023     0.9510  0.9260    1000
2 (bird)        0.8176     0.8250  0.8213    1000
3 (cat)         0.7883     0.7000  0.7415    1000
4 (deer)        0.8010     0.9180  0.8555    1000
5 (dog)         0.8423     0.7850  0.8126    1000
6 (frog)        0.9310     0.8630  0.8957    1000
7 (horse)       0.9201     0.8750  0.8970    1000
8 (ship)        0.9622     0.8650  0.9110    1000
9 (truck)       0.8944     0.9150  0.9046    1000
accuracy                           0.8629    10000
macro avg       0.8655     0.8629  0.8624    10000
weighted avg    0.8655     0.8629  0.8624    10000

Confusion matrix (row = true label, col = predicted label)

True \ Pred  Pred 0  Pred 1  Pred 2  Pred 3  Pred 4  Pred 5  Pred 6  Pred 7  Pred 8  Pred 9
True 0          932       6      18       6       3       1       0       3      14      17
True 1           14     951       0       1       0       0       1       0       2      31
True 2           51       2     825      27      44      15      16      13       3       4
True 3           30       6      45     700      54     102      29      17       5      12
True 4           10       2      23      18     918       4      12      11       1       1
True 5           14       3      31      82      47     785       4      25       4       5
True 6           16       2      53      35      17       7     863       5       1       1
True 7           16       3      10       9      62      16       1     875       0       8
True 8           71      18       4       7       1       2       1       2     865      29
True 9           17      61       0       3       0       0       0       0       4     915
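
The per-class precision, recall and F1 figures can be recomputed from the confusion matrix alone. The sketch below does this in plain Python; the function names `per_class_metrics` and `accuracy` are illustrative and not part of any released code for this model.

```python
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# Rows are true labels, columns are predicted labels (copied from the table above).
CONFUSION = [
    [932,   6,  18,   6,   3,   1,   0,   3,  14,  17],
    [ 14, 951,   0,   1,   0,   0,   1,   0,   2,  31],
    [ 51,   2, 825,  27,  44,  15,  16,  13,   3,   4],
    [ 30,   6,  45, 700,  54, 102,  29,  17,   5,  12],
    [ 10,   2,  23,  18, 918,   4,  12,  11,   1,   1],
    [ 14,   3,  31,  82,  47, 785,   4,  25,   4,   5],
    [ 16,   2,  53,  35,  17,   7, 863,   5,   1,   1],
    [ 16,   3,  10,   9,  62,  16,   1, 875,   0,   8],
    [ 71,  18,   4,   7,   1,   2,   1,   2, 865,  29],
    [ 17,  61,   0,   3,   0,   0,   0,   0,   4, 915],
]


def per_class_metrics(cm):
    """Return a list of (precision, recall, f1, support) tuples, one per class."""
    n = len(cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        support = sum(cm[k])                          # all samples whose true label is k
        predicted = sum(cm[i][k] for i in range(n))   # all samples predicted as k
        precision = tp / predicted if predicted else 0.0
        recall = tp / support if support else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results.append((precision, recall, f1, support))
    return results


def accuracy(cm):
    """Fraction of samples on the diagonal of the confusion matrix."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[k][k] for k in range(len(cm)))
    return correct / total


if __name__ == "__main__":
    for name, (p, r, f1, s) in zip(CLASSES, per_class_metrics(CONFUSION)):
        print(f"{name:<12}{p:.4f}  {r:.4f}  {f1:.4f}  {s}")
    print(f"accuracy    {accuracy(CONFUSION):.4f}")
```

Reading the matrix this way also makes the main error modes visible: cat and dog are confused with each other most often (102 cats predicted as dog, 82 dogs predicted as cat), which matches their lower F1-scores.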

Model size

272,474 trainable parameters.
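
The parameter count can be checked by hand from the layer shapes in the architecture given under "How to use". The following is a minimal sketch of that arithmetic in plain Python (no PyTorch required); the helper names are hypothetical, and it assumes bias-free convolutions, two learnable parameters per BatchNorm channel, and a 1×1 projection shortcut whenever the stride or channel count changes, as in that code.

```python
def conv_params(c_in, c_out, k):
    """Weights of a k x k convolution without bias."""
    return c_in * c_out * k * k


def bn_params(channels):
    """Learnable scale and shift of BatchNorm (running statistics are buffers, not parameters)."""
    return 2 * channels


def block_params(c_in, c_out, stride):
    """Parameters of one basic residual block, including its shortcut if it has one."""
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if stride != 1 or c_in != c_out:
        # 1x1 projection shortcut followed by BatchNorm
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total


def resnet20_params(num_classes=10, blocks_per_stage=3):
    """Total trainable parameters of the three-stage CIFAR ResNet described in this card."""
    total = conv_params(3, 16, 3) + bn_params(16)   # stem
    c_in = 16
    for c_out, stride in ((16, 1), (32, 2), (64, 2)):
        for i in range(blocks_per_stage):
            # Only the first block of a stage changes stride or width.
            total += block_params(c_in, c_out, stride if i == 0 else 1)
            c_in = c_out
    total += 64 * num_classes + num_classes          # final linear layer (weights + bias)
    return total


if __name__ == "__main__":
    print(resnet20_params())
```

Under these assumptions the total comes to 272,474, of which 2,752 parameters belong to the two projection shortcuts at the start of the second and third stages.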

How to use

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# --------------------------------------------------
# ResNet20 for CIFAR-10 (exact same architecture)
# --------------------------------------------------
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

def ResNet20():
    return ResNet(BasicBlock, [3, 3, 3])

# --------------------------------------------------
# Load the model weights
# --------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet20().to(device)
model.load_state_dict(torch.load("best_resnet20_cifar10.pth", map_location=device))
model.eval()

# --------------------------------------------------
# Preprocess an image (converted to RGB and resized to 32x32 to match CIFAR-10 inputs)
# --------------------------------------------------
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

def predict(image_path):
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        _, predicted = output.max(1)
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    return classes[predicted.item()]

# Example usage:
# print(predict("my_cat.jpg"))

Acknowledgements

  • Original CIFAR‑10 dataset by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.
  • PyTorch team for the framework and ROCm support.