| """ |
| Vision Transformer (ViT) training script for CIFAR-10. |
| |
| Reference: |
| - Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: |
| Transformers for Image Recognition at Scale. ICLR 2021. |
| https://arxiv.org/abs/2010.11929 |
| |
| This script covers: |
| 1) Loading CIFAR-10 |
| 2) Resizing images (default: 64x64) |
| 3) Normalizing pixel values to [-1, 1] |
| 4) Creating batched DataLoaders |
| 5) Building a ViT encoder + classification head |
| 6) Training with CrossEntropy + AdamW + LR scheduler |
| 7) Evaluation accuracy + misclassification visualization |
| """ |
| import os |
| os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| os.environ["CUDA_VISIBLE_DEVICES"] = "2" |
| from pathlib import Path |
| from typing import Any, Dict, List, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader, random_split |
| from torchvision import datasets, transforms |
|
|
| |
| |
| |
# CIFAR-10 label names, indexed by the integer class id (0..9) that
# torchvision.datasets.CIFAR10 assigns to each sample. Used only for
# human-readable output (e.g. titles in the misclassification plot).
CLASS_NAMES: Tuple[str, ...] = (
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
|
|
|
|
def get_cifar10_dataloaders(
    data_root: str = "./data",
    image_size: int = 64,
    batch_size: int = 128,
    num_workers: int = 2,
    pin_memory: bool = True,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build CIFAR-10 train/val/test dataloaders with resize + normalization.

    The official CIFAR-10 split is used (train=True -> 50,000 images,
    train=False -> 10,000 images); a validation subset is carved out of
    the training split deterministically via `seed`.

    Data source:
        https://www.cs.toronto.edu/~kriz/cifar.html

    Args:
        data_root: Directory to download/store CIFAR-10.
        image_size: Target image size after resizing (square).
        batch_size: Number of samples per batch.
        num_workers: Number of subprocesses for data loading.
        pin_memory: Pin memory for faster host-to-device transfer on CUDA.
        val_ratio: Fraction of official train split reserved for validation.
        seed: Random seed for deterministic train/val split.

    Returns:
        train_loader, val_loader, test_loader

    Raises:
        ValueError: If `val_ratio` is not strictly between 0 and 1.
    """
    if not 0.0 < val_ratio < 1.0:
        raise ValueError("val_ratio must be between 0 and 1.")

    # Resize -> tensor -> map [0, 1] pixel values into [-1, 1].
    preprocess = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    )

    root = Path(data_root)
    root.mkdir(parents=True, exist_ok=True)

    official_train = datasets.CIFAR10(
        root=str(root),
        train=True,
        download=True,
        transform=preprocess,
    )
    official_test = datasets.CIFAR10(
        root=str(root),
        train=False,
        download=True,
        transform=preprocess,
    )

    # Pinning host memory only helps when a CUDA device is actually present.
    effective_pin = pin_memory and torch.cuda.is_available()

    # Deterministic train/val split of the official training set.
    n_val = int(len(official_train) * val_ratio)
    n_train = len(official_train) - n_val
    split_gen = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(
        official_train, [n_train, n_val], generator=split_gen
    )

    def build_loader(dataset, shuffle: bool) -> DataLoader:
        # Shared DataLoader settings; only shuffling differs between splits.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=effective_pin,
        )

    train_loader = build_loader(train_set, shuffle=True)
    val_loader = build_loader(val_set, shuffle=False)
    test_loader = build_loader(official_test, shuffle=False)
    return train_loader, val_loader, test_loader
|
|
|
|
class PatchifyEmbedding(nn.Module):
    """
    Step 2: turn an image into a sequence of patch embeddings.

    The image is cut into non-overlapping PxP patches, each patch is
    flattened (channel-major), and a shared linear layer projects every
    flattened patch to the hidden dimension D.
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")

        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # A square grid of patches: side = H / P, total = side * side.
        self.num_patches_per_side = image_size // patch_size
        self.num_patches = self.num_patches_per_side ** 2
        flat_patch_dim = in_channels * patch_size * patch_size

        # Unfold extracts PxP blocks; the Linear layer maps each flattened
        # block to the embedding dimension.
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.proj = nn.Linear(flat_patch_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image batch of shape (B, C, H, W).

        Returns:
            Patch embeddings of shape (B, N, D) with N=num_patches,
            D=embed_dim.
        """
        blocks = self.unfold(x)            # (B, C*P*P, N)
        blocks = blocks.permute(0, 2, 1)   # (B, N, C*P*P)
        return self.proj(blocks)           # (B, N, D)
|
|
|
|
class TransformerEncoderBlock(nn.Module):
    """
    Step 4, one pre-norm Transformer block:

        LayerNorm -> multi-head self-attention -> residual add
        LayerNorm -> MLP (Linear, GELU, Linear) -> residual add
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        hidden = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,  # inputs/outputs are (B, N, D)
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention and MLP sub-layers; shape (B, N, D) is kept."""
        # Self-attention sub-layer (pre-norm, residual connection).
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attended
        # Feed-forward sub-layer (pre-norm, residual connection).
        return x + self.mlp(self.norm2(x))
|
|
|
|
class ViTEncoder(nn.Module):
    """
    Steps 2-4 of the ViT pipeline:

        image -> patch embeddings -> [CLS] token + positional embeddings
              -> stacked Transformer encoder blocks -> LayerNorm
              -> [CLS] representation
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.patch_embed = PatchifyEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        # One learnable [CLS] token plus a learnable position embedding for
        # each patch slot and the [CLS] slot (hence num_patches + 1).
        token_count = self.patch_embed.num_patches + 1
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            TransformerEncoderBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
            )
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)

        self._init_parameters()

    def _init_parameters(self) -> None:
        # Truncated-normal init (std=0.02) for the learned tokens,
        # matching common ViT initialization practice.
        for param in (self.cls_token, self.pos_embed):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image batch of shape (B, C, H, W).

        Returns:
            (B, D) tensor: the normalized [CLS] representation.
        """
        tokens = self.patch_embed(x)                       # (B, N, D)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)           # (B, N+1, D)
        tokens = self.pos_drop(tokens + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        # Final norm, then keep only the [CLS] slot (index 0).
        return self.norm(tokens)[:, 0]
|
|
|
|
class ViTClassifier(nn.Module):
    """
    Step 5: ViT encoder followed by a linear classification head.

    The head maps the encoder's [CLS] representation (dimension
    `embed_dim`) to `num_classes` logits.
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.encoder = ViTEncoder(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
        )
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (B, num_classes) for images `x`."""
        return self.head(self.encoder(x))
|
|
|
|
| |
| |
| |
def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Run one optimization epoch over the training set.

    Args:
        model: Classifier to optimize.
        dataloader: Training mini-batches.
        criterion: Loss function (typically CrossEntropyLoss for CIFAR-10).
        optimizer: Parameter optimizer (AdamW in this project).
        device: CPU or CUDA device.

    Returns:
        (avg_loss, avg_accuracy) over all training samples in this epoch.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = model(batch_images)
        batch_loss = criterion(outputs, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the (mean) batch loss by batch size so the epoch average
        # is a true per-sample mean even with a ragged final batch.
        batch_n = batch_labels.size(0)
        loss_sum += batch_loss.item() * batch_n
        num_correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
        num_seen += batch_n

    return loss_sum / num_seen, num_correct / num_seen
|
|
|
|
@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Evaluate model performance without gradient updates.

    Args:
        model: Classifier to evaluate.
        dataloader: Validation or test mini-batches.
        criterion: Loss function used for reporting.
        device: CPU or CUDA device.

    Returns:
        (avg_loss, avg_accuracy) over all samples from `dataloader`.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_images)
        batch_loss = criterion(outputs, batch_labels)

        # Weight the mean batch loss by batch size -> exact per-sample mean.
        batch_n = batch_labels.size(0)
        loss_sum += batch_loss.item() * batch_n
        num_correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
        num_seen += batch_n

    return loss_sum / num_seen, num_correct / num_seen
|
|
|
|
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 10,
    lr: float = 3e-4,
    weight_decay: float = 1e-4,
    save_dir: str = "./saved_model",
    checkpoint_name: str = "vit_cifar10_best.pt",
    model_config: Dict[str, Any] | None = None,
    early_stopping_patience: int = 5,
) -> Tuple[Dict[str, List[float]], str]:
    """
    Step 6: full training loop with validation, checkpointing and early stop.

    - Loss: CrossEntropy
    - Optimizer: AdamW
    - LR scheduler: StepLR decay (halves the LR every 5 epochs)
    - Validation each epoch
    - Early stopping on validation accuracy

    Args:
        model: Classifier to train (moved onto `device` here).
        train_loader: Training mini-batches.
        val_loader: Validation mini-batches.
        device: CPU or CUDA device.
        num_epochs: Max number of epochs before early stopping.
        lr: Initial learning rate for AdamW updates.
        weight_decay: L2-style regularization term in AdamW.
        save_dir: Directory for checkpoints (created if missing).
        checkpoint_name: File name for the best-validation checkpoint.
        model_config: Optional config dict stored inside checkpoints so the
            model can be rebuilt at load time.
        early_stopping_patience: Number of non-improving epochs allowed.
            This limits overfitting and unnecessary computation.

    Returns:
        (history dict of per-epoch metrics, path to best checkpoint)
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }
    best_val_acc = 0.0
    epochs_without_improvement = 0
    # Track how many epochs actually ran: early stopping may end the loop
    # before `num_epochs`, and the "last" checkpoint must record the truth.
    completed_epochs = 0
    save_dir_path = Path(save_dir)
    save_dir_path.mkdir(parents=True, exist_ok=True)
    best_checkpoint_path = str(save_dir_path / checkpoint_name)

    model.to(device)

    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(
            model=model,
            dataloader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
        )
        val_loss, val_acc = evaluate(
            model=model,
            dataloader=val_loader,
            criterion=criterion,
            device=device,
        )
        scheduler.step()
        completed_epochs = epoch + 1

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        # LR is read after scheduler.step(), i.e. the rate for the NEXT epoch.
        current_lr = optimizer.param_groups[0]["lr"]
        print(
            f"Epoch [{epoch + 1}/{num_epochs}] | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc * 100:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc * 100:.2f}% | "
            f"LR: {current_lr:.6f}"
        )

        if val_acc > best_val_acc:
            # New best validation accuracy: snapshot weights + metadata.
            best_val_acc = val_acc
            epochs_without_improvement = 0
            checkpoint = {
                "epoch": epoch + 1,
                "best_val_acc": best_val_acc,
                "model_state_dict": model.state_dict(),
                "model_config": model_config or {},
            }
            torch.save(checkpoint, best_checkpoint_path)
            print(f"Saved best checkpoint to: {best_checkpoint_path}")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                print(
                    "Early stopping triggered "
                    f"(no validation improvement for {early_stopping_patience} epochs)."
                )
                break

    # Always save the final (not necessarily best) weights as well.
    # BUG FIX: previously this recorded "epoch": num_epochs even when early
    # stopping ended training sooner; use the real number of epochs run.
    final_checkpoint_path = str(save_dir_path / "vit_cifar10_last.pt")
    torch.save(
        {
            "epoch": completed_epochs,
            "best_val_acc": best_val_acc,
            "model_state_dict": model.state_dict(),
            "model_config": model_config or {},
        },
        final_checkpoint_path,
    )
    print(f"Saved last checkpoint to: {final_checkpoint_path}")

    return history, best_checkpoint_path
