| | import torch |
| | import torchattacks |
| | from PIL import Image |
| | from typing import List, Tuple, Optional |
| | import numpy as np |
| | import warnings |
| | from pathlib import Path |
| | import types |
| | import os |
| |
|
| | try: |
| | import torchvision.models as tv_models |
| | except Exception: |
| | tv_models = None |
| |
|
| | try: |
| | import timm |
| | except Exception: |
| | timm = None |
| |
|
| | try: |
| | from huggingface_hub import hf_hub_download |
| | except Exception: |
| | hf_hub_download = None |
| |
|
def capture_outputs_and_attentions(model, x_norm: torch.Tensor):
    """Run a single forward pass while capturing attention maps via hooks.

    Works for timm-style ViTs that expose a ``blocks`` attribute whose
    elements contain an ``attn`` submodule with ``qkv`` and ``num_heads``.
    The attention map is recomputed inside the hook from the layer input,
    because timm does not expose it as a module output.

    Args:
        model: Model under attack. Models without ``blocks`` simply yield
            an empty attention list.
        x_norm: Input batch, already normalized for the model's forward.

    Returns:
        Tuple ``(outputs, attentions)`` where ``attentions`` is a list of
        detached CPU tensors of shape [B, H, T, T], one per attention layer.
    """
    attentions: List[torch.Tensor] = []

    def make_attention_hook():
        def hook(module, input, output):
            # Skip modules that do not follow the timm attention layout
            # (check BEFORE unpacking the shape so a foreign module cannot
            # crash the hook).
            if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')):
                return
            x = input[0]
            B, N, C = x.shape
            head_dim = C // module.num_heads
            qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            # Prefer the module's own scale when present; timm attention
            # modules define it and it may differ from the default.
            scale = getattr(module, 'scale', head_dim ** -0.5)
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)
            attentions.append(attn.detach())
        return hook

    hooks = []
    if hasattr(model, 'blocks'):
        for block in model.blocks:
            if hasattr(block, 'attn'):
                hooks.append(block.attn.register_forward_hook(make_attention_hook()))

    model.eval()  # NOTE: side effect — leaves the model in eval mode
    outputs = model(x_norm)

    for h in hooks:
        h.remove()

    # Move maps to CPU so long attack runs do not accumulate GPU memory.
    attentions = [a.cpu() for a in attentions]
    return outputs, attentions
