# Hugging Face upload metadata (scrape residue, kept as a comment so the
# file remains valid Python): uploaded by shahidul034 using the
# upload-large-folder tool, commit a16c07b (verified).
"""
Vision Transformer (ViT) training script for CIFAR-10.
Reference:
- Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:
Transformers for Image Recognition at Scale. ICLR 2021.
https://arxiv.org/abs/2010.11929
This script covers:
1) Loading CIFAR-10
2) Resizing images (default: 64x64)
3) Normalizing pixel values to [-1, 1]
4) Creating batched DataLoaders
5) Building a ViT encoder + classification head
6) Training with CrossEntropy + AdamW + LR scheduler
7) Evaluation accuracy + misclassification visualization
"""
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from pathlib import Path
from typing import Any, Dict, List, Tuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# ---------------------------------------------------------------------------
# Dataset metadata
# ---------------------------------------------------------------------------
# CIFAR-10 class names ordered by the integer label id (0-9) that
# torchvision's CIFAR10 dataset assigns to each category; used to render
# human-readable labels in the misclassification visualization.
CLASS_NAMES: Tuple[str, ...] = (
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
def get_cifar10_dataloaders(
    data_root: str = "./data",
    image_size: int = 64,
    batch_size: int = 128,
    num_workers: int = 2,
    pin_memory: bool = True,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Construct train/val/test DataLoaders for CIFAR-10.

    The official 50,000-image train split is partitioned deterministically
    (seeded) into train/val subsets; the official 10,000-image test split
    is kept intact. Every image is resized to ``image_size`` x ``image_size``
    and normalized from [0, 1] to [-1, 1].

    Data source:
    - https://www.cs.toronto.edu/~kriz/cifar.html

    Args:
        data_root: Directory to download/store CIFAR-10.
        image_size: Side length of the square resized image.
        batch_size: Samples per mini-batch.
        num_workers: Number of DataLoader worker subprocesses.
        pin_memory: Request pinned host memory (only honored when CUDA
            is available).
        val_ratio: Fraction of the official train split held out for
            validation.
        seed: Seed that fixes the train/val partition.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: If ``val_ratio`` is not strictly between 0 and 1.
    """
    if not 0.0 < val_ratio < 1.0:
        raise ValueError("val_ratio must be between 0 and 1.")

    preprocess = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),  # uint8 [0,255] -> float [0,1]
            transforms.Normalize(
                mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
            ),  # [0,1] -> [-1,1]
        ]
    )

    root = Path(data_root)
    root.mkdir(parents=True, exist_ok=True)

    train_full = datasets.CIFAR10(
        root=str(root), train=True, download=True, transform=preprocess
    )
    test_set = datasets.CIFAR10(
        root=str(root), train=False, download=True, transform=preprocess
    )

    # Pinned memory only speeds host->device copies when CUDA exists.
    pin = pin_memory and torch.cuda.is_available()

    n_val = int(len(train_full) * val_ratio)
    split_gen = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(
        train_full, [len(train_full) - n_val, n_val], generator=split_gen
    )

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        # All three loaders share every setting except shuffling.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin,
        )

    return (
        _make_loader(train_set, True),
        _make_loader(val_set, False),
        _make_loader(test_set, False),
    )
class PatchifyEmbedding(nn.Module):
    """
    Split an image into non-overlapping PxP patches and linearly project
    each flattened patch to a D-dimensional token embedding (ViT step 2).
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        side = image_size // patch_size
        self.num_patches_per_side = side
        self.num_patches = side * side
        flat_patch_dim = in_channels * patch_size * patch_size
        # nn.Unfold extracts and flattens every PxP patch in one pass.
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.proj = nn.Linear(flat_patch_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image batch of shape (B, C, H, W).

        Returns:
            Patch token embeddings of shape (B, N, D) with N=num_patches.
        """
        flat_patches = self.unfold(x).permute(0, 2, 1)  # (B, N, patch_dim)
        return self.proj(flat_patches)  # (B, N, D)
class TransformerEncoderBlock(nn.Module):
    """
    One pre-norm Transformer encoder block (ViT step 4):
    LayerNorm -> multi-head self-attention -> residual,
    then LayerNorm -> MLP -> residual.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        hidden = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,  # inputs/outputs are (B, N, D)
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def _self_attention(self, tokens: torch.Tensor) -> torch.Tensor:
        # Pre-norm then attend; attention weight maps are discarded.
        normed = self.norm1(tokens)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        return attended

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP sub-blocks, each with a residual add."""
        x = x + self._self_attention(x)
        x = x + self.mlp(self.norm2(x))
        return x
class ViTEncoder(nn.Module):
    """
    ViT encoder (steps 2-4): patch embedding, a learnable CLS token plus
    learnable positional embeddings, a stack of Transformer encoder
    blocks, and a final LayerNorm. Produces the CLS representation.
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.patch_embed = PatchifyEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )
        # One position slot per patch plus one for the CLS token.
        token_count = self.patch_embed.num_patches + 1
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        stack = []
        for _ in range(depth):
            stack.append(
                TransformerEncoderBlock(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                )
            )
        self.blocks = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(embed_dim)
        self._init_parameters()

    def _init_parameters(self) -> None:
        # Truncated-normal init (std 0.02) for the learned token/positions.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image batch of shape (B, C, H, W).

        Returns:
            CLS representation of shape (B, D) after the encoder stack.
        """
        tokens = self.patch_embed(x)  # (B, N, D)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)  # (B, 1, D)
        tokens = torch.cat((cls, tokens), dim=1)  # (B, N+1, D)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)
        return tokens[:, 0]  # CLS slot only -> (B, D)
class ViTClassifier(nn.Module):
    """
    Full ViT classifier (step 5): the encoder produces a CLS feature,
    and a single linear head maps it to class logits.
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.encoder = ViTEncoder(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
        )
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch (B, C, H, W) to logits (B, num_classes)."""
        return self.head(self.encoder(x))
# ---------------------------------------------------------------------------
# Training and evaluation helpers
# ---------------------------------------------------------------------------
def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Run a single optimization pass over the training set.

    Args:
        model: Classifier to optimize (set to train mode here).
        dataloader: Training mini-batches of (images, labels).
        criterion: Loss function (CrossEntropyLoss for CIFAR-10).
        optimizer: Parameter optimizer.
        device: CPU or CUDA device the batches are moved to.

    Returns:
        (avg_loss, avg_accuracy), both sample-weighted over the epoch.
    """
    model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0
    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        outputs = model(batch_images)
        batch_loss = criterion(outputs, batch_labels)
        batch_loss.backward()
        optimizer.step()

        count = batch_labels.size(0)
        seen += count
        # Weight by batch size so partial final batches average correctly.
        loss_sum += batch_loss.item() * count
        hits += (outputs.argmax(dim=1) == batch_labels).sum().item()
    return loss_sum / seen, hits / seen
@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Measure loss and accuracy over a dataset without gradient updates.

    Args:
        model: Classifier to evaluate (set to eval mode here).
        dataloader: Validation or test mini-batches of (images, labels).
        criterion: Loss function used only for reporting.
        device: CPU or CUDA device the batches are moved to.

    Returns:
        (avg_loss, avg_accuracy), both sample-weighted over `dataloader`.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0
    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_images)
        batch_loss = criterion(outputs, batch_labels)

        count = batch_labels.size(0)
        seen += count
        # Weight by batch size so partial final batches average correctly.
        loss_sum += batch_loss.item() * count
        hits += (outputs.argmax(dim=1) == batch_labels).sum().item()
    return loss_sum / seen, hits / seen
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 10,
    lr: float = 3e-4,
    weight_decay: float = 1e-4,
    save_dir: str = "./saved_model",
    checkpoint_name: str = "vit_cifar10_best.pt",
    model_config: Dict[str, Any] | None = None,
    early_stopping_patience: int = 5,
) -> Tuple[Dict[str, List[float]], str]:
    """
    Train `model` with CrossEntropy + AdamW + StepLR, validating each epoch.

    The best checkpoint (highest validation accuracy) is written to
    `save_dir/checkpoint_name`; the state after the final completed epoch
    is written to `save_dir/vit_cifar10_last.pt`. Training stops early
    after `early_stopping_patience` consecutive epochs without a
    validation-accuracy improvement.

    Args:
        model: Classifier to optimize (moved onto `device` here).
        train_loader: Training mini-batches.
        val_loader: Validation mini-batches.
        device: CPU or CUDA device.
        num_epochs: Maximum number of epochs before early stopping.
        lr: Initial AdamW learning rate.
        weight_decay: AdamW decoupled weight-decay coefficient.
        save_dir: Checkpoint directory (created if missing).
        checkpoint_name: Filename for the best-accuracy checkpoint.
        model_config: Architecture kwargs stored inside each checkpoint so
            the model can be re-instantiated from the file alone.
        early_stopping_patience: Allowed count of non-improving epochs.

    Returns:
        (history, best_checkpoint_path) where history maps
        "train_loss"/"train_acc"/"val_loss"/"val_acc" to per-epoch lists.
    """
    # Move the model before constructing the optimizer so the optimizer is
    # built over device-resident parameters from the start.
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Halve the learning rate every 5 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }
    best_val_acc = 0.0
    epochs_without_improvement = 0
    epochs_run = 0  # number of epochs actually completed (early-stop aware)
    save_dir_path = Path(save_dir)
    save_dir_path.mkdir(parents=True, exist_ok=True)
    best_checkpoint_path = str(save_dir_path / checkpoint_name)
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(
            model=model,
            dataloader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
        )
        val_loss, val_acc = evaluate(
            model=model,
            dataloader=val_loader,
            criterion=criterion,
            device=device,
        )
        scheduler.step()
        epochs_run = epoch + 1
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        current_lr = optimizer.param_groups[0]["lr"]
        print(
            f"Epoch [{epoch + 1}/{num_epochs}] | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc * 100:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc * 100:.2f}% | "
            f"LR: {current_lr:.6f}"
        )
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_improvement = 0
            checkpoint = {
                "epoch": epoch + 1,
                "best_val_acc": best_val_acc,
                "model_state_dict": model.state_dict(),
                "model_config": model_config or {},
            }
            torch.save(checkpoint, best_checkpoint_path)
            print(f"Saved best checkpoint to: {best_checkpoint_path}")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                print(
                    "Early stopping triggered "
                    f"(no validation improvement for {early_stopping_patience} epochs)."
                )
                break
    final_checkpoint_path = str(save_dir_path / "vit_cifar10_last.pt")
    torch.save(
        {
            # BUGFIX: record the epoch actually reached; the previous code
            # always wrote `num_epochs` even when early stopping fired.
            "epoch": epochs_run,
            "best_val_acc": best_val_acc,
            "model_state_dict": model.state_dict(),
            "model_config": model_config or {},
        },
        final_checkpoint_path,
    )
    print(f"Saved last checkpoint to: {final_checkpoint_path}")
    return history, best_checkpoint_path