|
|
|
|
| |
| |
| |
@torch.no_grad()
def collect_misclassified(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    max_samples: int = 16,
) -> List[Tuple[torch.Tensor, int, int]]:
    """
    Step 7 (error analysis helper).

    Walk `dataloader`, keeping up to `max_samples` wrongly predicted samples
    as (image_tensor_on_cpu, true_label, pred_label) tuples, in encounter
    order.
    """
    model.eval()
    collected: List[Tuple[torch.Tensor, int, int]] = []

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        preds = model(images).argmax(dim=1)
        mismatch = preds != labels

        # Iterate only the wrong predictions of this batch, in order.
        for img, true_lbl, pred_lbl in zip(
            images[mismatch], labels[mismatch], preds[mismatch]
        ):
            collected.append(
                (img.detach().cpu(), int(true_lbl.item()), int(pred_lbl.item()))
            )
            # Stop as soon as the quota is reached.
            if len(collected) >= max_samples:
                return collected

    return collected
|
|
|
|
def denormalize_image(img: torch.Tensor) -> torch.Tensor:
    """
    Undo the Normalize(mean=0.5, std=0.5) preprocessing for visualization.

    Maps values from [-1, 1] back to [0, 1]; anything outside that range is
    clamped so the result is always a valid pixel intensity.
    """
    return img.mul(0.5).add(0.5).clamp(0.0, 1.0)
