| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | import matplotlib.pyplot as plt |
| | import matplotlib.cm as cm |
| | from torchvision.models.feature_extraction import create_feature_extractor |
| | from typing import Dict, Tuple, List, Optional |
| |
|
| |
|
| | def extract_attention_maps(model, image: torch.Tensor) -> list: |
| | """ |
| | Extrai attention maps de todas as camadas do ViT usando hooks. |
| | |
| | Implementação simplificada e robusta que calcula attention manualmente. |
| | |
| | Args: |
| | model: Modelo ViT |
| | image: Tensor de imagem [1, 3, 224, 224] |
| | |
| | Returns: |
| | attentions: lista de tensores [batch, heads, patches, patches] |
| | """ |
| | attentions = [] |
| | |
| | |
| | def make_attention_hook(): |
| | def hook(module, input, output): |
| | x = input[0] |
| | B, N, C = x.shape |
| | |
| | |
| | if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')): |
| | return |
| | |
| | |
| | qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) |
| | q, k, v = qkv.unbind(0) |
| | |
| | |
| | scale = (C // module.num_heads) ** -0.5 |
| | attn = (q @ k.transpose(-2, -1)) * scale |
| | attn = attn.softmax(dim=-1) |
| | |
| | |
| | attentions.append(attn.detach().cpu()) |
| | |
| | return hook |
| | |
| | |
| | hooks = [] |
| | if not hasattr(model, 'blocks'): |
| | raise ValueError("Modelo não tem atributo 'blocks'. Não é um ViT compatível.") |
| | |
| | for i, block in enumerate(model.blocks): |
| | if hasattr(block, 'attn'): |
| | hook = block.attn.register_forward_hook(make_attention_hook()) |
| | hooks.append(hook) |
| | |
| | if len(hooks) == 0: |
| | raise ValueError("Não foi possível registrar hooks. Verifique a arquitetura do modelo.") |
| | |
| | |
| | model.eval() |
| | with torch.inference_mode(): |
| | _ = model(image) |
| | |
| | |
| | for hook in hooks: |
| | hook.remove() |
| |
|
| | |
| | if len(attentions) == 0: |
| | raise ValueError( |
| | f"Nenhuma atenção capturada após registrar {len(hooks)} hooks. " |
| | f"A arquitetura do modelo pode não ser compatível." |
| | ) |
| | return attentions |
| |
|
| |
|
| | def _infer_grid_size_from_attentions(attentions_per_iter: list) -> int: |
| | """Infere o tamanho do grid a partir dos tensores de atenção.""" |
| | if not attentions_per_iter: |
| | return 14 |
| | for iter_attns in attentions_per_iter: |
| | if not iter_attns: |
| | continue |
| | for layer_tensor in iter_attns: |
| | if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4: |
| | |
| | num_tokens = layer_tensor.shape[-1] |
| | num_patches = num_tokens - 1 |
| | side = int(num_patches ** 0.5) |
| | if side * side == num_patches: |
| | return side |
| | return 14 |
| |
|
| |
|
| | def extract_layer_head_masks( |
| | attentions_per_iter: list, |
| | layer_idx: int, |
| | head_idx: int, |
| | cls_only: bool = True |
| | ) -> list: |
| | """ |
| | Extrai máscaras por iteração para uma cabeça específica de uma camada arbitrária. |
| | |
| | Args: |
| | attentions_per_iter: Lista por iteração; cada item é lista de tensores [B, H, T, T] por camada |
| | layer_idx: Índice da camada (0-based) |
| | head_idx: Índice da cabeça (0-based) |
| | cls_only: Se True, usa apenas a atenção do token CLS para os patches |
| | |
| | Returns: |
| | Lista de máscaras [grid, grid] normalizadas [0,1] |
| | """ |
| | masks = [] |
| | if attentions_per_iter is None or len(attentions_per_iter) == 0: |
| | return masks |
| | |
| | |
| | default_grid = _infer_grid_size_from_attentions(attentions_per_iter) |
| | eps = 1e-8 |
| | |
| | for iter_attns in attentions_per_iter: |
| | if not iter_attns or layer_idx < 0 or layer_idx >= len(iter_attns): |
| | masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) |
| | continue |
| | layer_tensor = iter_attns[layer_idx] |
| | if isinstance(layer_tensor, torch.Tensor): |
| | att = layer_tensor.detach().cpu() |
| | else: |
| | att = torch.as_tensor(layer_tensor) |
| | if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1): |
| | masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) |
| | continue |
| | att_head = att[0, head_idx] |
| | vec = att_head[0] if cls_only else att_head.mean(dim=0) |
| | vec_patches = vec[1:] |
| | tokens = vec_patches.numel() |
| | side = int(tokens ** 0.5) |
| | if side * side != tokens: |
| | masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) |
| | continue |
| | mask = vec_patches.reshape(side, side) |
| | mask = mask / (mask.max() + eps) |
| | masks.append(mask.numpy()) |
| | return masks |
| |
|
| |
|
| | def get_num_layers_heads_from_cached(attentions_per_iter: List[List[torch.Tensor]]) -> Tuple[int, int]: |
| | """ |
| | Inspeciona o cache de atenções para obter número de camadas e cabeças. |
| | |
| | Args: |
| | attentions_per_iter: Lista por iteração; cada item é lista por camada com tensores [B, H, T, T]. |
| | |
| | Returns: |
| | (num_layers, num_heads) |
| | """ |
| | if not attentions_per_iter: |
| | return 0, 0 |
| | first_iter = attentions_per_iter[0] |
| | if not first_iter: |
| | return 0, 0 |
| | num_layers = len(first_iter) |
| | |
| | h = first_iter[0] |
| | if isinstance(h, torch.Tensor): |
| | num_heads = int(h.shape[1]) if h.ndim == 4 else 0 |
| | else: |
| | h_t = torch.as_tensor(h) |
| | num_heads = int(h_t.shape[1]) if h_t.ndim == 4 else 0 |
| | return num_layers, num_heads |
| |
|
| |
|
| | def compute_layer_head_masks_from_cached_attns(iter_attns: List[torch.Tensor], cls_only: bool = True) -> List[List[np.ndarray]]: |
| | """ |
| | Para uma iteração, computa máscaras por camada e cabeça. |
| | |
| | Args: |
| | iter_attns: Lista por camada de tensores [B, H, T, T] |
| | cls_only: Se True, usa linha do CLS para patches |
| | |
| | Returns: |
| | Lista [layer] de listas [head] com máscaras [side, side] normalizadas. |
| | """ |
| | per_layer_head_masks: List[List[np.ndarray]] = [] |
| | eps = 1e-8 |
| | |
| | |
| | default_grid = 14 |
| | for layer_tensor in iter_attns: |
| | if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4: |
| | num_tokens = layer_tensor.shape[-1] |
| | num_patches = num_tokens - 1 |
| | side = int(num_patches ** 0.5) |
| | if side * side == num_patches: |
| | default_grid = side |
| | break |
| | |
| | for li, layer_tensor in enumerate(iter_attns): |
| | if isinstance(layer_tensor, torch.Tensor): |
| | att = layer_tensor.detach().cpu() |
| | else: |
| | att = torch.as_tensor(layer_tensor) |
| | if att.ndim != 4 or att.size(0) < 1: |
| | |
| | per_layer_head_masks.append([]) |
| | continue |
| | heads_masks: List[np.ndarray] = [] |
| | |
| | for h in range(att.size(1)): |
| | att_head = att[0, h] |
| | vec = att_head[0] if cls_only else att_head.mean(dim=0) |
| | vec_patches = vec[1:] |
| | tokens = vec_patches.numel() |
| | side = int(tokens ** 0.5) |
| | if side * side != tokens: |
| | |
| | heads_masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) |
| | continue |
| | mask = vec_patches.reshape(side, side) |
| | mmax = float(mask.max()) |
| | mask = mask / (mmax + eps) |
| | if mmax == 0: |
| | |
| | pass |
| | heads_masks.append(mask.numpy()) |
| | per_layer_head_masks.append(heads_masks) |
| | return per_layer_head_masks |
| |
|
| |
|
| | def batch_precompute_all_masks( |
| | attentions_per_iter: List[List[torch.Tensor]], |
| | discard_ratio: float = 0.9, |
| | head_fusion: str = 'max', |
| | precompute_heads: bool = True |
| | ) -> Tuple[List[np.ndarray], Optional[List[List[List[np.ndarray]]]]]: |
| | """ |
| | Pré-computa todas as máscaras de atenção: |
| | - Rollout por iteração |
| | - Opcionalmente, por camada/cabeça por iteração |
| | |
| | Args: |
| | attentions_per_iter: Lista por iteração com listas por camada [B,H,T,T] |
| | discard_ratio: parâmetro do rollout |
| | head_fusion: fusão de cabeças no rollout |
| | precompute_heads: se True, computa todas heads por camada |
| | |
| | Returns: |
| | (rollout_masks_por_iter, per_iter_layer_head_masks ou None) |
| | """ |
| | rollout_masks: List[np.ndarray] = [] |
| | per_iter_layer_head_masks: Optional[List[List[List[np.ndarray]]]] = [] if precompute_heads else None |
| |
|
| | if not attentions_per_iter: |
| | return rollout_masks, per_iter_layer_head_masks |
| |
|
| | for it_idx, iter_attns in enumerate(attentions_per_iter): |
| | |
| | attentions_cpu = [] |
| | for li, att in enumerate(iter_attns): |
| | if isinstance(att, torch.Tensor): |
| | attentions_cpu.append(att.detach().cpu()) |
| | else: |
| | attentions_cpu.append(torch.as_tensor(att)) |
| | if len(attentions_cpu) == 0: |
| | |
| | pass |
| | rollout_mask = attention_rollout( |
| | attentions_cpu, |
| | discard_ratio=discard_ratio, |
| | head_fusion=head_fusion |
| | ) |
| | rollout_masks.append(rollout_mask) |
| |
|
| | |
| | if precompute_heads: |
| | |
| | per_layer = compute_layer_head_masks_from_cached_attns(iter_attns, cls_only=True) |
| | per_iter_layer_head_masks.append(per_layer) |
| |
|
| | return rollout_masks, per_iter_layer_head_masks |
| |
|
| |
|
| | def attention_rollout(attentions: list, |
| | discard_ratio: float = 0.9, |
| | head_fusion: str = 'max') -> np.ndarray: |
| | """ |
| | Implementa Attention Rollout seguindo a implementação original. |
| | |
| | Referência: https://github.com/jacobgil/vit-explain |
| | |
| | Args: |
| | attentions: Lista de tensores [batch, heads, patches, patches] |
| | discard_ratio: Proporção de atenções mais fracas a descartar (default: 0.9) |
| | head_fusion: Como agregar múltiplas cabeças - 'mean', 'max' ou 'min' |
| | |
| | Returns: |
| | mask: Array numpy [grid_size, grid_size] com valores normalizados [0, 1] |
| | """ |
| | |
| | result = torch.eye(attentions[0].size(-1)) |
| | |
| | with torch.no_grad(): |
| | for attention in attentions: |
| | |
| | if head_fusion == 'mean': |
| | attention_heads_fused = attention.mean(axis=1) |
| | elif head_fusion == 'max': |
| | attention_heads_fused = attention.max(axis=1)[0] |
| | elif head_fusion == 'min': |
| | attention_heads_fused = attention.min(axis=1)[0] |
| | else: |
| | raise ValueError(f"head_fusion deve ser 'mean', 'max' ou 'min'") |
| | |
| | if discard_ratio > 0.0: |
| | bsz, tokens, _ = attention_heads_fused.shape |
| | flat = attention_heads_fused.view(bsz, -1) |
| | k = int(flat.size(-1) * discard_ratio) |
| | if k > 0: |
| | |
| | vals, idxs = torch.topk(flat, k, dim=-1, largest=False) |
| | for b in range(bsz): |
| | idxs_b = idxs[b] |
| | |
| | idxs_b = idxs_b[idxs_b != 0] |
| | flat[b, idxs_b] = 0 |
| | attention_heads_fused = flat.view(bsz, tokens, tokens) |
| |
|
| | |
| | I = torch.eye(attention_heads_fused.size(-1)) |
| | a = (attention_heads_fused + 1.0 * I) / 2 |
| | |
| | |
| | a = a / a.sum(dim=-1) |
| |
|
| | |
| | result = torch.matmul(a, result) |
| | |
| | |
| | mask = result[0, 0, 1:] |
| | |
| | |
| | width = int(mask.size(-1) ** 0.5) |
| | mask = mask.reshape(width, width).numpy() |
| | |
| | |
| | mask = mask / np.max(mask) |
| |
|
| | return mask |
| |
|
| |
|
| | def create_attention_overlay(original_image: Image.Image, |
| | attention_mask: np.ndarray, |
| | alpha: float = 0.5, |
| | colormap: str = 'jet') -> Image.Image: |
| | """ |
| | Cria visualização sobrepondo o mapa de atenção na imagem original. |
| | |
| | Segue implementação de referência usando OpenCV. |
| | |
| | Args: |
| | original_image: Imagem PIL original |
| | attention_mask: Máscara de atenção [H, W] normalizada [0, 1] |
| | alpha: Peso da imagem original (0.7 = 70% imagem, 30% heatmap) |
| | colormap: 'jet' (padrão OpenCV) |
| | |
| | Returns: |
| | Imagem PIL com overlay de atenção |
| | """ |
| | import cv2 |
| | |
| | |
| | img_np = np.array(original_image).astype(np.float32) / 255.0 |
| | |
| | |
| | h, w = img_np.shape[:2] |
| | mask_resized = cv2.resize(attention_mask, (w, h)) |
| | |
| | |
| | heatmap = cv2.applyColorMap(np.uint8(255 * mask_resized), cv2.COLORMAP_JET) |
| | |
| | |
| | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) |
| | heatmap = heatmap.astype(np.float32) / 255.0 |
| | |
| | |
| | overlay = alpha * img_np + (1 - alpha) * heatmap |
| | overlay = np.clip(overlay, 0, 1) |
| | |
| | |
| | overlay_uint8 = (overlay * 255).astype(np.uint8) |
| | return Image.fromarray(overlay_uint8) |
| |
|
| |
|
| | def extract_attention_for_iterations( |
| | model, |
| | iteration_tensors: list, |
| | discard_ratio: float = 0.9, |
| | head_fusion: str = 'max' |
| | ) -> list: |
| | """ |
| | [Deprecated when cached attentions are present] |
| | Extrai mapas de atenção para cada iteração do ataque usando hooks. |
| | |
| | Args: |
| | model: Modelo ViT |
| | iteration_tensors: Lista de tensors normalizados [1, 3, 224, 224] de cada iteração |
| | discard_ratio: Proporção de atenções fracas a descartar |
| | head_fusion: Como agregar heads ('mean', 'max', 'min') |
| | |
| | Returns: |
| | Lista de máscaras de atenção [14, 14] normalizadas [0, 1] |
| | """ |
| | attention_masks = [] |
| | |
| | for tensor in iteration_tensors: |
| | |
| | attentions = extract_attention_maps(model, tensor) |
| | |
| | |
| | mask = attention_rollout( |
| | attentions, |
| | discard_ratio=discard_ratio, |
| | head_fusion=head_fusion |
| | ) |
| | |
| | attention_masks.append(mask) |
| | |
| | return attention_masks |
| |
|
| |
|
| | def rollout_from_cached_attentions( |
| | attentions_per_iter: list, |
| | discard_ratio: float = 0.9, |
| | head_fusion: str = 'max' |
| | ) -> list: |
| | """ |
| | Gera máscaras de atenção por iteração a partir de atenções já capturadas no ataque. |
| | |
| | Args: |
| | attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada |
| | discard_ratio: Proporção de atenções fracas a descartar |
| | head_fusion: Como agregar heads ('mean', 'max', 'min') |
| | |
| | Returns: |
| | Lista de máscaras de atenção [grid, grid] normalizadas [0, 1] |
| | """ |
| | attention_masks = [] |
| |
|
| | if attentions_per_iter is None or len(attentions_per_iter) == 0: |
| | return attention_masks |
| |
|
| | for layer_attns in attentions_per_iter: |
| | |
| | |
| | attentions_cpu = [] |
| | for att in layer_attns: |
| | if isinstance(att, torch.Tensor): |
| | attentions_cpu.append(att.detach().cpu()) |
| | else: |
| | |
| | attentions_cpu.append(torch.as_tensor(att)) |
| |
|
| | |
| | mask = attention_rollout( |
| | attentions_cpu, |
| | discard_ratio=discard_ratio, |
| | head_fusion=head_fusion |
| | ) |
| | attention_masks.append(mask) |
| |
|
| | return attention_masks |
| |
|
| |
|
| | def extract_last_layer_head_masks( |
| | attentions_per_iter: list, |
| | head_idx: int, |
| | cls_only: bool = True |
| | ) -> list: |
| | """ |
| | Extrai máscaras por iteração para uma única cabeça da última camada. |
| | |
| | Args: |
| | attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada |
| | head_idx: Índice da cabeça na última camada (0-based) |
| | cls_only: Se True, usa a atenção do token CLS (linha 0) para os patches |
| | |
| | Returns: |
| | Lista de máscaras [grid, grid] normalizadas [0, 1] |
| | """ |
| | masks = [] |
| | if attentions_per_iter is None or len(attentions_per_iter) == 0: |
| | return masks |
| |
|
| | |
| | default_grid = _infer_grid_size_from_attentions(attentions_per_iter) |
| | eps = 1e-8 |
| | |
| | for iter_attns in attentions_per_iter: |
| | if not iter_attns: |
| | masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) |
| | print("Atenções vazias para esta iteração.") |
| | continue |
| | |
| | last_layer = iter_attns[-1] |
| | if isinstance(last_layer, torch.Tensor): |
| | att = last_layer.detach().cpu() |
| | else: |
| | att = torch.as_tensor(last_layer) |
| |
|
| | |
| | if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1): |
| | masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) |
| | print("Atenção inválida na última camada.") |
| | continue |
| |
|
| | |
| | att_head = att[0, head_idx] |
| |
|
| | |
| | if cls_only: |
| | vec = att_head[0] |
| | else: |
| | |
| | vec = att_head.mean(dim=0) |
| |
|
| | |
| | vec_patches = vec[1:] |
| | tokens = vec_patches.numel() |
| | side = int(tokens ** 0.5) |
| | if side * side != tokens: |
| | |
| | masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) |
| | print("Número de patches não forma uma grade quadrada.") |
| | continue |
| |
|
| | mask = vec_patches.reshape(side, side) |
| | mask = mask / (mask.max() + eps) |
| | masks.append(mask.numpy()) |
| |
|
| | return masks |
| |
|
| |
|
| | def create_iteration_attention_overlays( |
| | iteration_images: list, |
| | attention_masks: list, |
| | alpha: float = 0.7 |
| | ) -> list: |
| | """ |
| | Cria overlays de atenção para cada iteração do ataque. |
| | OTIMIZADO para velocidade de renderização. |
| | |
| | Args: |
| | iteration_images: Lista de PIL Images (uma por iteração) |
| | attention_masks: Lista de máscaras de atenção [14, 14] |
| | alpha: Transparência do overlay |
| | |
| | Returns: |
| | Lista de PIL Images com heatmaps sobrepostos (comprimidas) |
| | """ |
| | overlays = [] |
| | |
| | for img, mask in zip(iteration_images, attention_masks): |
| | overlay = create_attention_overlay(img, mask, alpha=alpha) |
| | |
| | |
| | overlay = overlay.resize((224, 224), Image.LANCZOS) |
| | |
| | |
| | if overlay.mode in ('RGBA', 'LA', 'P'): |
| | background = Image.new('RGB', overlay.size, (255, 255, 255)) |
| | if overlay.mode == 'P': |
| | overlay = overlay.convert('RGBA') |
| | background.paste(overlay, mask=overlay.split()[-1] if overlay.mode == 'RGBA' else None) |
| | overlay = background |
| | |
| | overlays.append(overlay) |
| | |
| | return overlays |
| |
|
| |
|