| |
|
def denormalize_imagenet(tensor: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet normalization on a (C, H, W) tensor.

    Args:
        tensor: Normalized tensor with mean=[0.485, 0.456, 0.406] and
            std=[0.229, 0.224, 0.225].

    Returns:
        Denormalized tensor clamped to [0, 1].
    """
    device = tensor.device
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)
    # Reverse x_norm = (x - mean) / std, then clip to the valid pixel range.
    return (tensor * imagenet_std + imagenet_mean).clamp(0, 1)
| |
|
def tensor_to_pil(tensor: torch.Tensor, denormalize: bool = True) -> Image.Image:
    """Convert a (C, H, W) tensor into an RGB PIL image.

    Args:
        tensor: Tensor with shape (C, H, W).
        denormalize: When True, undo ImageNet normalization before converting.

    Returns:
        PIL Image in RGB space with values in [0, 255].
    """
    data = denormalize_imagenet(tensor) if denormalize else tensor

    # CHW float -> HWC uint8 for PIL.
    array = data.cpu().detach().numpy()
    array = np.transpose(array, (1, 2, 0))
    array = (array * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(array, mode='RGB')
| |
|
class FGSM(torchattacks.FGSM):
    """Fast Gradient Sign Method attack that also records the original and
    final adversarial images.

    FGSM is a single-step (non-iterative) attack; besides the adversarial
    tensor, the instance keeps per-state PIL images, normalized tensors and
    captured attention maps for later visualization.
    """

    def __init__(self, model, eps=0.03):
        super().__init__(model, eps=eps)
        # Recorded states: [original, adversarial].
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        # Attention maps per recorded state (one list of layer maps each).
        self.attentions_per_iter: List[List[torch.Tensor]] = []

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run FGSM.

        Returns:
            adv_images: final adversarial tensor (ImageNet-normalized space)
            iteration_images: [original_image, adversarial_image] as PIL
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        criterion = torch.nn.CrossEntropyLoss()

        # ImageNet statistics used to move between normalized and pixel space.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)
        pixel_images = images * std + mean

        # Reset instrumentation and record the clean state.
        self.iteration_images = [tensor_to_pil(pixel_images[0], denormalize=False)]
        self.iteration_tensors = [images.clone().detach()]
        self.attentions_per_iter = []

        # Forward on the clean input, capturing attention maps.
        images.requires_grad = True
        outputs, clean_attn = capture_outputs_and_attentions(self.model, images)
        self.attentions_per_iter.append(list(clean_attn))

        if self.targeted:
            target_labels = self.get_target_label(images, labels)
            cost = -criterion(outputs, target_labels)
        else:
            cost = criterion(outputs, labels)

        grad = torch.autograd.grad(cost, images, retain_graph=False, create_graph=False)[0]

        # Single signed step in pixel space, clipped to the valid range.
        # (std > 0, so the sign of the normalized-space gradient equals the
        # sign of the pixel-space gradient.)
        adv_pixel = torch.clamp(pixel_images + self.eps * grad.sign(), min=0, max=1).detach()
        adv_images = (adv_pixel - mean) / std

        self.iteration_images.append(tensor_to_pil(adv_pixel[0], denormalize=False))
        self.iteration_tensors.append(adv_images.clone().detach())

        # Capture the adversarial input's attention maps as well.
        _, adv_attn = capture_outputs_and_attentions(self.model, adv_images)
        self.attentions_per_iter.append(list(adv_attn))

        return adv_images, self.iteration_images
| |
|
class PGDIterations(torchattacks.PGD):
    """Standard PGD attack extended to record every iteration.

    For each run the instance keeps per-iteration PIL images, normalized
    tensors and the attention maps captured during each forward pass.
    """

    def __init__(self, model, eps=0.05, alpha=0.005, steps=10, random_start=True):
        super().__init__(model, eps=eps, alpha=alpha, steps=steps, random_start=random_start)
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        self.attentions_per_iter: List[List[torch.Tensor]] = []

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run PGD, adapted for ImageNet-normalized inputs.

        Returns:
            adv_images: final adversarial tensor (normalized space)
            iteration_images: PIL images — the original plus one per step

        NOTE(review): for i >= 1, attentions_per_iter[i] is captured from the
        adversarial image *before* step i's update, while iteration_images[i]
        is the image *after* the update, so the two lists are offset by one
        step — confirm downstream consumers expect this.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        criterion = torch.nn.CrossEntropyLoss()

        # ImageNet statistics for converting to/from pixel space [0, 1].
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

        pixel_clean = images * std + mean
        pixel_adv = pixel_clean.clone().detach()

        if self.random_start:
            # Uniform start inside the L_inf ball, clipped to valid pixels.
            pixel_adv = pixel_adv + torch.empty_like(pixel_adv).uniform_(-self.eps, self.eps)
            pixel_adv = torch.clamp(pixel_adv, min=0, max=1).detach()

        # Reset instrumentation and record the clean state.
        self.iteration_images = [tensor_to_pil(pixel_clean[0], denormalize=False)]
        self.iteration_tensors = [images.clone().detach()]
        self.attentions_per_iter = []

        _, clean_attn = capture_outputs_and_attentions(self.model, images)
        self.attentions_per_iter.append(list(clean_attn))

        for _ in range(self.steps):
            adv_norm = (pixel_adv - mean) / std
            adv_norm.requires_grad = True
            outputs, step_attn = capture_outputs_and_attentions(self.model, adv_norm)

            if self.targeted:
                cost = -criterion(outputs, target_labels)
            else:
                cost = criterion(outputs, labels)

            grad = torch.autograd.grad(cost, adv_norm, retain_graph=False, create_graph=False)[0]

            # Signed ascent step in pixel space, then L_inf projection + clip.
            pixel_adv = pixel_adv.detach() + self.alpha * grad.sign()
            delta = torch.clamp(pixel_adv - pixel_clean, min=-self.eps, max=self.eps)
            pixel_adv = torch.clamp(pixel_clean + delta, min=0, max=1).detach()

            self.iteration_images.append(tensor_to_pil(pixel_adv[0], denormalize=False))
            self.iteration_tensors.append(((pixel_adv - mean) / std).clone().detach())
            self.attentions_per_iter.append(list(step_attn))

        adv_images = (pixel_adv - mean) / std
        return adv_images, self.iteration_images
| | |
class SAGA(torch.nn.Module):
    """SAGA: Self-Attention Gradient Attack.

    Adversarial attack specific to Vision Transformers: the FGSM-style
    gradient is multiplied by the model's attention-rollout map, focusing
    the perturbation on the regions the model attends to.

    Based on: https://github.com/MetaMain/ViTRobust
    Paper: "On the Robustness of Vision Transformers to Adversarial Examples" (ICCV 2021)
    """

    def __init__(self, model, eps=8/255, steps=10, discard_ratio: float = 0.0,
                 head_fusion: str = "mean", use_resnet: bool = False,
                 cnn_checkpoint_path: str = "resnet.pth", vit_weight=0.5):
        """SAGA implementation following the original code (SelfAttentionGradientAttack).

        Args:
            model: Vision Transformer (attentions captured via hooks on
                timm-style ``blocks[*].attn`` submodules).
            eps: maximum L_inf budget (in pixel space [0, 1]).
            steps: number of iterations (iterative FGSM).
            discard_ratio: discard ratio used by attention rollout.
            head_fusion: head-fusion strategy ('mean', 'max', 'min').
            use_resnet: if True, blend the gradient of an auxiliary CNN
                backbone into the attention-weighted ViT gradient.
            cnn_checkpoint_path: path (local or ``hf://``) of the auxiliary
                CNN checkpoint; loaded lazily on first use.
            vit_weight: weight of the ViT gradient in the blend when
                ``use_resnet`` is enabled (the CNN gets ``1 - vit_weight``).
        """
        super().__init__()
        self.model = model
        self.eps = eps
        self.steps = steps
        # Per-step size so `steps` full steps span the whole L_inf budget.
        self.eps_step = self.eps / max(1, steps)
        self.discard_ratio = discard_ratio
        self.head_fusion = head_fusion
        self.use_resnet = use_resnet

        # CNN backbone is resolved/loaded lazily (see _load_cnn_backbone).
        self.cnn_checkpoint_spec = cnn_checkpoint_path
        self.cnn_model: Optional[torch.nn.Module] = None
        self.vit_weight = vit_weight
        self.device = next(model.parameters()).device
        # Per-run instrumentation (reset at the start of every forward()).
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        self.attention_masks_cache: List[np.ndarray] = []
        self.attentions_per_iter: List[List[torch.Tensor]] = []
        self.loss_fn = torch.nn.CrossEntropyLoss()

    @staticmethod
    def _resolve_checkpoint_path(spec: object) -> Optional[Path]:
        """Resolve a local or Hugging Face Hub checkpoint spec to a local Path.

        Supported HF formats:
            - hf://owner/repo/path/to/file.pth
            - hf://owner/repo@revision/path/to/file.pth

        Returns None when the spec cannot be resolved.

        Raises:
            ValueError: on a malformed ``hf://`` spec (fewer than 3 parts).
        """
        if spec is None:
            return None

        if isinstance(spec, Path):
            return spec

        if not isinstance(spec, str):
            return None

        s = spec.strip()
        if s.startswith("hf://") or s.startswith("hf:"):
            if hf_hub_download is None:
                warnings.warn("huggingface_hub não está instalado; não é possível baixar checkpoint via HF Hub.")
                return None

            rest = s[len("hf://"):] if s.startswith("hf://") else s[len("hf:"):]
            rest = rest.lstrip("/")
            parts = [p for p in rest.split("/") if p]
            if len(parts) < 3:
                raise ValueError(
                    "Formato inválido para checkpoint HF. Use hf://owner/repo/path/to/file.pth"
                )

            # First two parts form owner/repo (optionally with @revision);
            # the remainder is the file path inside the repo.
            repo_part = "/".join(parts[:2])
            filename = "/".join(parts[2:])
            revision = None
            if "@" in repo_part:
                repo_id, revision = repo_part.split("@", 1)
            else:
                repo_id = repo_part

            cache_dir = os.getenv("HF_HOME") or None
            local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, cache_dir=cache_dir)
            return Path(local_path)

        return Path(s)

    def _attention_map(self, images_norm: torch.Tensor, save: bool = False) -> torch.Tensor:
        """Deprecated entry point kept for interface compatibility.

        Attention maps are now captured during the forward pass itself
        (see ``_capture_outputs_and_attentions``), so any direct call is
        an error.
        """
        raise RuntimeError("_attention_map should not be called directly; use integrated forward attention capture.")

    def _capture_outputs_and_attentions(self, x_norm: torch.Tensor):
        """Single forward pass capturing per-layer attention maps via hooks.

        Returns (outputs, attentions) where attentions is a list of detached
        CPU tensors [B, H, T, T], one per attention layer; models without a
        timm-style ``blocks`` attribute yield an empty list.
        """
        # Delegate to the module-level helper so both code paths share the
        # same hook logic and CPU offloading.
        return capture_outputs_and_attentions(self.model, x_norm)

    def _load_cnn_backbone(self) -> Optional[torch.nn.Module]:
        """Lazily load the auxiliary CNN backbone used when ``use_resnet``.

        Resolution order:
            1. User-supplied checkpoint (local path or ``hf://`` spec).
            2. Pretrained timm BiT ResNet (``resnetv2_101x1_bit``).
            3. Pretrained torchvision ResNet-101.

        Returns None (with a warning) when no backbone can be built.
        """
        if not self.use_resnet:
            return None
        if self.cnn_model is not None:
            return self.cnn_model
        # NOTE(review): missing torchvision also disables the timm fallback —
        # presumably intentional (torchvision is the last resort), but confirm.
        if tv_models is None:
            warnings.warn("torchvision não disponível; desabilitando modo CNN do SAGA.")
            return None

        model: Optional[torch.nn.Module] = None
        checkpoint_model_name = "resnetv2_101x1_bit.goog_in21k_ft_in1k"

        resolved_ckpt_path = None
        try:
            resolved_ckpt_path = self._resolve_checkpoint_path(self.cnn_checkpoint_spec)
        except Exception as exc:
            warnings.warn(f"Falha ao resolver cnn_checkpoint_path='{self.cnn_checkpoint_spec}': {exc}")

        if resolved_ckpt_path and resolved_ckpt_path.exists():
            try:
                checkpoint = torch.load(resolved_ckpt_path, map_location=self.device)
                if isinstance(checkpoint, torch.nn.Module):
                    # Fully serialized model object.
                    model = checkpoint
                elif isinstance(checkpoint, dict):
                    state_dict = checkpoint.get('model_state_dict') or checkpoint.get('state_dict') or checkpoint
                    if timm is not None and any(key.startswith("stem.") for key in state_dict.keys()):
                        # "stem." keys indicate a timm BiT-style ResNet.
                        num_classes = None
                        head_bias = state_dict.get('head.fc.bias')
                        if isinstance(head_bias, torch.Tensor):
                            num_classes = head_bias.shape[0]
                        model = timm.create_model(
                            checkpoint.get("model_name", checkpoint_model_name),
                            pretrained=False,
                            num_classes=num_classes or 1000
                        )
                        load_result = model.load_state_dict(state_dict, strict=False)
                    else:
                        model = tv_models.resnet101(weights=None)
                        load_result = model.load_state_dict(state_dict, strict=False)
                    missing = load_result.missing_keys
                    unexpected = load_result.unexpected_keys
                    if missing or unexpected:
                        warn_msg = "[SAGA] ResNet checkpoint keys mismatch."
                        if missing:
                            warn_msg += f" Missing: {missing[:5]}{'...' if len(missing) > 5 else ''}."
                        if unexpected:
                            warn_msg += f" Unexpected: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}."
                        warnings.warn(warn_msg + " Using available weights (strict=False).")
                else:
                    warnings.warn(f"Formato de checkpoint desconhecido em {resolved_ckpt_path}; utilizando pesos padrão.")
            except Exception as exc:
                warnings.warn(f"Falha ao carregar {resolved_ckpt_path}: {exc}. Usando ResNet padrão.")
                # Bug fix: discard any partially-constructed model so the
                # pretrained fallbacks below actually run, matching the
                # warning above.
                model = None

        if model is None:
            if timm is not None:
                try:
                    model = timm.create_model(checkpoint_model_name, pretrained=True)
                except Exception:
                    model = None
            if model is None and tv_models is not None:
                try:
                    model = tv_models.resnet101(weights="IMAGENET1K_V2")
                except Exception:
                    # Older torchvision versions use the `pretrained` flag.
                    model = tv_models.resnet101(pretrained=True)

        # Bug fix: previously a None model here crashed on `.to(...)`.
        if model is None:
            warnings.warn("[SAGA] Nenhum backbone CNN pôde ser carregado; desabilitando gradiente CNN.")
            return None

        model = model.to(self.device)
        model.eval()
        self.cnn_model = model
        return self.cnn_model

    def _compute_cnn_gradient(self, images_norm: torch.Tensor, labels: torch.Tensor) -> Optional[torch.Tensor]:
        """Gradient of the auxiliary CNN's loss w.r.t. the normalized input.

        Returns None when no CNN backbone is available.
        """
        cnn_model = self._load_cnn_backbone()
        if cnn_model is None:
            return None

        cnn_input = images_norm.detach().clone().requires_grad_(True)
        outputs = cnn_model(cnn_input)
        loss = self.loss_fn(outputs, labels)
        grad = torch.autograd.grad(loss, cnn_input, retain_graph=False, create_graph=False)[0]
        return grad

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run the SAGA attack (iterative FGSM weighted by attention).

        Per-iteration flow:
            1. Normalize the current adversarial image.
            2. Compute loss and gradient.
            3. Extract the attention map and weight the gradient with it.
            4. Apply a signed FGSM step in pixel space [0, 1].
            5. Project onto the L_inf ball and clip to [0, 1].
            6. Record the PIL image and the normalized tensor.

        Returns:
            (final adversarial tensor in normalized space, list of PIL images)
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        # ImageNet statistics for pixel-space <-> normalized conversions.
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

        images_denorm = images * std + mean
        adv_denorm = images_denorm.clone().detach()

        # Reset per-run instrumentation.
        self.iteration_images = []
        self.iteration_tensors = []
        self.attention_masks_cache = []
        self.attentions_per_iter = []

        # Record the clean image and its attention maps.
        self.iteration_images.append(tensor_to_pil(images_denorm[0], denormalize=False))
        self.iteration_tensors.append(images.clone().detach())

        _, attentions0 = self._capture_outputs_and_attentions(images)
        self.attentions_per_iter.append(list(attentions0))

        # Project-local imports kept function-scoped (matches original code).
        from utils.visualization import attention_rollout
        import cv2

        mask0 = attention_rollout(attentions0, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion)
        self.attention_masks_cache.append(mask0.copy())

        for _ in range(self.steps):
            adv_norm = (adv_denorm - mean) / std
            adv_norm.requires_grad = True
            outputs, attentions = self._capture_outputs_and_attentions(adv_norm)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            loss = self.loss_fn(outputs, labels)
            grad = torch.autograd.grad(loss, adv_norm, retain_graph=False, create_graph=False)[0]

            self.attentions_per_iter.append(list(attentions))

            # Attention rollout -> normalized [0, 1] map at image resolution.
            mask = attention_rollout(attentions, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion)
            mask_resized = cv2.resize(mask, (adv_norm.shape[-1], adv_norm.shape[-2]))
            mmax = mask_resized.max() if mask_resized.max() > 0 else 1.0
            mask_resized = (mask_resized / mmax).astype('float32')
            att_map = torch.from_numpy(mask_resized).to(self.device).unsqueeze(0).unsqueeze(0).repeat(adv_norm.size(0), 3, 1, 1)

            self.attention_masks_cache.append(mask.copy())
            grad_weighted = grad * att_map

            grad_final = grad_weighted
            if self.use_resnet:
                cnn_grad = self._compute_cnn_gradient(adv_norm, labels)
                if cnn_grad is not None:
                    # Blend attention-weighted ViT gradient with CNN gradient.
                    grad_final = self.vit_weight * grad_weighted + (1 - self.vit_weight) * cnn_grad

            # Signed FGSM step in pixel space.
            adv_denorm = adv_denorm.detach() + self.eps_step * grad_final.sign()

            # L_inf projection around the clean image, then clip to [0, 1].
            delta = torch.clamp(adv_denorm - images_denorm, min=-self.eps, max=self.eps)
            adv_denorm = torch.clamp(images_denorm + delta, 0.0, 1.0).detach()

            self.iteration_images.append(tensor_to_pil(adv_denorm[0], denormalize=False))
            self.iteration_tensors.append(((adv_denorm - mean) / std).clone().detach())

        adv_final = (adv_denorm - mean) / std
        return adv_final, self.iteration_images
| |
|
class AttentionWeightedPGD(torch.nn.Module):
    """[Deprecated]
    Incorrect implementation of the SAGA attack that nevertheless performs
    effective adversarial attacks on ViTs by weighting the gradient with
    attention maps. Kept for reproducibility of earlier experiments.
    """

    def __init__(self, model, eps=0.03, steps=10):
        super().__init__()
        self.model = model
        self.eps = eps
        self.steps = steps
        # Bug fix: guard steps == 0 (previously ZeroDivisionError) and stay
        # consistent with SAGA's eps_step definition.
        self.eps_step = self.eps / max(1, self.steps)
        self.device = next(model.parameters()).device
        # Per-run instrumentation (reset at the start of every forward()).
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        self.attention_masks_cache: List[np.ndarray] = []

    def get_attention_map(self, images: torch.Tensor, save_for_viz: bool = False) -> tuple:
        """Extract the ViT attention map via attention rollout.

        Returns:
            mask_tensor: [B, C, H, W] tensor used to weight the gradient.
            mask_np: [H, W] numpy array for visualization when
                save_for_viz=True, otherwise None.
        """
        from utils.visualization import extract_attention_maps, attention_rollout
        import cv2

        batch_size = images.shape[0]
        img_size = images.shape[2]

        attentions = extract_attention_maps(self.model, images)

        # NOTE(review): discard_ratio/head_fusion are hard-coded here, unlike
        # SAGA where they are constructor parameters.
        mask = attention_rollout(attentions, discard_ratio=0.9, head_fusion='max')

        if save_for_viz:
            self.attention_masks_cache.append(mask.copy())

        # Resize the token-level mask to the input resolution.
        mask_resized = cv2.resize(mask, (img_size, img_size))

        mask_tensor = torch.from_numpy(mask_resized).float().to(self.device)
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)
        mask_tensor = mask_tensor.repeat(batch_size, 3, 1, 1)

        return mask_tensor, mask if save_for_viz else None

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run the attack.

        Returns:
            adv_images: final adversarial tensor (normalized space)
            iteration_images: PIL images — the original plus one per step
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        loss_fn = torch.nn.CrossEntropyLoss()

        # ImageNet statistics for pixel-space conversions.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

        images_denorm = images * std + mean
        adv_images_denorm = images_denorm.clone().detach()

        self.iteration_images = []
        self.iteration_tensors = []
        self.attention_masks_cache = []

        # Record the clean state.
        pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
        self.iteration_images.append(pil_img_orig)
        self.iteration_tensors.append(images.clone().detach())

        # Cache the clean image's attention map for visualization.
        attention_map, _ = self.get_attention_map(images, save_for_viz=True)

        for step in range(self.steps):
            adv_images = (adv_images_denorm - mean) / std
            adv_images.requires_grad = True

            outputs = self.model(adv_images)
            cost = loss_fn(outputs, labels)

            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False,
                                       create_graph=False)[0]

            # Re-extract the attention map from the current adversarial image.
            attention_map, _ = self.get_attention_map(adv_images.detach(), save_for_viz=True)

            # Weight the gradient by attention before taking the signed step.
            grad_weighted = grad * attention_map

            # Signed step in pixel space, L_inf projection, clip to [0, 1].
            adv_images_denorm = adv_images_denorm.detach() + self.eps_step * grad_weighted.sign()
            delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
            adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()

            adv_images_normalized = (adv_images_denorm - mean) / std

            pil_img = tensor_to_pil(adv_images_denorm[0], denormalize=False)
            self.iteration_images.append(pil_img)
            self.iteration_tensors.append(adv_images_normalized.clone().detach())

        adv_images = (adv_images_denorm - mean) / std
        return adv_images, self.iteration_images
| |
|
class MIFGSM(torchattacks.MIFGSM):
    """MI-FGSM: Momentum Iterative Fast Gradient Sign Method.

    Extension of torchattacks' MIFGSM that records the image, normalized
    tensor and attention maps of every iteration. Momentum stabilizes the
    gradient direction and improves transferability.

    Paper: "Boosting Adversarial Attacks with Momentum" (2017)
    https://arxiv.org/abs/1710.06081
    """

    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, decay=1.0):
        super().__init__(model, eps=eps, alpha=alpha, steps=steps, decay=decay)
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        self.attentions_per_iter: List[List[torch.Tensor]] = []

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run MI-FGSM, adapted for ImageNet-normalized inputs.

        Returns:
            adv_images: final adversarial tensor (normalized space)
            iteration_images: PIL images — the original plus one per step
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        loss = torch.nn.CrossEntropyLoss()

        # ImageNet statistics for pixel-space conversions.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

        images_denorm = images * std + mean
        adv_images_denorm = images_denorm.clone().detach()

        # Momentum accumulator lives in pixel space.
        momentum = torch.zeros_like(images_denorm).detach().to(self.device)
        self.iteration_images = []
        self.iteration_tensors = []
        self.attentions_per_iter = []

        # Record the clean state.
        pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
        self.iteration_images.append(pil_img_orig)
        self.iteration_tensors.append(images.clone().detach())

        # Attention maps of the clean input.
        outputs0, attentions0 = capture_outputs_and_attentions(self.model, images)
        self.attentions_per_iter.append(list(attentions0))

        for step in range(self.steps):
            adv_images = (adv_images_denorm - mean) / std
            adv_images.requires_grad = True
            outputs, attentions = capture_outputs_and_attentions(self.model, adv_images)

            if self.targeted:
                cost = -loss(outputs, target_labels)
            else:
                cost = loss(outputs, labels)

            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]

            self.attentions_per_iter.append(list(attentions))

            # Bug fix (chain rule): x_norm = (x_pix - mean) / std, so the
            # pixel-space gradient is grad / std, not grad * std. The sign is
            # unaffected (std > 0), but the momentum's cross-channel weighting
            # was inverted by the old formula.
            grad_denorm = grad / std

            # L1-normalize per sample (standard MI-FGSM), then accumulate momentum.
            grad_denorm = grad_denorm / torch.mean(torch.abs(grad_denorm), dim=(1, 2, 3), keepdim=True)
            grad_denorm = grad_denorm + momentum * self.decay
            momentum = grad_denorm

            # Signed step in pixel space, L_inf projection, clip to [0, 1].
            adv_images_denorm = adv_images_denorm.detach() + self.alpha * grad_denorm.sign()
            delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
            adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()

            # Record this iteration.
            adv_images_normalized = (adv_images_denorm - mean) / std
            self.iteration_tensors.append(adv_images_normalized.clone().detach())
            pil_iter = tensor_to_pil(adv_images_denorm[0], denormalize=False)
            self.iteration_images.append(pil_iter)

        adv_images = (adv_images_denorm - mean) / std

        return adv_images, self.iteration_images
| |
|
| |
|
| | class TGR(torch.nn.Module): |
| | """TGR: Token Gradient Regularization attack. |
| | |
| | Ataque iterativo untargeted, white-box, no estilo MI-FGSM, que aplica |
| | regularização de gradiente em módulos internos do transformer via |
| | backward hooks (Attention map, QKV, MLP). |
| | |
| | Diferenças-chave vs. MI-FGSM: |
| | - Attention: zera LINHAS e COLUNAS inteiras do mapa N×N (pares extremos) |
| | - QKV/MLP: zera TOKENS INTEIROS (todas as features de tokens extremos) |
| | - Escala por componente (código oficial): s_attn=0.25, s_qkv=0.75, s_mlp=0.5 |
| | |
| | O ataque trabalha em pixel space [0,1], respeitando orçamento L_inf. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | model: torch.nn.Module, |
| | eps: float = 16 / 255, |
| | steps: int = 10, |
| | decay: float = 1.0, |
| | k: int = 1, |
| | gamma_attn: float = 0.25, |
| | gamma_qkv: float = 0.75, |
| | gamma_mlp: float = 0.5, |
| | debug_shapes: bool = False, |
| | enable_attn_hook: bool = True, |
| | enable_qkv_hook: bool = True, |
| | enable_mlp_hook: bool = True, |
| | debug_stats: bool = False, |
| | protect_cls_token: bool = True, |
| | debug_progress: bool = False, |
| | ) -> None: |
| | super().__init__() |
| | self.model = model |
| | self.eps = float(eps) |
| | self.steps = int(steps) |
| | self.decay = float(decay) |
| | self.k = int(k) |
| | self.eps_step = self.eps / max(1, self.steps) |
| | self.gamma_attn = float(gamma_attn) |
| | self.gamma_qkv = float(gamma_qkv) |
| | self.gamma_mlp = float(gamma_mlp) |
| | self.debug_shapes = bool(debug_shapes) |
| | self.enable_attn_hook = bool(enable_attn_hook) |
| | self.enable_qkv_hook = bool(enable_qkv_hook) |
| | self.enable_mlp_hook = bool(enable_mlp_hook) |
| | self.debug_stats = bool(debug_stats) |
| | self.protect_cls_token = bool(protect_cls_token) |
| | self.debug_progress = bool(debug_progress) |
| |
|
| | self.device = next(model.parameters()).device |
| | self.loss_fn = torch.nn.CrossEntropyLoss() |
| |
|
| | self.iteration_images: List[Image.Image] = [] |
| | self.iteration_tensors: List[torch.Tensor] = [] |
| | self.attentions_per_iter: List[List[torch.Tensor]] = [] |
| |
|
| | self.debug_last: dict = {} |
| | self.debug_progress_log: List[dict] = [] |
| | self._patched_attn_forwards: dict = {} |
| |
|
| | |
| |
|
| | def _patch_attention_forward(self, attn_module: torch.nn.Module) -> None: |
| | """Monkeypatch do forward do Attention para anexar hook no mapa de atenção. |
| | |
| | Isso permite aplicar o Algoritmo 1 de forma paper-faithful em timm ViTs, |
| | porque o tensor de atenção [B,H,N,N] não é exposto diretamente como |
| | saída de um submódulo. |
| | """ |
| | if attn_module in self._patched_attn_forwards: |
| | return |
| |
|
| | orig_forward = attn_module.forward |
| | self._patched_attn_forwards[attn_module] = orig_forward |
| |
|
| | def forward_patched(this, x, attn_mask=None, **kwargs): |
| | B, N, C = x.shape |
| | num_heads = getattr(this, "num_heads") |
| |
|
| | qkv = this.qkv(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) |
| | q, k, v = qkv.unbind(0) |
| | scale = getattr(this, "scale", (C // num_heads) ** -0.5) |
| | attn = (q @ k.transpose(-2, -1)) * scale |
| | |
| | |
| | if attn_mask is not None: |
| | |
| | attn = attn + attn_mask |
| | attn = attn.softmax(dim=-1) |
| |
|
| | if self.debug_shapes and not getattr(self, "_debug_attn_map_printed", False): |
| | print(f"[TGR DEBUG] attn_map tensor shape (patched): {attn.shape}") |
| | print(f"[TGR DEBUG] attn_map tensor ndim (patched): {attn.ndim}") |
| | self._debug_attn_map_printed = True |
| |
|
| | def grad_hook(grad): |
| | return self._tgr_process_grad_attention(grad, self.gamma_attn) |
| |
|
| | attn.register_hook(grad_hook) |
| |
|
| | attn = this.attn_drop(attn) |
| | x_out = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| | x_out = this.proj(x_out) |
| | proj_drop = getattr(this, "proj_drop", None) |
| | if proj_drop is not None: |
| | x_out = proj_drop(x_out) |
| | return x_out |
| |
|
| | attn_module.forward = types.MethodType(forward_patched, attn_module) |
| |
|
    def _tgr_process_grad_attention(self, grad: torch.Tensor, gamma: float) -> torch.Tensor:
        """TGR regularization for the Attention component.

        Paper-faithful path (Algorithm 1): operates on the gradient of the
        attention map, shape [B, H, N, N] (H = heads). For each head it picks
        the 2k most extreme entries and zeroes the whole row and column of each.

        Two legacy fallbacks are also kept:
        - [B, N, C] (tokens): for architectures exposing only a token-wise gradient.
        - [B, C, H, W] (CNN): historical fallback.

        Args:
            grad: gradient of shape [B,H,N,N] (attention map), [B,N,C] or [B,C,H,W]
            gamma: scale factor (the paper uses 0.25). gamma=1.0 is a passthrough.

        Returns:
            A new tensor (grad * gamma with extremes zeroed), or grad unchanged
            when there is nothing to do.
        """
        if grad is None:
            return grad

        # gamma == 1.0 means "no regularization": hand the gradient back untouched.
        if abs(gamma - 1.0) < 1e-6:
            return grad

        # Work on a scaled copy; all zeroing below mutates `g`, never `grad`.
        g = grad * gamma

        # Paper-faithful branch: square last two dims => attention map.
        # `shape[1] <= 64` is a heuristic separating ViT head counts from a CNN
        # channel axis, which would also be 4-D.
        if g.ndim == 4 and g.shape[-1] == g.shape[-2] and g.shape[1] <= 64:
            try:
                B, Hh, N, _ = g.shape
                k_actual = min(self.k, N * N)
                if k_actual <= 0:
                    return g

                for b in range(B):
                    for h in range(Hh):
                        gh = g[b, h]
                        flat = gh.reshape(-1)
                        # 2k extremes: k largest plus k smallest entries.
                        _, idx_max = torch.topk(flat, k_actual, largest=True)
                        _, idx_min = torch.topk(flat, k_actual, largest=False)
                        idxs = torch.cat([idx_max, idx_min], dim=0)

                        removed_cls = False
                        for idx in idxs.tolist():
                            # Map flat index back to (row, col) of the N x N map.
                            r = idx // N
                            c = idx % N
                            # Row/col 0 belongs to the CLS token; optionally spare it.
                            if self.protect_cls_token and (r == 0 or c == 0):
                                removed_cls = True
                                continue
                            g[b, h, r, :] = 0.0
                            g[b, h, :, c] = 0.0

                        if self.debug_shapes and b == 0 and h == 0:
                            extra = " (CLS protegido)" if removed_cls else ""
                            print(
                                f"[TGR DEBUG] AttentionMap: head0 zerou linhas/cols por 2k={2*k_actual} entradas{extra}"
                            )

                return g
            except Exception as e:
                # Best-effort: on any failure fall back to the plain scaled gradient.
                warnings.warn(f"[TGR] AttentionMap ([B,H,N,N]): fallback ({e})")
                return g

        # Legacy token-wise fallback: zero entire extreme tokens.
        if g.ndim == 3:
            try:
                B, N, C = g.shape
                for b in range(B):
                    # Union of extreme token indices across all channels.
                    token_ids = set()
                    for c in range(C):
                        v = g[b, :, c]
                        k_actual = min(self.k, N)

                        if k_actual > 0:
                            _, idx_max = torch.topk(v, k_actual, largest=True)
                            _, idx_min = torch.topk(v, k_actual, largest=False)
                            token_ids.update(idx_max.tolist())
                            token_ids.update(idx_min.tolist())

                    removed_cls = False
                    if self.protect_cls_token and 0 in token_ids:
                        token_ids.discard(0)
                        removed_cls = True

                    if self.debug_shapes and b == 0:
                        extra = " (CLS protegido)" if removed_cls else ""
                        print(f"[TGR DEBUG] AttentionTokens: zerando {len(token_ids)}/{N} tokens (k={self.k}, C={C}){extra}")

                    # Zero the whole feature vector of each selected token.
                    for t in token_ids:
                        g[b, t, :] = 0.0

                return g
            except Exception as e:
                warnings.warn(f"[TGR] Atenção ([B,N,C]): fallback ({e})")
                return g

        # Legacy CNN fallback: non-square 4-D gradient treated as [B, C, H, W].
        elif g.ndim == 4:
            B, C, H, W = g.shape

            # Only sensible when the spatial plane is at least as big as C.
            if H * W >= C:
                try:
                    # NOTE(review): extremes are computed from batch element 0
                    # only (`g[0]`) but zeroed across the WHOLE batch — looks
                    # intentional as a cheap legacy fallback; confirm.
                    g_flat = g[0].reshape(C, H * W)
                    max_idx = g_flat.argmax(dim=1)
                    min_idx = g_flat.argmin(dim=1)

                    max_h = max_idx // W
                    max_w = max_idx % W
                    min_h = min_idx // W
                    min_w = min_idx % W

                    # Per-channel: zero the row and column of each extreme pixel.
                    c_range = torch.arange(C, device=g.device)
                    g[:, c_range, max_h, :] = 0.0
                    g[:, c_range, :, max_w] = 0.0
                    g[:, c_range, min_h, :] = 0.0
                    g[:, c_range, :, min_w] = 0.0

                    return g
                except Exception as e:
                    warnings.warn(f"[TGR] Atenção ([B,C,H,W]): fallback ({e})")
                    return g

        # Unknown rank/shape: return the scaled gradient unmodified.
        return g
| | |
| | def _tgr_process_grad_tokens(self, grad: torch.Tensor, gamma: float) -> torch.Tensor: |
| | """Regularização TGR para componentes QKV/MLP (conforme implementação original do paper). |
| | |
| | Para gradiente shape [B, N, C] (entrada do QKV/MLP): |
| | - Escala por gamma |
| | - Para cada canal c, encontra top-k e bottom-k tokens (por valor) |
| | - Zera as ENTRADAS extremas (token, canal), isto é: g[b, token, c] = 0 |
| | |
| | Observação: isso difere de "zerar token inteiro". É o que o código oficial |
| | faz quando executa: out_grad[:, max_all, range(c)] = 0.0. |
| | """ |
| | if grad is None: |
| | return grad |
| | |
| | |
| | if abs(gamma - 1.0) < 1e-6: |
| | return grad |
| | |
| | g = grad * gamma |
| | |
| | try: |
| | if g.ndim == 3: |
| | B, N, C = g.shape |
| | for b in range(B): |
| | |
| | k_actual = min(self.k, N) |
| | zeroed = 0 |
| | for c in range(C): |
| | v = g[b, :, c] |
| | if k_actual <= 0: |
| | continue |
| | _, idx_max = torch.topk(v, k_actual, largest=True) |
| | _, idx_min = torch.topk(v, k_actual, largest=False) |
| | for t in idx_max.tolist() + idx_min.tolist(): |
| | if self.protect_cls_token and t == 0: |
| | continue |
| | g[b, t, c] = 0.0 |
| | zeroed += 1 |
| |
|
| | if self.debug_shapes and b == 0: |
| | if not hasattr(self, "_debug_token_zero_counts"): |
| | self._debug_token_zero_counts = {} |
| | key = f"gamma={gamma:.3f}" |
| | count = self._debug_token_zero_counts.get(key, 0) |
| | if count < 3: |
| | |
| | print( |
| | f"[TGR DEBUG] Tokens ({key}): zerando ~{zeroed} entradas (2*k*C={2*k_actual*C}, ataque em [token,canal])" |
| | ) |
| | self._debug_token_zero_counts[key] = count + 1 |
| | |
| | except Exception as e: |
| | warnings.warn(f"[TGR] Tokens: fallback no processo de QKV/MLP ({e})") |
| | g = grad * gamma |
| | |
| | return g |
| |
|
| | def _make_attention_hook(self): |
| | raise RuntimeError("_make_attention_hook não é mais usado; use _patch_attention_forward") |
| | |
| | def _make_qkv_hook(self): |
| | """Hook para componente QKV.""" |
| | def hook(module, grad_input, grad_output): |
| | if not grad_input or grad_input[0] is None: |
| | return grad_input |
| | g0_new = self._tgr_process_grad_tokens(grad_input[0], self.gamma_qkv) |
| | return (g0_new,) + tuple(grad_input[1:]) |
| | return hook |
| | |
| | def _make_mlp_hook(self): |
| | """Hook para componente MLP.""" |
| | def hook(module, grad_input, grad_output): |
| | if not grad_input or grad_input[0] is None: |
| | return grad_input |
| | g0_new = self._tgr_process_grad_tokens(grad_input[0], self.gamma_mlp) |
| | return (g0_new,) + tuple(grad_input[1:]) |
| | return hook |
| |
|
| | def _register_tgr_hooks(self) -> List[torch.utils.hooks.RemovableHandle]: |
| | """Registra hooks conforme Algoritmo 1 do paper TGR. |
| | |
| | Implementação alinhada ao código oficial: |
| | - Attention: aplica TGR no gradiente do mapa de atenção [B,H,N,N] |
| | (monkeypatch no forward do módulo de atenção para anexar hook no tensor `attn`) |
| | - QKV: hook em `attn.qkv` para regularizar grad_input[0] ([B,N,C]) |
| | - MLP: hook no `mlp` para regularizar grad_input[0] ([B,N,C]) |
| | |
| | Se não encontrar nenhum módulo compatível, não registra nada; o ataque |
| | ainda funciona (equivale a um MI-FGSM), apenas sem regularização TGR. |
| | """ |
| | handles: List[torch.utils.hooks.RemovableHandle] = [] |
| | warned_attn = False |
| |
|
| | |
| | if hasattr(self.model, "blocks"): |
| | for block in self.model.blocks: |
| | attn_module = getattr(block, "attn", None) |
| | if attn_module is not None: |
| | |
| | |
| | if self.enable_attn_hook: |
| | if hasattr(attn_module, "qkv") and hasattr(attn_module, "num_heads") and hasattr(attn_module, "proj"): |
| | self._patch_attention_forward(attn_module) |
| | elif not warned_attn: |
| | warnings.warn( |
| | "[TGR] Nenhum módulo de atenção compatível encontrado (qkv/num_heads/proj); " |
| | "pulando regularização TGR-Attention. Apenas QKV/MLP serão regularizados." |
| | ) |
| | warned_attn = True |
| | |
| | |
| | if self.enable_qkv_hook and hasattr(attn_module, "qkv"): |
| | handles.append( |
| | attn_module.qkv.register_full_backward_hook(self._make_qkv_hook()) |
| | ) |
| |
|
| | |
| | mlp = getattr(block, "mlp", None) |
| | if self.enable_mlp_hook and mlp is not None: |
| | handles.append(mlp.register_full_backward_hook(self._make_mlp_hook())) |
| |
|
| | if not handles: |
| | warnings.warn( |
| | "[TGR] Nenhum módulo compatível encontrado para hooks; " |
| | "executando como MI-FGSM (sem regularização interna)." |
| | ) |
| | elif self.debug_shapes: |
| | print(f"[TGR DEBUG] Registrados {len(handles)} hooks") |
| |
|
| | return handles |
| |
|
| | |
| |
|
    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run the TGR attack (an MI-FGSM loop with TGR gradient hooks installed).

        Args:
            images: input batch, assumed ImageNet-normalized — TODO confirm
                against callers; the mean/std below are used to move between
                normalized space and pixel [0, 1] space.
            labels: ground-truth labels for the (untargeted) cross-entropy loss.

        Returns:
            - adv_images: final adversarial tensor (normalized)
            - iteration_images: list of PIL Images (one per iteration, including the original)
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        # ImageNet statistics; the perturbation lives in de-normalized pixel
        # space so the eps budget is expressed in [0, 1] units.
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

        # De-normalize the inputs once; `unnorm_inps` is the fixed clean image.
        images_denorm = images * std + mean
        unnorm_inps = images_denorm.clone().detach()

        # Perturbation accumulator, updated each iteration.
        perts = torch.zeros_like(unnorm_inps).detach()

        # Reset per-attack bookkeeping.
        self.iteration_images = []
        self.iteration_tensors = []
        self.attentions_per_iter = []
        self.debug_progress_log = []

        # Record the clean image as iteration 0.
        self.iteration_images.append(tensor_to_pil(images_denorm[0], denormalize=False))
        self.iteration_tensors.append(images.clone().detach())

        # Force eval mode during the attack; restored in the `finally` block.
        was_training = self.model.training
        self.model.eval()

        # Baseline forward pass to capture the clean attention maps.
        outputs0, attentions0 = capture_outputs_and_attentions(self.model, images)
        self.attentions_per_iter.append([att.detach().cpu() for att in attentions0])

        # MI-FGSM momentum accumulator.
        momentum = torch.zeros_like(perts).detach().to(self.device)

        handles: List[torch.utils.hooks.RemovableHandle] = []
        try:
            handles = self._register_tgr_hooks()

            self.debug_last = {}
            for step_idx in range(self.steps):
                # Re-attach grad to the perturbation and build the normalized
                # adversarial input (clamped to valid pixel range first).
                perts = perts.detach().requires_grad_(True)
                adv_norm = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - mean) / std

                outputs, attentions = capture_outputs_and_attentions(self.model, adv_norm)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]

                loss = self.loss_fn(outputs, labels)

                # Optional per-iteration progress logging (no autograd impact).
                if self.debug_progress:
                    with torch.no_grad():
                        probs = torch.softmax(outputs, dim=1)
                        pred = probs.argmax(dim=1)
                        conf_pred = probs.gather(1, pred.view(-1, 1)).squeeze(1)
                        conf_label = probs.gather(1, labels.view(-1, 1)).squeeze(1)
                        delta_now = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps).detach()
                        dmax = float(delta_now.abs().max().item())
                        dmean = float(delta_now.abs().mean().item())
                        changed = float((delta_now.abs() > 1e-6).float().mean().item())
                        self.debug_progress_log.append(
                            {
                                "iter": int(step_idx),
                                "loss": float(loss.detach().item()),
                                "pred": pred.detach().cpu().tolist(),
                                "label": labels.detach().cpu().tolist(),
                                "conf_pred": conf_pred.detach().cpu().tolist(),
                                "conf_label": conf_label.detach().cpu().tolist(),
                                "delta_linf": dmax,
                                "delta_mean": dmean,
                                "pixels_changed_ratio": changed,
                            }
                        )

                        print(
                            f"[TGR PROGRESS] it={step_idx} loss={loss.item():.4f} "
                            f"pred={int(pred[0])} conf_pred={conf_pred[0].item():.4f} "
                            f"label={int(labels[0])} conf_label={conf_label[0].item():.4f} "
                            f"dLinf={dmax:.6f} dMean={dmean:.6f} changed={changed*100:.1f}%"
                        )
                # Gradient of the loss w.r.t. the pixel-space perturbation; the
                # TGR hooks fire inside this backward pass.
                grad_norm = torch.autograd.grad(
                    loss,
                    perts,
                    retain_graph=False,
                    create_graph=False,
                )[0]

                # Keep this iteration's attention maps for later inspection.
                self.attentions_per_iter.append([att.detach().cpu() for att in attentions])

                # NOTE(review): despite the name, no mean/std re-scaling happens
                # here — the gradient w.r.t. `perts` is already in pixel space.
                grad_denorm = grad_norm

                if self.debug_stats:
                    # Raw gradient statistics, pre-normalization.
                    self.debug_last[f"iter_{step_idx}"] = {
                        "loss": float(loss.detach().item()),
                        "grad_norm_abs_mean": float(grad_norm.detach().abs().mean().item()),
                        "grad_norm_abs_max": float(grad_norm.detach().abs().max().item()),
                        "grad_denorm_abs_mean_pre_norm": float(grad_denorm.detach().abs().mean().item()),
                        "grad_denorm_abs_max_pre_norm": float(grad_denorm.detach().abs().max().item()),
                    }

                # MI-FGSM: L1-normalize the gradient per sample, then fold in
                # the decayed momentum. `momentum = g + decay * momentum`.
                denom = torch.mean(torch.abs(grad_denorm), dim=(1, 2, 3), keepdim=True) + 1e-12
                grad_denorm = grad_denorm / denom

                grad_denorm = grad_denorm + momentum * self.decay
                momentum = grad_denorm

                # Signed step, then project to the eps ball and the valid
                # pixel range (order matters: eps clamp before pixel clamp).
                perts = perts.detach() + self.eps_step * grad_denorm.sign()
                perts = torch.clamp(perts, -self.eps, self.eps)

                perts = torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps

                if self.debug_shapes:
                    step_size = (self.eps_step * grad_denorm.sign()).abs().max().item()
                    grad_sign_nonzero = (grad_denorm.sign().abs() > 0).float().mean().item()
                    print(f"[TGR DEBUG] Step size: {step_size:.6f}, grad_sign non-zero: {grad_sign_nonzero*100:.1f}%")

                if self.debug_stats:
                    # Post-normalization statistics merged into this iteration's entry.
                    iter_stats = self.debug_last.get(f"iter_{step_idx}", {})
                    iter_stats.update(
                        {
                            "denom_abs_mean": float(denom.detach().mean().item()),
                            "grad_denorm_abs_mean_post_norm": float(grad_denorm.detach().abs().mean().item()),
                            "grad_denorm_abs_max_post_norm": float(grad_denorm.detach().abs().max().item()),
                            "grad_sign_nonzero_ratio": float(
                                (grad_denorm.detach().sign().abs() > 0).float().mean().item()
                            ),
                            "step_size": float((self.eps_step * grad_denorm.detach().sign()).abs().max().item()),
                        }
                    )
                    self.debug_last[f"iter_{step_idx}"] = iter_stats

                if self.debug_shapes:
                    actual_delta = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps).abs().max().item()
                    print(f"[TGR DEBUG] Iteration delta: {actual_delta:.6f} (eps={self.eps:.6f}, eps_step={self.eps_step:.6f})")

                # Snapshot this iteration's adversarial image (pixel space and normalized).
                adv_denorm = torch.clamp(unnorm_inps + perts, 0.0, 1.0).detach()
                self.iteration_images.append(tensor_to_pil(adv_denorm[0], denormalize=False))
                self.iteration_tensors.append(((adv_denorm - mean) / std).clone().detach())

        finally:
            # Always undo the hooks and forward monkeypatches, even on error.
            for h in handles:
                h.remove()

            if self._patched_attn_forwards:
                for attn_module, orig_forward in list(self._patched_attn_forwards.items()):
                    try:
                        attn_module.forward = orig_forward
                    except Exception:
                        pass
                self._patched_attn_forwards.clear()
            if hasattr(self, "_debug_attn_map_printed"):
                delattr(self, "_debug_attn_map_printed")

            # Restore the model's original train/eval mode.
            if was_training:
                self.model.train()

        # Final adversarial batch, back in normalized space.
        adv_final = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - mean) / std
        return adv_final, self.iteration_images