|
|
|
|
def visualize_misclassified(
    samples: List[Tuple[torch.Tensor, int, int]],
    class_names: Tuple[str, ...],
    save_path: str = "misclassified_examples.png",
) -> None:
    """
    Visualize wrongly predicted images and save the figure to disk.

    Args:
        samples: Tuples of (image_tensor, true_label, pred_label) as produced
            by `collect_misclassified`.
        class_names: Label-id -> human-readable name mapping.
        save_path: Output PNG path; parent directories are created if needed.
    """
    if len(samples) == 0:
        print("No misclassified samples to visualize.")
        return

    # matplotlib is an optional dependency for this script; degrade gracefully.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib is not installed. Skipping visualization.")
        return

    n = len(samples)
    cols = min(4, n)
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))

    # plt.subplots returns a scalar Axes, a 1-D array, or a 2-D array
    # depending on rows/cols; normalize to a flat list.
    if rows == 1 and cols == 1:
        axes = [axes]
    elif rows == 1 or cols == 1:
        axes = list(axes)
    else:
        axes = axes.flatten()

    for idx, ax in enumerate(axes):
        if idx < n:
            img, true_lbl, pred_lbl = samples[idx]
            img = denormalize_image(img).permute(1, 2, 0).numpy()
            ax.imshow(img)
            ax.set_title(f"True: {class_names[true_lbl]}\nPred: {class_names[pred_lbl]}")
            ax.axis("off")
        else:
            # Hide unused grid cells so the figure has no empty frames.
            ax.axis("off")

    fig.tight_layout()
    # BUG FIX: savefig raises FileNotFoundError if the target directory does
    # not exist (e.g. the "./results/..." path used by __main__); create it.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=150)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Saved misclassified visualization to: {save_path}")
