groffo commited on
Commit ·
8573586
0
Parent(s):
Initial commit of FSG-ViT
Browse files- .idea/.gitignore +3 -0
- .idea/ViT_with_FSG.iml +12 -0
- .idea/inspectionProfiles/Project_Default.xml +6 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- README.md +144 -0
- demo_inference_imnet.py +124 -0
- demo_inference_mnist.py +108 -0
- demo_training_imnet.py +114 -0
- demo_training_mnist.py +106 -0
- vit_with_fsg.py +109 -0
.idea/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
.idea/ViT_with_FSG.iml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="jdk" jdkName="Python 3.10 (cvpr)" jdkType="Python SDK" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
<component name="PyDocumentationSettings">
|
| 9 |
+
<option name="format" value="PLAIN" />
|
| 10 |
+
<option name="myDocStringFormat" value="Plain" />
|
| 11 |
+
</component>
|
| 12 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="false" level="WARNING" enabled_by_default="false" />
|
| 5 |
+
</profile>
|
| 6 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (cvpr)" project-jdk-type="Python SDK" />
|
| 4 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/ViT_with_FSG.iml" filepath="$PROJECT_DIR$/.idea/ViT_with_FSG.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
README.md
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🔬 Feature Selection Gates (FSG) for Vision Transformers (ViT)
|
| 2 |
+
|
| 3 |
+
This repository provides a modular, extensible PyTorch implementation of **Feature Selection Gates (FSG)** with **Gradient Routing (GR)**, integrated into **Vision Transformers (ViTs)**. The approach is proposed in:
|
| 4 |
+
|
| 5 |
+
> **Feature Selection Gates with Gradient Routing for Endoscopic Image Computing**
|
| 6 |
+
> Giorgio Roffo, Carlo Biffi, Pietro Salvagnini, Andrea Cherubini
|
| 7 |
+
> Presented at MICCAI 2024
|
| 8 |
+
> 📄 [Paper](https://papers.miccai.org/miccai-2024/316-Paper0410.html) | 🧠 [arXiv](https://arxiv.org/abs/2407.04400) | 💻 [Code](https://github.com/cosmoimd/feature-selection-gates)
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## 📌 What Is FSG?
|
| 13 |
+
|
| 14 |
+
**FSG** introduces **learnable gates** that sparsify transformer blocks by modulating residual connections, acting as **online feature selectors**. This process encourages **sparse connectivity**, which reduces overfitting and increases generalization — especially valuable in small and imbalanced datasets.
|
| 15 |
+
|
| 16 |
+
**Gradient Routing (GR)** enables dual-phase optimization:
|
| 17 |
+
- One optimizer updates FSG parameters
|
| 18 |
+
- A second optimizer updates the base model
|
| 19 |
+
This separation allows **task-specific tuning** and ensures stable learning.
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## 💡 Why Use FSG?
|
| 24 |
+
|
| 25 |
+
✅ **Plug & play**: Can be integrated into **any ViT architecture**
|
| 26 |
+
✅ Works on **natural images**, **medical images**, and beyond
|
| 27 |
+
✅ Can be adapted to **NLP Transformers** like GPTs and BERT
|
| 28 |
+
✅ Lightweight and highly regularizing
|
| 29 |
+
✅ Compatible with **multi-stream CNNs** and hybrid models
|
| 30 |
+
|
| 31 |
+
⚠️ While our focus is on **endoscopic image computing**, the method has shown performance improvements on **CIFAR-100**, proving its applicability to **standard vision tasks**.
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## 🧪 How to Use the FSG Wrapper
|
| 36 |
+
|
| 37 |
+
Use the `vit_with_fsg.py` script to augment a pretrained ViT from `torchvision`.
|
| 38 |
+
|
| 39 |
+
```python
|
| 40 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
| 41 |
+
from vit_with_fsg import vit_with_fsg
|
| 42 |
+
import torch
|
| 43 |
+
|
| 44 |
+
print("📥 Loading pretrained ViT_B_16...")
|
| 45 |
+
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
|
| 46 |
+
|
| 47 |
+
print("🔧 Wrapping with Feature Selection Gates (FSG)...")
|
| 48 |
+
model = vit_with_fsg(vit_backbone=backbone)
|
| 49 |
+
|
| 50 |
+
print("🧪 Running dummy input...")
|
| 51 |
+
dummy_input = torch.randn(1, 3, 224, 224)
|
| 52 |
+
output = model(dummy_input)
|
| 53 |
+
|
| 54 |
+
print("✅ Done. Output shape:", output.shape)
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## 🚀 Demo Scripts
|
| 60 |
+
|
| 61 |
+
We provide full working training and inference examples:
|
| 62 |
+
|
| 63 |
+
| Dataset | Training Script | Inference Script | Checkpoint Path |
|
| 64 |
+
|-------------|-----------------------------|------------------------------|----------------------------------------------|
|
| 65 |
+
| MNIST | `demo_training_mnist.py` | `demo_inference_mnist.py` | `./checkpoints/fsg_vit_mnist_demo.pth` |
|
| 66 |
+
| Imagenette | `demo_training_imnet.py` | `demo_inference_imnet.py` | `./checkpoints/fsg_vit_imagenette_demo.pth` |
|
| 67 |
+
|
| 68 |
+
Each demo:
|
| 69 |
+
- Trains a ViT+B16 with FSG on a reduced dataset for speed.
|
| 70 |
+
- Uses separate learning rates for FSG and base model parameters.
|
| 71 |
+
- Includes GPU-aware prints and a training progress bar.
|
| 72 |
+
- Saves checkpoints for reproducible inference.
|
| 73 |
+
|
| 74 |
+
### ▶️ Example Usage
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
# Train on Imagenette
|
| 78 |
+
python demo_training_imnet.py
|
| 79 |
+
|
| 80 |
+
# Inference on Imagenette
|
| 81 |
+
python demo_inference_imnet.py --checkpoint ./checkpoints/fsg_vit_imagenette_demo.pth
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
# Train on MNIST
|
| 86 |
+
python demo_training_mnist.py
|
| 87 |
+
|
| 88 |
+
# Inference on MNIST
|
| 89 |
+
python demo_inference_mnist.py --checkpoint ./checkpoints/fsg_vit_mnist_demo.pth
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
> ⚠️ These demos use reduced test sets and train for few iterations to make training quick. They're not meant for benchmarking, but rather for showcasing FSG integration.
|
| 93 |
+
|
| 94 |
+
---
|
| 95 |
+
|
| 96 |
+
## 🧠 Applicability Beyond Endoscopy
|
| 97 |
+
|
| 98 |
+
Although designed for **polyp size estimation in colonoscopy**, FSG is a **general mechanism** for:
|
| 99 |
+
- **Image classification**
|
| 100 |
+
- **Medical image analysis**
|
| 101 |
+
- **Multimodal fusion**
|
| 102 |
+
- **NLP Transformers** (e.g., GPTs, BERT) — apply FSG over token embeddings
|
| 103 |
+
|
| 104 |
+
We strongly encourage researchers to test FSG in **non-medical** domains.
|
| 105 |
+
|
| 106 |
+
---
|
| 107 |
+
|
| 108 |
+
## 📦 Files and Structure
|
| 109 |
+
|
| 110 |
+
```
|
| 111 |
+
.
|
| 112 |
+
├── vit_with_fsg.py # ViT + FSG wrapper
|
| 113 |
+
├── demo_training_mnist.py
|
| 114 |
+
├── demo_inference_mnist.py
|
| 115 |
+
├── demo_training_imnet.py
|
| 116 |
+
├── demo_inference_imnet.py
|
| 117 |
+
├── checkpoints/ # Folder for .pth checkpoints
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
## 📚 Citation
|
| 123 |
+
|
| 124 |
+
Please cite our work if you use this repository:
|
| 125 |
+
|
| 126 |
+
```bibtex
|
| 127 |
+
@inproceedings{roffo2024FSG,
|
| 128 |
+
title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
|
| 129 |
+
author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
|
| 130 |
+
booktitle={MICCAI 2024, the 27th International Conference on Medical Image Computing and Computer Assisted Intervention, Marrakech, Morocco, October 2024.},
|
| 131 |
+
year={2024},
|
| 132 |
+
organization={Springer}
|
| 133 |
+
}
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
---
|
| 137 |
+
|
| 138 |
+
## 📬 Contact
|
| 139 |
+
|
| 140 |
+
Lead Author: **Giorgio Roffo**
|
| 141 |
+
📧 giorgio.roffo@gmail.com
|
| 142 |
+
🏢 Cosmo Intelligent Medical Devices (IMD), Lainate, Italy
|
| 143 |
+
|
| 144 |
+
For more: [github.com/cosmoimd/feature-selection-gates](https://github.com/cosmoimd/feature-selection-gates)
|
demo_inference_imnet.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Demo script for applying Feature Selection Gates (FSG) to torchvision Vision Transformers
|
| 3 |
+
and running inference on the ImageNet-mini (Imagenette) validation set.
|
| 4 |
+
|
| 5 |
+
Each image is resized to 224x224 and has 3 RGB channels to be compatible with ViT.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
|
| 9 |
+
demo_inference_imnet.py --checkpoint ./checkpoints/fsg_vit_imagenette_demo.pth
|
| 10 |
+
|
| 11 |
+
Paper:
|
| 12 |
+
https://papers.miccai.org/miccai-2024/316-Paper0410.html
|
| 13 |
+
Code:
|
| 14 |
+
https://github.com/cosmoimd/feature-selection-gates
|
| 15 |
+
Contact:
|
| 16 |
+
giorgio.roffo@gmail.com
|
| 17 |
+
'''
|
| 18 |
+
|
| 19 |
+
import warnings
|
| 20 |
+
warnings.filterwarnings("ignore")
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
import tarfile
|
| 25 |
+
import urllib.request
|
| 26 |
+
import torch
|
| 27 |
+
import psutil
|
| 28 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
| 29 |
+
from vit_with_fsg import vit_with_fsg
|
| 30 |
+
from torchvision import transforms
|
| 31 |
+
from torchvision.datasets import ImageFolder
|
| 32 |
+
from torch.utils.data import DataLoader
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
| 35 |
+
from tqdm import tqdm
|
| 36 |
+
|
| 37 |
+
import argparse
|
| 38 |
+
|
| 39 |
+
parser = argparse.ArgumentParser(description="FSG-ViT inference on Imagenette")
|
| 40 |
+
parser.add_argument("--checkpoint", type=str, default=None, help="Path to .pth file of trained FSG-ViT model")
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
|
| 45 |
+
wrn = False
|
| 46 |
+
print(f"\n📌 To run this script:\n"
|
| 47 |
+
f" ▶ Without checkpoint: python {os.path.basename(__file__)}\n"
|
| 48 |
+
f" ▶ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")
|
| 49 |
+
|
| 50 |
+
# Device and system info
|
| 51 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 52 |
+
print(f"\n🖥️ Using device: {device}")
|
| 53 |
+
if device.type == "cuda":
|
| 54 |
+
print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
|
| 55 |
+
print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
|
| 56 |
+
print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
|
| 57 |
+
|
| 58 |
+
print("\n📥 Loading pretrained ViT backbone from torchvision...")
|
| 59 |
+
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
|
| 60 |
+
|
| 61 |
+
print("🔧 Wrapping with Feature Selection Gates (FSG)...")
|
| 62 |
+
model = vit_with_fsg(backbone).to(device)
|
| 63 |
+
|
| 64 |
+
if args.checkpoint is not None:
|
| 65 |
+
print(f"📂 Loading model weights from: {args.checkpoint}")
|
| 66 |
+
model.load_state_dict(torch.load(args.checkpoint, map_location=device))
|
| 67 |
+
else:
|
| 68 |
+
wrn = True
|
| 69 |
+
print("\n⚠️ No checkpoint provided. Evaluating randomly initialized model! 🧪\n")
|
| 70 |
+
print("❗ Note: The model has not been trained. Results will reflect a randomly initialized backbone.")
|
| 71 |
+
|
| 72 |
+
model.eval()
|
| 73 |
+
|
| 74 |
+
print("📚 Loading Imagenette validation set (224x224 RGB)...")
|
| 75 |
+
imagenette_path = "./imagenette2-160/val"
|
| 76 |
+
if not os.path.exists(imagenette_path):
|
| 77 |
+
print("📦 Downloading Imagenette...")
|
| 78 |
+
url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
|
| 79 |
+
tgz_path = "imagenette2-160.tgz"
|
| 80 |
+
urllib.request.urlretrieve(url, tgz_path)
|
| 81 |
+
print("📂 Extracting Imagenette dataset...")
|
| 82 |
+
with tarfile.open(tgz_path, "r:gz") as tar:
|
| 83 |
+
tar.extractall()
|
| 84 |
+
os.remove(tgz_path)
|
| 85 |
+
print("✅ Dataset ready.")
|
| 86 |
+
|
| 87 |
+
transform = transforms.Compose([
|
| 88 |
+
transforms.Resize((224, 224)),
|
| 89 |
+
transforms.ToTensor(),
|
| 90 |
+
transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
|
| 91 |
+
])
|
| 92 |
+
|
| 93 |
+
dataset = ImageFolder(root=imagenette_path, transform=transform)
|
| 94 |
+
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
|
| 95 |
+
|
| 96 |
+
y_true = []
|
| 97 |
+
y_pred = []
|
| 98 |
+
|
| 99 |
+
print("🧪 Running inference on Imagenette validation set using FSG-ViT-B-16 (code by G. Roffo)...\n\n")
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
for images, labels in tqdm(dataloader, desc="🔍 Inference progress", ncols=100):
|
| 102 |
+
images = images.to(device)
|
| 103 |
+
labels = labels.to(device)
|
| 104 |
+
outputs = model(images)
|
| 105 |
+
preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)
|
| 106 |
+
y_true.extend(labels.cpu().tolist())
|
| 107 |
+
y_pred.extend(preds.cpu().tolist())
|
| 108 |
+
|
| 109 |
+
print("✅ Inference completed.")
|
| 110 |
+
|
| 111 |
+
acc = accuracy_score(y_true, y_pred)
|
| 112 |
+
prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
|
| 113 |
+
rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
|
| 114 |
+
f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
|
| 115 |
+
|
| 116 |
+
if wrn == True:
|
| 117 |
+
print("\n⚠️ No checkpoint provided. Evaluated randomly initialized model! 🧪\n")
|
| 118 |
+
print(f"\n📌 To run this script:\n"
|
| 119 |
+
f" ▶ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")
|
| 120 |
+
|
| 121 |
+
print(f"📊 Accuracy: {acc * 100:.2f}%")
|
| 122 |
+
print(f"📊 Precision: {prec * 100:.2f}%")
|
| 123 |
+
print(f"📊 Recall: {rec * 100:.2f}%")
|
| 124 |
+
print(f"📊 F1 Score: {f1 * 100:.2f}%")
|
demo_inference_mnist.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Demo script for applying Feature Selection Gates (FSG) to torchvision Vision Transformers
|
| 3 |
+
and running inference on the MNIST test set.
|
| 4 |
+
|
| 5 |
+
Each MNIST image is resized to 224x224 and converted to 3 channels to be compatible with ViT.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
|
| 9 |
+
demo_inference_mnist.py --checkpoint ./checkpoints/fsg_vit_mnist_demo.pth
|
| 10 |
+
|
| 11 |
+
Paper:
|
| 12 |
+
https://papers.miccai.org/miccai-2024/316-Paper0410.html
|
| 13 |
+
Code:
|
| 14 |
+
https://github.com/cosmoimd/feature-selection-gates
|
| 15 |
+
Contact:
|
| 16 |
+
giorgio.roffo@gmail.com
|
| 17 |
+
'''
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import psutil
|
| 21 |
+
import argparse
|
| 22 |
+
import warnings
|
| 23 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
| 24 |
+
from vit_with_fsg import vit_with_fsg
|
| 25 |
+
from torchvision.datasets import MNIST
|
| 26 |
+
from torchvision import transforms
|
| 27 |
+
from torch.utils.data import DataLoader
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
| 30 |
+
from tqdm import tqdm
|
| 31 |
+
import os
|
| 32 |
+
|
| 33 |
+
warnings.filterwarnings("ignore")
|
| 34 |
+
|
| 35 |
+
parser = argparse.ArgumentParser(description="FSG-ViT inference on MNIST")
|
| 36 |
+
parser.add_argument("--checkpoint", type=str, default=None, help="Path to .pth file of trained FSG-ViT model")
|
| 37 |
+
args = parser.parse_args()
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
|
| 40 |
+
warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
|
| 41 |
+
wrn = False
|
| 42 |
+
print(f"\n📌 To run this script:\n"
|
| 43 |
+
f" ▶ Without checkpoint: python {os.path.basename(__file__)}\n"
|
| 44 |
+
f" ▶ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")
|
| 45 |
+
|
| 46 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 47 |
+
print(f"\n🖥️ Using device: {device}")
|
| 48 |
+
if device.type == "cuda":
|
| 49 |
+
print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
|
| 50 |
+
print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
|
| 51 |
+
print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
|
| 52 |
+
|
| 53 |
+
print("\n📥 Loading pretrained ViT backbone from torchvision...")
|
| 54 |
+
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
|
| 55 |
+
|
| 56 |
+
print("🔧 Wrapping with Feature Selection Gates (FSG)...")
|
| 57 |
+
model = vit_with_fsg(backbone).to(device)
|
| 58 |
+
|
| 59 |
+
if args.checkpoint is not None:
|
| 60 |
+
print(f"📂 Loading model weights from: {args.checkpoint}")
|
| 61 |
+
model.load_state_dict(torch.load(args.checkpoint, map_location=device))
|
| 62 |
+
else:
|
| 63 |
+
wrn = True
|
| 64 |
+
print("\n⚠️ No checkpoint provided. Evaluating randomly initialized model! 🧪\n")
|
| 65 |
+
print("❗ Note: The model has not been trained. Results will reflect a randomly initialized backbone.")
|
| 66 |
+
|
| 67 |
+
model.eval()
|
| 68 |
+
|
| 69 |
+
print("📚 Loading MNIST test set (resized to 224x224, 3-channel)...")
|
| 70 |
+
transform = transforms.Compose([
|
| 71 |
+
transforms.Resize((224, 224)),
|
| 72 |
+
transforms.Grayscale(num_output_channels=3),
|
| 73 |
+
transforms.ToTensor(),
|
| 74 |
+
transforms.Normalize((0.5,), (0.5,))
|
| 75 |
+
])
|
| 76 |
+
|
| 77 |
+
test_dataset = MNIST(root="./data", train=False, download=True, transform=transform)
|
| 78 |
+
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
| 79 |
+
|
| 80 |
+
y_true = []
|
| 81 |
+
y_pred = []
|
| 82 |
+
|
| 83 |
+
print("🧪 Running inference on MNIST test set using FSG-ViT-B-16 (code by G. Roffo)...")
|
| 84 |
+
with torch.no_grad():
|
| 85 |
+
for images, labels in tqdm(test_loader, desc="🔍 Inference progress", ncols=100):
|
| 86 |
+
images = images.to(device)
|
| 87 |
+
labels = labels.to(device)
|
| 88 |
+
outputs = model(images)
|
| 89 |
+
preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)
|
| 90 |
+
y_true.extend(labels.cpu().tolist())
|
| 91 |
+
y_pred.extend(preds.cpu().tolist())
|
| 92 |
+
|
| 93 |
+
print("✅ Inference completed.")
|
| 94 |
+
|
| 95 |
+
acc = accuracy_score(y_true, y_pred)
|
| 96 |
+
prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
|
| 97 |
+
rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
|
| 98 |
+
f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
|
| 99 |
+
|
| 100 |
+
if wrn == True:
|
| 101 |
+
print("\n⚠️ No checkpoint provided. Evaluated randomly initialized model! 🧪\n")
|
| 102 |
+
print(f"\n📌 To run this script:\n"
|
| 103 |
+
f" ▶ With checkpoint: python {os.path.basename(__file__)} --checkpoint path/to/model.pth\n")
|
| 104 |
+
|
| 105 |
+
print(f"📊 Accuracy: {acc * 100:.2f}%")
|
| 106 |
+
print(f"📊 Precision: {prec * 100:.2f}%")
|
| 107 |
+
print(f"📊 Recall: {rec * 100:.2f}%")
|
| 108 |
+
print(f"📊 F1 Score: {f1 * 100:.2f}%")
|
demo_training_imnet.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Demo training script for Feature Selection Gates (FSG) with ViT on Imagenette
|
| 3 |
+
|
| 4 |
+
This script loads the Imagenette dataset (ImageNet-mini),
|
| 5 |
+
trains a ViT model augmented with FSG, and saves the model checkpoint.
|
| 6 |
+
|
| 7 |
+
Paper:
|
| 8 |
+
https://papers.miccai.org/miccai-2024/316-Paper0410.html
|
| 9 |
+
Code:
|
| 10 |
+
https://github.com/cosmoimd/feature-selection-gates
|
| 11 |
+
Contact:
|
| 12 |
+
giorgio.roffo@gmail.com
|
| 13 |
+
'''
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import tarfile
|
| 17 |
+
import urllib.request
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.optim as optim
|
| 21 |
+
import psutil
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
from torchvision import transforms
|
| 24 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
| 25 |
+
from torchvision.datasets import ImageFolder
|
| 26 |
+
from torch.utils.data import DataLoader
|
| 27 |
+
from vit_with_fsg import vit_with_fsg
|
| 28 |
+
|
| 29 |
+
# System info
|
| 30 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 31 |
+
print(f"\n🖥️ Using device: {device}")
|
| 32 |
+
if device.type == "cuda":
|
| 33 |
+
print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
|
| 34 |
+
print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
|
| 35 |
+
print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
|
| 36 |
+
|
| 37 |
+
# Dataset path
|
| 38 |
+
imagenette_path = "./imagenette2-160/val"
|
| 39 |
+
if not os.path.exists(imagenette_path):
|
| 40 |
+
print("📦 Downloading Imagenette...")
|
| 41 |
+
url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
|
| 42 |
+
tgz_path = "imagenette2-160.tgz"
|
| 43 |
+
urllib.request.urlretrieve(url, tgz_path)
|
| 44 |
+
print("📂 Extracting Imagenette dataset...")
|
| 45 |
+
with tarfile.open(tgz_path, "r:gz") as tar:
|
| 46 |
+
tar.extractall()
|
| 47 |
+
os.remove(tgz_path)
|
| 48 |
+
print("✅ Dataset ready.")
|
| 49 |
+
|
| 50 |
+
# Transforms
|
| 51 |
+
transform = transforms.Compose([
|
| 52 |
+
transforms.Resize((224, 224)),
|
| 53 |
+
transforms.ToTensor(),
|
| 54 |
+
transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
|
| 55 |
+
])
|
| 56 |
+
|
| 57 |
+
# Dataset and loader
|
| 58 |
+
dataset = ImageFolder(root=imagenette_path, transform=transform)
|
| 59 |
+
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
| 60 |
+
|
| 61 |
+
# Model setup
|
| 62 |
+
print("\n📥 Loading pretrained ViT backbone from torchvision...")
|
| 63 |
+
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
|
| 64 |
+
model = vit_with_fsg(backbone).to(device)
|
| 65 |
+
|
| 66 |
+
# Optimizer with separate LRs for FSG and base ViT
|
| 67 |
+
fsg_params, base_params = [], []
|
| 68 |
+
for name, param in model.named_parameters():
|
| 69 |
+
if 'fsag_rgb_ls' in name:
|
| 70 |
+
fsg_params.append(param)
|
| 71 |
+
else:
|
| 72 |
+
base_params.append(param)
|
| 73 |
+
|
| 74 |
+
lr_base = 1e-4
|
| 75 |
+
lr_fsg = 5e-4
|
| 76 |
+
print(f"\n🔧 Optimizer setup:")
|
| 77 |
+
print(f" 🔹 Base ViT parameters LR: {lr_base}")
|
| 78 |
+
print(f" 🔸 FSG parameters LR: {lr_fsg}")
|
| 79 |
+
|
| 80 |
+
optimizer = optim.AdamW([
|
| 81 |
+
{"params": base_params, "lr": lr_base},
|
| 82 |
+
{"params": fsg_params, "lr": lr_fsg}
|
| 83 |
+
])
|
| 84 |
+
criterion = nn.CrossEntropyLoss()
|
| 85 |
+
|
| 86 |
+
# Training loop
|
| 87 |
+
epochs = 3
|
| 88 |
+
print(f"\n🚀 Starting demo training for {epochs} epochs...")
|
| 89 |
+
model.train()
|
| 90 |
+
for epoch in range(epochs):
|
| 91 |
+
steps_demo = 0 # to remove: for demo only
|
| 92 |
+
running_loss = 0.0
|
| 93 |
+
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
|
| 94 |
+
for inputs, targets in pbar:
|
| 95 |
+
if steps_demo > 25: # to remove: for demo only
|
| 96 |
+
break # to remove: for demo only
|
| 97 |
+
steps_demo += 1 # to remove: for demo only
|
| 98 |
+
inputs, targets = inputs.to(device), targets.to(device)
|
| 99 |
+
optimizer.zero_grad()
|
| 100 |
+
outputs = model(inputs)
|
| 101 |
+
loss = criterion(outputs, targets)
|
| 102 |
+
loss.backward()
|
| 103 |
+
optimizer.step()
|
| 104 |
+
running_loss += loss.item()
|
| 105 |
+
pbar.set_postfix({"loss": running_loss / (pbar.n + 1e-8)})
|
| 106 |
+
|
| 107 |
+
print("\n✅ Training complete.")
|
| 108 |
+
|
| 109 |
+
# Save checkpoint
|
| 110 |
+
ckpt_dir = "./checkpoints"
|
| 111 |
+
os.makedirs(ckpt_dir, exist_ok=True)
|
| 112 |
+
ckpt_path = os.path.join(ckpt_dir, "fsg_vit_imagenette_demo.pth")
|
| 113 |
+
torch.save(model.state_dict(), ckpt_path)
|
| 114 |
+
print(f"💾 Checkpoint saved to: {ckpt_path}")
|
demo_training_mnist.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Demo training script for Feature Selection Gates (FSG) with ViT on MNIST test set
|
| 3 |
+
|
| 4 |
+
This is a minimal demo: we train only on the MNIST test set (resized and converted to 3-channel)
|
| 5 |
+
for a few epochs to simulate training, save the checkpoint, and allow downstream inference.
|
| 6 |
+
|
| 7 |
+
Paper:
|
| 8 |
+
https://papers.miccai.org/miccai-2024/316-Paper0410.html
|
| 9 |
+
Code:
|
| 10 |
+
https://github.com/cosmoimd/feature-selection-gates
|
| 11 |
+
Contact:
|
| 12 |
+
giorgio.roffo@gmail.com
|
| 13 |
+
'''
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import warnings
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.optim as optim
|
| 20 |
+
import psutil
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
from torchvision import transforms
|
| 23 |
+
from torchvision.datasets import MNIST
|
| 24 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
| 25 |
+
from torch.utils.data import DataLoader
|
| 26 |
+
from vit_with_fsg import vit_with_fsg
|
| 27 |
+
|
| 28 |
+
warnings.filterwarnings("ignore")
|
| 29 |
+
|
| 30 |
+
# Device info
|
| 31 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 32 |
+
print(f"\n🖥️ Using device: {device}")
|
| 33 |
+
if device.type == "cuda":
|
| 34 |
+
print(f"🚀 CUDA device: {torch.cuda.get_device_name(0)}")
|
| 35 |
+
print(f"💾 GPU memory total: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
|
| 36 |
+
print(f"🧠 System RAM: {psutil.virtual_memory().total / (1024 ** 3):.2f} GB")
|
| 37 |
+
|
| 38 |
+
# Dataset loading
|
| 39 |
+
print("\n📚 Loading MNIST demo set for demo training (resized to 224x224, 3-channel)...")
|
| 40 |
+
transform = transforms.Compose([
|
| 41 |
+
transforms.Resize((224, 224)),
|
| 42 |
+
transforms.Grayscale(num_output_channels=3),
|
| 43 |
+
transforms.ToTensor(),
|
| 44 |
+
transforms.Normalize((0.5,), (0.5,))
|
| 45 |
+
])
|
| 46 |
+
|
| 47 |
+
dataset = MNIST(root="./data", train=False, download=True, transform=transform)
|
| 48 |
+
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
| 49 |
+
|
| 50 |
+
# Load ViT backbone and wrap with FSG
|
| 51 |
+
print("\n📥 Loading pretrained ViT backbone from torchvision...")
|
| 52 |
+
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
|
| 53 |
+
model = vit_with_fsg(backbone).to(device)
|
| 54 |
+
|
| 55 |
+
# Prepare optimizer with different LRs for FSG parameters and base model
|
| 56 |
+
fsg_params = []
|
| 57 |
+
base_params = []
|
| 58 |
+
for name, param in model.named_parameters():
|
| 59 |
+
if 'fsag_rgb_ls' in name:
|
| 60 |
+
fsg_params.append(param)
|
| 61 |
+
else:
|
| 62 |
+
base_params.append(param)
|
| 63 |
+
|
| 64 |
+
# Assign a higher LR to FSG parameters, lower to base ViT params
|
| 65 |
+
lr_base = 1e-4
|
| 66 |
+
lr_fsg = 5e-4
|
| 67 |
+
print(f"\n🔧 Optimizer setup:")
|
| 68 |
+
print(f" 🔹 Base ViT parameters LR: {lr_base}")
|
| 69 |
+
print(f" 🔸 FSG parameters LR: {lr_fsg}")
|
| 70 |
+
|
| 71 |
+
optimizer = optim.AdamW([
|
| 72 |
+
{"params": base_params, "lr": lr_base},
|
| 73 |
+
{"params": fsg_params, "lr": lr_fsg}
|
| 74 |
+
])
|
| 75 |
+
|
| 76 |
+
criterion = nn.CrossEntropyLoss()
|
| 77 |
+
epochs = 3
|
| 78 |
+
print(f"\n🚀 Starting demo training for {epochs} epochs...")
|
| 79 |
+
|
| 80 |
+
model.train()
|
| 81 |
+
for epoch in range(epochs):
|
| 82 |
+
steps_demo = 0 # to remove: for demo only
|
| 83 |
+
running_loss = 0.0
|
| 84 |
+
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
|
| 85 |
+
for inputs, targets in pbar:
|
| 86 |
+
if steps_demo > 25: # to remove: for demo only
|
| 87 |
+
break # to remove: for demo only
|
| 88 |
+
steps_demo += 1 # to remove: for demo only
|
| 89 |
+
inputs, targets = inputs.to(device), targets.to(device)
|
| 90 |
+
optimizer.zero_grad()
|
| 91 |
+
outputs = model(inputs)
|
| 92 |
+
loss = criterion(outputs, targets)
|
| 93 |
+
loss.backward()
|
| 94 |
+
optimizer.step()
|
| 95 |
+
|
| 96 |
+
running_loss += loss.item()
|
| 97 |
+
pbar.set_postfix({"loss": running_loss / (pbar.n + 1e-8)})
|
| 98 |
+
|
| 99 |
+
print("\n✅ Training complete.")
|
| 100 |
+
|
| 101 |
+
# Save checkpoint
|
| 102 |
+
ckpt_dir = "./checkpoints"
|
| 103 |
+
os.makedirs(ckpt_dir, exist_ok=True)
|
| 104 |
+
ckpt_path = os.path.join(ckpt_dir, "fsg_vit_mnist_demo.pth")
|
| 105 |
+
torch.save(model.state_dict(), ckpt_path)
|
| 106 |
+
print(f"💾 Checkpoint saved to: {ckpt_path}")
|
vit_with_fsg.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
ViTwithFSG: Vision Transformer wrapper with Feature Selection Gates (FSG)
|
| 3 |
+
|
| 4 |
+
This script defines a wrapper class to apply Feature Selection Gates (FSG) to a Vision Transformer (ViT) model.
|
| 5 |
+
FSG enhances model generalization by introducing sparse, learnable gates on the residual paths of attention and MLP blocks.
|
| 6 |
+
It is a form of architectural regularization designed for vision tasks and applicable to NLP tasks.
|
| 7 |
+
|
| 8 |
+
The method is introduced in:
|
| 9 |
+
|
| 10 |
+
@inproceedings{roffo2024FSG,
|
| 11 |
+
title={Feature Selection Gates with Gradient Routing for Endoscopic Image Computing},
|
| 12 |
+
author={Giorgio Roffo and Carlo Biffi and Pietro Salvagnini and Andrea Cherubini},
|
| 13 |
+
booktitle={MICCAI 2024, the 27th International Conference on Medical Image Computing and Computer Assisted Intervention, Marrakech, Morocco, October 2024.},
|
| 14 |
+
year={2024},
|
| 15 |
+
organization={Springer}
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
- Publication: https://papers.miccai.org/miccai-2024/316-Paper0410.html
|
| 19 |
+
- Code: https://github.com/cosmoimd/feature-selection-gates
|
| 20 |
+
- Contact: giorgio.roffo@gmail.com
|
| 21 |
+
- Affiliation: Cosmo Intelligent Medical Devices (IMD), Lainate, Italy
|
| 22 |
+
'''
|
| 23 |
+
|
| 24 |
+
# imports
|
| 25 |
+
import warnings
|
| 26 |
+
warnings.filterwarnings("ignore")
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
from torchvision.models.vision_transformer import VisionTransformer
|
| 31 |
+
|
| 32 |
+
class FSGBlock(nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
A Transformer encoder block augmented with Feature Selection Gates (FSG).
|
| 35 |
+
Each residual path (attention and MLP) is weighted element-wise by a learnable sigmoid gate.
|
| 36 |
+
This promotes sparse activation and serves as a regularization mechanism to avoid overfitting.
|
| 37 |
+
"""
|
| 38 |
+
def __init__(self, original_block):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.self_attention = original_block.self_attention # Multi-head self-attention module
|
| 41 |
+
self.mlp = original_block.mlp # Feedforward network (2-layer MLP)
|
| 42 |
+
self.ln_1 = original_block.ln_1 # LayerNorm before attention
|
| 43 |
+
self.ln_2 = original_block.ln_2 # LayerNorm before MLP
|
| 44 |
+
self.dropout = original_block.dropout # Dropout after attention
|
| 45 |
+
|
| 46 |
+
dim = self.ln_1.normalized_shape[0] # Dimensionality of the model
|
| 47 |
+
|
| 48 |
+
# FSG: learnable gates (one per channel), initialized with Xavier normal
|
| 49 |
+
self.fsg_rectifier = nn.Sigmoid()
|
| 50 |
+
self.fsg_rgb_ls1 = nn.Parameter(torch.empty(dim)) # Gate for attention path
|
| 51 |
+
self.fsg_rgb_ls2 = nn.Parameter(torch.empty(dim)) # Gate for MLP path
|
| 52 |
+
nn.init.xavier_normal_(self.fsg_rgb_ls1.unsqueeze(0), gain=nn.init.calculate_gain('sigmoid'))
|
| 53 |
+
nn.init.xavier_normal_(self.fsg_rgb_ls2.unsqueeze(0), gain=nn.init.calculate_gain('sigmoid'))
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
# Self-attention + gate
|
| 57 |
+
x_norm = self.ln_1(x)
|
| 58 |
+
attn_output, _ = self.self_attention(x_norm, x_norm, x_norm, need_weights=False)
|
| 59 |
+
attn_output = self.dropout(attn_output)
|
| 60 |
+
fsg_scores_1 = self.fsg_rectifier(self.fsg_rgb_ls1)
|
| 61 |
+
x = x + attn_output * fsg_scores_1 # Residual connection weighted by gate
|
| 62 |
+
|
| 63 |
+
# MLP + gate
|
| 64 |
+
x_norm = self.ln_2(x)
|
| 65 |
+
mlp_output = self.mlp(x_norm)
|
| 66 |
+
fsg_scores_2 = self.fsg_rectifier(self.fsg_rgb_ls2)
|
| 67 |
+
x = x + mlp_output * fsg_scores_2 # Residual connection weighted by gate
|
| 68 |
+
|
| 69 |
+
return x
|
| 70 |
+
|
| 71 |
+
class ViTwithFSG(nn.Module):
|
| 72 |
+
"""
|
| 73 |
+
Wrapper module that injects FSGBlocks into each Transformer encoder block of a given ViT model.
|
| 74 |
+
"""
|
| 75 |
+
def __init__(self, vit_backbone: VisionTransformer):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.vit = vit_backbone
|
| 78 |
+
for i, blk in enumerate(self.vit.encoder.layers):
|
| 79 |
+
self.vit.encoder.layers[i] = FSGBlock(blk) # Replace original block with FSGBlock
|
| 80 |
+
|
| 81 |
+
def forward(self, x):
|
| 82 |
+
return self.vit(x)
|
| 83 |
+
|
| 84 |
+
def vit_with_fsg(vit_backbone: VisionTransformer):
|
| 85 |
+
"""
|
| 86 |
+
Factory function that wraps a torchvision VisionTransformer with FSG-enhanced encoder blocks.
|
| 87 |
+
"""
|
| 88 |
+
return ViTwithFSG(vit_backbone)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# === Example Usage ===
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
import warnings
|
| 94 |
+
warnings.filterwarnings("ignore", message="Failed to load image Python extension*")
|
| 95 |
+
|
| 96 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
| 97 |
+
|
| 98 |
+
print("\n📥 Loading pretrained ViT_B_16 backbone from torchvision...")
|
| 99 |
+
backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
|
| 100 |
+
|
| 101 |
+
print("🔧 Wrapping with Feature Selection Gates (FSG)...")
|
| 102 |
+
model = vit_with_fsg(vit_backbone=backbone)
|
| 103 |
+
|
| 104 |
+
print("🧪 Running dummy input through FSG-augmented ViT...")
|
| 105 |
+
dummy_input = torch.randn(1, 3, 224, 224)
|
| 106 |
+
output = model(dummy_input)
|
| 107 |
+
|
| 108 |
+
print("✅ Inference completed.")
|
| 109 |
+
print("📐 Output shape:", output.shape)
|