| """ |
| Separate testing/inference script for CIFAR-10 ViT model. |
| |
| Loads a saved checkpoint, runs inference on the test set, |
| prints final performance, and saves misclassification analysis. |
| Also supports an optional transfer-learning experiment with |
| a pre-trained torchvision ViT model. |
| |
The optional transfer-learning experiment fine-tunes a pre-trained
Transformer (torchvision ViT-B/16) on CIFAR-10 and evaluates it, to
quantify the impact of transfer learning against the from-scratch baseline.
| """ |
|
|
| import argparse |
| from pathlib import Path |
| from typing import List |
| from typing import Any, Dict, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader, random_split |
| from torchvision import datasets, models, transforms |
| from torchvision.models import ViT_B_16_Weights |
|
|
| from c1 import ( |
| CLASS_NAMES, |
| ViTClassifier, |
| collect_misclassified, |
| visualize_misclassified, |
| ) |
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Compute average loss and accuracy for a model on a dataset split.

    Args:
        model: Trained model to evaluate.
        dataloader: Batches from validation or test split.
        device: CPU or CUDA device for inference.

    Returns:
        (avg_loss, accuracy) aggregated over all samples in `dataloader`.

    Raises:
        ValueError: If `dataloader` yields no samples. (Previously this
            surfaced as an opaque ZeroDivisionError.)
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)
        preds = logits.argmax(dim=1)

        # Weight the batch-mean loss by batch size so the final average is
        # per-sample even when the last batch is smaller.
        total_loss += loss.item() * images.size(0)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate_model received an empty dataloader.")

    avg_loss = total_loss / total
    acc = correct / total
    return avg_loss, acc
|
|
|
|
def load_model_from_checkpoint(
    checkpoint_path: str,
    device: torch.device,
) -> ViTClassifier:
    """
    Rebuild a `ViTClassifier` from a saved checkpoint and prepare it for inference.

    The checkpoint may store:
      - `model_state_dict`: learned parameters (the raw checkpoint object is
        used as the state dict when this key is absent)
      - `model_config`: architecture hyperparameters (falls back to the
        training defaults used in `c1.py` when missing or empty)
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Defaults mirror the training configuration in c1.py.
    fallback_config: Dict[str, Any] = {
        "image_size": 64,
        "patch_size": 4,
        "in_channels": 3,
        "embed_dim": 256,
        "depth": 6,
        "num_heads": 8,
        "mlp_ratio": 4.0,
        "dropout": 0.1,
        "num_classes": 10,
    }
    model_config = checkpoint.get("model_config") or fallback_config

    model = ViTClassifier(**model_config)
    model.load_state_dict(checkpoint.get("model_state_dict", checkpoint))
    model.to(device)
    # Inference-only script: switch off dropout / use eval-mode statistics.
    model.eval()
    return model
|
|
|
|
@torch.no_grad()
def collect_predictions(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Gather predicted and ground-truth labels over a whole dataloader.

    Returns:
        Tuple of 1-D CPU tensors `(predictions, targets)` concatenated across
        all batches, suitable for confusion-matrix / error analysis.
    """
    model.eval()
    prediction_chunks: List[torch.Tensor] = []
    target_chunks: List[torch.Tensor] = []
    for batch_images, batch_targets in dataloader:
        outputs = model(batch_images.to(device))
        # Keep results on CPU so the full test set fits regardless of GPU memory.
        prediction_chunks.append(outputs.argmax(dim=1).cpu())
        target_chunks.append(batch_targets.cpu())
    return torch.cat(prediction_chunks), torch.cat(target_chunks)
|
|
|
|
def build_confusion_matrix(
    preds: torch.Tensor,
    labels: torch.Tensor,
    num_classes: int,
) -> torch.Tensor:
    """
    Build a confusion matrix where rows=true class and cols=predicted class.

    Args:
        preds: 1-D tensor of predicted class indices.
        labels: 1-D tensor of ground-truth class indices (same length as preds).
        num_classes: Number of classes; output is (num_classes, num_classes).

    Returns:
        int64 tensor of per-(true, predicted) sample counts.
    """
    # Encode each (true, pred) pair as one flat index and count all pairs in a
    # single vectorized bincount, replacing the per-sample Python loop which
    # was O(n) Python-level work on the full test set.
    flat_indices = labels.to(torch.int64) * num_classes + preds.to(torch.int64)
    counts = torch.bincount(flat_indices, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)
