---
license: mit
datasets:
- uoft-cs/cifar10
language:
- en
metrics:
- accuracy
tags:
- code
---

# Vision Transformer for CIFAR-10

A Vision Transformer (ViT) model trained from scratch on the CIFAR-10 dataset, achieving 82.08% test accuracy.

## Model Description

This model implements the Vision Transformer architecture, which processes images as sequences of patches rather than through convolutional layers. The input image is split into fixed-size patches, each patch is linearly embedded, positional encodings are added, and the resulting token sequence is processed by a stack of transformer encoder layers.

**Architecture Details:**
- Image Size: 32x32 (CIFAR-10)
- Patch Size: 4x4
- Number of Patches: 64
- Embedding Dimension: 192
- Attention Heads: 3
- Transformer Layers: 12
- MLP Hidden Size: 768
- Total Parameters: 5.36M

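The patch count follows directly from the geometry: a 32x32 image split into 4x4 patches yields (32/4)² = 64 patches. A minimal sketch of the patch-embedding step, assuming the common strided-Conv2d formulation (the variable names are illustrative, not this repository's actual API):

```python
import torch
import torch.nn as nn

# Patch embedding as a strided Conv2d: kernel = stride = patch size,
# so each non-overlapping 4x4 patch is projected to a 192-dim embedding
patch_embed = nn.Conv2d(in_channels=3, out_channels=192,
                        kernel_size=4, stride=4)

images = torch.randn(8, 3, 32, 32)           # (batch, channels, H, W)
patches = patch_embed(images)                # (8, 192, 8, 8)
tokens = patches.flatten(2).transpose(1, 2)  # (8, 64, 192): 64 patch tokens

print(tokens.shape)  # torch.Size([8, 64, 192])
```

Flattening the 8x8 spatial grid and transposing produces the (batch, num_patches, embed_dim) token sequence that the transformer layers consume.
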
## Performance

**Overall Metrics:**
- Test Accuracy: 82.08%
- Test Loss: 1.0026

**Per-Class Accuracy:**

| Class      | Accuracy |
|------------|----------|
| Airplane   | 85.50%   |
| Automobile | 92.80%   |
| Bird       | 77.70%   |
| Cat        | 65.90%   |
| Deer       | 78.40%   |
| Dog        | 73.10%   |
| Frog       | 85.50%   |
| Horse      | 86.60%   |
| Ship       | 88.80%   |
| Truck      | 86.50%   |

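Per-class accuracies like those above can be computed from predictions and labels with a simple bincount. This is a generic sketch, not this repository's evaluation script:

```python
import torch

def per_class_accuracy(preds, labels, num_classes=10):
    """Fraction of correct predictions for each ground-truth class."""
    correct = torch.bincount(labels[preds == labels], minlength=num_classes)
    total = torch.bincount(labels, minlength=num_classes)
    return correct.float() / total.clamp(min=1).float()

# Toy example: 6 samples over 3 classes
labels = torch.tensor([0, 0, 1, 1, 2, 2])
preds  = torch.tensor([0, 1, 1, 1, 2, 0])
print(per_class_accuracy(preds, labels, num_classes=3))  # tensor([0.5000, 1.0000, 0.5000])
```
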
## Intended Use

This model is designed for image classification on CIFAR-10 or similar low-resolution datasets. It can be used for:
- Educational purposes, to understand the Vision Transformer architecture
- Research on transformer-based vision models
- Transfer learning for similar image classification tasks
- Benchmarking and comparison studies

## Training Details

**Dataset:** CIFAR-10
- Training samples: 50,000 images
- Test samples: 10,000 images
- Image dimensions: 32x32x3
- Classes: 10 (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck)

**Training Configuration:**
- Optimizer: AdamW
- Learning rate schedule: Cosine annealing with warmup
- Mixed precision training: Enabled
- Label smoothing: 0.1
- Data augmentation: Standard CIFAR-10 augmentations

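The configuration above can be wired together roughly as follows. The learning rate, weight decay, and warmup/annealing split are illustrative assumptions, since this card does not list them:

```python
import torch
import torch.nn as nn

# Stand-in model; the real ViT lives in vision_transformers.src.models
model = nn.Linear(3 * 32 * 32, 10)

# AdamW with label smoothing 0.1, as listed above; lr and weight decay
# are assumed values, not taken from this card
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Cosine annealing preceded by a linear warmup (the 5/95 epoch split is assumed)
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, total_iters=5)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=95)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[5])

# Mixed precision: the scaler is a no-op on CPU-only machines
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
```
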
## How to Use

Load the checkpoint using PyTorch:

```python
import torch
from vision_transformers.src.models import VisionTransformer
from vision_transformers.src.config import ViTConfig

# Initialize the model
config = ViTConfig()
model = VisionTransformer(config)

# Load the checkpoint (map_location keeps this working on CPU-only machines)
checkpoint = torch.load('best_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Inference: the model returns (logits, attention weights)
# `images` is a float tensor of shape (batch_size, 3, 32, 32) in [0, 1]
with torch.no_grad():
    logits, _ = model(images)
    predictions = torch.argmax(logits, dim=1)
```

**Input:** Tensor of shape (batch_size, 3, 32, 32) with values normalized to [0, 1]

**Output:** Logits of shape (batch_size, 10), one per CIFAR-10 class

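For example, an input batch matching this contract can be built from raw uint8 image data with plain tensor operations (the dummy data here stands in for CIFAR-10 samples):

```python
import torch

# A dummy batch of uint8 images in NHWC layout, standing in for CIFAR-10 samples
raw = torch.randint(0, 256, (2, 32, 32, 3), dtype=torch.uint8)

# Convert to the model's expected contract: NCHW float32 in [0, 1]
images = raw.permute(0, 3, 1, 2).float() / 255.0

print(images.shape)  # torch.Size([2, 3, 32, 32])
```
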
## Limitations

- Trained specifically on low-resolution (32x32) CIFAR-10 images
- Performance varies significantly by class (65.90% for cats vs. 92.80% for automobiles)
- Not suitable for high-resolution images without architecture modifications (e.g., interpolating the positional embeddings)
- May not generalize well to out-of-distribution images

## Technical Specifications

**Model Architecture:**
- Patch Embedding: Conv2d projection with learnable positional embeddings
- Transformer Encoder: Pre-norm architecture with multi-head self-attention
- Classification Head: Layer normalization followed by linear projection
- Activation: GELU in MLP blocks
- Dropout: 0.1 in both attention and MLP layers

**Checkpoint Contents:**

The `.pth` file contains:
- `model_state_dict`: Model weights
- `epoch`: Training epoch number
- `optimizer_state_dict`: Optimizer state
- `loss`: Training loss
- `accuracy`: Validation accuracy

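A quick way to inspect those checkpoint fields before instantiating the model. The dictionary below mirrors the keys listed above with stand-in values, and the file name is chosen to avoid clobbering a real checkpoint:

```python
import torch

# Build a checkpoint with the documented keys (stand-in values for brevity)
checkpoint = {
    'model_state_dict': {'head.weight': torch.zeros(10, 192)},
    'epoch': 100,
    'optimizer_state_dict': {},
    'loss': 0.95,
    'accuracy': 82.08,
}
torch.save(checkpoint, 'demo_checkpoint.pth')

# Reload and read the metadata without building the model
ckpt = torch.load('demo_checkpoint.pth', map_location='cpu')
print(sorted(ckpt.keys()))
print(f"epoch={ckpt['epoch']}, accuracy={ckpt['accuracy']}%")
```
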
## Citation

If you use this model, please cite the original Vision Transformer paper:

```bibtex
@inproceedings{dosovitskiy2020vit,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2021}
}
```

## License

MIT License