# ---------------------------------------------------------------------------
# Error analysis and visualization
# ---------------------------------------------------------------------------
@torch.no_grad()
def collect_misclassified(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    max_samples: int = 16,
) -> List[Tuple[torch.Tensor, int, int]]:
    """
    Gather up to `max_samples` wrongly-predicted samples for error analysis.

    Args:
        model: Trained classifier (set to eval mode here).
        dataloader: Mini-batches of (images, labels) to scan.
        device: CPU or CUDA device the batches are moved to.
        max_samples: Stop once this many mistakes have been collected.

    Returns:
        List of (cpu_image_tensor, true_label, predicted_label) triples.
    """
    model.eval()
    collected: List[Tuple[torch.Tensor, int, int]] = []
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        predictions = model(images).argmax(dim=1)
        mismatch = predictions != labels
        # Walk the wrong samples of this batch in order, stopping early
        # as soon as the requested count is reached.
        for img, true_lbl, pred_lbl in zip(
            images[mismatch], labels[mismatch], predictions[mismatch]
        ):
            collected.append(
                (img.detach().cpu(), int(true_lbl.item()), int(pred_lbl.item()))
            )
            if len(collected) >= max_samples:
                return collected
    return collected
def denormalize_image(img: torch.Tensor) -> torch.Tensor:
    """
    Map a normalized [-1, 1] image tensor back into [0, 1] for display.

    Inverts the Normalize(mean=0.5, std=0.5) preprocessing and clamps so
    out-of-range values cannot upset the plotting backend.
    """
    return img.mul(0.5).add(0.5).clamp(0.0, 1.0)