|
|
|
|
def format_error_analysis(
    preds: torch.Tensor,
    labels: torch.Tensor,
    class_names: Tuple[str, ...],
) -> str:
    """
    Render an error-analysis report as a multi-line string.

    The report lists per-class accuracy (sorted hardest class first) and the
    most frequent true -> predicted confusion pairs.
    """
    num_classes = len(class_names)
    confusion = build_confusion_matrix(preds=preds, labels=labels, num_classes=num_classes)
    totals_per_class = confusion.sum(dim=1)
    correct_per_class = confusion.diag()

    report_lines: List[str] = ["Per-class accuracy (lower = harder classes):"]

    # (accuracy, name, sample count) per class; stable sort keeps class order for ties.
    scored = []
    for class_idx, name in enumerate(class_names):
        n_total = int(totals_per_class[class_idx].item())
        n_correct = int(correct_per_class[class_idx].item())
        accuracy = (n_correct / n_total) if n_total > 0 else 0.0
        scored.append((accuracy, name, n_total))
    for accuracy, name, n_total in sorted(scored, key=lambda item: item[0]):
        report_lines.append(f"  {name:<10} | acc={accuracy * 100:6.2f}% | n={n_total}")

    report_lines.append("")
    report_lines.append("Top confusion pairs (true -> predicted):")

    # Every off-diagonal cell with at least one sample, largest counts first.
    pair_counts = [
        (int(confusion[true_idx, pred_idx].item()), true_idx, pred_idx)
        for true_idx in range(num_classes)
        for pred_idx in range(num_classes)
        if true_idx != pred_idx and int(confusion[true_idx, pred_idx].item()) > 0
    ]
    pair_counts.sort(key=lambda item: item[0], reverse=True)

    if not pair_counts:
        report_lines.append("  No confusions found (perfect classification).")
    else:
        for count, true_idx, pred_idx in pair_counts[:8]:
            report_lines.append(
                f"  {class_names[true_idx]} -> {class_names[pred_idx]}: {count} samples"
            )
    return "\n".join(report_lines)
|
|
|
|
def print_error_analysis(
    preds: torch.Tensor,
    labels: torch.Tensor,
    class_names: Tuple[str, ...],
) -> str:
    """
    Build the error-analysis report, echo it to stdout, and return it.

    Returning the text lets callers also persist the same report to disk.
    """
    summary = format_error_analysis(preds=preds, labels=labels, class_names=class_names)
    print("\n" + summary)
    return summary
