LeGrad + ImageBind Notebook Usage
This folder includes legrad_imagebind.ipynb, a notebook that demonstrates LeGrad explanations for the ImageBind model using ImageBindLeWrapper (imagebind/legrad_wrapper.py).
1. Environment and installation
From the ImageBind repo root, install ImageBind in editable mode:
```bash
pip install -e .
```
Install LeGrad so that the legrad package is importable, either as an editable install:
```bash
pip install -e /path/to/LeGrad
```
or by adding the LeGrad repo root to PYTHONPATH:
```bash
export PYTHONPATH=/path/to/LeGrad:$PYTHONPATH
```
GPU execution requires a CUDA-enabled PyTorch build.
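Before opening the notebook, it can help to confirm that torch, imagebind, and legrad all resolve from the active interpreter and that PyTorch sees a GPU. The snippet below is a minimal standard-library sketch; check_environment and cuda_available are hypothetical helper names, not part of ImageBind or LeGrad.

```python
import importlib.util


def check_environment(packages=("torch", "imagebind", "legrad")):
    """Map each top-level package name to whether it can be imported.

    importlib.util.find_spec locates a module without executing it,
    so this check stays cheap even for heavy packages such as torch.
    """
    return {name: importlib.util.find_spec(name) is not None for name in packages}


def cuda_available():
    """Return True only if torch is installed and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch  # imported lazily so the check also runs without torch

    return torch.cuda.is_available()


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:<10} {'OK' if ok else 'MISSING'}")
    print(f"{'cuda':<10} {'OK' if cuda_available() else 'UNAVAILABLE'}")
```

Run it with the same interpreter as the Jupyter kernel; a MISSING entry for legrad usually means the editable install or the PYTHONPATH change did not take effect in that environment.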
2. Open the notebook
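Change into this folder (the path below is relative to the top-level repository that contains xai/) and start Jupyter:
```bash
cd xai/ImageBind
jupyter lab legrad_imagebind.ipynb
```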
3. What the notebook does
The notebook walks through:
- Importing ImageBind and LeGrad
  - Adds the local LeGrad repo to sys.path.
  - Imports imagebind_model, ModalityType, and utility functions from imagebind.data.
- Loading text and image inputs
  - Defines text_list (e.g. "A dog.", "A car", "A bird").
  - Loads and preprocesses images from .assets/*.jpg using data.load_and_transform_vision_data.
- Running the base ImageBind model
  - Creates model = imagebind_model.imagebind_huge(pretrained=True).
  - Moves it to CUDA and computes the vanilla similarity scores embeddings[VISION] @ embeddings[TEXT].T.
- Wrapping with LeGrad
  - Uses ImageBindLeWrapper from imagebind/legrad_wrapper.py.
  - Hooks the transformer blocks and nn.MultiheadAttention so that the attention probabilities (attention_maps) stay in the autograd graph.
- Computing explanations
  - Vision: spatial heatmaps over image patches (compute_legrad_image(...) / compute_legrad_image_one_layer(...)).
  - Text: per-token relevance scores (compute_legrad_text(...) / compute_legrad_text_one_layer(...)).
- Visualizing results
  - Converts heatmaps to NumPy arrays and overlays them on the input images; plots token relevance over the prompt.
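The vision explanation step can be pictured with the NumPy sketch below. It follows the aggregation described in the LeGrad paper (keep only the positive gradients of the image-text score with respect to each layer's attention probabilities, average over heads and query tokens, drop the [CLS] column, combine layers, min-max normalize), but it is not the wrapper's code. It assumes the gradients are already available as arrays of shape (heads, tokens, tokens) with [CLS] at index 0, and legrad_patch_heatmap is a hypothetical name. In the notebook the gradients come from autograd through the attention_maps kept by ImageBindLeWrapper, and the wrapper's exact averaging may differ.

```python
import numpy as np


def legrad_patch_heatmap(attn_grads, grid_size):
    """Aggregate per-layer attention-map gradients into one patch heatmap.

    attn_grads: list of arrays, one per hooked layer, each (heads, tokens, tokens),
        holding d(score) / d(attention probabilities). Token 0 is assumed to be [CLS].
    grid_size: (rows, cols) of the patch grid, with rows * cols == tokens - 1.
    Returns a (rows, cols) array min-max normalized to [0, 1].
    """
    rows, cols = grid_size
    accum = np.zeros(rows * cols)
    for grad in attn_grads:
        positive = np.clip(grad, 0.0, None)           # keep positive sensitivities only
        per_key = positive.mean(axis=0).mean(axis=0)  # average over heads, then query tokens
        accum += per_key[1:]                          # drop the [CLS] column, keep patches
    # Summing rather than averaging over layers does not change the min-max result.
    heatmap = accum.reshape(rows, cols)
    span = heatmap.max() - heatmap.min()
    if span == 0:
        return np.zeros_like(heatmap)
    return (heatmap - heatmap.min()) / span
```

Assuming 224 x 224 inputs cut into 14-pixel patches, each layer's gradient would have 257 tokens and grid_size would be (16, 16). Clipping before averaging means attention that could only lower the score contributes nothing, so the map highlights supporting evidence.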
4. Minimal code sketch (inside the notebook)
The core pattern is:
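```python
import torch

from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.legrad_wrapper import ImageBindLeWrapper

# Use the first GPU if one is visible, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load pretrained ImageBind-Huge and switch to inference mode.
model = imagebind_model.imagebind_huge(pretrained=True).to(device).eval()

# Wrap the model so LeGrad can hook its attention maps (settings as used in the notebook).
wrapper = ImageBindLeWrapper(model, layer_index=-2, trunk_key=ModalityType.VISION)

# One prompt per image; paths are relative to this folder.
text_list = ["A dog.", "A car", "A bird"]
image_paths = [".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]

# Preprocess both modalities into batched tensors on the target device.
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
}
```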
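Section 3 mentions the vanilla similarity embeddings[VISION] @ embeddings[TEXT].T. In ImageBind, embeddings = model(inputs) returns a dict keyed by modality, and the product is taken on those tensors; the NumPy sketch below reproduces only the arithmetic on random stand-in embeddings so the shapes are easy to follow. The softmax over the text axis mirrors the example in ImageBind's own README, the 1024-wide vectors are an assumption about ImageBind-Huge's output size, and the helper names are hypothetical.

```python
import numpy as np


def l2_normalize(x, axis=-1):
    """Scale each row to unit Euclidean norm."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def vision_text_probs(vision_emb, text_emb):
    """Return an (n_images, n_texts) matrix; row i is a distribution over prompts."""
    logits = vision_emb @ text_emb.T  # same product as embeddings[VISION] @ embeddings[TEXT].T
    return softmax(logits, axis=-1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vision = l2_normalize(rng.normal(size=(3, 1024)))  # stand-in for embeddings[VISION]
    text = l2_normalize(rng.normal(size=(3, 1024)))    # stand-in for embeddings[TEXT]
    print(vision_text_probs(vision, text).round(3))
```

The L2 normalization applies only to the random stand-ins; the real embeddings come out of ImageBind's own postprocessing.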
You can then call the wrapper’s compute_legrad_* helpers in the notebook to obtain:
- Vision heatmaps for each (image, text) pair.
- Text token relevance for a chosen image or batch of images.
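For the vision heatmaps, the last step is mapping the patch grid back onto pixels. The sketch below assumes a heatmap already scaled to [0, 1] and an RGB image stored as floats in [0, 1]. It uses nearest-neighbour upsampling and a plain red tint, both simplifications (the notebook may use bilinear interpolation and a matplotlib colormap instead), and upsample_nearest and overlay are hypothetical names.

```python
import numpy as np


def upsample_nearest(heatmap, out_h, out_w):
    """Nearest-neighbour resize of a (rows, cols) patch map to (out_h, out_w)."""
    rows, cols = heatmap.shape
    row_idx = np.arange(out_h) * rows // out_h  # source row for every output row
    col_idx = np.arange(out_w) * cols // out_w  # source column for every output column
    return heatmap[row_idx[:, None], col_idx[None, :]]


def overlay(image, heatmap, alpha=0.5):
    """Blend a [0, 1] heatmap into an (H, W, 3) float image as a red tint."""
    h, w, _ = image.shape
    heat = upsample_nearest(heatmap, h, w)
    color = np.zeros_like(image)
    color[..., 0] = heat  # encode relevance as red intensity
    return (1.0 - alpha) * image + alpha * color
```

Assuming 224 x 224 inputs and 14-pixel patches, a 16 x 16 map is expanded into 14 x 14 pixel blocks; alpha trades off image visibility against heatmap contrast.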