File size: 2,935 Bytes
a3355a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image

def show_final_layer_attention_maps(
    outputs,
    processed_input,
    device,
    discard_ratio=0.6,
    head_fusion="max",
    only_last_layer=False,
):
    """Render an attention-rollout heatmap blended over the input image.

    Per layer, the head-fused attention matrix has its lowest values
    discarded, is averaged with the identity (modeling the residual
    connection), row-normalized, and multiplied into a running product.
    The CLS-token row of the product, reshaped to the patch grid, is the
    heatmap.

    Args:
        outputs: model output exposing ``attentions`` — a sequence of
            per-layer attention tensors shaped (batch, heads, tokens,
            tokens). Assumes batch size 1 — TODO confirm with callers.
        processed_input: a dict-like (e.g. HF BatchFeature) containing a
            ``"pixel_values"`` tensor of shape (1, C, H, W), or that
            tensor directly.
        device: torch device used for the intermediate computations.
        discard_ratio: fraction of the lowest attention values zeroed
            per layer before rollout (0.0 keeps everything).
        head_fusion: how heads are merged — "mean", "max", or "min".
        only_last_layer: if True, roll out only the final layer's
            attention instead of all layers.

    Returns:
        PIL.Image.Image: the jet-colormapped heatmap superimposed
        (0.4 / 0.6 blend) on the min-max-normalized input image.

    Raises:
        ValueError: if ``head_fusion`` is not a supported mode.
    """
    if head_fusion not in ("mean", "max", "min"):
        # Previously an unknown mode fell through to a NameError deep in
        # the loop; fail fast with an actionable message instead.
        raise ValueError(
            f"head_fusion must be 'mean', 'max' or 'min', got {head_fusion!r}"
        )

    with torch.no_grad():
        # processed_input may be a BatchFeature/dict carrying
        # "pixel_values", or already the pixel tensor itself.
        if hasattr(processed_input, "keys") and "pixel_values" in processed_input.keys():
            image_tensor = processed_input["pixel_values"].to(device)
        else:
            image_tensor = processed_input

        # (1, C, H, W) -> (C, H, W), then min-max normalize to [0, 1].
        image = image_tensor.squeeze(0)
        image = image - image.min()
        image = image / image.max()

        # Rollout accumulator starts as the identity over tokens.
        result = torch.eye(outputs.attentions[0].size(-1)).to(device)
        if only_last_layer:
            # unsqueeze(0) wraps the last layer so the loop below sees a
            # one-element sequence of (B, H, N, N) tensors.
            attention_list = outputs.attentions[-1].unsqueeze(0).to(device)
        else:
            attention_list = outputs.attentions

        for attention in attention_list:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(dim=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(dim=1)[0]
            else:  # "min" — validated at function entry
                attention_heads_fused = attention.min(dim=1)[0]

            # Zero the lowest `discard_ratio` fraction of attention
            # values, but never flattened index 0 (the CLS position).
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            # Average with the identity to model the residual stream,
            # then row-normalize. keepdim=True is essential: without it
            # the (B, N) divisor broadcasts as (B, 1, N) and each value
            # is divided by the *wrong* row's sum.
            I = torch.eye(attention_heads_fused.size(-1)).to(device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)

            result = torch.matmul(a, result)

        # Attention flowing from the CLS token to each patch token
        # (drop the CLS->CLS entry), reshaped to the square patch grid.
        mask = result[0, 0, 1:]
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).cpu().numpy()
        mask = mask / np.max(mask)

        # Upsample the patch grid to the actual input resolution
        # (generalizes the previous hard-coded 224x224; cv2.resize
        # takes dsize as (width, height)).
        img_h, img_w = image.shape[-2], image.shape[-1]
        mask = cv2.resize(mask, (img_w, img_h))

        # Re-normalize after interpolation and colorize (drop alpha).
        mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
        heatmap = plt.cm.jet(mask)[:, :, :3]

        # (C, H, W) -> (H, W, C) for image-space blending.
        showed_img = image.permute(1, 2, 0).detach().cpu().numpy()
        showed_img = (showed_img - np.min(showed_img)) / (
            np.max(showed_img) - np.min(showed_img)
        )
        superimposed_img = heatmap * 0.4 + showed_img * 0.6

        superimposed_img_pil = Image.fromarray(
            (superimposed_img * 255).astype(np.uint8)
        )

        return superimposed_img_pil