""" Vision Transformer (ViT) training script for CIFAR-10. Reference: - Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. https://arxiv.org/abs/2010.11929 This script covers: 1) Loading CIFAR-10 2) Resizing images (default: 64x64) 3) Normalizing pixel values to [-1, 1] 4) Creating batched DataLoaders 5) Building a ViT encoder + classification head 6) Training with CrossEntropy + AdamW + LR scheduler 7) Evaluation accuracy + misclassification visualization """ import os os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "2" from pathlib import Path from typing import Any, Dict, List, Tuple import torch import torch.nn as nn from torch.utils.data import DataLoader, random_split from torchvision import datasets, transforms # --------------------------------------------------------------------------- # Dataset metadata # --------------------------------------------------------------------------- CLASS_NAMES: Tuple[str, ...] = ( "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck", ) def get_cifar10_dataloaders( data_root: str = "./data", image_size: int = 64, batch_size: int = 128, num_workers: int = 2, pin_memory: bool = True, val_ratio: float = 0.1, seed: int = 42, ) -> Tuple[DataLoader, DataLoader, DataLoader]: """ Build CIFAR-10 train/val/test dataloaders with resize + normalization. Uses CIFAR-10's official split: - train=True -> 50,000 images - train=False -> 10,000 images Data source: - https://www.cs.toronto.edu/~kriz/cifar.html Args: data_root: Directory to download/store CIFAR-10. image_size: Target image size after resizing (square). batch_size: Number of samples per batch. num_workers: Number of subprocesses for data loading. pin_memory: Pin memory for faster host-to-device transfer on CUDA. val_ratio: Fraction of official train split reserved for validation. seed: Random seed for deterministic train/val split. Returns: train_loader, val_loader, test_loader """ if not 0.0 < val_ratio < 1.0: raise ValueError("val_ratio must be between 0 and 1.") transform = transforms.Compose( [ transforms.Resize((image_size, image_size)), transforms.ToTensor(), # Scales uint8 [0,255] -> float [0,1] transforms.Normalize( mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) ), # [0,1] -> [-1,1] ] ) data_root_path = Path(data_root) data_root_path.mkdir(parents=True, exist_ok=True) full_train_dataset = datasets.CIFAR10( root=str(data_root_path), train=True, download=True, transform=transform, ) test_dataset = datasets.CIFAR10( root=str(data_root_path), train=False, download=True, transform=transform, ) # pin_memory is useful only when CUDA is available. use_pin_memory = pin_memory and torch.cuda.is_available() val_size = int(len(full_train_dataset) * val_ratio) train_size = len(full_train_dataset) - val_size generator = torch.Generator().manual_seed(seed) train_dataset, val_dataset = random_split( full_train_dataset, [train_size, val_size], generator=generator ) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=use_pin_memory, ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=use_pin_memory, ) test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=use_pin_memory, ) return train_loader, val_loader, test_loader class PatchifyEmbedding(nn.Module): """ Step 2: - Divide image into PxP patches - Flatten each patch - Project flattened patches to hidden dim D """ def __init__( self, image_size: int = 64, patch_size: int = 4, in_channels: int = 3, embed_dim: int = 256, ) -> None: super().__init__() if image_size % patch_size != 0: raise ValueError("image_size must be divisible by patch_size.") self.image_size = image_size self.patch_size = patch_size self.in_channels = in_channels self.embed_dim = embed_dim self.num_patches_per_side = image_size // patch_size self.num_patches = self.num_patches_per_side * self.num_patches_per_side patch_dim = in_channels * patch_size * patch_size self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size) self.proj = nn.Linear(patch_dim, embed_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: (B, C, H, W) returns: (B, N, D), where N=num_patches, D=embed_dim """ patches = self.unfold(x) # (B, patch_dim, N) patches = patches.transpose(1, 2) # (B, N, patch_dim) embeddings = self.proj(patches) # (B, N, D) return embeddings class TransformerEncoderBlock(nn.Module): """ Step 4 single block: LayerNorm -> MSA -> residual -> LayerNorm -> MLP -> residual """ def __init__( self, embed_dim: int, num_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0, ) -> None: super().__init__() mlp_hidden_dim = int(embed_dim * mlp_ratio) self.norm1 = nn.LayerNorm(embed_dim) self.attn = nn.MultiheadAttention( embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True, ) self.norm2 = nn.LayerNorm(embed_dim) self.mlp = nn.Sequential( nn.Linear(embed_dim, mlp_hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden_dim, embed_dim), nn.Dropout(dropout), ) def forward(self, x: torch.Tensor) -> torch.Tensor: # MSA block + residual x_norm = self.norm1(x) attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False) x = x + attn_out # MLP block + residual x = x + self.mlp(self.norm2(x)) return x class ViTEncoder(nn.Module): """ Steps 2-4: - Patchify + projection - Learnable CLS token + learnable positional embeddings - Stacked Transformer encoder blocks """ def __init__( self, image_size: int = 64, patch_size: int = 4, in_channels: int = 3, embed_dim: int = 256, depth: int = 6, num_heads: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.0, ) -> None: super().__init__() self.patch_embed = PatchifyEmbedding( image_size=image_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim, ) num_patches = self.patch_embed.num_patches # Step 3: CLS token + positional embedding self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) self.pos_drop = nn.Dropout(dropout) self.blocks = nn.ModuleList( [ TransformerEncoderBlock( embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, dropout=dropout, ) for _ in range(depth) ] ) self.norm = nn.LayerNorm(embed_dim) self._init_parameters() def _init_parameters(self) -> None: nn.init.trunc_normal_(self.cls_token, std=0.02) nn.init.trunc_normal_(self.pos_embed, std=0.02) def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: (B, C, H, W) returns: (B, D) CLS representation after encoder """ x = self.patch_embed(x) # (B, N, D) batch_size = x.size(0) cls_tokens = self.cls_token.expand(batch_size, -1, -1) # (B, 1, D) x = torch.cat((cls_tokens, x), dim=1) # (B, N+1, D) x = self.pos_drop(x + self.pos_embed) # add positional information for block in self.blocks: x = block(x) x = self.norm(x) cls_representation = x[:, 0] # (B, D) return cls_representation class ViTClassifier(nn.Module): """ Step 5: - Extract CLS representation from encoder - Map to class logits with a Linear layer """ def __init__( self, image_size: int = 64, patch_size: int = 4, in_channels: int = 3, embed_dim: int = 256, depth: int = 6, num_heads: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.1, num_classes: int = 10, ) -> None: super().__init__() self.encoder = ViTEncoder( image_size=image_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, dropout=dropout, ) self.head = nn.Linear(embed_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: cls_features = self.encoder(x) # (B, D) logits = self.head(cls_features) # (B, num_classes) return logits # --------------------------------------------------------------------------- # Training and evaluation helpers # --------------------------------------------------------------------------- def train_one_epoch( model: nn.Module, dataloader: DataLoader, criterion: nn.Module, optimizer: torch.optim.Optimizer, device: torch.device, ) -> Tuple[float, float]: """ Run one optimization epoch over the training set. Args: model: Classifier to optimize. dataloader: Training mini-batches. criterion: Loss function (typically CrossEntropyLoss for CIFAR-10). optimizer: Parameter optimizer (AdamW in this project). device: CPU or CUDA device. Returns: (avg_loss, avg_accuracy) over all training samples in this epoch. """ model.train() running_loss = 0.0 correct = 0 total = 0 for images, labels in dataloader: images = images.to(device) labels = labels.to(device) optimizer.zero_grad() logits = model(images) loss = criterion(logits, labels) loss.backward() optimizer.step() running_loss += loss.item() * images.size(0) preds = logits.argmax(dim=1) correct += (preds == labels).sum().item() total += labels.size(0) avg_loss = running_loss / total avg_acc = correct / total return avg_loss, avg_acc @torch.no_grad() def evaluate( model: nn.Module, dataloader: DataLoader, criterion: nn.Module, device: torch.device, ) -> Tuple[float, float]: """ Evaluate model performance without gradient updates. Args: model: Classifier to evaluate. dataloader: Validation or test mini-batches. criterion: Loss function used for reporting. device: CPU or CUDA device. Returns: (avg_loss, avg_accuracy) over all samples from `dataloader`. """ model.eval() running_loss = 0.0 correct = 0 total = 0 for images, labels in dataloader: images = images.to(device) labels = labels.to(device) logits = model(images) loss = criterion(logits, labels) running_loss += loss.item() * images.size(0) preds = logits.argmax(dim=1) correct += (preds == labels).sum().item() total += labels.size(0) avg_loss = running_loss / total avg_acc = correct / total return avg_loss, avg_acc def train_model( model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, device: torch.device, num_epochs: int = 10, lr: float = 3e-4, weight_decay: float = 1e-4, save_dir: str = "./saved_model", checkpoint_name: str = "vit_cifar10_best.pt", model_config: Dict[str, Any] | None = None, early_stopping_patience: int = 5, ) -> Tuple[Dict[str, List[float]], str]: """ Step 6: - Loss: CrossEntropy - Optimizer: AdamW - LR scheduler: StepLR decay - Validation each epoch - Early stopping on validation accuracy Hyperparameters: - num_epochs: Max number of epochs before early stopping. - lr: Initial learning rate for AdamW updates. - weight_decay: L2-style regularization term in AdamW. - early_stopping_patience: Number of non-improving epochs allowed. This limits overfitting and unnecessary computation. """ criterion = nn.CrossEntropyLoss() optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) history: Dict[str, List[float]] = { "train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], } best_val_acc = 0.0 epochs_without_improvement = 0 save_dir_path = Path(save_dir) save_dir_path.mkdir(parents=True, exist_ok=True) best_checkpoint_path = str(save_dir_path / checkpoint_name) model.to(device) for epoch in range(num_epochs): train_loss, train_acc = train_one_epoch( model=model, dataloader=train_loader, criterion=criterion, optimizer=optimizer, device=device, ) val_loss, val_acc = evaluate( model=model, dataloader=val_loader, criterion=criterion, device=device, ) scheduler.step() history["train_loss"].append(train_loss) history["train_acc"].append(train_acc) history["val_loss"].append(val_loss) history["val_acc"].append(val_acc) current_lr = optimizer.param_groups[0]["lr"] print( f"Epoch [{epoch + 1}/{num_epochs}] | " f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc * 100:.2f}% | " f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc * 100:.2f}% | " f"LR: {current_lr:.6f}" ) if val_acc > best_val_acc: best_val_acc = val_acc epochs_without_improvement = 0 checkpoint = { "epoch": epoch + 1, "best_val_acc": best_val_acc, "model_state_dict": model.state_dict(), "model_config": model_config or {}, } torch.save(checkpoint, best_checkpoint_path) print(f"Saved best checkpoint to: {best_checkpoint_path}") else: epochs_without_improvement += 1 if epochs_without_improvement >= early_stopping_patience: print( "Early stopping triggered " f"(no validation improvement for {early_stopping_patience} epochs)." ) break final_checkpoint_path = str(save_dir_path / "vit_cifar10_last.pt") torch.save( { "epoch": num_epochs, "best_val_acc": best_val_acc, "model_state_dict": model.state_dict(), "model_config": model_config or {}, }, final_checkpoint_path, ) print(f"Saved last checkpoint to: {final_checkpoint_path}") return history, best_checkpoint_path # --------------------------------------------------------------------------- # Error analysis and visualization # --------------------------------------------------------------------------- @torch.no_grad() def collect_misclassified( model: nn.Module, dataloader: DataLoader, device: torch.device, max_samples: int = 16, ) -> List[Tuple[torch.Tensor, int, int]]: """ Step 7 (Error analysis helper): Collect misclassified samples: (image_tensor, true_label, pred_label). """ model.eval() misclassified: List[Tuple[torch.Tensor, int, int]] = [] for images, labels in dataloader: images = images.to(device) labels = labels.to(device) logits = model(images) preds = logits.argmax(dim=1) wrong_mask = preds != labels wrong_images = images[wrong_mask] wrong_labels = labels[wrong_mask] wrong_preds = preds[wrong_mask] for i in range(wrong_images.size(0)): misclassified.append( ( wrong_images[i].detach().cpu(), int(wrong_labels[i].item()), int(wrong_preds[i].item()), ) ) if len(misclassified) >= max_samples: return misclassified return misclassified def denormalize_image(img: torch.Tensor) -> torch.Tensor: """ Convert image from normalized [-1, 1] back to [0, 1] for visualization. """ return (img * 0.5 + 0.5).clamp(0.0, 1.0) def visualize_misclassified( samples: List[Tuple[torch.Tensor, int, int]], class_names: Tuple[str, ...], save_path: str = "misclassified_examples.png", ) -> None: """ Visualize wrongly predicted images and save to disk. """ if len(samples) == 0: print("No misclassified samples to visualize.") return try: import matplotlib.pyplot as plt except ImportError: print("matplotlib is not installed. Skipping visualization.") return n = len(samples) cols = min(4, n) rows = (n + cols - 1) // cols fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows)) if rows == 1 and cols == 1: axes = [axes] elif rows == 1 or cols == 1: axes = list(axes) else: axes = axes.flatten() for idx, ax in enumerate(axes): if idx < n: img, true_lbl, pred_lbl = samples[idx] img = denormalize_image(img).permute(1, 2, 0).numpy() ax.imshow(img) ax.set_title(f"True: {class_names[true_lbl]}\nPred: {class_names[pred_lbl]}") ax.axis("off") else: ax.axis("off") fig.tight_layout() fig.savefig(save_path, dpi=150) print(f"Saved misclassified visualization to: {save_path}") if __name__ == "__main__": # ----------------------------------------------------------------------- # Data preprocessing and loader setup # ----------------------------------------------------------------------- train_loader, val_loader, test_loader = get_cifar10_dataloaders( data_root="./data", image_size=64, batch_size=128, val_ratio=0.1, ) train_images, train_labels = next(iter(train_loader)) val_images, val_labels = next(iter(val_loader)) test_images, test_labels = next(iter(test_loader)) print(f"Train batch images shape: {train_images.shape}") print(f"Train batch labels shape: {train_labels.shape}") print(f"Val batch images shape: {val_images.shape}") print(f"Val batch labels shape: {val_labels.shape}") print(f"Test batch images shape: {test_images.shape}") print(f"Test batch labels shape: {test_labels.shape}") print(f"Train dataset size: {len(train_loader.dataset)}") print(f"Val dataset size: {len(val_loader.dataset)}") print(f"Test dataset size: {len(test_loader.dataset)}") print( "Image value range (approx after normalization): " f"[{train_images.min().item():.3f}, {train_images.max().item():.3f}]" ) # ----------------------------------------------------------------------- # Model architecture configuration (key hyperparameters) # ----------------------------------------------------------------------- # image_size=64: upscales CIFAR-10 from 32x32 for a richer patch grid. # patch_size=4: produces (64/4)^2 = 256 image patches per sample. # embed_dim=256: token representation size in attention/MLP blocks. # depth=6, num_heads=8: transformer depth and multi-head attention width. # mlp_ratio=4.0: hidden size in feed-forward layer = 4 * embed_dim. # dropout=0.1: regularization inside encoder blocks. model_kwargs: Dict[str, Any] = { "image_size": 64, "patch_size": 4, "in_channels": 3, "embed_dim": 256, "depth": 6, "num_heads": 8, "mlp_ratio": 4.0, "dropout": 0.1, "num_classes": 10, } model = ViTClassifier( image_size=64, patch_size=4, in_channels=3, embed_dim=256, depth=6, num_heads=8, mlp_ratio=4.0, dropout=0.1, num_classes=10, ) patch_embeddings = model.encoder.patch_embed(train_images) cls_features = model.encoder(train_images) logits = model(train_images) print(f"Patch embeddings shape (B, N, D): {patch_embeddings.shape}") print(f"CLS feature shape (B, D): {cls_features.shape}") print(f"Logits shape (B, num_classes): {logits.shape}") # ----------------------------------------------------------------------- # Training setup hyperparameters # ----------------------------------------------------------------------- # lr=3e-4: base AdamW learning rate. # weight_decay=1e-4: regularization to improve generalization. # num_epochs=10 and early_stopping_patience=5: train up to 10 epochs, # but stop if validation accuracy does not improve for 5 epochs. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") history, best_ckpt_path = train_model( model=model, train_loader=train_loader, val_loader=val_loader, device=device, num_epochs=10, lr=3e-4, weight_decay=1e-4, save_dir="./saved_model", checkpoint_name="vit_cifar10_best.pt", model_config=model_kwargs, early_stopping_patience=5, ) final_val_acc = history["val_acc"][-1] * 100 if history["val_acc"] else 0.0 print(f"Final validation accuracy: {final_val_acc:.2f}%") print(f"Best model checkpoint: {best_ckpt_path}") # ----------------------------------------------------------------------- # Final test evaluation and qualitative error analysis # ----------------------------------------------------------------------- # Evaluate on the held-out test set (not used for training/checkpointing). best_checkpoint = torch.load(best_ckpt_path, map_location=device) model.load_state_dict(best_checkpoint["model_state_dict"]) test_criterion = nn.CrossEntropyLoss() test_loss, test_acc = evaluate( model=model, dataloader=test_loader, criterion=test_criterion, device=device, ) print(f"Final test loss (best checkpoint): {test_loss:.4f}") print(f"Final test accuracy (best checkpoint): {test_acc * 100:.2f}%") wrong_samples = collect_misclassified( model=model, dataloader=test_loader, device=device, max_samples=12, ) visualize_misclassified( samples=wrong_samples, class_names=CLASS_NAMES, save_path="./results/misclassified_examples.png", )