ViTViz / utils /visualization.py
lucasddmc's picture
feat: adds different models capability
98cf39b
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from torchvision.models.feature_extraction import create_feature_extractor
from typing import Dict, Tuple, List, Optional
def extract_attention_maps(model, image: torch.Tensor) -> list:
"""
Extrai attention maps de todas as camadas do ViT usando hooks.
Implementação simplificada e robusta que calcula attention manualmente.
Args:
model: Modelo ViT
image: Tensor de imagem [1, 3, 224, 224]
Returns:
attentions: lista de tensores [batch, heads, patches, patches]
"""
attentions = []
# Função de hook simplificada que captura entrada e calcula attention
def make_attention_hook():
def hook(module, input, output):
x = input[0] # Input do módulo de atenção
B, N, C = x.shape
# Verificar se tem os componentes necessários
if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')):
return
# Calcular Q, K, V
qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
# Calcular attention weights
scale = (C // module.num_heads) ** -0.5
attn = (q @ k.transpose(-2, -1)) * scale
attn = attn.softmax(dim=-1)
# Salvar (já no CPU para não acumular na GPU)
attentions.append(attn.detach().cpu())
return hook
# Encontrar e registrar hooks nos módulos de atenção
hooks = []
if not hasattr(model, 'blocks'):
raise ValueError("Modelo não tem atributo 'blocks'. Não é um ViT compatível.")
for i, block in enumerate(model.blocks):
if hasattr(block, 'attn'):
hook = block.attn.register_forward_hook(make_attention_hook())
hooks.append(hook)
if len(hooks) == 0:
raise ValueError("Não foi possível registrar hooks. Verifique a arquitetura do modelo.")
# Executar forward pass
model.eval()
with torch.inference_mode():
_ = model(image)
# Remover hooks
for hook in hooks:
hook.remove()
# Garantir que capturamos atenções e retornar
if len(attentions) == 0:
raise ValueError(
f"Nenhuma atenção capturada após registrar {len(hooks)} hooks. "
f"A arquitetura do modelo pode não ser compatível."
)
return attentions
def _infer_grid_size_from_attentions(attentions_per_iter: list) -> int:
"""Infere o tamanho do grid a partir dos tensores de atenção."""
if not attentions_per_iter:
return 14
for iter_attns in attentions_per_iter:
if not iter_attns:
continue
for layer_tensor in iter_attns:
if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4:
# shape: [B, H, T, T] onde T = num_patches + 1 (CLS)
num_tokens = layer_tensor.shape[-1]
num_patches = num_tokens - 1
side = int(num_patches ** 0.5)
if side * side == num_patches:
return side
return 14 # fallback
def extract_layer_head_masks(
attentions_per_iter: list,
layer_idx: int,
head_idx: int,
cls_only: bool = True
) -> list:
"""
Extrai máscaras por iteração para uma cabeça específica de uma camada arbitrária.
Args:
attentions_per_iter: Lista por iteração; cada item é lista de tensores [B, H, T, T] por camada
layer_idx: Índice da camada (0-based)
head_idx: Índice da cabeça (0-based)
cls_only: Se True, usa apenas a atenção do token CLS para os patches
Returns:
Lista de máscaras [grid, grid] normalizadas [0,1]
"""
masks = []
if attentions_per_iter is None or len(attentions_per_iter) == 0:
return masks
# Inferir grid_size dinamicamente
default_grid = _infer_grid_size_from_attentions(attentions_per_iter)
eps = 1e-8
for iter_attns in attentions_per_iter:
if not iter_attns or layer_idx < 0 or layer_idx >= len(iter_attns):
masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
continue
layer_tensor = iter_attns[layer_idx]
if isinstance(layer_tensor, torch.Tensor):
att = layer_tensor.detach().cpu()
else:
att = torch.as_tensor(layer_tensor)
if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1):
masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
continue
att_head = att[0, head_idx] # [T,T]
vec = att_head[0] if cls_only else att_head.mean(dim=0)
vec_patches = vec[1:]
tokens = vec_patches.numel()
side = int(tokens ** 0.5)
if side * side != tokens:
masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
continue
mask = vec_patches.reshape(side, side)
mask = mask / (mask.max() + eps)
masks.append(mask.numpy())
return masks
def get_num_layers_heads_from_cached(attentions_per_iter: List[List[torch.Tensor]]) -> Tuple[int, int]:
"""
Inspeciona o cache de atenções para obter número de camadas e cabeças.
Args:
attentions_per_iter: Lista por iteração; cada item é lista por camada com tensores [B, H, T, T].
Returns:
(num_layers, num_heads)
"""
if not attentions_per_iter:
return 0, 0
first_iter = attentions_per_iter[0]
if not first_iter:
return 0, 0
num_layers = len(first_iter)
# assume cabeças constantes entre camadas
h = first_iter[0]
if isinstance(h, torch.Tensor):
num_heads = int(h.shape[1]) if h.ndim == 4 else 0
else:
h_t = torch.as_tensor(h)
num_heads = int(h_t.shape[1]) if h_t.ndim == 4 else 0
return num_layers, num_heads
def compute_layer_head_masks_from_cached_attns(iter_attns: List[torch.Tensor], cls_only: bool = True) -> List[List[np.ndarray]]:
"""
Para uma iteração, computa máscaras por camada e cabeça.
Args:
iter_attns: Lista por camada de tensores [B, H, T, T]
cls_only: Se True, usa linha do CLS para patches
Returns:
Lista [layer] de listas [head] com máscaras [side, side] normalizadas.
"""
per_layer_head_masks: List[List[np.ndarray]] = []
eps = 1e-8
# Inferir grid_size do primeiro tensor válido
default_grid = 14
for layer_tensor in iter_attns:
if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4:
num_tokens = layer_tensor.shape[-1]
num_patches = num_tokens - 1
side = int(num_patches ** 0.5)
if side * side == num_patches:
default_grid = side
break
for li, layer_tensor in enumerate(iter_attns):
if isinstance(layer_tensor, torch.Tensor):
att = layer_tensor.detach().cpu()
else:
att = torch.as_tensor(layer_tensor)
if att.ndim != 4 or att.size(0) < 1:
# print(f"[ViTViz][compute_layer_head_masks] Iter layer {li}: invalid attention shape {att.shape if hasattr(att,'shape') else type(att)}")
per_layer_head_masks.append([])
continue
heads_masks: List[np.ndarray] = []
# print(f"[ViTViz][compute_layer_head_masks] Layer {li}: B={att.size(0)}, H={att.size(1)}, T={att.size(2)}")
for h in range(att.size(1)):
att_head = att[0, h] # [T, T]
vec = att_head[0] if cls_only else att_head.mean(dim=0)
vec_patches = vec[1:]
tokens = vec_patches.numel()
side = int(tokens ** 0.5)
if side * side != tokens:
# print(f"[ViTViz][compute_layer_head_masks] Layer {li} head {h}: tokens {tokens} not square -> side={side}")
heads_masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
continue
mask = vec_patches.reshape(side, side)
mmax = float(mask.max())
mask = mask / (mmax + eps)
if mmax == 0:
# print(f"[ViTViz][compute_layer_head_masks] Layer {li} head {h}: max=0, produced zero mask")
pass
heads_masks.append(mask.numpy())
per_layer_head_masks.append(heads_masks)
return per_layer_head_masks
def batch_precompute_all_masks(
attentions_per_iter: List[List[torch.Tensor]],
discard_ratio: float = 0.9,
head_fusion: str = 'max',
precompute_heads: bool = True
) -> Tuple[List[np.ndarray], Optional[List[List[List[np.ndarray]]]]]:
"""
Pré-computa todas as máscaras de atenção:
- Rollout por iteração
- Opcionalmente, por camada/cabeça por iteração
Args:
attentions_per_iter: Lista por iteração com listas por camada [B,H,T,T]
discard_ratio: parâmetro do rollout
head_fusion: fusão de cabeças no rollout
precompute_heads: se True, computa todas heads por camada
Returns:
(rollout_masks_por_iter, per_iter_layer_head_masks ou None)
"""
rollout_masks: List[np.ndarray] = []
per_iter_layer_head_masks: Optional[List[List[List[np.ndarray]]]] = [] if precompute_heads else None
if not attentions_per_iter:
return rollout_masks, per_iter_layer_head_masks
for it_idx, iter_attns in enumerate(attentions_per_iter):
# Rollout desta iteração
attentions_cpu = []
for li, att in enumerate(iter_attns):
if isinstance(att, torch.Tensor):
attentions_cpu.append(att.detach().cpu())
else:
attentions_cpu.append(torch.as_tensor(att))
if len(attentions_cpu) == 0:
# print(f"[ViTViz][batch_precompute] Iter {it_idx}: empty attentions list")
pass
rollout_mask = attention_rollout(
attentions_cpu,
discard_ratio=discard_ratio,
head_fusion=head_fusion
)
rollout_masks.append(rollout_mask)
# Heads por camada desta iteração
if precompute_heads:
# print(f"[ViTViz][batch_precompute] Iter {it_idx}: computing per-layer/head masks; layers={len(iter_attns)}")
per_layer = compute_layer_head_masks_from_cached_attns(iter_attns, cls_only=True)
per_iter_layer_head_masks.append(per_layer)
return rollout_masks, per_iter_layer_head_masks
def attention_rollout(attentions: list,
discard_ratio: float = 0.9,
head_fusion: str = 'max') -> np.ndarray:
"""
Implementa Attention Rollout seguindo a implementação original.
Referência: https://github.com/jacobgil/vit-explain
Args:
attentions: Lista de tensores [batch, heads, patches, patches]
discard_ratio: Proporção de atenções mais fracas a descartar (default: 0.9)
head_fusion: Como agregar múltiplas cabeças - 'mean', 'max' ou 'min'
Returns:
mask: Array numpy [grid_size, grid_size] com valores normalizados [0, 1]
"""
# Inicializar com matriz identidade
result = torch.eye(attentions[0].size(-1))
with torch.no_grad():
for attention in attentions:
# Agregar heads
if head_fusion == 'mean':
attention_heads_fused = attention.mean(axis=1)
elif head_fusion == 'max':
attention_heads_fused = attention.max(axis=1)[0]
elif head_fusion == 'min':
attention_heads_fused = attention.min(axis=1)[0]
else:
raise ValueError(f"head_fusion deve ser 'mean', 'max' ou 'min'")
# Aplicar descarte condicional das atenções fracas por amostra
if discard_ratio > 0.0:
bsz, tokens, _ = attention_heads_fused.shape
flat = attention_heads_fused.view(bsz, -1)
k = int(flat.size(-1) * discard_ratio)
if k > 0:
# Menores valores (largest=False)
vals, idxs = torch.topk(flat, k, dim=-1, largest=False)
for b in range(bsz):
idxs_b = idxs[b]
# proteger CLS (posição 0 nas matrizes quadradas)
idxs_b = idxs_b[idxs_b != 0]
flat[b, idxs_b] = 0
attention_heads_fused = flat.view(bsz, tokens, tokens)
# Adicionar identidade e normalizar
I = torch.eye(attention_heads_fused.size(-1))
a = (attention_heads_fused + 1.0 * I) / 2
# CORREÇÃO 3: normalizar sem keepdim
a = a / a.sum(dim=-1)
# Rollout recursivo
result = torch.matmul(a, result)
# Look at the total attention between the class token and the image patches
mask = result[0, 0, 1:]
# Calcular tamanho do grid
width = int(mask.size(-1) ** 0.5)
mask = mask.reshape(width, width).numpy()
# Normalizar
mask = mask / np.max(mask)
return mask
def create_attention_overlay(original_image: Image.Image,
attention_mask: np.ndarray,
alpha: float = 0.5,
colormap: str = 'jet') -> Image.Image:
"""
Cria visualização sobrepondo o mapa de atenção na imagem original.
Segue implementação de referência usando OpenCV.
Args:
original_image: Imagem PIL original
attention_mask: Máscara de atenção [H, W] normalizada [0, 1]
alpha: Peso da imagem original (0.7 = 70% imagem, 30% heatmap)
colormap: 'jet' (padrão OpenCV)
Returns:
Imagem PIL com overlay de atenção
"""
import cv2
# Converter PIL para numpy array RGB
img_np = np.array(original_image).astype(np.float32) / 255.0
# Redimensionar máscara para o tamanho da imagem (224x224 ou tamanho original)
h, w = img_np.shape[:2]
mask_resized = cv2.resize(attention_mask, (w, h))
# Aplicar colormap do OpenCV (retorna BGR!)
heatmap = cv2.applyColorMap(np.uint8(255 * mask_resized), cv2.COLORMAP_JET)
# CRÍTICO: Converter BGR → RGB (OpenCV usa BGR!)
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
heatmap = heatmap.astype(np.float32) / 255.0
# Blend: alpha * img_original + (1-alpha) * heatmap
overlay = alpha * img_np + (1 - alpha) * heatmap
overlay = np.clip(overlay, 0, 1)
# Converter de volta para PIL
overlay_uint8 = (overlay * 255).astype(np.uint8)
return Image.fromarray(overlay_uint8)
def extract_attention_for_iterations(
model,
iteration_tensors: list,
discard_ratio: float = 0.9,
head_fusion: str = 'max'
) -> list:
"""
[Deprecated when cached attentions are present]
Extrai mapas de atenção para cada iteração do ataque usando hooks.
Args:
model: Modelo ViT
iteration_tensors: Lista de tensors normalizados [1, 3, 224, 224] de cada iteração
discard_ratio: Proporção de atenções fracas a descartar
head_fusion: Como agregar heads ('mean', 'max', 'min')
Returns:
Lista de máscaras de atenção [14, 14] normalizadas [0, 1]
"""
attention_masks = []
for tensor in iteration_tensors:
# Extrair attention maps para esta iteração
attentions = extract_attention_maps(model, tensor)
# Aplicar Attention Rollout
mask = attention_rollout(
attentions,
discard_ratio=discard_ratio,
head_fusion=head_fusion
)
attention_masks.append(mask)
return attention_masks
def rollout_from_cached_attentions(
attentions_per_iter: list,
discard_ratio: float = 0.9,
head_fusion: str = 'max'
) -> list:
"""
Gera máscaras de atenção por iteração a partir de atenções já capturadas no ataque.
Args:
attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada
discard_ratio: Proporção de atenções fracas a descartar
head_fusion: Como agregar heads ('mean', 'max', 'min')
Returns:
Lista de máscaras de atenção [grid, grid] normalizadas [0, 1]
"""
attention_masks = []
if attentions_per_iter is None or len(attentions_per_iter) == 0:
return attention_masks
for layer_attns in attentions_per_iter:
# layer_attns: lista de tensores por camada [B, H, T, T]
# Garantir CPU e detach
attentions_cpu = []
for att in layer_attns:
if isinstance(att, torch.Tensor):
attentions_cpu.append(att.detach().cpu())
else:
# já é CPU numpy/tensor? tentar converter via torch.as_tensor
attentions_cpu.append(torch.as_tensor(att))
# Aplicar rollout padrão sobre a lista de camadas
mask = attention_rollout(
attentions_cpu,
discard_ratio=discard_ratio,
head_fusion=head_fusion
)
attention_masks.append(mask)
return attention_masks
def extract_last_layer_head_masks(
attentions_per_iter: list,
head_idx: int,
cls_only: bool = True
) -> list:
"""
Extrai máscaras por iteração para uma única cabeça da última camada.
Args:
attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada
head_idx: Índice da cabeça na última camada (0-based)
cls_only: Se True, usa a atenção do token CLS (linha 0) para os patches
Returns:
Lista de máscaras [grid, grid] normalizadas [0, 1]
"""
masks = []
if attentions_per_iter is None or len(attentions_per_iter) == 0:
return masks
# Inferir grid_size dinamicamente
default_grid = _infer_grid_size_from_attentions(attentions_per_iter)
eps = 1e-8
for iter_attns in attentions_per_iter:
if not iter_attns:
masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
print("Atenções vazias para esta iteração.")
continue
# Última camada
last_layer = iter_attns[-1]
if isinstance(last_layer, torch.Tensor):
att = last_layer.detach().cpu()
else:
att = torch.as_tensor(last_layer)
# Esperado: [B, H, T, T] com B=1
if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1):
masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
print("Atenção inválida na última camada.")
continue
# Selecionar cabeça
att_head = att[0, head_idx] # [T, T]
# Vetor atenção CLS→tokens
if cls_only:
vec = att_head[0] # linha do CLS
else:
# média das linhas como alternativa
vec = att_head.mean(dim=0)
# Remover CLS e projetar para grade
vec_patches = vec[1:]
tokens = vec_patches.numel()
side = int(tokens ** 0.5)
if side * side != tokens:
# fallback: normalizar e retornar zeros coerentes
masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
print("Número de patches não forma uma grade quadrada.")
continue
mask = vec_patches.reshape(side, side)
mask = mask / (mask.max() + eps)
masks.append(mask.numpy())
return masks
def create_iteration_attention_overlays(
iteration_images: list,
attention_masks: list,
alpha: float = 0.7
) -> list:
"""
Cria overlays de atenção para cada iteração do ataque.
OTIMIZADO para velocidade de renderização.
Args:
iteration_images: Lista de PIL Images (uma por iteração)
attention_masks: Lista de máscaras de atenção [14, 14]
alpha: Transparência do overlay
Returns:
Lista de PIL Images com heatmaps sobrepostos (comprimidas)
"""
overlays = []
for img, mask in zip(iteration_images, attention_masks):
overlay = create_attention_overlay(img, mask, alpha=alpha)
# OTIMIZAÇÃO AGRESSIVA: reduzir para 224x224 JPEG qualidade 75
overlay = overlay.resize((224, 224), Image.LANCZOS)
# Converter para RGB se necessário (JPEG não suporta RGBA)
if overlay.mode in ('RGBA', 'LA', 'P'):
background = Image.new('RGB', overlay.size, (255, 255, 255))
if overlay.mode == 'P':
overlay = overlay.convert('RGBA')
background.paste(overlay, mask=overlay.split()[-1] if overlay.mode == 'RGBA' else None)
overlay = background
overlays.append(overlay)
return overlays