def visualize_misclassified(
    samples: List[Tuple[torch.Tensor, int, int]],
    class_names: Tuple[str, ...],
    save_path: str = "misclassified_examples.png",
) -> None:
    """
    Plot wrongly predicted images in a grid and save the figure to disk.

    Args:
        samples: (image, true_label, pred_label) triples; images are
            normalized CHW tensors as produced by the data pipeline.
        class_names: Mapping from label id to human-readable name.
        save_path: Output image path; parent directories are created
            if they do not exist.
    """
    if len(samples) == 0:
        print("No misclassified samples to visualize.")
        return
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib is not installed. Skipping visualization.")
        return
    n = len(samples)
    cols = min(4, n)
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))
    # Normalize `axes` to a flat list regardless of the grid shape
    # (plt.subplots returns a scalar, 1-D, or 2-D array of Axes).
    if rows == 1 and cols == 1:
        axes = [axes]
    elif rows == 1 or cols == 1:
        axes = list(axes)
    else:
        axes = axes.flatten()
    for idx, ax in enumerate(axes):
        if idx < n:
            img, true_lbl, pred_lbl = samples[idx]
            img = denormalize_image(img).permute(1, 2, 0).numpy()  # CHW -> HWC
            ax.imshow(img)
            ax.set_title(f"True: {class_names[true_lbl]}\nPred: {class_names[pred_lbl]}")
        # Hide axis ticks on every cell, including unused trailing ones.
        ax.axis("off")
    fig.tight_layout()
    # BUGFIX: create the destination directory first; savefig raises if a
    # caller passes a path like "./results/..." whose directory is missing.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=150)
    plt.close(fig)  # release the figure so repeated calls do not leak memory
    print(f"Saved misclassified visualization to: {save_path}")