|
|
|
|
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Step 1: data pipeline — download, resize, normalize, batch.
    # ------------------------------------------------------------------
    train_loader, val_loader, test_loader = get_cifar10_dataloaders(
        data_root="./data",
        image_size=64,
        batch_size=128,
        val_ratio=0.1,
    )

    # Pull one batch from each split purely as a shape/range sanity check.
    train_images, train_labels = next(iter(train_loader))
    val_images, val_labels = next(iter(val_loader))
    test_images, test_labels = next(iter(test_loader))

    print(f"Train batch images shape: {train_images.shape}")
    print(f"Train batch labels shape: {train_labels.shape}")
    print(f"Val batch images shape: {val_images.shape}")
    print(f"Val batch labels shape: {val_labels.shape}")
    print(f"Test batch images shape: {test_images.shape}")
    print(f"Test batch labels shape: {test_labels.shape}")
    print(f"Train dataset size: {len(train_loader.dataset)}")
    print(f"Val dataset size: {len(val_loader.dataset)}")
    print(f"Test dataset size: {len(test_loader.dataset)}")
    print(
        "Image value range (approx after normalization): "
        f"[{train_images.min().item():.3f}, {train_images.max().item():.3f}]"
    )

    # ------------------------------------------------------------------
    # Steps 2-5: build the ViT classifier.
    # NOTE(review): model_kwargs mirrors the ViTClassifier(...) call below
    # and is stored inside checkpoints so the model can be rebuilt at load
    # time — the two must be kept in sync manually.
    # ------------------------------------------------------------------
    model_kwargs: Dict[str, Any] = {
        "image_size": 64,
        "patch_size": 4,
        "in_channels": 3,
        "embed_dim": 256,
        "depth": 6,
        "num_heads": 8,
        "mlp_ratio": 4.0,
        "dropout": 0.1,
        "num_classes": 10,
    }
    model = ViTClassifier(
        image_size=64,
        patch_size=4,
        in_channels=3,
        embed_dim=256,
        depth=6,
        num_heads=8,
        mlp_ratio=4.0,
        dropout=0.1,
        num_classes=10,
    )

    # Forward-pass smoke test on a single CPU batch before training starts.
    patch_embeddings = model.encoder.patch_embed(train_images)
    cls_features = model.encoder(train_images)
    logits = model(train_images)

    print(f"Patch embeddings shape (B, N, D): {patch_embeddings.shape}")
    print(f"CLS feature shape (B, D): {cls_features.shape}")
    print(f"Logits shape (B, num_classes): {logits.shape}")

    # ------------------------------------------------------------------
    # Step 6: training with AdamW + StepLR + early stopping.
    # (CUDA_VISIBLE_DEVICES is pinned at the top of the file, so "cuda"
    # here refers to the selected physical GPU if one is available.)
    # ------------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    history, best_ckpt_path = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=10,
        lr=3e-4,
        weight_decay=1e-4,
        save_dir="./saved_model",
        checkpoint_name="vit_cifar10_best.pt",
        model_config=model_kwargs,
        early_stopping_patience=5,
    )

    final_val_acc = history["val_acc"][-1] * 100 if history["val_acc"] else 0.0
    print(f"Final validation accuracy: {final_val_acc:.2f}%")
    print(f"Best model checkpoint: {best_ckpt_path}")

    # ------------------------------------------------------------------
    # Step 7: evaluate the best checkpoint on the held-out test set.
    # NOTE(review): torch.load without weights_only — acceptable for
    # checkpoints this script wrote itself; confirm before loading files
    # from untrusted sources.
    # ------------------------------------------------------------------
    best_checkpoint = torch.load(best_ckpt_path, map_location=device)
    model.load_state_dict(best_checkpoint["model_state_dict"])
    test_criterion = nn.CrossEntropyLoss()
    test_loss, test_acc = evaluate(
        model=model,
        dataloader=test_loader,
        criterion=test_criterion,
        device=device,
    )
    print(f"Final test loss (best checkpoint): {test_loss:.4f}")
    print(f"Final test accuracy (best checkpoint): {test_acc * 100:.2f}%")

    # Error analysis: collect and plot a handful of wrong predictions.
    wrong_samples = collect_misclassified(
        model=model,
        dataloader=test_loader,
        device=device,
        max_samples=12,
    )
    visualize_misclassified(
        samples=wrong_samples,
        class_names=CLASS_NAMES,
        save_path="./results/misclassified_examples.png",
    )
|