# LeGrad + ImageBind Notebook Usage
This folder includes `legrad_imagebind.ipynb`, a notebook that demonstrates **LeGrad** explanations for the **ImageBind** model using `ImageBindLeWrapper` (`imagebind/legrad_wrapper.py`).
## 1. Environment and installation
From the `ImageBind` repo root:
```bash
pip install -e .
```
Install **LeGrad**, for example as an editable install:
```bash
pip install -e /path/to/LeGrad
```
Alternatively, add the LeGrad repo to `PYTHONPATH` so that the `legrad` package is importable.
You also need a working CUDA‑enabled PyTorch environment for GPU execution.
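Before opening the notebook, it can help to confirm that the packages it imports are actually discoverable. The following is a minimal sketch using only the standard library; the helper name `check_environment` and its `legrad_repo` argument are hypothetical and not part of ImageBind or LeGrad.

```python
import importlib.util
import sys
from pathlib import Path


def check_environment(legrad_repo=None):
    """Report whether the packages the notebook imports can be found.

    legrad_repo is a hypothetical path to a local LeGrad checkout; when given,
    it is prepended to sys.path, which has the same effect as PYTHONPATH.
    """
    if legrad_repo is not None:
        repo = str(Path(legrad_repo).expanduser())
        if repo not in sys.path:
            sys.path.insert(0, repo)

    # find_spec locates a top-level package without importing it.
    status = {
        name: importlib.util.find_spec(name) is not None
        for name in ("torch", "imagebind", "legrad")
    }

    # Only query CUDA when PyTorch is actually installed.
    status["cuda"] = False
    if status["torch"]:
        import torch

        status["cuda"] = torch.cuda.is_available()
    return status


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:10s} {'ok' if ok else 'missing'}")
```

If `legrad` is reported as missing, passing the path of your LeGrad checkout as `legrad_repo` is one way to make it importable for the current Python session.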
## 2. Open the notebook
Change into this folder (`xai/ImageBind`) and launch JupyterLab:
```bash
cd xai/ImageBind
jupyter lab legrad_imagebind.ipynb
```
## 3. What the notebook does
The notebook walks through:
1. **Importing ImageBind and LeGrad**
   - Adds the local `LeGrad` repo to `sys.path`.
   - Imports `imagebind_model`, `ModalityType`, and utility functions from `imagebind.data`.
2. **Loading text and image inputs**
   - Defines `text_list` (e.g. `"A dog."`, `"A car"`, `"A bird"`).
   - Loads and preprocesses images from `.assets/*.jpg` using `data.load_and_transform_vision_data`.
3. **Running the base ImageBind model**
   - Creates `model = imagebind_model.imagebind_huge(pretrained=True)`.
   - Moves it to CUDA and evaluates vanilla similarity scores `embeddings[VISION] @ embeddings[TEXT].T`.
4. **Wrapping with LeGrad**
   - Uses `ImageBindLeWrapper` from `imagebind/legrad_wrapper.py`.
   - Hooks transformer blocks and `nn.MultiheadAttention` to keep attention probabilities (`attention_maps`) in the autograd graph.
5. **Computing explanations**
   - Vision: spatial heatmaps over image patches (`compute_legrad_image(...)` / `compute_legrad_image_one_layer(...)`).
   - Text: per‑token relevance scores (`compute_legrad_text(...)` / `compute_legrad_text_one_layer(...)`).
6. **Visualizing results**
   - Converts heatmaps to numpy and overlays them on the input images; plots token relevance over the prompt.
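The similarity score in step 3, `embeddings[VISION] @ embeddings[TEXT].T`, is a matrix of dot products between image and text embeddings. The sketch below reproduces that computation in NumPy with random stand-in embeddings rather than real ImageBind outputs. The softmax over text prompts and the embedding size of 8 are assumptions made for illustration; the notebook may normalize differently.

```python
import numpy as np


def similarity_scores(vision_emb, text_emb):
    """Return softmax-normalised image-to-text similarities, shape (n_images, n_texts)."""
    logits = vision_emb @ text_emb.T
    # Subtract the row maximum before exponentiating for numerical stability.
    logits = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(logits)
    return exp / exp.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
vision_emb = rng.normal(size=(3, 8))  # stand-in for embeddings[ModalityType.VISION]
text_emb = rng.normal(size=(3, 8))    # stand-in for embeddings[ModalityType.TEXT]

scores = similarity_scores(vision_emb, text_emb)
print(scores.shape)         # one row per image, one column per text prompt
print(scores.sum(axis=-1))  # each row sums to 1
```

Row `i` of `scores` can be read as how strongly image `i` matches each prompt, which gives a baseline to compare the LeGrad explanations against.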
## 4. Minimal code sketch (inside the notebook)
The core pattern is:
```python
import torch

from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.legrad_wrapper import ImageBindLeWrapper

# Fall back to CPU when no GPU is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pretrained model and wrap it so LeGrad can access attention maps.
model = imagebind_model.imagebind_huge(pretrained=True).to(device).eval()
wrapper = ImageBindLeWrapper(model, layer_index=-2, trunk_key=ModalityType.VISION)

text_list = ["A dog.", "A car", "A bird"]
image_paths = [".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]

# Tokenize the prompts and preprocess the images into model-ready tensors.
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
}
```
You can then call the wrapper’s `compute_legrad_*` helpers in the notebook to obtain:
- **Vision heatmaps** for each (image, text) pair.
- **Text token relevance** for a chosen image or batch of images.
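Once a vision heatmap has been converted to a NumPy array, overlaying it on the image comes down to normalizing, upsampling, and blending. The sketch below shows one way to do this with NumPy alone. The function `overlay_heatmap` is hypothetical and not part of the wrapper, and it assumes a patch-level heatmap whose grid evenly divides the image size; the 16×16 grid and 224×224 image are illustrative values only.

```python
import numpy as np


def overlay_heatmap(image, heatmap, alpha=0.5):
    """Blend a low-resolution heatmap over an RGB image.

    image:   float array in [0, 1], shape (H, W, 3)
    heatmap: float array, shape (h, w), with H and W multiples of h and w
    """
    height, width, _ = image.shape
    h, w = heatmap.shape
    if height % h or width % w:
        raise ValueError("image size must be a multiple of the heatmap size")

    # Min-max normalise to [0, 1]; a constant map becomes all zeros.
    span = heatmap.max() - heatmap.min()
    norm = (heatmap - heatmap.min()) / span if span > 0 else np.zeros_like(heatmap)

    # Nearest-neighbour upsampling: repeat each patch value over its pixel block.
    up = np.kron(norm, np.ones((height // h, width // w)))

    # Use the red channel as a simple colour map (an arbitrary choice).
    colour = np.zeros_like(image)
    colour[..., 0] = up
    return (1 - alpha) * image + alpha * colour


image = np.full((224, 224, 3), 0.5)                         # stand-in grey image
heatmap = np.arange(16 * 16, dtype=float).reshape(16, 16)   # stand-in heatmap
blended = overlay_heatmap(image, heatmap)
print(blended.shape, blended.min(), blended.max())
```

If the heatmap is already at image resolution, the upsampling factor is 1 and the function only normalizes and blends. The result can be passed to `matplotlib.pyplot.imshow` for display.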