| | """ |
| | Separate testing/inference script for CIFAR-10 ViT model. |
| | |
| | Loads a saved checkpoint, runs inference on the test set, |
| | prints final performance, and saves misclassification analysis. |
| | Also supports an optional transfer-learning experiment with |
| | a pre-trained torchvision ViT model. |
| | |
| | Experiment with pre-trained models: Consider fine-tuning pre-trained |
| | Transformer models (e.g., ViT) on your task and evaluate their |
| | performance to understand the impact of transfer learning. |
| | """ |
| |
|
| | import argparse |
| | from pathlib import Path |
| | from typing import List |
| | from typing import Any, Dict, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader, random_split |
| | from torchvision import datasets, models, transforms |
| | from torchvision.models import ViT_B_16_Weights |
| |
|
| | from c1 import ( |
| | CLASS_NAMES, |
| | ViTClassifier, |
| | collect_misclassified, |
| | visualize_misclassified, |
| | ) |
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Compute average loss and accuracy for a model on a dataset split.

    Args:
        model: Trained model to evaluate.
        dataloader: Batches from validation or test split.
        device: CPU or CUDA device for inference.

    Returns:
        (avg_loss, accuracy) aggregated over all samples in `dataloader`.
        Returns (0.0, 0.0) for an empty dataloader instead of raising
        ZeroDivisionError.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)
        preds = logits.argmax(dim=1)

        # Weight the batch-mean loss by batch size so the final average is
        # per-sample even when the last batch is smaller.
        total_loss += loss.item() * images.size(0)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    # Guard against an empty dataloader (e.g. a split that produced 0 samples);
    # the original code divided by zero here.
    if total == 0:
        return 0.0, 0.0
    avg_loss = total_loss / total
    acc = correct / total
    return avg_loss, acc
| |
|
| |
|
def load_model_from_checkpoint(
    checkpoint_path: str,
    device: torch.device,
) -> ViTClassifier:
    """
    Restore a `ViTClassifier` from a saved checkpoint.

    The checkpoint may contain:
      - `model_state_dict`: learned parameters (when absent, the whole
        checkpoint object is treated as a raw state dict)
      - `model_config`: architecture hyperparameters (when absent or empty,
        the training defaults from `c1.py` are used)
    """
    # Training defaults from c1.py, used only when the checkpoint carries
    # no explicit `model_config`.
    fallback_config: Dict[str, Any] = {
        "image_size": 64,
        "patch_size": 4,
        "in_channels": 3,
        "embed_dim": 256,
        "depth": 6,
        "num_heads": 8,
        "mlp_ratio": 4.0,
        "dropout": 0.1,
        "num_classes": 10,
    }

    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint.get("model_config") or fallback_config

    model = ViTClassifier(**config)
    model.load_state_dict(checkpoint.get("model_state_dict", checkpoint))
    model.to(device)
    model.eval()
    return model
