import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image


def _extract_pixel_values(processed_input):
    """Return the pixel tensor from a dict-like processor output, or the input itself."""
    # BatchFeature and plain dicts both expose .keys(), so one check covers both.
    if hasattr(processed_input, "keys") and "pixel_values" in processed_input.keys():
        return processed_input["pixel_values"]
    return processed_input


def _fuse_heads(attention, head_fusion):
    """Collapse the head axis (dim 1) of one layer's attention tensor."""
    if head_fusion == "mean":
        return attention.mean(dim=1)
    if head_fusion == "max":
        return attention.max(dim=1).values
    if head_fusion == "min":
        return attention.min(dim=1).values
    raise ValueError(
        f"Unsupported head_fusion {head_fusion!r}; expected 'mean', 'max' or 'min'."
    )


def _attention_rollout(attentions, device, discard_ratio, head_fusion, only_last_layer):
    """Compute the attention-rollout matrix (shape (1, N, N)) for batch sample 0."""
    layers = attentions[-1:] if only_last_layer else attentions
    num_tokens = attentions[0].size(-1)
    identity = torch.eye(num_tokens, device=device)
    result = identity
    for attention in layers:
        # Only sample 0 is visualized; slicing keeps the discard step below from
        # applying indices computed for other samples to sample 0. Casting to
        # float32 keeps the matmul dtype-compatible with the float32 identity.
        attention = attention[:1].to(device=device, dtype=torch.float32)
        fused = _fuse_heads(attention, head_fusion)  # (1, tokens, tokens)

        # Zero the lowest `discard_ratio` fraction of entries, but never flat
        # index 0 (row 0, column 0).
        flat = fused.view(1, -1)  # a view, so writes below land in `fused`
        num_discard = int(flat.size(-1) * discard_ratio)
        if num_discard > 0:
            _, lowest = flat.topk(num_discard, dim=-1, largest=False)
            lowest = lowest[lowest != 0]
            flat[0, lowest] = 0

        # Add the identity (residual path), halve, then re-normalize every row
        # to sum to 1. keepdim=True is required: without it the (1, N) row sums
        # broadcast along the last axis and divide columns instead of rows.
        a = (fused + identity) / 2
        a = a / a.sum(dim=-1, keepdim=True)
        result = torch.matmul(a, result)
    return result


def _min_max_normalize(array):
    """Linearly rescale a NumPy array into [0, 1]; a constant array becomes zeros."""
    shifted = array - array.min()
    peak = shifted.max()
    return shifted / peak if peak > 0 else shifted


def _overlay_heatmap(mask, image):
    """Blend a jet-colored mask (40%) over a channels-first image (60%) into a PIL image."""
    picture = _min_max_normalize(image.permute(1, 2, 0).detach().cpu().float().numpy())
    height, width = picture.shape[:2]
    # cv2.resize takes dsize as (width, height); match the image rather than a
    # fixed 224x224 so non-224 inputs blend without a shape mismatch.
    mask = _min_max_normalize(cv2.resize(mask, (width, height)))
    heatmap = plt.cm.jet(mask)[:, :, :3]  # drop the alpha channel
    blended = heatmap * 0.4 + picture * 0.6
    return Image.fromarray((blended * 255).astype(np.uint8))


def show_final_layer_attention_maps(
    outputs,
    processed_input,
    device,
    discard_ratio=0.6,
    head_fusion="max",
    only_last_layer=False,
):
    """Overlay an attention-rollout heatmap on the model's input image.

    Despite the name, every layer in ``outputs.attentions`` is rolled out
    unless ``only_last_layer`` is True.

    Args:
        outputs: Object with an ``attentions`` sequence of per-layer tensors.
            Each is treated here as (batch, heads, tokens, tokens): heads are
            fused over dim 1 and the last two dims are combined with an
            identity matrix. Presumably a model output produced with attention
            outputs enabled — confirm at the call site.
        processed_input: Dict-like object with a ``"pixel_values"`` entry
            (e.g. a processor ``BatchFeature``) or the pixel tensor itself,
            shaped (batch, channels, H, W) or (channels, H, W).
        device: Device on which the rollout matrices are computed.
        discard_ratio: Fraction of the lowest fused-attention entries zeroed
            in each layer before rollout.
        head_fusion: How heads are combined: ``"mean"``, ``"max"`` or ``"min"``.
        only_last_layer: If True, roll out only the last layer's attention.

    Returns:
        PIL.Image.Image: RGB overlay (40% jet heatmap, 60% min-max-normalized
        image) at the input image's spatial size. Only the first sample of a
        batch is visualized.

    Raises:
        ValueError: If ``head_fusion`` is not one of the supported values.
    """
    with torch.no_grad():
        image = _extract_pixel_values(processed_input)
        if image.dim() == 4:
            image = image[0]  # visualize the first sample only

        result = _attention_rollout(
            outputs.attentions, device, discard_ratio, head_fusion, only_last_layer
        )
        # Row 0 (taken to be the class token) attending to the remaining
        # tokens, which are assumed to form a square patch grid; reshape raises
        # if they do not.
        mask = result[0, 0, 1:]
        grid = int(mask.numel() ** 0.5)
        mask = mask.reshape(grid, grid).cpu().numpy()
        return _overlay_heatmap(mask, image)