import torch import torchattacks from PIL import Image from typing import List, Tuple, Optional import numpy as np import warnings from pathlib import Path import types import os try: import torchvision.models as tv_models except Exception: # pragma: no cover - torchvision is optional for ViT-only mode tv_models = None try: import timm except Exception: # pragma: no cover - timm is optional for CNN blending timm = None try: from huggingface_hub import hf_hub_download except Exception: # pragma: no cover - optional dependency hf_hub_download = None def capture_outputs_and_attentions(model, x_norm: torch.Tensor): """Executa um forward único capturando atenções via hooks nas camadas de atenção do ViT. Retorna (outputs, attentions_list) onde attentions_list é lista de tensores [B,H,T,T] por camada. Funciona para modelos do timm com atributo 'blocks' e submódulo 'attn'. """ # TODO: adaptar para pytorch também, além de timm attentions: List[torch.Tensor] = [] def make_attention_hook(): def hook(module, input, output): x = input[0] B, N, C = x.shape if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')): return qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) scale = (C // module.num_heads) ** -0.5 attn = (q @ k.transpose(-2, -1)) * scale attn = attn.softmax(dim=-1) attentions.append(attn.detach()) return hook hooks = [] if hasattr(model, 'blocks'): for block in model.blocks: if hasattr(block, 'attn'): hooks.append(block.attn.register_forward_hook(make_attention_hook())) model.eval() outputs = model(x_norm) for h in hooks: h.remove() attentions = [a.cpu() for a in attentions] return outputs, attentions def denormalize_imagenet(tensor: torch.Tensor) -> torch.Tensor: """ Reverte a normalização ImageNet de um tensor. Args: tensor: Tensor normalizado (CxHxW) com mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] Returns: Tensor desnormalizado com valores em [0, 1] """ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(tensor.device) std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(tensor.device) # Inverte: x_norm = (x - mean) / std => x = x_norm * std + mean denorm = tensor * std + mean # Clip para garantir [0, 1] return torch.clamp(denorm, 0, 1) def tensor_to_pil(tensor: torch.Tensor, denormalize: bool = True) -> Image.Image: """ Converte tensor (CxHxW) para PIL Image RGB. Args: tensor: Tensor com shape (C, H, W) denormalize: Se True, aplica desnormalização ImageNet antes da conversão Returns: PIL Image no espaço RGB [0, 255] """ if denormalize: tensor = denormalize_imagenet(tensor) # tensor shape: (C, H, W) com valores [0, 1] img_np = tensor.cpu().detach().numpy() img_np = np.transpose(img_np, (1, 2, 0)) # HxWxC img_np = (img_np * 255).clip(0, 255).astype(np.uint8) return Image.fromarray(img_np, mode='RGB') class FGSM(torchattacks.FGSM): """ Extensão do ataque FGSM (Fast Gradient Sign Method) que captura a imagem original e a imagem adversarial final. FGSM é um ataque de 1 única iteração (non-iterative). """ def __init__(self, model, eps=0.03): super().__init__(model, eps=eps) self.iteration_images: List[Image.Image] = [] self.iteration_tensors: List[torch.Tensor] = [] # Atenções por iteração (iteração 0: original, iteração 1: adversarial) self.attentions_per_iter: List[List[torch.Tensor]] = [] def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]: """ Executa o ataque FGSM e retorna: - adv_images: tensor adversarial final - iteration_images: lista com [imagem_original, imagem_adversarial] """ images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) loss = torch.nn.CrossEntropyLoss() # Desnormalizar para trabalhar no espaço [0,1] mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device) std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device) images_denorm = images * std + mean self.iteration_images = [] self.iteration_tensors = [] self.attentions_per_iter = [] # Salvar imagem original pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False) self.iteration_images.append(pil_img_orig) self.iteration_tensors.append(images.clone().detach()) # Calcular gradiente images.requires_grad = True # Capturar atenções e logits para imagem original outputs, attentions0 = capture_outputs_and_attentions(self.model, images) self.attentions_per_iter.append([att for att in attentions0]) if self.targeted: target_labels = self.get_target_label(images, labels) cost = -loss(outputs, target_labels) else: cost = loss(outputs, labels) grad = torch.autograd.grad(cost, images, retain_graph=False, create_graph=False)[0] # Aplicar perturbação no espaço desnormalizado [0,1] # sign(grad) dá a direção, eps é a magnitude no pixel space adv_images_denorm = images_denorm + self.eps * grad.sign() adv_images_denorm = torch.clamp(adv_images_denorm, min=0, max=1).detach() # Normalizar de volta adv_images = (adv_images_denorm - mean) / std # Salvar imagem adversarial pil_img_adv = tensor_to_pil(adv_images_denorm[0], denormalize=False) self.iteration_images.append(pil_img_adv) self.iteration_tensors.append(adv_images.clone().detach()) # Capturar atenções para imagem adversarial final outputs_adv, attentions1 = capture_outputs_and_attentions(self.model, adv_images) self.attentions_per_iter.append([att for att in attentions1]) return adv_images, self.iteration_images class PGDIterations(torchattacks.PGD): """ Extensão do ataque PGD padrão que captura e retorna as imagens adversariais de cada iteração como lista de PIL Images. """ def __init__(self, model, eps=0.05, alpha=0.005, steps=10, random_start=True): # Inicializa PGD padrão com os parâmetros super().__init__(model, eps=eps, alpha=alpha, steps=steps, random_start=random_start) self.iteration_images: List[Image.Image] = [] self.iteration_tensors: List[torch.Tensor] = [] self.attentions_per_iter: List[List[torch.Tensor]] = [] def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]: """ Executa o ataque PGD e retorna: - adv_images: tensor adversarial final - iteration_images: lista de PIL Images (uma por iteração do ataque) Implementação adaptada para trabalhar com imagens normalizadas ImageNet. """ images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) # Para targeted attack (se implementarmos no futuro) if self.targeted: target_labels = self.get_target_label(images, labels) loss = torch.nn.CrossEntropyLoss() adv_images = images.clone().detach() # Desnormalizar para aplicar eps e clipping no espaço correto [0,1] mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device) std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device) # Converter para espaço [0,1] images_denorm = images * std + mean adv_images_denorm = images_denorm.clone().detach() if self.random_start: # Starting at a uniformly random point no espaço [0,1] adv_images_denorm = adv_images_denorm + torch.empty_like(adv_images_denorm).uniform_(-self.eps, self.eps) adv_images_denorm = torch.clamp(adv_images_denorm, min=0, max=1).detach() self.iteration_images = [] self.iteration_tensors = [] self.attentions_per_iter = [] # Salvar iteração 0 (imagem original) pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False) self.iteration_images.append(pil_img_orig) self.iteration_tensors.append(images.clone().detach()) # Atenções da imagem original outputs0, attentions0 = capture_outputs_and_attentions(self.model, images) self.attentions_per_iter.append([att for att in attentions0]) for step_idx in range(self.steps): # Normalizar para passar pelo modelo adv_images = (adv_images_denorm - mean) / std adv_images.requires_grad = True outputs, attentions = capture_outputs_and_attentions(self.model, adv_images) # Calculate loss if self.targeted: cost = -loss(outputs, target_labels) else: cost = loss(outputs, labels) # Update adversarial images grad = torch.autograd.grad(cost, adv_images, retain_graph=False, create_graph=False)[0] # Aplicar perturbação no espaço desnormalizado [0,1] # sign(grad) dá a direção, alpha é o step size no pixel space adv_images_denorm = adv_images_denorm.detach() + self.alpha * grad.sign() delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps) adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach() # Normalizar para salvar tensor adv_images_normalized = (adv_images_denorm - mean) / std # Capturar imagem e tensor desta iteração pil_img = tensor_to_pil(adv_images_denorm[0], denormalize=False) self.iteration_images.append(pil_img) self.iteration_tensors.append(adv_images_normalized.clone().detach()) # Atenções desta iteração self.attentions_per_iter.append([att for att in attentions]) # Retornar imagem normalizada para o modelo adv_images = (adv_images_denorm - mean) / std return adv_images, self.iteration_images class SAGA(torch.nn.Module): """ SAGA: Self-Attention Gradient Attack Ataque adversarial específico para Vision Transformers que multiplica o gradiente FGSM pelo mapa de atenção do modelo, focando perturbações nas regiões que o modelo considera importantes. Baseado em: https://github.com/MetaMain/ViTRobust Paper: "On the Robustness of Vision Transformers to Adversarial Examples" (ICCV 2021) """ def __init__(self, model, eps=8/255, steps=10, discard_ratio: float = 0.0, head_fusion: str = "mean", use_resnet: bool = False, cnn_checkpoint_path: str = "resnet.pth", vit_weight=0.5): """Implementação correta do SAGA baseada no código original (SelfAttentionGradientAttack). Parâmetros: - model: Vision Transformer (deve expor atenções via forward ou função auxiliar em visualization utils) - eps: orçamento L_inf máximo (em pixel space [0,1]) - steps: número de iterações (FGSM iterativo) - discard_ratio: razão de descarte usada no attention rollout - head_fusion: estratégia de fusão de heads ('mean','max','min') - use_resnet: se True, acumula gradiente de um backbone CNN externo e o mistura ao gradiente ponderado pela atenção - cnn_checkpoint_path: caminho padrão do backbone CNN auxiliar (será carregado sob demanda) """ super().__init__() self.model = model self.eps = eps self.steps = steps self.eps_step = self.eps / max(1, steps) self.discard_ratio = discard_ratio self.head_fusion = head_fusion self.use_resnet = use_resnet # Pode ser um caminho local ou uma referência ao Hugging Face Hub: # - Local: "models/resnet.pth" # - HF Hub: "hf://usuario/repo/path/no/repo/resnet.pth" (opcionalmente com @revision) self.cnn_checkpoint_spec = cnn_checkpoint_path self.cnn_model: Optional[torch.nn.Module] = None self.vit_weight = vit_weight self.device = next(model.parameters()).device self.iteration_images: List[Image.Image] = [] self.iteration_tensors: List[torch.Tensor] = [] self.attention_masks_cache: List[np.ndarray] = [] # Cache opcional: atenções por camada/head em cada iteração # Formato: lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada self.attentions_per_iter: List[List[torch.Tensor]] = [] self.loss_fn = torch.nn.CrossEntropyLoss() @staticmethod def _resolve_checkpoint_path(spec: object) -> Optional[Path]: """Resolve um checkpoint local ou no Hugging Face Hub para um Path local. Formato suportado (HF): - hf://owner/repo/path/to/file.pth - hf://owner/repo@revision/path/to/file.pth Retorna None se não conseguir resolver. """ if spec is None: return None if isinstance(spec, Path): return spec if not isinstance(spec, str): return None s = spec.strip() if s.startswith("hf://") or s.startswith("hf:"): if hf_hub_download is None: warnings.warn("huggingface_hub não está instalado; não é possível baixar checkpoint via HF Hub.") return None rest = s[len("hf://"):] if s.startswith("hf://") else s[len("hf:"):] rest = rest.lstrip("/") parts = [p for p in rest.split("/") if p] if len(parts) < 3: raise ValueError( "Formato inválido para checkpoint HF. Use hf://owner/repo/path/to/file.pth" ) repo_part = "/".join(parts[:2]) filename = "/".join(parts[2:]) revision = None if "@" in repo_part: repo_id, revision = repo_part.split("@", 1) else: repo_id = repo_part cache_dir = os.getenv("HF_HOME") or None local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, cache_dir=cache_dir) return Path(local_path) return Path(s) def _attention_map(self, images_norm: torch.Tensor, save: bool = False) -> torch.Tensor: """Extrai mapa de atenção (rollout) e retorna tensor expandido [B,3,H,W] em [0,1]. images_norm: imagens já normalizadas para o forward do modelo. """ # Esta função agora assume que as atenções foram capturadas no mesmo forward # e serão passadas externamente; mantida para compatibilidade se necessário. raise RuntimeError("_attention_map should not be called directly; use integrated forward attention capture.") def _capture_outputs_and_attentions(self, x_norm: torch.Tensor): """Executa um forward único capturando atenções via hooks nas camadas de atenção do ViT. Retorna (outputs, attentions_list) onde attentions_list é lista de tensores [B,H,T,T] por camada. """ attentions: List[torch.Tensor] = [] def make_attention_hook(): def hook(module, input, output): # input[0] é o embedding antes de atenção (B, N, C) x = input[0] B, N, C = x.shape if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')): return qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) scale = (C // module.num_heads) ** -0.5 attn = (q @ k.transpose(-2, -1)) * scale attn = attn.softmax(dim=-1) attentions.append(attn.detach()) return hook hooks = [] if not hasattr(self.model, 'blocks'): outputs = self.model(x_norm) return outputs, [] for block in self.model.blocks: if hasattr(block, 'attn'): hooks.append(block.attn.register_forward_hook(make_attention_hook())) self.model.eval() outputs = self.model(x_norm) for h in hooks: h.remove() # mover atenções para CPU para cache leve attentions = [a.cpu() for a in attentions] return outputs, attentions def _load_cnn_backbone(self) -> Optional[torch.nn.Module]: """Carrega (lazy) o backbone CNN auxiliar usado quando use_resnet=True.""" if not self.use_resnet: return None if self.cnn_model is not None: return self.cnn_model if tv_models is None: warnings.warn("torchvision não disponível; desabilitando modo CNN do SAGA.") return None model: Optional[torch.nn.Module] = None checkpoint_model_name = "resnetv2_101x1_bit.goog_in21k_ft_in1k" resolved_ckpt_path = None try: resolved_ckpt_path = self._resolve_checkpoint_path(self.cnn_checkpoint_spec) except Exception as exc: warnings.warn(f"Falha ao resolver cnn_checkpoint_path='{self.cnn_checkpoint_spec}': {exc}") if resolved_ckpt_path and resolved_ckpt_path.exists(): try: checkpoint = torch.load(resolved_ckpt_path, map_location=self.device) if isinstance(checkpoint, torch.nn.Module): model = checkpoint elif isinstance(checkpoint, dict): state_dict = checkpoint.get('model_state_dict') or checkpoint.get('state_dict') or checkpoint if timm is not None and any(key.startswith("stem.") for key in state_dict.keys()): num_classes = None head_bias = state_dict.get('head.fc.bias') if isinstance(head_bias, torch.Tensor): num_classes = head_bias.shape[0] model = timm.create_model( checkpoint.get("model_name", checkpoint_model_name), pretrained=False, num_classes=num_classes or 1000 ) load_result = model.load_state_dict(state_dict, strict=False) else: model = tv_models.resnet101(weights=None) load_result = model.load_state_dict(state_dict, strict=False) missing = load_result.missing_keys unexpected = load_result.unexpected_keys if missing or unexpected: warn_msg = "[SAGA] ResNet checkpoint keys mismatch." if missing: warn_msg += f" Missing: {missing[:5]}{'...' if len(missing) > 5 else ''}." if unexpected: warn_msg += f" Unexpected: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}." warnings.warn(warn_msg + " Using available weights (strict=False).") else: warnings.warn(f"Formato de checkpoint desconhecido em {resolved_ckpt_path}; utilizando pesos padrão.") except Exception as exc: # pragma: no cover - fallback resiliente warnings.warn(f"Falha ao carregar {resolved_ckpt_path}: {exc}. Usando ResNet padrão.") if model is None: if timm is not None: try: model = timm.create_model(checkpoint_model_name, pretrained=True) except Exception: model = None if model is None and tv_models is not None: try: model = tv_models.resnet101(weights="IMAGENET1K_V2") except Exception: model = tv_models.resnet101(pretrained=True) model = model.to(self.device) model.eval() self.cnn_model = model return self.cnn_model def _compute_cnn_gradient(self, images_norm: torch.Tensor, labels: torch.Tensor) -> Optional[torch.Tensor]: """Obtém gradientes do backbone CNN auxiliar para a mesma imagem normalizada.""" cnn_model = self._load_cnn_backbone() if cnn_model is None: return None cnn_input = images_norm.detach().clone().requires_grad_(True) outputs = cnn_model(cnn_input) loss = self.loss_fn(outputs, labels) grad = torch.autograd.grad(loss, cnn_input, retain_graph=False, create_graph=False)[0] return grad def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]: """Executa o ataque SAGA (FGSM iterativo com ponderação por atenção). Fluxo por iteração: 1. Normaliza a imagem adversarial atual. 2. Calcula loss e gradiente. 3. Extrai mapa de atenção da imagem atual e pondera gradiente. 4. Aplica passo FGSM (sign) em pixel space [0,1]. 5. Projeta em L_inf (clamp delta) e clip final para [0,1]. 6. Salva imagem e tensor normalizado. """ images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) # Mean/std ImageNet para conversão entre espaços mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1) std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1) # Pixel space [0,1] images_denorm = images * std + mean adv_denorm = images_denorm.clone().detach() # Reset buffers self.iteration_images = [] self.iteration_tensors = [] self.attention_masks_cache = [] self.attentions_per_iter = [] # Iteração 0 (imagem original) self.iteration_images.append(tensor_to_pil(images_denorm[0], denormalize=False)) self.iteration_tensors.append(images.clone().detach()) # Atenção da imagem original: captura integrada outputs0, attentions0 = self._capture_outputs_and_attentions(images) # Guardar atenções brutas self.attentions_per_iter.append([att for att in attentions0]) # Gerar máscara de rollout para cache visual from utils.visualization import attention_rollout import cv2 b, _, h, w = images.shape mask0 = attention_rollout(attentions0, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion) mask0_resized = cv2.resize(mask0, (w, h)) self.attention_masks_cache.append(mask0.copy()) for step_idx in range(self.steps): # Normalizar para forward adv_norm = (adv_denorm - mean) / std adv_norm.requires_grad = True outputs, attentions = self._capture_outputs_and_attentions(adv_norm) if isinstance(outputs, tuple): # compatibilidade com modelos que retornam extras outputs = outputs[0] loss = self.loss_fn(outputs, labels) grad = torch.autograd.grad(loss, adv_norm, retain_graph=False, create_graph=False)[0] # Atenção da imagem adversarial atual (já capturada) # Cache de atenções por camada/head self.attentions_per_iter.append([att for att in attentions]) # Rollout para gerar mapa usado na ponderação mask = attention_rollout(attentions, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion) mask_resized = cv2.resize(mask, (adv_norm.shape[-1], adv_norm.shape[-2])) mmax = mask_resized.max() if mask_resized.max() > 0 else 1.0 mask_resized = (mask_resized / mmax).astype('float32') att_map = torch.from_numpy(mask_resized).to(self.device).unsqueeze(0).unsqueeze(0).repeat(adv_norm.size(0), 3, 1, 1) # Cache visual self.attention_masks_cache.append(mask.copy()) grad_weighted = grad * att_map grad_final = grad_weighted if self.use_resnet: cnn_grad = self._compute_cnn_gradient(adv_norm, labels) if cnn_grad is not None: vit_contrib = grad_weighted.detach().abs().mean().item() cnn_contrib = cnn_grad.detach().abs().mean().item() grad_final = self.vit_weight * grad_weighted + (1 - self.vit_weight) * cnn_grad blended_contrib = grad_final.detach().abs().mean().item() # FGSM step em pixel space (sign do gradiente normalizado equivale ao do desnormalizado) adv_denorm = adv_denorm.detach() + self.eps_step * grad_final.sign() # Projeção na bola L_inf de raio eps em relação à imagem original delta = torch.clamp(adv_denorm - images_denorm, min=-self.eps, max=self.eps) adv_denorm = torch.clamp(images_denorm + delta, 0.0, 1.0).detach() # Salvar artefatos desta iteração self.iteration_images.append(tensor_to_pil(adv_denorm[0], denormalize=False)) self.iteration_tensors.append(((adv_denorm - mean) / std).clone().detach()) # Retorna tensor normalizado final adv_final = (adv_denorm - mean) / std return adv_final, self.iteration_images class AttentionWeightedPGD(torch.nn.Module): """ [Deprecated] Implementação errada do ataque SAGA, mas que consegue fazer ataques adversariais eficazes em ViTs usando mapas de atenção para pesar o gradiente. """ def __init__(self, model, eps=0.03, steps=10): super().__init__() self.model = model self.eps = eps self.steps = steps self.eps_step = self.eps / self.steps self.device = next(model.parameters()).device self.iteration_images: List[Image.Image] = [] self.iteration_tensors: List[torch.Tensor] = [] self.attention_masks_cache: List[np.ndarray] = [] # Cache das máscaras de atenção def get_attention_map(self, images: torch.Tensor, save_for_viz: bool = False) -> tuple: """ Extrai mapa de atenção do ViT usando attention rollout. Retorna: - mask_tensor: [B, C, H, W] para uso no ataque - mask_np: [H, W] numpy array para visualização (se save_for_viz=True) """ from utils.visualization import extract_attention_maps, attention_rollout import cv2 batch_size = images.shape[0] img_size = images.shape[2] # Extrair attention maps attentions = extract_attention_maps(self.model, images) # Aplicar attention rollout mask = attention_rollout(attentions, discard_ratio=0.9, head_fusion='max') # Salvar para visualização se necessário if save_for_viz: self.attention_masks_cache.append(mask.copy()) # Redimensionar para tamanho da imagem (14x14 -> 224x224) mask_resized = cv2.resize(mask, (img_size, img_size)) # Expandir para 3 canais e batch: [H, W] -> [B, C, H, W] mask_tensor = torch.from_numpy(mask_resized).float().to(self.device) mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) # [1, 1, H, W] mask_tensor = mask_tensor.repeat(batch_size, 3, 1, 1) # [B, 3, H, W] return mask_tensor, mask if save_for_viz else None def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]: """ Executa ataque SAGA e retorna: - adv_images: tensor adversarial final - iteration_images: lista de PIL Images de cada iteração """ images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) loss_fn = torch.nn.CrossEntropyLoss() # Desnormalizar para trabalhar no espaço [0,1] mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device) std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device) images_denorm = images * std + mean adv_images_denorm = images_denorm.clone().detach() self.iteration_images = [] self.iteration_tensors = [] self.attention_masks_cache = [] # Salvar imagem original (iteração 0) pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False) self.iteration_images.append(pil_img_orig) self.iteration_tensors.append(images.clone().detach()) # Calcular atenção para imagem original e salvar attention_map, _ = self.get_attention_map(images, save_for_viz=True) for step in range(self.steps): # Normalizar para passar pelo modelo adv_images = (adv_images_denorm - mean) / std adv_images.requires_grad = True # Forward pass outputs = self.model(adv_images) # Calcular loss cost = loss_fn(outputs, labels) # Calcular gradiente grad = torch.autograd.grad(cost, adv_images, retain_graph=False, create_graph=False)[0] # RECALCULAR atenção para a imagem adversarial ATUAL (chave do SAGA!) attention_map, _ = self.get_attention_map(adv_images.detach(), save_for_viz=True) # SAGA: Multiplicar gradiente pelo mapa de atenção grad_weighted = grad * attention_map # Aplicar perturbação no espaço desnormalizado [0,1] adv_images_denorm = adv_images_denorm.detach() + self.eps_step * grad_weighted.sign() delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps) adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach() # Normalizar para salvar tensor adv_images_normalized = (adv_images_denorm - mean) / std # Salvar iteração pil_img = tensor_to_pil(adv_images_denorm[0], denormalize=False) self.iteration_images.append(pil_img) self.iteration_tensors.append(adv_images_normalized.clone().detach()) # Retornar imagem normalizada adv_images = (adv_images_denorm - mean) / std return adv_images, self.iteration_images class MIFGSM(torchattacks.MIFGSM): """ MI-FGSM: Momentum Iterative Fast Gradient Sign Method Extensão do ataque MIFGSM que captura imagens e atenção de cada iteração. Usa momentum para estabilizar direção do gradiente e melhorar transferabilidade. Paper: "Boosting Adversarial Attacks with Momentum" (2017) https://arxiv.org/abs/1710.06081 """ def __init__(self, model, eps=8/255, alpha=2/255, steps=10, decay=1.0): super().__init__(model, eps=eps, alpha=alpha, steps=steps, decay=decay) self.iteration_images: List[Image.Image] = [] self.iteration_tensors: List[torch.Tensor] = [] self.attentions_per_iter: List[List[torch.Tensor]] = [] def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]: """ Executa o ataque MI-FGSM e retorna: - adv_images: tensor adversarial final - iteration_images: lista de PIL Images (uma por iteração) Implementação adaptada para trabalhar com imagens normalizadas ImageNet e capturar todas as iterações. """ images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) if self.targeted: target_labels = self.get_target_label(images, labels) loss = torch.nn.CrossEntropyLoss() # Desnormalizar para aplicar eps e clipping no espaço correto [0,1] mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device) std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device) images_denorm = images * std + mean adv_images_denorm = images_denorm.clone().detach() # Inicializar momentum no espaço desnormalizado momentum = torch.zeros_like(images_denorm).detach().to(self.device) self.iteration_images = [] self.iteration_tensors = [] self.attentions_per_iter = [] # Salvar imagem original (iteração 0) pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False) self.iteration_images.append(pil_img_orig) self.iteration_tensors.append(images.clone().detach()) # Atenções da imagem original outputs0, attentions0 = capture_outputs_and_attentions(self.model, images) self.attentions_per_iter.append([att for att in attentions0]) for step in range(self.steps): # Normalizar para passar pelo modelo com gradiente adv_images = (adv_images_denorm - mean) / std adv_images.requires_grad = True outputs, attentions = capture_outputs_and_attentions(self.model, adv_images) # Calcular loss if self.targeted: cost = -loss(outputs, target_labels) else: cost = loss(outputs, labels) # Calcular gradiente no espaço normalizado grad = torch.autograd.grad(cost, adv_images, retain_graph=False, create_graph=False)[0] # Cache de atenções desta iteração self.attentions_per_iter.append([att for att in attentions]) # Converter gradiente para espaço desnormalizado grad_denorm = grad * std # Normalizar gradiente (chave do MI-FGSM!) grad_denorm = grad_denorm / torch.mean(torch.abs(grad_denorm), dim=(1, 2, 3), keepdim=True) # Aplicar momentum no espaço desnormalizado grad_denorm = grad_denorm + momentum * self.decay momentum = grad_denorm # Aplicar perturbação no espaço desnormalizado adv_images_denorm = adv_images_denorm.detach() + self.alpha * grad_denorm.sign() delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps) adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach() # Normalizar e armazenar artefatos desta iteração adv_images_normalized = (adv_images_denorm - mean) / std self.iteration_tensors.append(adv_images_normalized.clone().detach()) pil_iter = tensor_to_pil(adv_images_denorm[0], denormalize=False) self.iteration_images.append(pil_iter) adv_images = (adv_images_denorm - mean) / std return adv_images, self.iteration_images class TGR(torch.nn.Module): """TGR: Token Gradient Regularization attack. Ataque iterativo untargeted, white-box, no estilo MI-FGSM, que aplica regularização de gradiente em módulos internos do transformer via backward hooks (Attention map, QKV, MLP). Diferenças-chave vs. MI-FGSM: - Attention: zera LINHAS e COLUNAS inteiras do mapa N×N (pares extremos) - QKV/MLP: zera TOKENS INTEIROS (todas as features de tokens extremos) - Escala por componente (código oficial): s_attn=0.25, s_qkv=0.75, s_mlp=0.5 O ataque trabalha em pixel space [0,1], respeitando orçamento L_inf. """ def __init__( self, model: torch.nn.Module, eps: float = 16 / 255, steps: int = 10, decay: float = 1.0, k: int = 1, gamma_attn: float = 0.25, gamma_qkv: float = 0.75, gamma_mlp: float = 0.5, debug_shapes: bool = False, enable_attn_hook: bool = True, enable_qkv_hook: bool = True, enable_mlp_hook: bool = True, debug_stats: bool = False, protect_cls_token: bool = True, debug_progress: bool = False, ) -> None: super().__init__() self.model = model self.eps = float(eps) self.steps = int(steps) self.decay = float(decay) self.k = int(k) # número de extremos (paper usa k=1) self.eps_step = self.eps / max(1, self.steps) self.gamma_attn = float(gamma_attn) self.gamma_qkv = float(gamma_qkv) self.gamma_mlp = float(gamma_mlp) self.debug_shapes = bool(debug_shapes) self.enable_attn_hook = bool(enable_attn_hook) self.enable_qkv_hook = bool(enable_qkv_hook) self.enable_mlp_hook = bool(enable_mlp_hook) self.debug_stats = bool(debug_stats) self.protect_cls_token = bool(protect_cls_token) self.debug_progress = bool(debug_progress) self.device = next(model.parameters()).device self.loss_fn = torch.nn.CrossEntropyLoss() self.iteration_images: List[Image.Image] = [] self.iteration_tensors: List[torch.Tensor] = [] self.attentions_per_iter: List[List[torch.Tensor]] = [] self.debug_last: dict = {} self.debug_progress_log: List[dict] = [] self._patched_attn_forwards: dict = {} # ---------------------- hooks & grad processing ---------------------- def _patch_attention_forward(self, attn_module: torch.nn.Module) -> None: """Monkeypatch do forward do Attention para anexar hook no mapa de atenção. Isso permite aplicar o Algoritmo 1 de forma paper-faithful em timm ViTs, porque o tensor de atenção [B,H,N,N] não é exposto diretamente como saída de um submódulo. """ if attn_module in self._patched_attn_forwards: return orig_forward = attn_module.forward self._patched_attn_forwards[attn_module] = orig_forward def forward_patched(this, x, attn_mask=None, **kwargs): B, N, C = x.shape num_heads = getattr(this, "num_heads") qkv = this.qkv(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) scale = getattr(this, "scale", (C // num_heads) ** -0.5) attn = (q @ k.transpose(-2, -1)) * scale # `attn_mask` pode existir em variantes do timm (p.ex. atenção com máscara). Aqui, # para ViT-B/16 padrão, costuma ser None. if attn_mask is not None: # Espera-se broadcastável para [B, H, N, N] attn = attn + attn_mask attn = attn.softmax(dim=-1) if self.debug_shapes and not getattr(self, "_debug_attn_map_printed", False): print(f"[TGR DEBUG] attn_map tensor shape (patched): {attn.shape}") print(f"[TGR DEBUG] attn_map tensor ndim (patched): {attn.ndim}") self._debug_attn_map_printed = True def grad_hook(grad): return self._tgr_process_grad_attention(grad, self.gamma_attn) attn.register_hook(grad_hook) attn = this.attn_drop(attn) x_out = (attn @ v).transpose(1, 2).reshape(B, N, C) x_out = this.proj(x_out) proj_drop = getattr(this, "proj_drop", None) if proj_drop is not None: x_out = proj_drop(x_out) return x_out attn_module.forward = types.MethodType(forward_patched, attn_module) def _tgr_process_grad_attention(self, grad: torch.Tensor, gamma: float) -> torch.Tensor: """Regularização TGR para componente Attention. Paper-faithful (Algoritmo 1): atua no gradiente do mapa de atenção com shape [B, H, N, N] (H=heads). Para cada head (canal de saída), seleciona 2k posições extremas e zera a linha e a coluna correspondentes. Mantemos também suportes legados: - [B, N, C] (tokens): fallback para arquiteturas onde só há gradiente token-wise. - [B, C, H, W] (CNN): fallback histórico. Args: grad: gradiente [B,H,N,N] (atenção) ou [B,N,C] ou [B,C,H,W] gamma: fator de escala (paper usa 0.25). Se gamma=1.0, retorna sem modificação. """ if grad is None: return grad # Se gamma=1.0, não há regularização TGR - retorna gradiente original if abs(gamma - 1.0) < 1e-6: return grad g = grad * gamma # Caso 0: [B, H, N, N] - gradiente do mapa de atenção (paper) if g.ndim == 4 and g.shape[-1] == g.shape[-2] and g.shape[1] <= 64: try: B, Hh, N, _ = g.shape k_actual = min(self.k, N * N) if k_actual <= 0: return g for b in range(B): for h in range(Hh): gh = g[b, h] # [N, N] flat = gh.reshape(-1) _, idx_max = torch.topk(flat, k_actual, largest=True) _, idx_min = torch.topk(flat, k_actual, largest=False) idxs = torch.cat([idx_max, idx_min], dim=0) removed_cls = False for idx in idxs.tolist(): r = idx // N c = idx % N if self.protect_cls_token and (r == 0 or c == 0): removed_cls = True continue g[b, h, r, :] = 0.0 g[b, h, :, c] = 0.0 if self.debug_shapes and b == 0 and h == 0: extra = " (CLS protegido)" if removed_cls else "" print( f"[TGR DEBUG] AttentionMap: head0 zerou linhas/cols por 2k={2*k_actual} entradas{extra}" ) return g except Exception as e: warnings.warn(f"[TGR] AttentionMap ([B,H,N,N]): fallback ({e})") return g # Caso 1: [B, N, C] - tokens (fallback) if g.ndim == 3: # Usar mesma lógica de _tgr_process_grad_tokens try: B, N, C = g.shape for b in range(B): # Paper: rank by channel independently (Seção 3.2) token_ids = set() for c in range(C): v = g[b, :, c] # [N] valores do canal c k_actual = min(self.k, N) if k_actual > 0: _, idx_max = torch.topk(v, k_actual, largest=True) _, idx_min = torch.topk(v, k_actual, largest=False) token_ids.update(idx_max.tolist()) token_ids.update(idx_min.tolist()) removed_cls = False if self.protect_cls_token and 0 in token_ids: token_ids.discard(0) removed_cls = True # Debug: mostrar quantos tokens serão zerados if self.debug_shapes and b == 0: extra = " (CLS protegido)" if removed_cls else "" print(f"[TGR DEBUG] AttentionTokens: zerando {len(token_ids)}/{N} tokens (k={self.k}, C={C}){extra}") # Zera todas as features dos tokens extremos for t in token_ids: g[b, t, :] = 0.0 return g except Exception as e: warnings.warn(f"[TGR] Atenção ([B,N,C]): fallback ({e})") return g # Caso 2: [B, C, H, W] - feature maps espaciais (implementação original TGR) elif g.ndim == 4: B, C, H, W = g.shape # Verifica se é formato espacial (não formato de atenção N×N) if H * W >= C: try: g_flat = g[0].reshape(C, H * W) max_idx = g_flat.argmax(dim=1) min_idx = g_flat.argmin(dim=1) max_h = max_idx // W max_w = max_idx % W min_h = min_idx // W min_w = min_idx % W c_range = torch.arange(C, device=g.device) g[:, c_range, max_h, :] = 0.0 g[:, c_range, :, max_w] = 0.0 g[:, c_range, min_h, :] = 0.0 g[:, c_range, :, min_w] = 0.0 return g except Exception as e: warnings.warn(f"[TGR] Atenção ([B,C,H,W]): fallback ({e})") return g # Fallback: apenas escala return g def _tgr_process_grad_tokens(self, grad: torch.Tensor, gamma: float) -> torch.Tensor: """Regularização TGR para componentes QKV/MLP (conforme implementação original do paper). Para gradiente shape [B, N, C] (entrada do QKV/MLP): - Escala por gamma - Para cada canal c, encontra top-k e bottom-k tokens (por valor) - Zera as ENTRADAS extremas (token, canal), isto é: g[b, token, c] = 0 Observação: isso difere de "zerar token inteiro". É o que o código oficial faz quando executa: out_grad[:, max_all, range(c)] = 0.0. """ if grad is None: return grad # Se gamma=1.0, não há regularização TGR - retorna gradiente original if abs(gamma - 1.0) < 1e-6: return grad g = grad * gamma try: if g.ndim == 3: # [B, N, C] B, N, C = g.shape for b in range(B): # Seleção por canal, como no código oficial k_actual = min(self.k, N) zeroed = 0 for c in range(C): v = g[b, :, c] # [N] if k_actual <= 0: continue _, idx_max = torch.topk(v, k_actual, largest=True) _, idx_min = torch.topk(v, k_actual, largest=False) for t in idx_max.tolist() + idx_min.tolist(): if self.protect_cls_token and t == 0: continue g[b, t, c] = 0.0 zeroed += 1 if self.debug_shapes and b == 0: if not hasattr(self, "_debug_token_zero_counts"): self._debug_token_zero_counts = {} key = f"gamma={gamma:.3f}" count = self._debug_token_zero_counts.get(key, 0) if count < 3: # total de entradas potencialmente zeradas = 2*k*C print( f"[TGR DEBUG] Tokens ({key}): zerando ~{zeroed} entradas (2*k*C={2*k_actual*C}, ataque em [token,canal])" ) self._debug_token_zero_counts[key] = count + 1 except Exception as e: warnings.warn(f"[TGR] Tokens: fallback no processo de QKV/MLP ({e})") g = grad * gamma return g def _make_attention_hook(self): raise RuntimeError("_make_attention_hook não é mais usado; use _patch_attention_forward") def _make_qkv_hook(self): """Hook para componente QKV.""" def hook(module, grad_input, grad_output): if not grad_input or grad_input[0] is None: return grad_input g0_new = self._tgr_process_grad_tokens(grad_input[0], self.gamma_qkv) return (g0_new,) + tuple(grad_input[1:]) return hook def _make_mlp_hook(self): """Hook para componente MLP.""" def hook(module, grad_input, grad_output): if not grad_input or grad_input[0] is None: return grad_input g0_new = self._tgr_process_grad_tokens(grad_input[0], self.gamma_mlp) return (g0_new,) + tuple(grad_input[1:]) return hook def _register_tgr_hooks(self) -> List[torch.utils.hooks.RemovableHandle]: """Registra hooks conforme Algoritmo 1 do paper TGR. Implementação alinhada ao código oficial: - Attention: aplica TGR no gradiente do mapa de atenção [B,H,N,N] (monkeypatch no forward do módulo de atenção para anexar hook no tensor `attn`) - QKV: hook em `attn.qkv` para regularizar grad_input[0] ([B,N,C]) - MLP: hook no `mlp` para regularizar grad_input[0] ([B,N,C]) Se não encontrar nenhum módulo compatível, não registra nada; o ataque ainda funciona (equivale a um MI-FGSM), apenas sem regularização TGR. """ handles: List[torch.utils.hooks.RemovableHandle] = [] warned_attn = False # ViTs estilo timm normalmente expõem model.blocks[*].attn e .mlp if hasattr(self.model, "blocks"): for block in self.model.blocks: attn_module = getattr(block, "attn", None) if attn_module is not None: # Hook 1: Attention component (paper-faithful) # - Monkeypatch do forward para anexar hook no tensor `attn` (softmax) [B,H,N,N]. if self.enable_attn_hook: if hasattr(attn_module, "qkv") and hasattr(attn_module, "num_heads") and hasattr(attn_module, "proj"): self._patch_attention_forward(attn_module) elif not warned_attn: warnings.warn( "[TGR] Nenhum módulo de atenção compatível encontrado (qkv/num_heads/proj); " "pulando regularização TGR-Attention. Apenas QKV/MLP serão regularizados." ) warned_attn = True # Hook 2: QKV component if self.enable_qkv_hook and hasattr(attn_module, "qkv"): handles.append( attn_module.qkv.register_full_backward_hook(self._make_qkv_hook()) ) # Hook 3: MLP component mlp = getattr(block, "mlp", None) if self.enable_mlp_hook and mlp is not None: handles.append(mlp.register_full_backward_hook(self._make_mlp_hook())) if not handles: warnings.warn( "[TGR] Nenhum módulo compatível encontrado para hooks; " "executando como MI-FGSM (sem regularização interna)." ) elif self.debug_shapes: print(f"[TGR DEBUG] Registrados {len(handles)} hooks") return handles # ------------------------------ forward ------------------------------ def forward(self, images: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, List[Image.Image]]: """Executa o ataque TGR. Retorna: - adv_images: tensor adversarial final (normalizado) - iteration_images: lista de PIL Images (uma por iteração, incluindo original) """ images = images.clone().detach().to(self.device) labels = labels.clone().detach().to(self.device) # Mean/std ImageNet para conversão entre espaços mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1) std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1) # Pixel space [0,1] images_denorm = images * std + mean unnorm_inps = images_denorm.clone().detach() # Perturbação em pixel-space, como no código oficial perts = torch.zeros_like(unnorm_inps).detach() # Reset buffers self.iteration_images = [] self.iteration_tensors = [] self.attentions_per_iter = [] self.debug_progress_log = [] # Iteração 0 (imagem original) self.iteration_images.append(tensor_to_pil(images_denorm[0], denormalize=False)) self.iteration_tensors.append(images.clone().detach()) # Garantir eval mode (evita dropout/ruído durante ataque) was_training = self.model.training self.model.eval() # Atenções da imagem original (detach para evitar vazamento de memória) outputs0, attentions0 = capture_outputs_and_attentions(self.model, images) self.attentions_per_iter.append([att.detach().cpu() for att in attentions0]) momentum = torch.zeros_like(perts).detach().to(self.device) handles: List[torch.utils.hooks.RemovableHandle] = [] try: handles = self._register_tgr_hooks() self.debug_last = {} for step_idx in range(self.steps): # Forward do modelo com (imagem + perturbação) em pixel space perts = perts.detach().requires_grad_(True) adv_norm = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - mean) / std outputs, attentions = capture_outputs_and_attentions(self.model, adv_norm) if isinstance(outputs, tuple): outputs = outputs[0] loss = self.loss_fn(outputs, labels) if self.debug_progress: with torch.no_grad(): probs = torch.softmax(outputs, dim=1) pred = probs.argmax(dim=1) conf_pred = probs.gather(1, pred.view(-1, 1)).squeeze(1) conf_label = probs.gather(1, labels.view(-1, 1)).squeeze(1) delta_now = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps).detach() dmax = float(delta_now.abs().max().item()) dmean = float(delta_now.abs().mean().item()) changed = float((delta_now.abs() > 1e-6).float().mean().item()) self.debug_progress_log.append( { "iter": int(step_idx), "loss": float(loss.detach().item()), "pred": pred.detach().cpu().tolist(), "label": labels.detach().cpu().tolist(), "conf_pred": conf_pred.detach().cpu().tolist(), "conf_label": conf_label.detach().cpu().tolist(), "delta_linf": dmax, "delta_mean": dmean, "pixels_changed_ratio": changed, } ) # Mostra só o batch 0 para não poluir print( f"[TGR PROGRESS] it={step_idx} loss={loss.item():.4f} " f"pred={int(pred[0])} conf_pred={conf_pred[0].item():.4f} " f"label={int(labels[0])} conf_label={conf_label[0].item():.4f} " f"dLinf={dmax:.6f} dMean={dmean:.6f} changed={changed*100:.1f}%" ) grad_norm = torch.autograd.grad( loss, perts, retain_graph=False, create_graph=False, )[0] # Cache de atenções desta iteração (detach para evitar vazamento) self.attentions_per_iter.append([att.detach().cpu() for att in attentions]) # Aqui grad_norm já é dL/d(perts) no pixel-space (depois de normalização interna do modelo) grad_denorm = grad_norm if self.debug_stats: # estatísticas no espaço normalizado e no pixel-space self.debug_last[f"iter_{step_idx}"] = { "loss": float(loss.detach().item()), "grad_norm_abs_mean": float(grad_norm.detach().abs().mean().item()), "grad_norm_abs_max": float(grad_norm.detach().abs().max().item()), "grad_denorm_abs_mean_pre_norm": float(grad_denorm.detach().abs().mean().item()), "grad_denorm_abs_max_pre_norm": float(grad_denorm.detach().abs().max().item()), } # Normalizar gradiente (como MI-FGSM) denom = torch.mean(torch.abs(grad_denorm), dim=(1, 2, 3), keepdim=True) + 1e-12 grad_denorm = grad_denorm / denom # Momentum grad_denorm = grad_denorm + momentum * self.decay momentum = grad_denorm # Atualiza perturbação (igual ao código oficial) perts = perts.detach() + self.eps_step * grad_denorm.sign() perts = torch.clamp(perts, -self.eps, self.eps) # clamp final em pixel space e volta para delta perts = torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps if self.debug_shapes: step_size = (self.eps_step * grad_denorm.sign()).abs().max().item() grad_sign_nonzero = (grad_denorm.sign().abs() > 0).float().mean().item() print(f"[TGR DEBUG] Step size: {step_size:.6f}, grad_sign non-zero: {grad_sign_nonzero*100:.1f}%") if self.debug_stats: # completa estatísticas após normalização + momentum iter_stats = self.debug_last.get(f"iter_{step_idx}", {}) iter_stats.update( { "denom_abs_mean": float(denom.detach().mean().item()), "grad_denorm_abs_mean_post_norm": float(grad_denorm.detach().abs().mean().item()), "grad_denorm_abs_max_post_norm": float(grad_denorm.detach().abs().max().item()), "grad_sign_nonzero_ratio": float( (grad_denorm.detach().sign().abs() > 0).float().mean().item() ), "step_size": float((self.eps_step * grad_denorm.detach().sign()).abs().max().item()), } ) self.debug_last[f"iter_{step_idx}"] = iter_stats if self.debug_shapes: actual_delta = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps).abs().max().item() print(f"[TGR DEBUG] Iteration delta: {actual_delta:.6f} (eps={self.eps:.6f}, eps_step={self.eps_step:.6f})") # Salvar artefatos desta iteração adv_denorm = torch.clamp(unnorm_inps + perts, 0.0, 1.0).detach() self.iteration_images.append(tensor_to_pil(adv_denorm[0], denormalize=False)) self.iteration_tensors.append(((adv_denorm - mean) / std).clone().detach()) finally: for h in handles: h.remove() # Restaurar forwards originais de atenção if self._patched_attn_forwards: for attn_module, orig_forward in list(self._patched_attn_forwards.items()): try: attn_module.forward = orig_forward except Exception: pass self._patched_attn_forwards.clear() if hasattr(self, "_debug_attn_map_printed"): delattr(self, "_debug_attn_map_printed") # Restaurar modo de treinamento original if was_training: self.model.train() adv_final = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - mean) / std return adv_final, self.iteration_images