'''
Demo script for applying Feature Selection Gates (FSG) to torchvision Vision Transformers
and running inference on the Imagenette validation set (a 10-class subset of ImageNet).

Each image is resized to 224x224 with 3 RGB channels, the input format expected by ViT-B/16.

Usage:

    python demo_inference_imnet.py --checkpoint ./checkpoints/fsg_vit_imagenette_demo.pth

    The --checkpoint argument is optional; without it, no trained FSG-ViT weights are loaded.

Paper:
    https://papers.miccai.org/miccai-2024/316-Paper0410.html
Code:
    https://github.com/cosmoimd/feature-selection-gates
Contact:
    giorgio.roffo@gmail.com
'''
|
|
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| import os |
| import sys |
| import tarfile |
| import urllib.request |
| import torch |
| import psutil |
| from torchvision.models import vit_b_16, ViT_B_16_Weights |
| from vit_with_fsg import vit_with_fsg |
| from torchvision import transforms |
| from torchvision.datasets import ImageFolder |
| from torch.utils.data import DataLoader |
| import torch.nn.functional as F |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score |
| from tqdm import tqdm |
|
|
| import argparse |
|
|
| parser = argparse.ArgumentParser(description="FSG-ViT inference on Imagenette") |
| parser.add_argument("--checkpoint", type=str, default=None, help="Path to .pth file of trained FSG-ViT model") |
| args = parser.parse_args() |
|
|
| if __name__ == "__main__": |
| warnings.filterwarnings("ignore", message="Failed to load image Python extension*") |
| wrn = False |
| print(f"\nπ To run this script:\n" |
| f" βΆ Without checkpoint: python {os.path.basename(__file__)}\n" |
| f" βΆ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n") |
|
|
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print(f"\nπ₯οΈ Using device: {device}") |
| if device.type == "cuda": |
| print(f"π CUDA device: {torch.cuda.get_device_name(0)}") |
| print(f"πΎ GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB") |
| print(f"π§ System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB") |
|
|
| print("\nπ₯ Loading pretrained ViT backbone from torchvision...") |
| backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT) |
|
|
| print("π§ Wrapping with Feature Selection Gates (FSG)...") |
| model = vit_with_fsg(backbone).to(device) |
|
|
| if args.checkpoint is not None: |
| print(f"π Loading model weights from: {args.checkpoint}") |
| model.load_state_dict(torch.load(args.checkpoint, map_location=device)) |
| else: |
| wrn = True |
| print("\nβ οΈ No checkpoint provided. Evaluating randomly initialized model! π§ͺ\n") |
| print("β Note: The model has not been trained. Results will reflect a randomly initialized backbone.") |
|
|
| model.eval() |
|
|
| print("π Loading Imagenette validation set (224x224 RGB)...") |
| imagenette_path = "./imagenette2-160/val" |
| if not os.path.exists(imagenette_path): |
| print("π¦ Downloading Imagenette...") |
| url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz" |
| tgz_path = "imagenette2-160.tgz" |
| urllib.request.urlretrieve(url, tgz_path) |
| print("π Extracting Imagenette dataset...") |
| with tarfile.open(tgz_path, "r:gz") as tar: |
| tar.extractall() |
| os.remove(tgz_path) |
| print("β
Dataset ready.") |
|
|
| transform = transforms.Compose([ |
| transforms.Resize((224, 224)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.5]*3, std=[0.5]*3) |
| ]) |
|
|
| dataset = ImageFolder(root=imagenette_path, transform=transform) |
| dataloader = DataLoader(dataset, batch_size=32, shuffle=False) |
|
|
| y_true = [] |
| y_pred = [] |
|
|
| print("π§ͺ Running inference on Imagenette validation set using FSG-ViT-B-16 (code by G. Roffo)...\n\n") |
| with torch.no_grad(): |
| for images, labels in tqdm(dataloader, desc="π Inference progress", ncols=100): |
| images = images.to(device) |
| labels = labels.to(device) |
| outputs = model(images) |
| preds = torch.argmax(F.softmax(outputs, dim=1), dim=1) |
| y_true.extend(labels.cpu().tolist()) |
| y_pred.extend(preds.cpu().tolist()) |
|
|
| print("β
Inference completed.") |
|
|
| acc = accuracy_score(y_true, y_pred) |
| prec = precision_score(y_true, y_pred, average='macro', zero_division=0) |
| rec = recall_score(y_true, y_pred, average='macro', zero_division=0) |
| f1 = f1_score(y_true, y_pred, average='macro', zero_division=0) |
|
|
| if wrn == True: |
| print("\nβ οΈ No checkpoint provided. Evaluated randomly initialized model! π§ͺ\n") |
| print(f"\nπ To run this script:\n" |
| f" βΆ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n") |
|
|
| print(f"π Accuracy: {acc * 100:.2f}%") |
| print(f"π Precision: {prec * 100:.2f}%") |
| print(f"π Recall: {rec * 100:.2f}%") |
| print(f"π F1 Score: {f1 * 100:.2f}%") |
|
|