import torch import numpy as np from PIL import Image import matplotlib.pyplot as plt import matplotlib.cm as cm from torchvision.models.feature_extraction import create_feature_extractor from typing import Dict, Tuple, List, Optional def extract_attention_maps(model, image: torch.Tensor) -> list: """ Extrai attention maps de todas as camadas do ViT usando hooks. Implementação simplificada e robusta que calcula attention manualmente. Args: model: Modelo ViT image: Tensor de imagem [1, 3, 224, 224] Returns: attentions: lista de tensores [batch, heads, patches, patches] """ attentions = [] # Função de hook simplificada que captura entrada e calcula attention def make_attention_hook(): def hook(module, input, output): x = input[0] # Input do módulo de atenção B, N, C = x.shape # Verificar se tem os componentes necessários if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')): return # Calcular Q, K, V qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # Calcular attention weights scale = (C // module.num_heads) ** -0.5 attn = (q @ k.transpose(-2, -1)) * scale attn = attn.softmax(dim=-1) # Salvar (já no CPU para não acumular na GPU) attentions.append(attn.detach().cpu()) return hook # Encontrar e registrar hooks nos módulos de atenção hooks = [] if not hasattr(model, 'blocks'): raise ValueError("Modelo não tem atributo 'blocks'. Não é um ViT compatível.") for i, block in enumerate(model.blocks): if hasattr(block, 'attn'): hook = block.attn.register_forward_hook(make_attention_hook()) hooks.append(hook) if len(hooks) == 0: raise ValueError("Não foi possível registrar hooks. Verifique a arquitetura do modelo.") # Executar forward pass model.eval() with torch.inference_mode(): _ = model(image) # Remover hooks for hook in hooks: hook.remove() # Garantir que capturamos atenções e retornar if len(attentions) == 0: raise ValueError( f"Nenhuma atenção capturada após registrar {len(hooks)} hooks. " f"A arquitetura do modelo pode não ser compatível." ) return attentions def _infer_grid_size_from_attentions(attentions_per_iter: list) -> int: """Infere o tamanho do grid a partir dos tensores de atenção.""" if not attentions_per_iter: return 14 for iter_attns in attentions_per_iter: if not iter_attns: continue for layer_tensor in iter_attns: if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4: # shape: [B, H, T, T] onde T = num_patches + 1 (CLS) num_tokens = layer_tensor.shape[-1] num_patches = num_tokens - 1 side = int(num_patches ** 0.5) if side * side == num_patches: return side return 14 # fallback def extract_layer_head_masks( attentions_per_iter: list, layer_idx: int, head_idx: int, cls_only: bool = True ) -> list: """ Extrai máscaras por iteração para uma cabeça específica de uma camada arbitrária. Args: attentions_per_iter: Lista por iteração; cada item é lista de tensores [B, H, T, T] por camada layer_idx: Índice da camada (0-based) head_idx: Índice da cabeça (0-based) cls_only: Se True, usa apenas a atenção do token CLS para os patches Returns: Lista de máscaras [grid, grid] normalizadas [0,1] """ masks = [] if attentions_per_iter is None or len(attentions_per_iter) == 0: return masks # Inferir grid_size dinamicamente default_grid = _infer_grid_size_from_attentions(attentions_per_iter) eps = 1e-8 for iter_attns in attentions_per_iter: if not iter_attns or layer_idx < 0 or layer_idx >= len(iter_attns): masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) continue layer_tensor = iter_attns[layer_idx] if isinstance(layer_tensor, torch.Tensor): att = layer_tensor.detach().cpu() else: att = torch.as_tensor(layer_tensor) if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1): masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) continue att_head = att[0, head_idx] # [T,T] vec = att_head[0] if cls_only else att_head.mean(dim=0) vec_patches = vec[1:] tokens = vec_patches.numel() side = int(tokens ** 0.5) if side * side != tokens: masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) continue mask = vec_patches.reshape(side, side) mask = mask / (mask.max() + eps) masks.append(mask.numpy()) return masks def get_num_layers_heads_from_cached(attentions_per_iter: List[List[torch.Tensor]]) -> Tuple[int, int]: """ Inspeciona o cache de atenções para obter número de camadas e cabeças. Args: attentions_per_iter: Lista por iteração; cada item é lista por camada com tensores [B, H, T, T]. Returns: (num_layers, num_heads) """ if not attentions_per_iter: return 0, 0 first_iter = attentions_per_iter[0] if not first_iter: return 0, 0 num_layers = len(first_iter) # assume cabeças constantes entre camadas h = first_iter[0] if isinstance(h, torch.Tensor): num_heads = int(h.shape[1]) if h.ndim == 4 else 0 else: h_t = torch.as_tensor(h) num_heads = int(h_t.shape[1]) if h_t.ndim == 4 else 0 return num_layers, num_heads def compute_layer_head_masks_from_cached_attns(iter_attns: List[torch.Tensor], cls_only: bool = True) -> List[List[np.ndarray]]: """ Para uma iteração, computa máscaras por camada e cabeça. Args: iter_attns: Lista por camada de tensores [B, H, T, T] cls_only: Se True, usa linha do CLS para patches Returns: Lista [layer] de listas [head] com máscaras [side, side] normalizadas. """ per_layer_head_masks: List[List[np.ndarray]] = [] eps = 1e-8 # Inferir grid_size do primeiro tensor válido default_grid = 14 for layer_tensor in iter_attns: if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4: num_tokens = layer_tensor.shape[-1] num_patches = num_tokens - 1 side = int(num_patches ** 0.5) if side * side == num_patches: default_grid = side break for li, layer_tensor in enumerate(iter_attns): if isinstance(layer_tensor, torch.Tensor): att = layer_tensor.detach().cpu() else: att = torch.as_tensor(layer_tensor) if att.ndim != 4 or att.size(0) < 1: # print(f"[ViTViz][compute_layer_head_masks] Iter layer {li}: invalid attention shape {att.shape if hasattr(att,'shape') else type(att)}") per_layer_head_masks.append([]) continue heads_masks: List[np.ndarray] = [] # print(f"[ViTViz][compute_layer_head_masks] Layer {li}: B={att.size(0)}, H={att.size(1)}, T={att.size(2)}") for h in range(att.size(1)): att_head = att[0, h] # [T, T] vec = att_head[0] if cls_only else att_head.mean(dim=0) vec_patches = vec[1:] tokens = vec_patches.numel() side = int(tokens ** 0.5) if side * side != tokens: # print(f"[ViTViz][compute_layer_head_masks] Layer {li} head {h}: tokens {tokens} not square -> side={side}") heads_masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) continue mask = vec_patches.reshape(side, side) mmax = float(mask.max()) mask = mask / (mmax + eps) if mmax == 0: # print(f"[ViTViz][compute_layer_head_masks] Layer {li} head {h}: max=0, produced zero mask") pass heads_masks.append(mask.numpy()) per_layer_head_masks.append(heads_masks) return per_layer_head_masks def batch_precompute_all_masks( attentions_per_iter: List[List[torch.Tensor]], discard_ratio: float = 0.9, head_fusion: str = 'max', precompute_heads: bool = True ) -> Tuple[List[np.ndarray], Optional[List[List[List[np.ndarray]]]]]: """ Pré-computa todas as máscaras de atenção: - Rollout por iteração - Opcionalmente, por camada/cabeça por iteração Args: attentions_per_iter: Lista por iteração com listas por camada [B,H,T,T] discard_ratio: parâmetro do rollout head_fusion: fusão de cabeças no rollout precompute_heads: se True, computa todas heads por camada Returns: (rollout_masks_por_iter, per_iter_layer_head_masks ou None) """ rollout_masks: List[np.ndarray] = [] per_iter_layer_head_masks: Optional[List[List[List[np.ndarray]]]] = [] if precompute_heads else None if not attentions_per_iter: return rollout_masks, per_iter_layer_head_masks for it_idx, iter_attns in enumerate(attentions_per_iter): # Rollout desta iteração attentions_cpu = [] for li, att in enumerate(iter_attns): if isinstance(att, torch.Tensor): attentions_cpu.append(att.detach().cpu()) else: attentions_cpu.append(torch.as_tensor(att)) if len(attentions_cpu) == 0: # print(f"[ViTViz][batch_precompute] Iter {it_idx}: empty attentions list") pass rollout_mask = attention_rollout( attentions_cpu, discard_ratio=discard_ratio, head_fusion=head_fusion ) rollout_masks.append(rollout_mask) # Heads por camada desta iteração if precompute_heads: # print(f"[ViTViz][batch_precompute] Iter {it_idx}: computing per-layer/head masks; layers={len(iter_attns)}") per_layer = compute_layer_head_masks_from_cached_attns(iter_attns, cls_only=True) per_iter_layer_head_masks.append(per_layer) return rollout_masks, per_iter_layer_head_masks def attention_rollout(attentions: list, discard_ratio: float = 0.9, head_fusion: str = 'max') -> np.ndarray: """ Implementa Attention Rollout seguindo a implementação original. Referência: https://github.com/jacobgil/vit-explain Args: attentions: Lista de tensores [batch, heads, patches, patches] discard_ratio: Proporção de atenções mais fracas a descartar (default: 0.9) head_fusion: Como agregar múltiplas cabeças - 'mean', 'max' ou 'min' Returns: mask: Array numpy [grid_size, grid_size] com valores normalizados [0, 1] """ # Inicializar com matriz identidade result = torch.eye(attentions[0].size(-1)) with torch.no_grad(): for attention in attentions: # Agregar heads if head_fusion == 'mean': attention_heads_fused = attention.mean(axis=1) elif head_fusion == 'max': attention_heads_fused = attention.max(axis=1)[0] elif head_fusion == 'min': attention_heads_fused = attention.min(axis=1)[0] else: raise ValueError(f"head_fusion deve ser 'mean', 'max' ou 'min'") # Aplicar descarte condicional das atenções fracas por amostra if discard_ratio > 0.0: bsz, tokens, _ = attention_heads_fused.shape flat = attention_heads_fused.view(bsz, -1) k = int(flat.size(-1) * discard_ratio) if k > 0: # Menores valores (largest=False) vals, idxs = torch.topk(flat, k, dim=-1, largest=False) for b in range(bsz): idxs_b = idxs[b] # proteger CLS (posição 0 nas matrizes quadradas) idxs_b = idxs_b[idxs_b != 0] flat[b, idxs_b] = 0 attention_heads_fused = flat.view(bsz, tokens, tokens) # Adicionar identidade e normalizar I = torch.eye(attention_heads_fused.size(-1)) a = (attention_heads_fused + 1.0 * I) / 2 # CORREÇÃO 3: normalizar sem keepdim a = a / a.sum(dim=-1) # Rollout recursivo result = torch.matmul(a, result) # Look at the total attention between the class token and the image patches mask = result[0, 0, 1:] # Calcular tamanho do grid width = int(mask.size(-1) ** 0.5) mask = mask.reshape(width, width).numpy() # Normalizar mask = mask / np.max(mask) return mask def create_attention_overlay(original_image: Image.Image, attention_mask: np.ndarray, alpha: float = 0.5, colormap: str = 'jet') -> Image.Image: """ Cria visualização sobrepondo o mapa de atenção na imagem original. Segue implementação de referência usando OpenCV. Args: original_image: Imagem PIL original attention_mask: Máscara de atenção [H, W] normalizada [0, 1] alpha: Peso da imagem original (0.7 = 70% imagem, 30% heatmap) colormap: 'jet' (padrão OpenCV) Returns: Imagem PIL com overlay de atenção """ import cv2 # Converter PIL para numpy array RGB img_np = np.array(original_image).astype(np.float32) / 255.0 # Redimensionar máscara para o tamanho da imagem (224x224 ou tamanho original) h, w = img_np.shape[:2] mask_resized = cv2.resize(attention_mask, (w, h)) # Aplicar colormap do OpenCV (retorna BGR!) heatmap = cv2.applyColorMap(np.uint8(255 * mask_resized), cv2.COLORMAP_JET) # CRÍTICO: Converter BGR → RGB (OpenCV usa BGR!) heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) heatmap = heatmap.astype(np.float32) / 255.0 # Blend: alpha * img_original + (1-alpha) * heatmap overlay = alpha * img_np + (1 - alpha) * heatmap overlay = np.clip(overlay, 0, 1) # Converter de volta para PIL overlay_uint8 = (overlay * 255).astype(np.uint8) return Image.fromarray(overlay_uint8) def extract_attention_for_iterations( model, iteration_tensors: list, discard_ratio: float = 0.9, head_fusion: str = 'max' ) -> list: """ [Deprecated when cached attentions are present] Extrai mapas de atenção para cada iteração do ataque usando hooks. Args: model: Modelo ViT iteration_tensors: Lista de tensors normalizados [1, 3, 224, 224] de cada iteração discard_ratio: Proporção de atenções fracas a descartar head_fusion: Como agregar heads ('mean', 'max', 'min') Returns: Lista de máscaras de atenção [14, 14] normalizadas [0, 1] """ attention_masks = [] for tensor in iteration_tensors: # Extrair attention maps para esta iteração attentions = extract_attention_maps(model, tensor) # Aplicar Attention Rollout mask = attention_rollout( attentions, discard_ratio=discard_ratio, head_fusion=head_fusion ) attention_masks.append(mask) return attention_masks def rollout_from_cached_attentions( attentions_per_iter: list, discard_ratio: float = 0.9, head_fusion: str = 'max' ) -> list: """ Gera máscaras de atenção por iteração a partir de atenções já capturadas no ataque. Args: attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada discard_ratio: Proporção de atenções fracas a descartar head_fusion: Como agregar heads ('mean', 'max', 'min') Returns: Lista de máscaras de atenção [grid, grid] normalizadas [0, 1] """ attention_masks = [] if attentions_per_iter is None or len(attentions_per_iter) == 0: return attention_masks for layer_attns in attentions_per_iter: # layer_attns: lista de tensores por camada [B, H, T, T] # Garantir CPU e detach attentions_cpu = [] for att in layer_attns: if isinstance(att, torch.Tensor): attentions_cpu.append(att.detach().cpu()) else: # já é CPU numpy/tensor? tentar converter via torch.as_tensor attentions_cpu.append(torch.as_tensor(att)) # Aplicar rollout padrão sobre a lista de camadas mask = attention_rollout( attentions_cpu, discard_ratio=discard_ratio, head_fusion=head_fusion ) attention_masks.append(mask) return attention_masks def extract_last_layer_head_masks( attentions_per_iter: list, head_idx: int, cls_only: bool = True ) -> list: """ Extrai máscaras por iteração para uma única cabeça da última camada. Args: attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada head_idx: Índice da cabeça na última camada (0-based) cls_only: Se True, usa a atenção do token CLS (linha 0) para os patches Returns: Lista de máscaras [grid, grid] normalizadas [0, 1] """ masks = [] if attentions_per_iter is None or len(attentions_per_iter) == 0: return masks # Inferir grid_size dinamicamente default_grid = _infer_grid_size_from_attentions(attentions_per_iter) eps = 1e-8 for iter_attns in attentions_per_iter: if not iter_attns: masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) print("Atenções vazias para esta iteração.") continue # Última camada last_layer = iter_attns[-1] if isinstance(last_layer, torch.Tensor): att = last_layer.detach().cpu() else: att = torch.as_tensor(last_layer) # Esperado: [B, H, T, T] com B=1 if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1): masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) print("Atenção inválida na última camada.") continue # Selecionar cabeça att_head = att[0, head_idx] # [T, T] # Vetor atenção CLS→tokens if cls_only: vec = att_head[0] # linha do CLS else: # média das linhas como alternativa vec = att_head.mean(dim=0) # Remover CLS e projetar para grade vec_patches = vec[1:] tokens = vec_patches.numel() side = int(tokens ** 0.5) if side * side != tokens: # fallback: normalizar e retornar zeros coerentes masks.append(np.zeros((default_grid, default_grid), dtype=np.float32)) print("Número de patches não forma uma grade quadrada.") continue mask = vec_patches.reshape(side, side) mask = mask / (mask.max() + eps) masks.append(mask.numpy()) return masks def create_iteration_attention_overlays( iteration_images: list, attention_masks: list, alpha: float = 0.7 ) -> list: """ Cria overlays de atenção para cada iteração do ataque. OTIMIZADO para velocidade de renderização. Args: iteration_images: Lista de PIL Images (uma por iteração) attention_masks: Lista de máscaras de atenção [14, 14] alpha: Transparência do overlay Returns: Lista de PIL Images com heatmaps sobrepostos (comprimidas) """ overlays = [] for img, mask in zip(iteration_images, attention_masks): overlay = create_attention_overlay(img, mask, alpha=alpha) # OTIMIZAÇÃO AGRESSIVA: reduzir para 224x224 JPEG qualidade 75 overlay = overlay.resize((224, 224), Image.LANCZOS) # Converter para RGB se necessário (JPEG não suporta RGBA) if overlay.mode in ('RGBA', 'LA', 'P'): background = Image.new('RGB', overlay.size, (255, 255, 255)) if overlay.mode == 'P': overlay = overlay.convert('RGBA') background.paste(overlay, mask=overlay.split()[-1] if overlay.mode == 'RGBA' else None) overlay = background overlays.append(overlay) return overlays