| |
|
| |
|
@torch.no_grad()
def collect_predictions(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Gather predicted and ground-truth labels over an entire dataloader.

    Returns:
        (preds, labels) as 1-D CPU tensors, aligned sample-for-sample,
        suitable for full-dataset error analysis.
    """
    model.eval()
    pred_batches = []
    label_batches = []
    for images, labels in dataloader:
        logits = model(images.to(device))
        pred_batches.append(logits.argmax(dim=1).cpu())
        label_batches.append(labels.to(device).cpu())
    return torch.cat(pred_batches), torch.cat(label_batches)
| |
|
| |
|
def build_confusion_matrix(
    preds: torch.Tensor,
    labels: torch.Tensor,
    num_classes: int,
) -> torch.Tensor:
    """
    Build a confusion matrix where rows=true class and cols=predicted class.

    Args:
        preds: 1-D tensor of predicted class indices in [0, num_classes).
        labels: 1-D tensor of ground-truth class indices in [0, num_classes).
        num_classes: Number of classes; output is (num_classes, num_classes).

    Returns:
        int64 count tensor; all zeros when the inputs are empty.
    """
    # Vectorized tally: encode each (true, pred) pair as one flat index,
    # count with bincount, then reshape. One C-level pass instead of a
    # Python loop over every sample.
    flat_indices = labels.to(torch.int64) * num_classes + preds.to(torch.int64)
    counts = torch.bincount(flat_indices, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)
| |
|
| |
|
def format_error_analysis(
    preds: torch.Tensor,
    labels: torch.Tensor,
    class_names: Tuple[str, ...],
) -> str:
    """
    Produce a readable report: per-class accuracy sorted hardest-first,
    followed by the most frequent off-diagonal confusion pairs.
    """
    num_classes = len(class_names)

    # Tally confusion counts in place: rows = true class, cols = predicted.
    matrix = torch.zeros((num_classes, num_classes), dtype=torch.int64)
    for true_label, pred_label in zip(labels, preds):
        matrix[int(true_label), int(pred_label)] += 1

    report: List[str] = ["Per-class accuracy (lower = harder classes):"]
    stats = []
    for class_idx in range(num_classes):
        n_total = int(matrix[class_idx].sum().item())
        n_correct = int(matrix[class_idx, class_idx].item())
        accuracy = n_correct / n_total if n_total > 0 else 0.0
        stats.append((accuracy, class_names[class_idx], n_total))
    # Stable sort on accuracy only, so ties keep class order.
    for accuracy, name, n_total in sorted(stats, key=lambda item: item[0]):
        report.append(f"  {name:<10} | acc={accuracy * 100:6.2f}% | n={n_total}")

    report.extend(["", "Top confusion pairs (true -> predicted):"])
    pairs = [
        (int(matrix[t, p].item()), t, p)
        for t in range(num_classes)
        for p in range(num_classes)
        if t != p and int(matrix[t, p].item()) > 0
    ]
    pairs.sort(key=lambda item: item[0], reverse=True)
    if not pairs:
        report.append("  No confusions found (perfect classification).")
    else:
        # Show at most the 8 most frequent confusion pairs.
        for count, t, p in pairs[:8]:
            report.append(
                f"  {class_names[t]} -> {class_names[p]}: {count} samples"
            )
    return "\n".join(report)
| |
|
| |
|
def print_error_analysis(
    preds: torch.Tensor,
    labels: torch.Tensor,
    class_names: Tuple[str, ...],
) -> str:
    """
    Print the error-analysis summary to stdout and return it for saving.
    """
    summary = format_error_analysis(
        preds=preds, labels=labels, class_names=class_names
    )
    print(f"\n{summary}")
    return summary
| |
|
| |
|
def get_imagenet_style_cifar10_dataloaders(
    data_root: str = "./data",
    batch_size: int = 128,
    num_workers: int = 2,
    pin_memory: bool = True,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build CIFAR-10 train/val/test DataLoaders preprocessed for ViT-B/16.

    Preprocessing rationale:
      - 224x224 resize: torchvision ViT-B/16 expects ImageNet-sized input.
      - ImageNet mean/std normalization: aligns input statistics with the
        weights' pre-training distribution.

    Raises:
        ValueError: If `val_ratio` is not strictly between 0 and 1.
    """
    if not 0.0 < val_ratio < 1.0:
        raise ValueError("val_ratio must be between 0 and 1.")

    preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )

    root = Path(data_root)
    root.mkdir(parents=True, exist_ok=True)
    train_full = datasets.CIFAR10(
        root=str(root), train=True, download=True, transform=preprocess
    )
    test_set = datasets.CIFAR10(
        root=str(root), train=False, download=True, transform=preprocess
    )

    # Deterministic train/val split driven by `seed`.
    n_val = int(len(train_full) * val_ratio)
    n_train = len(train_full) - n_val
    split_rng = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(
        train_full, [n_train, n_val], generator=split_rng
    )

    # pin_memory only helps when host->CUDA transfers actually happen.
    effective_pin = pin_memory and torch.cuda.is_available()

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        # Shared loader settings; only shuffling differs per split.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=effective_pin,
        )

    return (
        _make_loader(train_set, True),
        _make_loader(val_set, False),
        _make_loader(test_set, False),
    )
| |
|
| |
|
def build_pretrained_vit_classifier(num_classes: int = 10) -> nn.Module:
    """
    Create a torchvision ViT-B/16 initialized with ImageNet-1k weights and
    swap its classification head for a fresh `num_classes`-way linear layer.
    """
    backbone = models.vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
    # Replace only the final head; every other weight keeps its
    # pre-trained value for fine-tuning.
    backbone.heads.head = nn.Linear(backbone.heads.head.in_features, num_classes)
    return backbone
