# Assignment Image: Vision Transformer Documentation
## 1) Overview
This documentation explains two scripts:
- `assignment_image/code/c1.py`: end-to-end **training pipeline** for a custom Vision Transformer (ViT) on CIFAR-10.
- `assignment_image/code/c1_test.py`: **evaluation and analysis pipeline** for saved checkpoints, with an optional transfer-learning experiment using a pre-trained torchvision ViT.
Together, these scripts cover:
1. Data preprocessing and DataLoader creation
2. ViT architecture definition
3. Training, validation, checkpointing, and early stopping
4. Final test evaluation
5. Error analysis (per-class accuracy + confusion patterns + misclassified images)
---
## 2) Project Organization
Logical separation in the codebase:
- **Data preprocessing**
- `get_cifar10_dataloaders()` in `c1.py`
- `get_imagenet_style_cifar10_dataloaders()` in `c1_test.py` (for pre-trained ViT)
- **Model architecture**
- `PatchifyEmbedding`, `TransformerEncoderBlock`, `ViTEncoder`, `ViTClassifier` in `c1.py`
- **Training loop**
- `train_one_epoch()`, `train_model()` in `c1.py`
- **Evaluation**
- `evaluate()` in `c1.py`
- `evaluate_model()` in `c1_test.py`
- **Error analysis and visualization**
- `collect_misclassified()`, `visualize_misclassified()` in `c1.py`
- `collect_predictions()`, `build_confusion_matrix()`, `format_error_analysis()` in `c1_test.py`
---
## 3) `c1.py` (Training Script) Documentation
### Purpose
`c1.py` trains a custom ViT classifier on CIFAR-10 and saves:
- best checkpoint by validation accuracy: `vit_cifar10_best.pt`
- final checkpoint after training ends: `vit_cifar10_last.pt`
- optional misclassification visualization image
### Data Pipeline
`get_cifar10_dataloaders()` performs:
- resize CIFAR-10 images to `image_size x image_size` (default `64x64`)
- convert to tensor (`[0, 255] -> [0, 1]`)
- normalize channels from `[0,1]` to `[-1,1]` using mean/std `(0.5, 0.5, 0.5)`
- split official training set into train/validation by `val_ratio`
- build train/val/test DataLoaders with configurable batch size and workers
### Model Architecture
The custom ViT follows a standard encoder-only design:
1. **Patchify + Projection**
`PatchifyEmbedding` creates non-overlapping patches and projects each patch to `embed_dim`.
2. **Token + Position Encoding**
`ViTEncoder` prepends a learnable CLS token and adds learnable positional embeddings.
3. **Transformer Blocks**
`TransformerEncoderBlock` applies:
- LayerNorm -> Multi-Head Self-Attention -> Residual
- LayerNorm -> MLP (GELU + Dropout) -> Residual
4. **Classification Head**
`ViTClassifier` extracts CLS representation and maps it to 10 class logits.
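A minimal sketch of the first three components described above. Class names follow the description; the exact layer layout in `c1.py` may differ:

```python
import torch
import torch.nn as nn

class PatchifyEmbedding(nn.Module):
    """Split the image into non-overlapping patches and project each to embed_dim.
    A strided Conv2d performs both steps in one operation."""
    def __init__(self, image_size=64, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2  # 256 for 64/4
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                      # (B, 3, H, W)
        x = self.proj(x)                       # (B, D, H/p, W/p)
        return x.flatten(2).transpose(1, 2)    # (B, num_patches, D)

class TransformerEncoderBlock(nn.Module):
    """Pre-norm block: LN -> MHSA -> residual, then LN -> MLP -> residual."""
    def __init__(self, embed_dim=256, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim), nn.Dropout(dropout))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # MHSA + residual
        x = x + self.mlp(self.norm2(x))                    # MLP + residual
        return x
```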
### Training and Validation
`train_model()` uses:
- loss: `CrossEntropyLoss`
- optimizer: `AdamW`
- scheduler: `StepLR(step_size=5, gamma=0.5)`
- early stopping: stop when validation accuracy does not improve for `early_stopping_patience` epochs
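The optimizer, scheduler, checkpointing, and early-stopping logic above could be wired as below. This is a hedged skeleton, not the real `train_model()`; the small `evaluate_accuracy` helper stands in for `evaluate()` in `c1.py`:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def evaluate_accuracy(model, loader, device="cpu"):
    # Fraction of correctly classified samples over the whole loader.
    model.eval()
    correct = total = 0
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1)
        correct += (preds == labels.to(device)).sum().item()
        total += labels.size(0)
    return correct / total

def train_model(model, train_loader, val_loader, num_epochs=10, lr=3e-4,
                weight_decay=1e-4, early_stopping_patience=5,
                save_path="vit_cifar10_best.pt", device="cpu"):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                                  weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    best_acc, epochs_without_improvement = 0.0, 0
    for epoch in range(num_epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        scheduler.step()

        val_acc = evaluate_accuracy(model, val_loader, device)
        if val_acc > best_acc:
            best_acc, epochs_without_improvement = val_acc, 0
            torch.save(model.state_dict(), save_path)  # best-by-val-acc checkpoint
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                break  # early stopping: no val-accuracy improvement
    return best_acc
```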
### Main Outputs
During training:
- epoch-wise train/validation loss, accuracy, and learning rate logs
- checkpoint files saved in `save_dir`
After training:
- final validation summary
- test loss/accuracy using best checkpoint
- optional plot of misclassified examples
---
## 4) `c1_test.py` (Evaluation + Analysis Script) Documentation
### Purpose
`c1_test.py` is a separate script for:
- loading a trained checkpoint
- evaluating on test data
- generating error analysis reports
- optionally running transfer learning with pre-trained ViT-B/16
### Baseline Evaluation Flow
1. Load checkpoint with `load_model_from_checkpoint()`
2. Recreate the test DataLoader with the same preprocessing used during training
3. Run `evaluate_model()` for test loss and accuracy
4. Collect predictions via `collect_predictions()`
5. Generate:
- per-class accuracy
- top confusion pairs (true -> predicted)
6. Save analysis text report and misclassified image grid
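The per-class accuracy and confusion-pair steps can be illustrated with a small sketch. The function name is borrowed from the description; the real report format in `c1_test.py` may differ:

```python
from collections import Counter

def format_error_analysis(y_true, y_pred, class_names, top_k=5):
    """Per-class accuracy plus the most frequent (true -> predicted) confusions."""
    correct, total, confusions = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
        else:
            confusions[(t, p)] += 1
    per_class = {class_names[c]: correct[c] / total[c]
                 for c in range(len(class_names)) if total[c]}
    top_pairs = [(class_names[t], class_names[p], count)
                 for (t, p), count in confusions.most_common(top_k)]
    return per_class, top_pairs
```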
### Optional Transfer-Learning Experiment
When `--run-pretrained-experiment` is enabled:
- build pre-trained `vit_b_16` from torchvision
- replace classification head for 10 CIFAR-10 classes
- preprocess data with ImageNet normalization and `224x224` resize
- fine-tune with `fine_tune_pretrained()`
- evaluate and save separate analysis artifacts
### Baseline vs Pre-trained Comparison (Recorded Result)
From `results/comparison_report.txt`:
| Model | Test Loss | Test Accuracy |
|---|---:|---:|
| Baseline ViT (custom checkpoint) | 0.8916 | 68.57% |
| Pre-trained ViT-B/16 | 0.1495 | 95.15% |
Key comparison metrics:
- Accuracy gain (pre-trained - baseline): **+26.58 percentage points**
- Loss delta (pre-trained - baseline): **-0.7420**
Interpretation: transfer learning with pre-trained ViT-B/16 provides a large performance improvement over the baseline custom-trained ViT in this run.
---
## 5) Hyperparameters and Their Significance
### Core model hyperparameters (`c1.py`)
- `image_size=64`
  Upscales CIFAR-10 images from `32x32` to `64x64`, yielding more patch tokens per image.
- `patch_size=4`
Number of patches per image becomes `(64/4)^2 = 256`.
- `embed_dim=256`
Dimensionality of token embeddings; larger values increase representation capacity and compute cost.
- `depth=6`
Number of transformer encoder blocks; deeper models can learn more complex patterns but train slower.
- `num_heads=8`
Attention heads per block; controls multi-view attention decomposition.
- `mlp_ratio=4.0`
Hidden size of feed-forward block equals `4 * embed_dim`.
- `dropout=0.1`
Regularization in transformer blocks to reduce overfitting risk.
### Training hyperparameters (`c1.py`)
- `batch_size=128`
Balance between gradient stability, memory use, and throughput.
- `num_epochs=10`
  Maximum number of training epochs; early stopping may end training sooner.
- `lr=3e-4`
Initial learning rate for AdamW.
- `weight_decay=1e-4`
L2-style regularization used by AdamW.
- `early_stopping_patience=5`
Stops training if validation accuracy does not improve for 5 epochs.
- `StepLR(step_size=5, gamma=0.5)`
Learning rate decays by half every 5 epochs.
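The `StepLR(step_size=5, gamma=0.5)` schedule can be verified with a short, illustrative snippet:

```python
import torch

# Dummy parameter/optimizer pair just to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=3e-4, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.5)

lrs = []
for epoch in range(10):
    lrs.append(opt.param_groups[0]["lr"])  # lr used during this epoch
    opt.step()
    sched.step()
# epochs 0-4 run at 3e-4, epochs 5-9 at 1.5e-4
```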
### Transfer-learning hyperparameters (`c1_test.py`)
- `pretrained_epochs=2` (default)
Short fine-tuning schedule for quick comparison against baseline.
- `lr=1e-4`, `weight_decay=1e-4`
Conservative adaptation from ImageNet features to CIFAR-10.
- ImageNet transform: `Resize(224,224)` + ImageNet mean/std
Matches input assumptions of pre-trained ViT-B/16.
---
## 6) CLI Usage
### Train custom ViT
From `assignment_image/code`:
```bash
python c1.py
```
### Evaluate custom checkpoint
```bash
python c1_test.py --checkpoint-path /path/to/vit_cifar10_best.pt
```
### Evaluate + run pre-trained ViT transfer experiment
```bash
python c1_test.py \
  --checkpoint-path /path/to/vit_cifar10_best.pt \
  --run-pretrained-experiment \
  --pretrained-epochs 2
```
---
## 7) Generated Artifacts
Common artifacts produced by the scripts:
- `saved_model/vit_cifar10_best.pt`
- `saved_model/vit_cifar10_last.pt`
- `misclassified_examples.png` (training script visualization)
- `results/baseline_analysis.txt`
- `results/misclassified_examples_test.png`
- `results/pretrained_vit_analysis.txt` (if transfer experiment runs)
- `results/misclassified_examples_pretrained_vit.png` (if transfer experiment runs)
---
## 8) Notes and Best Practices
- Keep training and evaluation preprocessing consistent when testing custom checkpoints.
- Do not use the test set for model selection; use the validation split for checkpoint selection.
- Use error analysis outputs (per-class and confusion pairs) to guide augmentation or architecture tuning.
- If GPU memory is limited, reduce `batch_size` or `image_size`.
---
## 9) References
- Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. (2021). *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*. ICLR 2021. https://arxiv.org/abs/2010.11929