| ''' |
| Demo training script for Feature Selection Gates (FSG) with ViT on Imagenette |
| |
| This script loads the Imagenette dataset (ImageNet-mini), |
| trains a ViT model augmented with FSG, and saves the model checkpoint. |
| |
| Paper: |
| https://papers.miccai.org/miccai-2024/316-Paper0410.html |
| Code: |
| https://github.com/cosmoimd/feature-selection-gates |
| Contact: |
| giorgio.roffo@gmail.com |
| ''' |
|
|
| import os |
| import tarfile |
| import urllib.request |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import psutil |
| from tqdm import tqdm |
| from torchvision import transforms |
| from torchvision.models import vit_b_16, ViT_B_16_Weights |
| from torchvision.datasets import ImageFolder |
| from torch.utils.data import DataLoader |
| from vit_with_fsg import vit_with_fsg |
|
|
| |
# Select the compute device (prefer CUDA when available) and report basic
# hardware information so the user can sanity-check the environment.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"\nπ₯οΈ Using device: {device}")
if device.type == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    print(f"π CUDA device: {gpu_name}")
    print(f"πΎ GPU memory total: {gpu_mem_gb:.2f} GB")
ram_gb = psutil.virtual_memory().total / (1024 ** 3)
print(f"π§ System RAM: {ram_gb:.2f} GB")
|
|
| |
# Download and extract Imagenette (the 160px variant) on first run only; the
# presence of the extracted "val" directory serves as the cache marker.
imagenette_path = "./imagenette2-160/val"
if not os.path.exists(imagenette_path):
    print("π¦ Downloading Imagenette...")
    url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
    tgz_path = "imagenette2-160.tgz"
    urllib.request.urlretrieve(url, tgz_path)
    print("π Extracting Imagenette dataset...")
    # NOTE(review): extractall() on a downloaded archive can write outside the
    # working directory for malicious tarballs. The fast.ai source is trusted,
    # but consider tarfile's `filter="data"` (Python 3.12+) for defense in depth.
    with tarfile.open(tgz_path, "r:gz") as tar:
        tar.extractall()
    os.remove(tgz_path)  # drop the archive once extracted to reclaim disk space
    # Fixed: this string literal was corrupted (split across two lines by a
    # mis-encoded emoji), which made the file a syntax error.
    print("✅ Dataset ready.")
|
|
| |
# Preprocessing pipeline: resize to the ViT input resolution (224x224),
# convert PIL images to tensors, and normalize each RGB channel to ~[-1, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
|
|
| |
# Build the dataset (ImageFolder: one class per subdirectory) and a shuffled
# loader over it for the training demo.
batch_size = 32
dataset = ImageFolder(root=imagenette_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
| |
print("\nπ₯ Loading pretrained ViT backbone from torchvision...")
# Load ImageNet-pretrained ViT-B/16, wrap it with Feature Selection Gates
# (FSG), and move the combined model to the selected device.
vit_backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
model = vit_with_fsg(vit_backbone).to(device)
|
|
| |
# Split the model's parameters into two groups so each can receive its own
# learning rate: FSG gate parameters (names containing 'fsag_rgb_ls') versus
# the remaining pretrained backbone weights.
fsg_params = []
base_params = []
for name, param in model.named_parameters():
    bucket = fsg_params if 'fsag_rgb_ls' in name else base_params
    bucket.append(param)
|
|
# Two learning rates: the newly added FSG gates are trained 5x faster than
# the already-pretrained backbone weights.
lr_base = 1e-4
lr_fsg = 5e-4
print(f"\nπ§ Optimizer setup:")
print(f" πΉ Base ViT parameters LR: {lr_base}")
print(f" πΈ FSG parameters LR: {lr_fsg}")

param_groups = [
    {"params": base_params, "lr": lr_base},
    {"params": fsg_params, "lr": lr_fsg},
]
optimizer = optim.AdamW(param_groups)
criterion = nn.CrossEntropyLoss()
|
|
| |
# Short demo training run: a few epochs, each capped at a fixed number of
# optimization steps so the script finishes quickly regardless of dataset size.
epochs = 3
max_steps_per_epoch = 25
print(f"\nπ Starting demo training for {epochs} epochs...")
model.train()
for epoch in range(epochs):
    steps_demo = 0
    running_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
    for inputs, targets in pbar:
        # Fixed off-by-one: the original `if steps_demo > 25: break` checked
        # before incrementing, so it actually ran 26 steps per epoch.
        if steps_demo >= max_steps_per_epoch:
            break
        steps_demo += 1
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Fixed: the original divided by `pbar.n + 1e-8`, but tqdm increments
        # `n` only after the loop body runs, so on the first step it displayed
        # running_loss / 1e-8 — a huge bogus value. Use the actual step count.
        pbar.set_postfix({"loss": running_loss / steps_demo})

# Fixed: this string literal was corrupted (split across two lines by a
# mis-encoded emoji), which made the file a syntax error.
print("✅ Training complete.")
|
|
| |
# Persist the trained weights so they can be reloaded later for evaluation
# or further fine-tuning.
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_file = os.path.join(checkpoint_dir, "fsg_vit_imagenette_demo.pth")
torch.save(model.state_dict(), checkpoint_file)
print(f"πΎ Checkpoint saved to: {checkpoint_file}")
|
|