|
|
|
|
def get_imagenet_style_cifar10_dataloaders(
    data_root: str = "./data",
    batch_size: int = 128,
    num_workers: int = 2,
    pin_memory: bool = True,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build train/val/test CIFAR-10 DataLoaders preprocessed for ViT-B/16.

    Why this preprocessing:
    - Resize to 224x224 because torchvision ViT-B/16 expects ImageNet-sized input.
    - Use ImageNet mean/std so input statistics align with pre-training.

    Args:
        data_root: Directory where CIFAR-10 is stored or downloaded.
        batch_size: Samples per batch for every split.
        num_workers: DataLoader worker processes.
        pin_memory: Request pinned host memory (honored only when CUDA is available).
        val_ratio: Fraction of the training set held out for validation.
        seed: RNG seed for a reproducible train/val split.

    Raises:
        ValueError: If `val_ratio` is not strictly between 0 and 1.
    """
    if not 0.0 < val_ratio < 1.0:
        raise ValueError("val_ratio must be between 0 and 1.")

    preprocessing = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )

    root = Path(data_root)
    root.mkdir(parents=True, exist_ok=True)
    train_full = datasets.CIFAR10(
        root=str(root),
        train=True,
        download=True,
        transform=preprocessing,
    )
    test_set = datasets.CIFAR10(
        root=str(root),
        train=False,
        download=True,
        transform=preprocessing,
    )

    # Seeded generator makes the train/val split reproducible across runs.
    n_val = int(len(train_full) * val_ratio)
    train_set, val_set = random_split(
        train_full,
        [len(train_full) - n_val, n_val],
        generator=torch.Generator().manual_seed(seed),
    )

    effective_pin = pin_memory and torch.cuda.is_available()

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        # Shared DataLoader settings for all three splits; only shuffling differs.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=effective_pin,
        )

    return (
        _make_loader(train_set, True),
        _make_loader(val_set, False),
        _make_loader(test_set, False),
    )
|
|
|
|
def build_pretrained_vit_classifier(num_classes: int = 10) -> nn.Module:
    """
    Create a ViT-B/16 initialized from ImageNet-1k weights with a fresh head.

    Args:
        num_classes: Output classes for the replacement head (10 for CIFAR-10).
    """
    backbone = models.vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
    # Only the classification head is re-initialized; all other weights keep
    # their pre-trained values for transfer learning.
    backbone.heads.head = nn.Linear(backbone.heads.head.in_features, num_classes)
    return backbone
|
|
|
|
def fine_tune_pretrained(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int = 2,
    lr: float = 1e-4,
    weight_decay: float = 1e-4,
) -> None:
    """
    Fine-tune a pre-trained ViT on CIFAR-10 and print epoch-level metrics.

    Args:
        model: Pre-trained network whose head targets CIFAR-10 classes.
        train_loader: Fine-tuning batches.
        val_loader: Validation batches scored after each epoch.
        device: Device used for both training and validation.
        epochs: Number of fine-tuning passes over the training data.
        lr: AdamW learning rate for adapting ImageNet weights to CIFAR-10.
        weight_decay: Regularization to reduce overfitting.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.to(device)

    for epoch_idx in range(epochs):
        model.train()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            outputs = model(batch_images)
            batch_loss = loss_fn(outputs, batch_labels)
            batch_loss.backward()
            optimizer.step()

            # Track per-sample loss and accuracy across the epoch.
            loss_sum += batch_loss.item() * batch_images.size(0)
            n_correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

        train_loss = loss_sum / n_seen
        train_acc = n_correct / n_seen
        val_loss, val_acc = evaluate_model(model=model, dataloader=val_loader, device=device)
        print(
            f"[Pretrained ViT] Epoch {epoch_idx + 1}/{epochs} | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc * 100:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc * 100:.2f}%"
        )
|
|
|
|
def build_comparison_report(
    baseline_loss: float,
    baseline_acc: float,
    pretrained_loss: float,
    pretrained_acc: float,
) -> str:
    """
    Build a compact side-by-side comparison report for baseline vs pre-trained ViT.

    Args:
        baseline_loss: Test loss of the custom-checkpoint ViT.
        baseline_acc: Test accuracy (0-1) of the custom-checkpoint ViT.
        pretrained_loss: Test loss of the fine-tuned ViT-B/16.
        pretrained_acc: Test accuracy (0-1) of the fine-tuned ViT-B/16.

    Returns:
        Multi-line report ending in a newline (ready for file writing).
    """
    baseline_label = "Baseline ViT (custom checkpoint)"
    pretrained_label = "Pre-trained ViT-B/16"
    # Size the label column from the longest label: the previous fixed width
    # of 28 was narrower than the 32-char baseline label, so the numeric
    # columns of the table rows did not line up.
    label_width = max(len("Model"), len(baseline_label), len(pretrained_label)) + 2
    table_width = label_width + 12 + 14

    acc_delta = (pretrained_acc - baseline_acc) * 100.0
    loss_delta = pretrained_loss - baseline_loss

    lines = [
        "Model comparison (baseline vs transfer learning)",
        "-" * table_width,
        f"{'Model':<{label_width}}{'Test Loss':>12}{'Test Acc':>14}",
        f"{baseline_label:<{label_width}}{baseline_loss:>12.4f}{baseline_acc * 100:>13.2f}%",
        f"{pretrained_label:<{label_width}}{pretrained_loss:>12.4f}{pretrained_acc * 100:>13.2f}%",
        "-" * table_width,
        f"Accuracy gain (pretrained - baseline): {acc_delta:+.2f} percentage points",
        f"Loss delta (pretrained - baseline): {loss_delta:+.4f}",
        "",
    ]
    return "\n".join(lines)
|
|
|
|
if __name__ == "__main__":



    # ---- CLI configuration -------------------------------------------------
    parser = argparse.ArgumentParser(description="Evaluate baseline ViT and run analysis.")
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default="./saved_model/vit_cifar10_best.pt",
        help="Path to custom ViT checkpoint.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Evaluation batch size.",
    )
    parser.add_argument(
        "--run-pretrained-experiment",
        action="store_true",
        help="If set, fine-tune a pre-trained ViT-B/16 on CIFAR-10 and compare.",
    )
    parser.add_argument(
        "--results-dir",
        type=str,
        default="./results",
        help="Directory to save plots and analysis reports.",
    )
    args = parser.parse_args()


    # ---- Output locations and device selection -----------------------------
    checkpoint_path = args.checkpoint_path
    results_dir = Path(args.results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Loading checkpoint: {checkpoint_path}")
    print(f"Saving results to: {results_dir}")




    # Imported here rather than at module top so the training-module
    # dependency stays local to the script entry point.
    from c1 import get_cifar10_dataloaders


    # Baseline evaluation reuses the 64x64 preprocessing the custom ViT was
    # trained with (see image_size below); only the test split is used.
    _, _, test_loader = get_cifar10_dataloaders(
        data_root="./data",
        image_size=64,
        batch_size=args.batch_size,
        val_ratio=0.1,
    )


    # ---- Baseline model: restore checkpoint, score the test set ------------
    model = load_model_from_checkpoint(checkpoint_path=checkpoint_path, device=device)
    test_loss, test_acc = evaluate_model(model=model, dataloader=test_loader, device=device)


    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc * 100:.2f}%")


    # Full-test-set predictions feed the per-class / confusion-pair report.
    preds, labels = collect_predictions(model=model, dataloader=test_loader, device=device)
    baseline_analysis = print_error_analysis(
        preds=preds, labels=labels, class_names=CLASS_NAMES
    )


    # Save a grid of misclassified test images for qualitative inspection.
    wrong_samples = collect_misclassified(
        model=model,
        dataloader=test_loader,
        device=device,
        max_samples=24,
    )
    visualize_misclassified(
        samples=wrong_samples,
        class_names=CLASS_NAMES,
        save_path=str(results_dir / "misclassified_examples_test.png"),
    )
    # Persist baseline metrics + analysis text for later comparison.
    baseline_report_path = results_dir / "baseline_analysis.txt"
    baseline_report_path.write_text(
        "\n".join(
            [
                "Baseline ViT (custom checkpoint) results",
                f"Checkpoint: {checkpoint_path}",
                f"Test Loss: {test_loss:.4f}",
                f"Test Accuracy: {test_acc * 100:.2f}%",
                "",
                baseline_analysis,
                "",
            ]
        ),
        encoding="utf-8",
    )
    print(f"Saved baseline analysis to: {baseline_report_path}")


    # ---- Optional transfer-learning experiment -----------------------------
    if args.run_pretrained_experiment:




        # Kept small: full fine-tuning of ViT-B/16 is expensive and it
        # adapts to CIFAR-10 within a couple of epochs.
        pretrained_epochs = 2
        print("\nRunning transfer-learning experiment with pre-trained ViT-B/16...")
        # The pre-trained ViT needs 224x224 ImageNet-normalized inputs, so it
        # gets its own dataloaders rather than reusing the 64x64 baseline ones.
        train_loader_pt, val_loader_pt, test_loader_pt = (
            get_imagenet_style_cifar10_dataloaders(
                data_root="./data",
                batch_size=args.batch_size,
                val_ratio=0.1,
            )
        )


        pretrained_model = build_pretrained_vit_classifier(num_classes=len(CLASS_NAMES))
        fine_tune_pretrained(
            model=pretrained_model,
            train_loader=train_loader_pt,
            val_loader=val_loader_pt,
            device=device,
            epochs=pretrained_epochs,
            lr=1e-4,
            weight_decay=1e-4,
        )
        pt_test_loss, pt_test_acc = evaluate_model(
            model=pretrained_model,
            dataloader=test_loader_pt,
            device=device,
        )
        print(f"[Pretrained ViT] Test Loss: {pt_test_loss:.4f}")
        print(f"[Pretrained ViT] Test Accuracy: {pt_test_acc * 100:.2f}%")
        comparison_report = build_comparison_report(
            baseline_loss=test_loss,
            baseline_acc=test_acc,
            pretrained_loss=pt_test_loss,
            pretrained_acc=pt_test_acc,
        )
        print("\n" + comparison_report)


        # Same error analysis + misclassification dump for the pre-trained model.
        pt_preds, pt_labels = collect_predictions(
            model=pretrained_model, dataloader=test_loader_pt, device=device
        )
        pretrained_analysis = print_error_analysis(
            preds=pt_preds, labels=pt_labels, class_names=CLASS_NAMES
        )


        pt_wrong_samples = collect_misclassified(
            model=pretrained_model,
            dataloader=test_loader_pt,
            device=device,
            max_samples=24,
        )
        visualize_misclassified(
            samples=pt_wrong_samples,
            class_names=CLASS_NAMES,
            save_path=str(results_dir / "misclassified_examples_pretrained_vit.png"),
        )
        # Persist pre-trained metrics + analysis, mirroring the baseline report.
        pretrained_report_path = results_dir / "pretrained_vit_analysis.txt"
        pretrained_report_path.write_text(
            "\n".join(
                [
                    "Pre-trained ViT-B/16 transfer learning results",
                    f"Fine-tuning epochs: {pretrained_epochs}",
                    f"Test Loss: {pt_test_loss:.4f}",
                    f"Test Accuracy: {pt_test_acc * 100:.2f}%",
                    "",
                    pretrained_analysis,
                    "",
                ]
            ),
            encoding="utf-8",
        )
        print(f"Saved pre-trained analysis to: {pretrained_report_path}")


        comparison_report_path = results_dir / "comparison_report.txt"
        comparison_report_path.write_text(comparison_report, encoding="utf-8")
        print(f"Saved model comparison report to: {comparison_report_path}")
|