Vision Transformer for CIFAR-10
A Vision Transformer (ViT) model trained from scratch on the CIFAR-10 dataset, achieving 82.08% test accuracy.
Model Description
This model implements the Vision Transformer architecture, which processes images as sequences of patches rather than using convolutional layers. The model splits input images into fixed-size patches, linearly embeds them, adds positional encodings, and processes them through multiple transformer encoder layers.
Architecture Details:
- Image Size: 32x32 (CIFAR-10)
- Patch Size: 4x4
- Number of Patches: 64
- Embedding Dimension: 192
- Attention Heads: 3
- Transformer Layers: 12
- MLP Hidden Size: 768
- Total Parameters: 5.36M
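To make the patch pipeline concrete, here is a minimal sketch of the patch-embedding step described above, using the Conv2d-based projection mentioned in the Technical Specifications; the class name `PatchEmbedding` is illustrative, not the repository's actual module:

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Illustrative sketch: split a 32x32 image into 4x4 patches and embed them."""
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=192):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # 64 patches
        # A Conv2d with kernel_size == stride acts as a per-patch linear projection
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x):
        x = self.proj(x)                  # (B, 192, 8, 8)
        x = x.flatten(2).transpose(1, 2)  # (B, 64, 192)
        return x + self.pos_embed         # add learnable positional embeddings

# 64 patch tokens of dimension 192 per image
tokens = PatchEmbedding()(torch.randn(2, 3, 32, 32))
print(tokens.shape)  # torch.Size([2, 64, 192])
```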
Performance
Overall Metrics:
- Test Accuracy: 82.08%
- Test Loss: 1.0026
Per-Class Accuracy:
| Class | Accuracy |
|---|---|
| Airplane | 85.50% |
| Automobile | 92.80% |
| Bird | 77.70% |
| Cat | 65.90% |
| Deer | 78.40% |
| Dog | 73.10% |
| Frog | 85.50% |
| Horse | 86.60% |
| Ship | 88.80% |
| Truck | 86.50% |
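For reference, a minimal sketch of how per-class accuracies of this kind can be computed from model predictions; the `test_loader` loop is an assumption, not the repository's evaluation code:

```python
import torch

def per_class_accuracy(model, test_loader, num_classes=10, device='cpu'):
    """Count correct predictions per class over a test loader (illustrative sketch)."""
    correct = torch.zeros(num_classes)
    total = torch.zeros(num_classes)
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            logits, _ = model(images.to(device))   # model returns (logits, ...) as in "How to Use"
            preds = logits.argmax(dim=1).cpu()
            for c in range(num_classes):
                mask = labels == c
                correct[c] += (preds[mask] == c).sum()
                total[c] += mask.sum()
    return correct / total
```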
Intended Use
This model is designed for image classification on CIFAR-10 or similar low-resolution datasets. It can be used for:
- Educational purposes to understand Vision Transformer architecture
- Research on transformer-based vision models
- Transfer learning for similar image classification tasks (a head-replacement sketch follows this list)
- Benchmarking and comparison studies
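As a rough illustration of the transfer-learning use case, the sketch below swaps the classification head for a new task. It assumes the model has been loaded as shown in "How to Use" and that the head is exposed as `model.head`; check the `VisionTransformer` implementation for the actual attribute name.

```python
import torch.nn as nn

# Hypothetical: `model.head` is an assumed attribute name for the classification head
num_new_classes = 100  # e.g., CIFAR-100
model.head = nn.Linear(192, num_new_classes)

# Optionally freeze the transformer backbone and fine-tune only the new head
for name, param in model.named_parameters():
    if not name.startswith('head'):
        param.requires_grad = False
```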
Training Details
Dataset: CIFAR-10
- Training samples: 50,000 images
- Test samples: 10,000 images
- Image dimensions: 32x32x3
- Classes: 10 (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck)
Training Configuration:
- Optimizer: AdamW
- Learning rate schedule: Cosine annealing with warmup
- Mixed precision training: Enabled
- Label smoothing: 0.1
- Data augmentation: Standard CIFAR-10 augmentations
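As a rough illustration of this configuration, the sketch below wires up AdamW, a warmup-plus-cosine schedule, label smoothing, mixed precision, and common CIFAR-10 augmentations. The specific values (learning rate, weight decay, warmup length, augmentation choices) and the `model`/`train_loader` objects are assumptions, not the settings used to train this checkpoint:

```python
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torchvision import transforms

# Assumed "standard" CIFAR-10 augmentations (not confirmed by the training code)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)  # values assumed
warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5)                 # 5 warmup epochs assumed
cosine = CosineAnnealingLR(optimizer, T_max=195)                               # remaining epochs assumed
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[5])

criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
scaler = torch.cuda.amp.GradScaler()  # mixed precision

for images, labels in train_loader:  # inner loop of one epoch; epoch loop omitted
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        logits, _ = model(images.cuda())
        loss = criterion(logits, labels.cuda())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
# scheduler.step() would be called once per epoch
```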
How to Use
Load the checkpoint using PyTorch:
```python
import torch

from vision_transformers.src.models import VisionTransformer
from vision_transformers.src.config import ViTConfig

# Initialize model from the default configuration
config = ViTConfig()
model = VisionTransformer(config)

# Load checkpoint
checkpoint = torch.load('best_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Inference (`images` is a preprocessed batch of shape (batch_size, 3, 32, 32))
with torch.no_grad():
    logits, _ = model(images)
    predictions = torch.argmax(logits, dim=1)
```
Input: Tensor of shape (batch_size, 3, 32, 32) with values normalized to [0, 1]
Output: Logits of shape (batch_size, 10) for CIFAR-10 classes
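A minimal sketch of preparing a CIFAR-10 test batch that matches this input format, assuming `model` has been loaded as above; plain [0, 1] scaling via `ToTensor()` is taken from the description, and any additional per-channel normalization used during training is not reflected here:

```python
import torch
from torchvision import datasets, transforms

# ToTensor() scales pixel values to [0, 1], matching the stated input format
transform = transforms.ToTensor()
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

images, labels = next(iter(loader))   # images: (128, 3, 32, 32)
with torch.no_grad():
    logits, _ = model(images)         # logits: (128, 10)
    predictions = logits.argmax(dim=1)
```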
Limitations
- Trained specifically on CIFAR-10 low-resolution images (32x32)
- Performance varies significantly by class (65.90% for cats vs 92.80% for automobiles)
- Not suitable for high-resolution images without architecture modifications
- May not generalize well to out-of-distribution images
Technical Specifications
Model Architecture:
- Patch Embedding: Conv2d projection with learnable positional embeddings
- Transformer Encoder: Pre-norm architecture with multi-head self-attention
- Classification Head: Layer normalization followed by linear projection
- Activation: GELU in MLP blocks
- Dropout: 0.1 in both attention and MLP layers
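To make the pre-norm layout concrete, here is a minimal sketch of one encoder block as described above (LayerNorm applied before attention and before the MLP, GELU activation, dropout 0.1); the class and attribute names are illustrative, not the repository's:

```python
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Illustrative pre-norm transformer encoder block."""
    def __init__(self, embed_dim=192, num_heads=3, mlp_hidden=768, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-norm: normalize first, then attend, then add the residual
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
```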
Checkpoint Contents: The `.pth` file contains:
- `model_state_dict`: Model weights
- `epoch`: Training epoch number
- `optimizer_state_dict`: Optimizer state
- `loss`: Training loss
- `accuracy`: Validation accuracy
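For completeness, a hedged sketch of how a checkpoint with these keys would typically be written during training; the actual saving code in this repository may differ:

```python
import torch

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,            # training epoch number
    'loss': train_loss,        # training loss at save time
    'accuracy': val_accuracy,  # validation accuracy at save time
}, 'best_model.pth')
```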
Citation
If you use this model, please cite the original Vision Transformer paper:
```bibtex
@inproceedings{dosovitskiy2020vit,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
  booktitle={International Conference on Learning Representations},
  year={2021}
}
```
License
MIT License