# LeGrad + ImageBind Notebook Usage

This folder includes `legrad_imagebind.ipynb`, a notebook that demonstrates **LeGrad** explanations for the **ImageBind** model using `ImageBindLeWrapper` (`imagebind/legrad_wrapper.py`).

## 1. Environment and installation

From the `ImageBind` repo root:

```bash
pip install -e .
```

Install **LeGrad** so that the `legrad` package is importable, either as an editable package:

```bash
pip install -e /path/to/LeGrad
```

or by adding the LeGrad repository root to `PYTHONPATH`.

You also need a working CUDA‑enabled PyTorch environment for GPU execution.
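The notebook adds the local LeGrad repo to `sys.path` itself. The same check can be sketched as follows, together with a CUDA probe. `LEGRAD_ROOT` reuses the placeholder path from the install command, and the helper names `ensure_legrad_importable` and `cuda_available` are hypothetical. They are not part of ImageBind or LeGrad.

```python
import importlib
import importlib.util
import sys

# Placeholder path (same as the install command); point it at your LeGrad checkout.
LEGRAD_ROOT = "/path/to/LeGrad"


def ensure_legrad_importable(root: str = LEGRAD_ROOT) -> bool:
    """Return True if `legrad` is importable, prepending `root` to sys.path if needed."""
    if importlib.util.find_spec("legrad") is None and root not in sys.path:
        sys.path.insert(0, root)
        importlib.invalidate_caches()  # make the import system rescan sys.path
    return importlib.util.find_spec("legrad") is not None


def cuda_available() -> bool:
    """Return True only if PyTorch is installed and sees a CUDA device."""
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


if __name__ == "__main__":
    print("legrad importable:", ensure_legrad_importable())
    print("CUDA available:", cuda_available())
```

If `legrad` is still not importable after this, the path usually points at the wrong directory. It must be the folder that *contains* the `legrad` package.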

## 2. Open the notebook

From the top-level repository (the one containing `xai/ImageBind`):

```bash
cd xai/ImageBind
jupyter lab legrad_imagebind.ipynb
```

## 3. What the notebook does

The notebook walks through:

1. **Importing ImageBind and LeGrad**  
   - Adds the local `LeGrad` repo to `sys.path`.  
   - Imports `imagebind_model`, `ModalityType`, and utility functions from `imagebind.data`.  
2. **Loading text and image inputs**  
   - Defines `text_list` (e.g. `"A dog."`, `"A car"`, `"A bird"`).  
   - Loads and preprocesses images from `.assets/*.jpg` using `data.load_and_transform_vision_data`.  
3. **Running the base ImageBind model**  
   - Creates `model = imagebind_model.imagebind_huge(pretrained=True)`.  
   - Moves it to CUDA and computes baseline image-text similarity scores `embeddings[VISION] @ embeddings[TEXT].T`.  
4. **Wrapping with LeGrad**  
   - Uses `ImageBindLeWrapper` from `imagebind/legrad_wrapper.py`.  
   - Hooks transformer blocks and `nn.MultiheadAttention` to keep attention probabilities (`attention_maps`) in the autograd graph.  
5. **Computing explanations**  
   - Vision: spatial heatmaps over image patches (`compute_legrad_image(...)` / `compute_legrad_image_one_layer(...)`).  
   - Text: per‑token relevance scores (`compute_legrad_text(...)` / `compute_legrad_text_one_layer(...)`).  
6. **Visualizing results**  
   - Converts heatmaps to NumPy arrays and overlays them on the input images; plots token relevance over the prompt.
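Step 6 needs to turn a patch-level heatmap into an image-sized overlay. The following is a minimal sketch of that step. It assumes the heatmap arrives as a small 2-D tensor over the patch grid; 16×16 is what a 224-pixel input with 14-pixel patches would give, but check the shape the wrapper actually returns. It also assumes the image is an RGB array in [0, 1]. `load_and_transform_vision_data` normalizes pixel values, so blend onto the original image (resized and cropped to match) rather than onto the model input tensor. The function name `overlay_heatmap` and the single-channel red colormap are illustrative choices, not the notebook's code.

```python
import numpy as np
import torch
import torch.nn.functional as F


def overlay_heatmap(image: np.ndarray, heatmap: torch.Tensor, alpha: float = 0.5) -> np.ndarray:
    """Hypothetical helper: blend a patch-level heatmap onto an RGB image.

    image:   (H, W, 3) float array with values in [0, 1]
    heatmap: (h, w) tensor of patch relevances, any range
    """
    H, W, _ = image.shape
    # Bilinearly upsample the patch grid to the image resolution.
    hm = F.interpolate(heatmap[None, None].float(), size=(H, W),
                       mode="bilinear", align_corners=False)[0, 0]
    # Min-max normalize to [0, 1] so maps from different prompts are comparable visually.
    hm = (hm - hm.min()) / (hm.max() - hm.min()).clamp(min=1e-8)
    hm = hm.detach().cpu().numpy()
    # Simple colormap: relevance drives the red channel only.
    color = np.zeros_like(image)
    color[..., 0] = hm
    # Convex blend keeps the result in [0, 1].
    return (1 - alpha) * image + alpha * color
```

With matplotlib, `plt.imshow(overlay_heatmap(img, heatmap))` displays the result. A perceptual colormap from `matplotlib.cm` is a common alternative to the red-only blend used here.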

## 4. Minimal code sketch (inside the notebook)

The core pattern is:

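```python
import torch
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.legrad_wrapper import ImageBindLeWrapper

# Use the first GPU if available; CPU works but is much slower.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load pretrained ImageBind-Huge and switch to evaluation mode.
model = imagebind_model.imagebind_huge(pretrained=True).to(device).eval()

# LeGrad wrapper; see imagebind/legrad_wrapper.py for what layer_index and trunk_key control.
wrapper = ImageBindLeWrapper(model, layer_index=-2, trunk_key=ModalityType.VISION)

# Prompts and the matching example images from .assets/ (paths relative to this folder).
text_list = ["A dog.", "A car", "A bird"]
image_paths = [".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]

# Preprocess each modality into tensors on `device`, keyed by ModalityType.
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
}
```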
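The baseline scores from step 3 form an image-by-text matrix. The sketch below reproduces that computation. Random unit vectors stand in for the outputs of `model(inputs)` so that it runs without downloading weights. The 3×1024 shapes are an assumption about `imagebind_huge` with three images and three prompts. ImageBind's own README applies a softmax over the text axis of this product in the same way.

```python
import torch
import torch.nn.functional as F

# With the real model (requires the setup above):
#     with torch.no_grad():
#         embeddings = model(inputs)
#     vision_emb = embeddings[ModalityType.VISION]
#     text_emb = embeddings[ModalityType.TEXT]
# Stand-ins: 3 images, 3 prompts, 1024-dim unit vectors (assumed size).
torch.manual_seed(0)
vision_emb = F.normalize(torch.randn(3, 1024), dim=-1)
text_emb = F.normalize(torch.randn(3, 1024), dim=-1)

# Entry (i, j): similarity between image i and prompt j.
scores = vision_emb @ text_emb.T
# Softmax over prompts: for each image, a distribution over text_list.
probs = torch.softmax(scores, dim=-1)
# Index into text_list of the best-matching prompt for each image.
best_prompt = probs.argmax(dim=-1)
```

The explanations answer a different question than these scores: not *which* prompt matches, but *where* in the image (or *which* tokens) drive a given (image, prompt) score.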

You can then call the wrapper’s `compute_legrad_*` helpers in the notebook to obtain:

- **Vision heatmaps** for each (image, text) pair.
- **Text token relevance** for a chosen image or batch of images.
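For intuition about what a vision heatmap measures, here is a deliberately simplified, single-layer illustration of the gradient-on-attention idea behind LeGrad. It runs on a toy attention layer with random weights rather than on ImageBind, and it is not the wrapper's implementation. The toy sizes, the CLS-token readout with no residual or MLP, and the name `toy_patch_relevance` are all hypothetical. The attention probabilities are computed explicitly so that they are a node in the autograd graph, which is the property the wrapper's hooks are there to preserve. The sketch then takes the gradient of a cosine-similarity score with respect to those probabilities. It keeps the positive part and averages over heads and query tokens. Finally, it drops the CLS column and min-max normalizes the result over the patch grid.

```python
import math

import torch
import torch.nn.functional as F


def toy_patch_relevance(grid: int = 4, dim: int = 32, heads: int = 4, seed: int = 0) -> torch.Tensor:
    """Hypothetical single-layer, LeGrad-style relevance map on a toy attention layer."""
    torch.manual_seed(seed)
    n_tokens = 1 + grid * grid                   # CLS token + grid*grid patch tokens
    head_dim = dim // heads
    x = torch.randn(n_tokens, dim)               # toy image tokens
    w_qkv = (torch.randn(dim, 3 * dim) / math.sqrt(dim)).requires_grad_()  # toy weights
    text_emb = torch.randn(dim)                  # toy text embedding

    # Project to queries/keys/values and split heads: each (heads, tokens, head_dim).
    q, k, v = [
        t.reshape(n_tokens, heads, head_dim).transpose(0, 1)
        for t in (x @ w_qkv).chunk(3, dim=-1)
    ]
    # Attention probabilities computed explicitly, so they are part of the autograd graph.
    probs = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
    out = (probs @ v).transpose(0, 1).reshape(n_tokens, dim)

    # Image-text score: cosine similarity between the CLS output and the text embedding.
    score = F.cosine_similarity(out[0], text_emb, dim=0)
    (grad,) = torch.autograd.grad(score, probs)  # d score / d attention: (heads, tokens, tokens)

    # Keep positive sensitivities, average over heads and query tokens -> one value per key token.
    rel = grad.clamp(min=0).mean(dim=(0, 1))
    rel = rel[1:].reshape(grid, grid)            # drop the CLS key, restore the patch grid
    rel = rel - rel.min()
    return rel / rel.max().clamp(min=1e-8)       # min-max normalize to [0, 1]
```

LeGrad itself combines such maps from several layers of the real model. How the wrapper's `compute_legrad_*` helpers select and aggregate layers is defined in `imagebind/legrad_wrapper.py`.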