| |
|
| |
|
def fine_tune_pretrained(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int = 2,
    lr: float = 1e-4,
    weight_decay: float = 1e-4,
) -> None:
    """
    Fine-tune a pre-trained ViT on CIFAR-10, printing per-epoch metrics.

    Args:
        model: Pre-trained network (head already replaced for CIFAR-10).
        train_loader: Fine-tuning batches.
        val_loader: Validation batches scored after each epoch.
        device: CPU or CUDA device.
        epochs: Number of fine-tuning passes over the training data.
        lr: AdamW learning rate for adapting from ImageNet to CIFAR-10.
        weight_decay: Regularization strength to curb overfitting.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    model.to(device)

    for epoch_idx in range(1, epochs + 1):
        model.train()
        epoch_loss_sum = 0.0
        epoch_correct = 0
        epoch_seen = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            batch_loss = loss_fn(logits, labels)
            batch_loss.backward()
            optimizer.step()

            # Accumulate per-sample loss and accuracy for the epoch summary.
            epoch_loss_sum += batch_loss.item() * images.size(0)
            epoch_correct += (logits.argmax(dim=1) == labels).sum().item()
            epoch_seen += labels.size(0)

        train_loss = epoch_loss_sum / epoch_seen
        train_acc = epoch_correct / epoch_seen
        val_loss, val_acc = evaluate_model(
            model=model, dataloader=val_loader, device=device
        )
        print(
            f"[Pretrained ViT] Epoch {epoch_idx}/{epochs} | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc * 100:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc * 100:.2f}%"
        )
| |
|
| |
|
def build_comparison_report(
    baseline_loss: float,
    baseline_acc: float,
    pretrained_loss: float,
    pretrained_acc: float,
) -> str:
    """
    Format a side-by-side table comparing the custom baseline ViT against
    the fine-tuned pre-trained ViT-B/16 on the test set.
    """
    separator = "-" * 56
    accuracy_gain_pts = 100.0 * (pretrained_acc - baseline_acc)
    loss_difference = pretrained_loss - baseline_loss

    return "\n".join(
        [
            "Model comparison (baseline vs transfer learning)",
            separator,
            f"{'Model':<28}{'Test Loss':>12}{'Test Acc':>14}",
            f"{'Baseline ViT (custom checkpoint)':<28}{baseline_loss:>12.4f}{baseline_acc * 100:>13.2f}%",
            f"{'Pre-trained ViT-B/16':<28}{pretrained_loss:>12.4f}{pretrained_acc * 100:>13.2f}%",
            separator,
            f"Accuracy gain (pretrained - baseline): {accuracy_gain_pts:+.2f} percentage points",
            f"Loss delta (pretrained - baseline): {loss_difference:+.4f}",
            "",
        ]
    )
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | |
| | |
| | parser = argparse.ArgumentParser(description="Evaluate baseline ViT and run analysis.") |
| | parser.add_argument( |
| | "--checkpoint-path", |
| | type=str, |
| | default="./saved_model/vit_cifar10_best.pt", |
| | help="Path to custom ViT checkpoint.", |
| | ) |
| | parser.add_argument( |
| | "--batch-size", |
| | type=int, |
| | default=128, |
| | help="Evaluation batch size.", |
| | ) |
| | parser.add_argument( |
| | "--run-pretrained-experiment", |
| | action="store_true", |
| | help="If set, fine-tune a pre-trained ViT-B/16 on CIFAR-10 and compare.", |
| | ) |
| | parser.add_argument( |
| | "--results-dir", |
| | type=str, |
| | default="./results", |
| | help="Directory to save plots and analysis reports.", |
| | ) |
| | args = parser.parse_args() |
| |
|
| | checkpoint_path = args.checkpoint_path |
| | results_dir = Path(args.results_dir) |
| | results_dir.mkdir(parents=True, exist_ok=True) |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | print(f"Using device: {device}") |
| | print(f"Loading checkpoint: {checkpoint_path}") |
| | print(f"Saving results to: {results_dir}") |
| |
|
| | |
| | |
| | |
| | |
| | from c1 import get_cifar10_dataloaders |
| |
|
| | _, _, test_loader = get_cifar10_dataloaders( |
| | data_root="./data", |
| | image_size=64, |
| | batch_size=args.batch_size, |
| | val_ratio=0.1, |
| | ) |
| |
|
| | model = load_model_from_checkpoint(checkpoint_path=checkpoint_path, device=device) |
| | test_loss, test_acc = evaluate_model(model=model, dataloader=test_loader, device=device) |
| |
|
| | print(f"Test Loss: {test_loss:.4f}") |
| | print(f"Test Accuracy: {test_acc * 100:.2f}%") |
| |
|
| | preds, labels = collect_predictions(model=model, dataloader=test_loader, device=device) |
| | baseline_analysis = print_error_analysis( |
| | preds=preds, labels=labels, class_names=CLASS_NAMES |
| | ) |
| |
|
| | wrong_samples = collect_misclassified( |
| | model=model, |
| | dataloader=test_loader, |
| | device=device, |
| | max_samples=24, |
| | ) |
| | visualize_misclassified( |
| | samples=wrong_samples, |
| | class_names=CLASS_NAMES, |
| | save_path=str(results_dir / "misclassified_examples_test.png"), |
| | ) |
| | baseline_report_path = results_dir / "baseline_analysis.txt" |
| | baseline_report_path.write_text( |
| | "\n".join( |
| | [ |
| | "Baseline ViT (custom checkpoint) results", |
| | f"Checkpoint: {checkpoint_path}", |
| | f"Test Loss: {test_loss:.4f}", |
| | f"Test Accuracy: {test_acc * 100:.2f}%", |
| | "", |
| | baseline_analysis, |
| | "", |
| | ] |
| | ), |
| | encoding="utf-8", |
| | ) |
| | print(f"Saved baseline analysis to: {baseline_report_path}") |
| |
|
| | if args.run_pretrained_experiment: |
| | |
| | |
| | |
| | |
| | |
| | pretrained_epochs = 2 |
| | print("\nRunning transfer-learning experiment with pre-trained ViT-B/16...") |
| | train_loader_pt, val_loader_pt, test_loader_pt = ( |
| | get_imagenet_style_cifar10_dataloaders( |
| | data_root="./data", |
| | batch_size=args.batch_size, |
| | val_ratio=0.1, |
| | ) |
| | ) |
| |
|
| | pretrained_model = build_pretrained_vit_classifier(num_classes=len(CLASS_NAMES)) |
| | fine_tune_pretrained( |
| | model=pretrained_model, |
| | train_loader=train_loader_pt, |
| | val_loader=val_loader_pt, |
| | device=device, |
| | epochs=pretrained_epochs, |
| | lr=1e-4, |
| | weight_decay=1e-4, |
| | ) |
| | pt_test_loss, pt_test_acc = evaluate_model( |
| | model=pretrained_model, |
| | dataloader=test_loader_pt, |
| | device=device, |
| | ) |
| | print(f"[Pretrained ViT] Test Loss: {pt_test_loss:.4f}") |
| | print(f"[Pretrained ViT] Test Accuracy: {pt_test_acc * 100:.2f}%") |
| | comparison_report = build_comparison_report( |
| | baseline_loss=test_loss, |
| | baseline_acc=test_acc, |
| | pretrained_loss=pt_test_loss, |
| | pretrained_acc=pt_test_acc, |
| | ) |
| | print("\n" + comparison_report) |
| |
|
| | pt_preds, pt_labels = collect_predictions( |
| | model=pretrained_model, dataloader=test_loader_pt, device=device |
| | ) |
| | pretrained_analysis = print_error_analysis( |
| | preds=pt_preds, labels=pt_labels, class_names=CLASS_NAMES |
| | ) |
| |
|
| | pt_wrong_samples = collect_misclassified( |
| | model=pretrained_model, |
| | dataloader=test_loader_pt, |
| | device=device, |
| | max_samples=24, |
| | ) |
| | visualize_misclassified( |
| | samples=pt_wrong_samples, |
| | class_names=CLASS_NAMES, |
| | save_path=str(results_dir / "misclassified_examples_pretrained_vit.png"), |
| | ) |
| | pretrained_report_path = results_dir / "pretrained_vit_analysis.txt" |
| | pretrained_report_path.write_text( |
| | "\n".join( |
| | [ |
| | "Pre-trained ViT-B/16 transfer learning results", |
| | f"Fine-tuning epochs: {pretrained_epochs}", |
| | f"Test Loss: {pt_test_loss:.4f}", |
| | f"Test Accuracy: {pt_test_acc * 100:.2f}%", |
| | "", |
| | pretrained_analysis, |
| | "", |
| | ] |
| | ), |
| | encoding="utf-8", |
| | ) |
| | print(f"Saved pre-trained analysis to: {pretrained_report_path}") |
| |
|
| | comparison_report_path = results_dir / "comparison_report.txt" |
| | comparison_report_path.write_text(comparison_report, encoding="utf-8") |
| | print(f"Saved model comparison report to: {comparison_report_path}") |
| |
|