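"""Demo helpers for visualizing cross-attention during caption generation.

Runs greedy token-by-token decoding with a T5 decoder over projected image
embeddings and, for every generated token, renders head-averaged and per-head
cross-attention heatmaps overlaid on the input image (e.g. for a Gradio demo).
"""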
import torch
from PIL import Image
from src.interpretability import cross_attention_to_image
import numpy as np
import matplotlib.cm as cm


def resize_for_display(pil_img, max_dim=5000):
    """Downscale pil_img so its longer side is at most max_dim, keeping aspect ratio."""
    w, h = pil_img.size
    if max(w, h) <= max_dim:
        return pil_img
    scale = max_dim / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    return pil_img.resize((new_w, new_h), Image.LANCZOS)


@torch.no_grad()
def generate_rollout_for_demo(model, tokenizer, img, preprocess,
                              device="cuda", max_new_tokens=32, alpha=0.45):
    """Greedily decode a caption, capturing a cross-attention overlay per token.

    Returns a dict with the decoded caption plus head-averaged and per-head
    overlay frames for each generated token (see the return statement below).
    """
    model.eval()

    # Preprocess the image and encode it with the vision backbone.
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    vision_out = model.vision_encoder(img_tensor)
    img_embeds = vision_out["image_embeds"]

    # Ensure a sequence dimension: (batch, dim) -> (batch, 1, dim).
    if img_embeds.dim() == 2:
        img_embeds = img_embeds.unsqueeze(1)

    # Project vision features into the T5 hidden space.
    projected = model.projector(img_embeds)

    # Start decoding from T5's decoder start token.
    decoder_input_ids = torch.tensor(
        [[model.t5.config.decoder_start_token_id]], device=device
    )

    generated_ids = []
    avg_frames = []        # head-averaged overlay per generated token
    labels = []            # per-token frame captions
    per_head_frames = []   # per-token list of per-head overlays

    num_heads = None

    # Decode token-by-token
    for step in range(max_new_tokens):

        # Feed the projected image features as precomputed encoder outputs.
        outputs = model.t5(
            encoder_outputs=(projected,),
            decoder_input_ids=decoder_input_ids,
            output_attentions=True,
            return_dict=True,
        )

        # Cross-attention from last decoder layer
        last_cross = outputs.cross_attentions[-1][0]   # (heads, tgt, src)
        num_heads = last_cross.size(0)

        # average over heads (tgt, src)
        attn_avg = last_cross.mean(dim=0)

        # Get attention for the last generated token (tgt index = -1)
        attn_vec = attn_avg[-1]     # shape: (src_len,)
        heat_avg = cross_attention_to_image(attn_vec)

        # cross_attention_to_image may return a tuple or an ndarray; coerce to PIL.
        if isinstance(heat_avg, tuple):
            heat_avg = heat_avg[0]
        if isinstance(heat_avg, np.ndarray):
            heat_avg = Image.fromarray((heat_avg * 255).astype("uint8"))

        avg_frames.append(
            overlay_attention_for_demo(img_tensor, heat_avg, alpha=alpha)
        )

        # Build one overlay per attention head.
        head_overlays = []
        for h in range(num_heads):
            attn_vec_h = last_cross[h][-1]        # (src_len,)
            hmap = cross_attention_to_image(attn_vec_h)

            if isinstance(hmap, tuple):
                hmap = hmap[0]
            if isinstance(hmap, np.ndarray):
                hmap = Image.fromarray((hmap * 255).astype("uint8"))

            head_overlays.append(
                overlay_attention_for_demo(img_tensor, hmap, alpha=alpha)
            )

        per_head_frames.append(head_overlays)

        # Greedy decoding: take the highest-probability next token.
        next_token = outputs.logits[:, -1, :].argmax(-1)
        token_str = tokenizer.decode(next_token, skip_special_tokens=True)
        labels.append(f"Token #{step}: \"{token_str}\"")

        generated_ids.append(int(next_token))

        if next_token.item() == tokenizer.eos_token_id:
            break

        decoder_input_ids = torch.cat(
            [decoder_input_ids, next_token.unsqueeze(0)], dim=1
        )

    # Caption
    caption = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Return structured dict for Gradio
    return {
        "caption": caption,
        "avg": {
            "frames": avg_frames,
            "labels": labels
        },
        "heads": {
            "frames": per_head_frames,   # list[step][head] = PIL image
            "labels": labels,
            "num_heads": num_heads
        }
    }


def overlay_attention_for_demo(image_tensor, heatmap, alpha=0.45):
    """Blend a heatmap (PIL image) over a (1, C, H, W) image tensor."""
    img = image_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)  # normalize to [0, 1]
    img_uint8 = (img * 255).astype("uint8")

    # Match heatmap size to the image; force single-channel before colormapping.
    heatmap = heatmap.convert("L").resize(
        (img_uint8.shape[1], img_uint8.shape[0]), Image.BILINEAR
    )
    heat_np = np.asarray(heatmap).astype("float32") / 255.0

    base = Image.fromarray(img_uint8).convert("RGBA")

    colored = cm.inferno(heat_np)  # returns RGBA float array

    colored_uint8 = (colored * 255).astype("uint8")
    heat = Image.fromarray(colored_uint8).convert("RGBA")

    # Apply uniform overlay opacity.
    heat.putalpha(int(alpha * 255))

    blended = Image.alpha_composite(base, heat)
    blended = blended.convert("RGB")
    return blended  # optionally: resize_for_display(blended, max_dim=500)
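

# Example usage (a sketch; model loading is project-specific, and the names
# below are placeholders rather than part of this module):
#
#     model, tokenizer, preprocess = ...  # load the captioning model + tokenizer
#     img = Image.open("example.jpg").convert("RGB")
#     out = generate_rollout_for_demo(model, tokenizer, img, preprocess)
#     print(out["caption"])
#     out["avg"]["frames"][0].save("token0_avg.png")         # head-averaged overlay
#     out["heads"]["frames"][0][3].save("token0_head3.png")  # overlay for head 3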