if __name__ == "__main__":
    # -----------------------------------------------------------------------
    # Data preprocessing and loader setup
    # -----------------------------------------------------------------------
    train_loader, val_loader, test_loader = get_cifar10_dataloaders(
        data_root="./data",
        image_size=64,
        batch_size=128,
        val_ratio=0.1,
    )
    # Pull one batch from each split as a shape/range sanity check before
    # committing to a full training run.
    train_images, train_labels = next(iter(train_loader))
    val_images, val_labels = next(iter(val_loader))
    test_images, test_labels = next(iter(test_loader))
    print(f"Train batch images shape: {train_images.shape}")
    print(f"Train batch labels shape: {train_labels.shape}")
    print(f"Val batch images shape: {val_images.shape}")
    print(f"Val batch labels shape: {val_labels.shape}")
    print(f"Test batch images shape: {test_images.shape}")
    print(f"Test batch labels shape: {test_labels.shape}")
    print(f"Train dataset size: {len(train_loader.dataset)}")
    print(f"Val dataset size: {len(val_loader.dataset)}")
    print(f"Test dataset size: {len(test_loader.dataset)}")
    print(
        "Image value range (approx after normalization): "
        f"[{train_images.min().item():.3f}, {train_images.max().item():.3f}]"
    )
    # -----------------------------------------------------------------------
    # Model architecture configuration (key hyperparameters)
    # -----------------------------------------------------------------------
    # image_size=64: upscales CIFAR-10 from 32x32 for a richer patch grid.
    # patch_size=4: produces (64/4)^2 = 256 image patches per sample.
    # embed_dim=256: token representation size in attention/MLP blocks.
    # depth=6, num_heads=8: transformer depth and multi-head attention width.
    # mlp_ratio=4.0: hidden size in feed-forward layer = 4 * embed_dim.
    # dropout=0.1: regularization inside encoder blocks.
    model_kwargs: Dict[str, Any] = {
        "image_size": 64,
        "patch_size": 4,
        "in_channels": 3,
        "embed_dim": 256,
        "depth": 6,
        "num_heads": 8,
        "mlp_ratio": 4.0,
        "dropout": 0.1,
        "num_classes": 10,
    }
    # Build the model from the same dict that gets stored in checkpoints,
    # so the instantiated architecture can never drift from the saved config.
    model = ViTClassifier(**model_kwargs)
    # Forward-pass smoke test on one CPU batch before training starts.
    patch_embeddings = model.encoder.patch_embed(train_images)
    cls_features = model.encoder(train_images)
    logits = model(train_images)
    print(f"Patch embeddings shape (B, N, D): {patch_embeddings.shape}")
    print(f"CLS feature shape (B, D): {cls_features.shape}")
    print(f"Logits shape (B, num_classes): {logits.shape}")
    # -----------------------------------------------------------------------
    # Training setup hyperparameters
    # -----------------------------------------------------------------------
    # lr=3e-4: base AdamW learning rate.
    # weight_decay=1e-4: regularization to improve generalization.
    # num_epochs=10 and early_stopping_patience=5: train up to 10 epochs,
    # but stop if validation accuracy does not improve for 5 epochs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    history, best_ckpt_path = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=10,
        lr=3e-4,
        weight_decay=1e-4,
        save_dir="./saved_model",
        checkpoint_name="vit_cifar10_best.pt",
        model_config=model_kwargs,
        early_stopping_patience=5,
    )
    final_val_acc = history["val_acc"][-1] * 100 if history["val_acc"] else 0.0
    print(f"Final validation accuracy: {final_val_acc:.2f}%")
    print(f"Best model checkpoint: {best_ckpt_path}")
    # -----------------------------------------------------------------------
    # Final test evaluation and qualitative error analysis
    # -----------------------------------------------------------------------
    # Evaluate on the held-out test set (not used for training/checkpointing).
    best_checkpoint = torch.load(best_ckpt_path, map_location=device)
    model.load_state_dict(best_checkpoint["model_state_dict"])
    test_criterion = nn.CrossEntropyLoss()
    test_loss, test_acc = evaluate(
        model=model,
        dataloader=test_loader,
        criterion=test_criterion,
        device=device,
    )
    print(f"Final test loss (best checkpoint): {test_loss:.4f}")
    print(f"Final test accuracy (best checkpoint): {test_acc * 100:.2f}%")
    wrong_samples = collect_misclassified(
        model=model,
        dataloader=test_loader,
        device=device,
        max_samples=12,
    )
    # BUGFIX: ensure the output directory exists; matplotlib's savefig
    # raises FileNotFoundError when the target directory is missing.
    Path("./results").mkdir(parents=True, exist_ok=True)
    visualize_misclassified(
        samples=wrong_samples,
        class_names=CLASS_NAMES,
        save_path="./results/misclassified_examples.png",
    )