Vision Transformer for CIFAR-10

A Vision Transformer (ViT) model trained from scratch on the CIFAR-10 dataset, achieving 82.08% test accuracy.

Model Description

This model implements the Vision Transformer architecture, which processes images as sequences of patches rather than using convolutional layers. The model splits input images into fixed-size patches, linearly embeds them, adds positional encodings, and processes them through multiple transformer encoder layers.

Architecture Details:

  • Image Size: 32x32 (CIFAR-10)
  • Patch Size: 4x4
  • Number of Patches: 64
  • Embedding Dimension: 192
  • Attention Heads: 3
  • Transformer Layers: 12
  • MLP Hidden Size: 768
  • Total Parameters: 5.36M
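
For reference, the listed hyperparameters correspond to a patch-embedding stage like the following minimal sketch (class and argument names here are illustrative, not the repository's actual API):

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Illustrative sketch: 32x32 image -> 64 patch tokens of dimension 192."""
    def __init__(self, image_size=32, patch_size=4, in_channels=3, embed_dim=192):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2  # (32/4)^2 = 64 patches
        # A Conv2d with stride equal to kernel size splits the image into
        # non-overlapping patches and linearly projects each one
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

    def forward(self, x):
        x = self.proj(x)                  # (B, 192, 8, 8)
        x = x.flatten(2).transpose(1, 2)  # (B, 64, 192)
        return x + self.pos_embed         # add learnable positional embeddings

tokens = PatchEmbedding()(torch.rand(2, 3, 32, 32))
print(tokens.shape)  # torch.Size([2, 64, 192])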

Performance

Overall Metrics:

  • Test Accuracy: 82.08%
  • Test Loss: 1.0026

Per-Class Accuracy:

Class         Accuracy
Airplane        85.50%
Automobile      92.80%
Bird            77.70%
Cat             65.90%
Deer            78.40%
Dog             73.10%
Frog            85.50%
Horse           86.60%
Ship            88.80%
Truck           86.50%
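
These per-class figures can be reproduced with a simple tally over the test set. A minimal sketch, assuming model is the loaded VisionTransformer from the How to Use section below and test_loader yields (images, labels) batches:

import torch

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
correct = torch.zeros(10)
total = torch.zeros(10)

model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        logits, _ = model(images)
        preds = torch.argmax(logits, dim=1)
        for c in range(10):
            mask = labels == c
            correct[c] += (preds[mask] == c).sum()
            total[c] += mask.sum()

for name, acc in zip(classes, correct / total * 100):
    print(f'{name}: {acc:.2f}%')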

Intended Use

This model is designed for image classification on CIFAR-10 or similar low-resolution datasets. It can be used for:

  • Educational purposes to understand Vision Transformer architecture
  • Research on transformer-based vision models
  • Transfer learning for similar image classification tasks
  • Benchmarking and comparison studies

Training Details

Dataset: CIFAR-10

  • Training samples: 50,000 images
  • Test samples: 10,000 images
  • Image dimensions: 32x32x3
  • Classes: 10 (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck)

Training Configuration:

  • Optimizer: AdamW
  • Learning rate schedule: Cosine annealing with warmup
  • Mixed precision training: Enabled
  • Label smoothing: 0.1
  • Data augmentation: Standard CIFAR-10 augmentations
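
The items above could be wired together roughly as follows. This is a minimal sketch: the learning rate, weight decay, warmup length, epoch count, and the specific augmentations are assumptions, not the exact values used for this checkpoint, and model is assumed to be the VisionTransformer instance.

import torch
from torch import nn, optim
from torchvision import transforms

lr, weight_decay, warmup_epochs, epochs = 1e-3, 0.05, 5, 100  # assumed values

optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Cosine annealing preceded by a linear warmup phase
warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, cosine],
                                            milestones=[warmup_epochs])

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # label smoothing
scaler = torch.cuda.amp.GradScaler()                  # mixed precision

# Typical CIFAR-10 augmentations (assumed): random crop with padding and horizontal flip
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# One training step under mixed precision; images and labels come from the training DataLoader
with torch.cuda.amp.autocast():
    logits, _ = model(images)
    loss = criterion(logits, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()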

How to Use

Load the checkpoint using PyTorch:

import torch
from vision_transformers.src.models import VisionTransformer
from vision_transformers.src.config import ViTConfig

# Initialize model
config = ViTConfig()
model = VisionTransformer(config)

# Load checkpoint (map_location='cpu' allows loading without a GPU)
checkpoint = torch.load('best_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Inference: input shape (batch_size, 3, 32, 32), values in [0, 1]
images = torch.rand(8, 3, 32, 32)  # placeholder batch; replace with real data
with torch.no_grad():
    logits, _ = model(images)  # forward pass returns a tuple; the first element is the class logits
    predictions = torch.argmax(logits, dim=1)

Input: Tensor of shape (batch_size, 3, 32, 32) with values normalized to [0, 1]

Output: Logits of shape (batch_size, 10) for CIFAR-10 classes
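
A minimal example of producing input in that format from the CIFAR-10 test set (the torchvision pipeline here is one option; any loader producing (batch_size, 3, 32, 32) tensors in [0, 1] works):

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ToTensor() scales pixel values to [0, 1], matching the expected input range
test_set = datasets.CIFAR10(root='./data', train=False, download=True,
                            transform=transforms.ToTensor())
loader = DataLoader(test_set, batch_size=32, shuffle=False)

images, labels = next(iter(loader))      # images: (32, 3, 32, 32)
with torch.no_grad():
    logits, _ = model(images)            # logits: (32, 10)
    predictions = torch.argmax(logits, dim=1)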

Limitations

  • Trained specifically on CIFAR-10 low-resolution images (32x32)
  • Performance varies significantly by class (65.90% for cats vs 92.80% for automobiles)
  • Not suitable for high-resolution images without architecture modifications
  • May not generalize well to out-of-distribution images

Technical Specifications

Model Architecture:

  • Patch Embedding: Conv2d projection with learnable positional embeddings
  • Transformer Encoder: Pre-norm architecture with multi-head self-attention
  • Classification Head: Layer normalization followed by linear projection
  • Activation: GELU in MLP blocks
  • Dropout: 0.1 in both attention and MLP layers
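
A minimal sketch of a single pre-norm encoder block matching these specifications (192-dim embeddings, 3 heads, 768-wide GELU MLP, dropout 0.1); class names are illustrative, not the repository's actual API:

import torch
from torch import nn

class EncoderBlock(nn.Module):
    """Pre-norm encoder block: LayerNorm is applied before attention and before the MLP."""
    def __init__(self, embed_dim=192, num_heads=3, mlp_dim=768, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # residual around attention
        x = x + self.mlp(self.norm2(x))                    # residual around MLP
        return x

out = EncoderBlock()(torch.rand(2, 64, 192))  # (batch, tokens, embed_dim)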

Checkpoint Contents: The .pth file contains:

  • model_state_dict: Model weights
  • epoch: Training epoch number
  • optimizer_state_dict: Optimizer state
  • loss: Training loss
  • accuracy: Validation accuracy
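
To verify a downloaded checkpoint, its keys can be inspected before loading the weights:

import torch

checkpoint = torch.load('best_model.pth', map_location='cpu')
print(list(checkpoint.keys()))
# ['model_state_dict', 'epoch', 'optimizer_state_dict', 'loss', 'accuracy']
print(checkpoint['epoch'], checkpoint['accuracy'])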

Citation

If you use this model, please cite the original Vision Transformer paper:

@inproceedings{dosovitskiy2020vit,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
  booktitle={International Conference on Learning Representations},
  year={2021}
}

License

MIT License
