lucasddmc commited on
Commit
8780912
·
1 Parent(s): e21e78f

feat: add support for ResNet gradients in the SAGA attack

Browse files
Files changed (2) hide show
  1. app.py +80 -15
  2. utils/attacks.py +116 -8
app.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
  import torch
4
  from PIL import Image
5
  from typing import Optional, List, Tuple
 
6
 
7
  from utils.model_loader import load_model_and_labels
8
  from utils.preprocessing import get_default_transform, preprocess_image
@@ -24,6 +25,7 @@ from utils.visualization import (
24
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
  transform = get_default_transform()
 
27
 
28
  def _to_path(file_like: Optional[object]) -> Optional[str]:
29
  """Extrai caminho de um objeto vindo do Gradio File (string, dict com 'name' ou objeto com atributo .name)."""
@@ -204,6 +206,8 @@ def run_attack(
204
  eps: float,
205
  alpha: float,
206
  steps: int,
 
 
207
  ) -> Tuple[List[Image.Image], str, List[List[torch.Tensor]]]:
208
  """
209
  Executa ataque adversarial (FGSM ou PGD) untargeted e extrai atenção.
@@ -213,8 +217,9 @@ def run_attack(
213
  image: imagem PIL
214
  attack_type: "FGSM" (single-step) ou "PGD" (iterativo)
215
  eps: epsilon (perturbação máxima)
216
- alpha: step size (apenas PGD)
217
- steps: número de iterações (apenas PGD)
 
218
  discard_ratio: proporção de atenções fracas a descartar
219
  head_fusion: como agregar heads ('mean', 'max', 'min')
220
  alpha_overlay: transparência da sobreposição
@@ -245,7 +250,16 @@ def run_attack(
245
  if attack_type == "FGSM":
246
  attack = FGSM(model, eps=eps)
247
  elif attack_type == "MIM":
248
- attack = MIFGSM(model, eps=eps, alpha=alpha, steps=steps, decay=1.0)
 
 
 
 
 
 
 
 
 
249
  elif attack_type == "SAGA":
250
  attack = SAGA(model, eps=eps, steps=steps)
251
  else: # PGD
@@ -278,9 +292,9 @@ def run_attack(
278
  elif attack_type == "MIM":
279
  result += f"- Alpha (α): {alpha:.4f}\n"
280
  result += f"- Steps: {steps}\n"
281
- result += f"- Momentum decay: 1.0\n"
282
  result += f"- Normalized gradient with momentum accumulation\n"
283
- elif attack_type == "SAGA":
284
  result += f"- Steps: {steps}\n"
285
  result += f"- Attention-weighted gradient (ViT-specific)\n"
286
  else: # FGSM
@@ -433,10 +447,10 @@ def create_app():
433
  gr.Markdown("#### Attack Configuration")
434
 
435
  attack_type = gr.Dropdown(
436
- choices=["PGD", "FGSM", "MIM", "SAGA"],
437
  value="PGD",
438
  label="Attack Type",
439
- info="PGD/MIM: iterative | FGSM: single-step | SAGA: gradient × attention"
440
  )
441
 
442
  eps_input = gr.Slider(
@@ -467,6 +481,26 @@ def create_app():
467
  step=1,
468
  label="Steps - NNumber of Iterations"
469
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
  with gr.Column():
472
  output_text_attack = gr.Markdown(label="Result")
@@ -475,18 +509,49 @@ def create_app():
475
  def update_attack_params(attack_type):
476
  if attack_type == "FGSM":
477
  # FGSM: não usa alpha nem steps
478
- return gr.update(visible=False), gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  elif attack_type == "SAGA":
480
- # SAGA: usa steps mas não usa alpha
481
- return gr.update(visible=False), gr.update(visible=True)
482
- else: # PGD ou MIM
483
- # PGD e MIM: usam alpha e steps
484
- return gr.update(visible=True), gr.update(visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  attack_type.change(
487
  fn=update_attack_params,
488
  inputs=[attack_type],
489
- outputs=[alpha_group, steps_group]
490
  )
491
 
492
  # Removido: configuração de rollout da área de ataque
@@ -647,7 +712,7 @@ def create_app():
647
  fn=run_attack,
648
  inputs=[
649
  model_upload_attack, image_upload_attack,
650
- attack_type, eps_input, alpha_input, steps_input
651
  ],
652
  outputs=[iteration_images_state, output_text_attack, cached_attentions_state]
653
  ).then(
 
3
  import torch
4
  from PIL import Image
5
  from typing import Optional, List, Tuple
6
+ from pathlib import Path
7
 
8
  from utils.model_loader import load_model_and_labels
9
  from utils.preprocessing import get_default_transform, preprocess_image
 
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
 
27
  transform = get_default_transform()
28
+ RESNET_BACKBONE_PATH = Path("models/resnet.pth")
29
 
30
  def _to_path(file_like: Optional[object]) -> Optional[str]:
31
  """Extrai caminho de um objeto vindo do Gradio File (string, dict com 'name' ou objeto com atributo .name)."""
 
206
  eps: float,
207
  alpha: float,
208
  steps: int,
209
+ decay: float,
210
+ vit_weight: float,
211
  ) -> Tuple[List[Image.Image], str, List[List[torch.Tensor]]]:
212
  """
213
  Executa ataque adversarial (FGSM ou PGD) untargeted e extrai atenção.
 
217
  image: imagem PIL
218
  attack_type: "FGSM" (single-step) ou "PGD" (iterativo)
219
  eps: epsilon (perturbação máxima)
220
+ alpha: step size (apenas PGD/MIM)
221
+ steps: número de iterações (iterativos)
222
+ decay: momentum decay (apenas MIM)
223
  discard_ratio: proporção de atenções fracas a descartar
224
  head_fusion: como agregar heads ('mean', 'max', 'min')
225
  alpha_overlay: transparência da sobreposição
 
250
  if attack_type == "FGSM":
251
  attack = FGSM(model, eps=eps)
252
  elif attack_type == "MIM":
253
+ attack = MIFGSM(model, eps=eps, alpha=alpha, steps=steps, decay=decay)
254
+ elif attack_type == "SAGA (with CNN gradient)":
255
+ attack = SAGA(
256
+ model,
257
+ eps=eps,
258
+ steps=steps,
259
+ use_resnet=True,
260
+ vit_weight=vit_weight,
261
+ cnn_checkpoint_path=str(RESNET_BACKBONE_PATH)
262
+ )
263
  elif attack_type == "SAGA":
264
  attack = SAGA(model, eps=eps, steps=steps)
265
  else: # PGD
 
292
  elif attack_type == "MIM":
293
  result += f"- Alpha (α): {alpha:.4f}\n"
294
  result += f"- Steps: {steps}\n"
295
+ result += f"- Momentum decay: {decay:.2f}\n"
296
  result += f"- Normalized gradient with momentum accumulation\n"
297
+ elif attack_type in ("SAGA", "SAGA (with CNN gradient)"):
298
  result += f"- Steps: {steps}\n"
299
  result += f"- Attention-weighted gradient (ViT-specific)\n"
300
  else: # FGSM
 
447
  gr.Markdown("#### Attack Configuration")
448
 
449
  attack_type = gr.Dropdown(
450
+ choices=["PGD", "FGSM", "MIM", "SAGA", "SAGA (with CNN gradient)"],
451
  value="PGD",
452
  label="Attack Type",
453
+ info="PGD/MIM: iterative | FGSM: single-step | SAGA variants: gradient × attention"
454
  )
455
 
456
  eps_input = gr.Slider(
 
481
  step=1,
482
  label="Steps - NNumber of Iterations"
483
  )
484
+
485
+ decay_group = gr.Group(visible=False)
486
+ with decay_group:
487
+ decay_input = gr.Slider(
488
+ minimum=0.0,
489
+ maximum=1.0,
490
+ value=1.0,
491
+ step=0.05,
492
+ label="Momentum Decay (MIM only)"
493
+ )
494
+
495
+ vit_weight_slider = gr.Slider(
496
+ minimum=0.0,
497
+ maximum=1.0,
498
+ value=0.5,
499
+ step=0.05,
500
+ label="ViT Gradient Weight",
501
+ info="Blend between ViT attention gradient (1.0) and CNN gradient (0.0)",
502
+ visible=False
503
+ )
504
 
505
  with gr.Column():
506
  output_text_attack = gr.Markdown(label="Result")
 
509
def update_attack_params(attack_type):
    """Toggle visibility of the attack hyper-parameter widgets.

    Returns four gr.update objects, in order: alpha group, steps group,
    momentum-decay group (MIM), and the ViT/CNN gradient-blend slider
    (SAGA with CNN gradient).
    """
    # Visibility per attack type as (alpha, steps, decay, vit_weight).
    # Any type not listed (i.e. PGD) uses alpha + steps only.
    visibility_table = {
        "FGSM": (False, False, False, False),
        "SAGA (with CNN gradient)": (False, True, False, True),
        "SAGA": (False, True, False, False),
        "MIM": (True, True, True, False),
    }
    alpha_vis, steps_vis, decay_vis, blend_vis = visibility_table.get(
        attack_type, (True, True, False, False)
    )
    return (
        gr.update(visible=alpha_vis),
        gr.update(visible=steps_vis),
        gr.update(visible=decay_vis),
        gr.update(visible=blend_vis),
    )
550
 
551
  attack_type.change(
552
  fn=update_attack_params,
553
  inputs=[attack_type],
554
+ outputs=[alpha_group, steps_group, decay_group, vit_weight_slider]
555
  )
556
 
557
  # Removido: configuração de rollout da área de ataque
 
712
  fn=run_attack,
713
  inputs=[
714
  model_upload_attack, image_upload_attack,
715
+ attack_type, eps_input, alpha_input, steps_input, decay_input, vit_weight_slider
716
  ],
717
  outputs=[iteration_images_state, output_text_attack, cached_attentions_state]
718
  ).then(
utils/attacks.py CHANGED
@@ -1,8 +1,20 @@
1
  import torch
2
  import torchattacks
3
  from PIL import Image
4
- from typing import List, Tuple
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  def capture_outputs_and_attentions(model, x_norm: torch.Tensor):
@@ -212,7 +224,7 @@ class PGDIterations(torchattacks.PGD):
212
  outputs0, attentions0 = capture_outputs_and_attentions(self.model, images)
213
  self.attentions_per_iter.append([att for att in attentions0])
214
 
215
- for _ in range(self.steps):
216
  # Normalizar para passar pelo modelo
217
  adv_images = (adv_images_denorm - mean) / std
218
  adv_images.requires_grad = True
@@ -259,7 +271,9 @@ class SAGA(torch.nn.Module):
259
  Paper: "On the Robustness of Vision Transformers to Adversarial Examples" (ICCV 2021)
260
  """
261
 
262
- def __init__(self, model, eps=8/255, steps=10, discard_ratio: float = 0.0, head_fusion: str = "mean"):
 
 
263
  """Implementação correta do SAGA baseada no código original (SelfAttentionGradientAttack).
264
 
265
  Parâmetros:
@@ -268,6 +282,8 @@ class SAGA(torch.nn.Module):
268
  - steps: número de iterações (FGSM iterativo)
269
  - discard_ratio: razão de descarte usada no attention rollout
270
  - head_fusion: estratégia de fusão de heads ('mean','max','min')
 
 
271
  """
272
  super().__init__()
273
  self.model = model
@@ -276,6 +292,10 @@ class SAGA(torch.nn.Module):
276
  self.eps_step = self.eps / max(1, steps)
277
  self.discard_ratio = discard_ratio
278
  self.head_fusion = head_fusion
 
 
 
 
279
  self.device = next(model.parameters()).device
280
  self.iteration_images: List[Image.Image] = []
281
  self.iteration_tensors: List[torch.Tensor] = []
@@ -283,6 +303,7 @@ class SAGA(torch.nn.Module):
283
  # Cache opcional: atenções por camada/head em cada iteração
284
  # Formato: lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada
285
  self.attentions_per_iter: List[List[torch.Tensor]] = []
 
286
 
287
  def _attention_map(self, images_norm: torch.Tensor, save: bool = False) -> torch.Tensor:
288
  """Extrai mapa de atenção (rollout) e retorna tensor expandido [B,3,H,W] em [0,1].
@@ -331,6 +352,82 @@ class SAGA(torch.nn.Module):
331
  attentions = [a.cpu() for a in attentions]
332
  return outputs, attentions
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
335
  """Executa o ataque SAGA (FGSM iterativo com ponderação por atenção).
336
 
@@ -374,16 +471,14 @@ class SAGA(torch.nn.Module):
374
  mask0_resized = cv2.resize(mask0, (w, h))
375
  self.attention_masks_cache.append(mask0.copy())
376
 
377
- loss_fn = torch.nn.CrossEntropyLoss()
378
-
379
- for _ in range(self.steps):
380
  # Normalizar para forward
381
  adv_norm = (adv_denorm - mean) / std
382
  adv_norm.requires_grad = True
383
  outputs, attentions = self._capture_outputs_and_attentions(adv_norm)
384
  if isinstance(outputs, tuple): # compatibilidade com modelos que retornam extras
385
  outputs = outputs[0]
386
- loss = loss_fn(outputs, labels)
387
  grad = torch.autograd.grad(loss, adv_norm, retain_graph=False, create_graph=False)[0]
388
 
389
  # Atenção da imagem adversarial atual (já capturada)
@@ -399,8 +494,21 @@ class SAGA(torch.nn.Module):
399
  self.attention_masks_cache.append(mask.copy())
400
  grad_weighted = grad * att_map
401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  # FGSM step em pixel space (sign do gradiente normalizado equivale ao do desnormalizado)
403
- adv_denorm = adv_denorm.detach() + self.eps_step * grad_weighted.sign()
404
 
405
  # Projeção na bola L_inf de raio eps em relação à imagem original
406
  delta = torch.clamp(adv_denorm - images_denorm, min=-self.eps, max=self.eps)
 
1
  import torch
2
  import torchattacks
3
  from PIL import Image
4
+ from typing import List, Tuple, Optional
5
  import numpy as np
6
+ import warnings
7
+ from pathlib import Path
8
+
9
+ try:
10
+ import torchvision.models as tv_models
11
+ except Exception: # pragma: no cover - torchvision is optional for ViT-only mode
12
+ tv_models = None
13
+
14
+ try:
15
+ import timm
16
+ except Exception: # pragma: no cover - timm is optional for CNN blending
17
+ timm = None
18
 
19
 
20
  def capture_outputs_and_attentions(model, x_norm: torch.Tensor):
 
224
  outputs0, attentions0 = capture_outputs_and_attentions(self.model, images)
225
  self.attentions_per_iter.append([att for att in attentions0])
226
 
227
+ for step_idx in range(self.steps):
228
  # Normalizar para passar pelo modelo
229
  adv_images = (adv_images_denorm - mean) / std
230
  adv_images.requires_grad = True
 
271
  Paper: "On the Robustness of Vision Transformers to Adversarial Examples" (ICCV 2021)
272
  """
273
 
274
+ def __init__(self, model, eps=8/255, steps=10, discard_ratio: float = 0.0,
275
+ head_fusion: str = "mean", use_resnet: bool = False,
276
+ cnn_checkpoint_path: str = "resnet.pth", vit_weight=0.5):
277
  """Implementação correta do SAGA baseada no código original (SelfAttentionGradientAttack).
278
 
279
  Parâmetros:
 
282
  - steps: número de iterações (FGSM iterativo)
283
  - discard_ratio: razão de descarte usada no attention rollout
284
  - head_fusion: estratégia de fusão de heads ('mean','max','min')
285
+ - use_resnet: se True, acumula gradiente de um backbone CNN externo e o mistura ao gradiente ponderado pela atenção
286
+ - cnn_checkpoint_path: caminho padrão do backbone CNN auxiliar (será carregado sob demanda)
287
  """
288
  super().__init__()
289
  self.model = model
 
292
  self.eps_step = self.eps / max(1, steps)
293
  self.discard_ratio = discard_ratio
294
  self.head_fusion = head_fusion
295
+ self.use_resnet = use_resnet
296
+ self.cnn_checkpoint_path = Path(cnn_checkpoint_path)
297
+ self.cnn_model: Optional[torch.nn.Module] = None
298
+ self.vit_weight = vit_weight
299
  self.device = next(model.parameters()).device
300
  self.iteration_images: List[Image.Image] = []
301
  self.iteration_tensors: List[torch.Tensor] = []
 
303
  # Cache opcional: atenções por camada/head em cada iteração
304
  # Formato: lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada
305
  self.attentions_per_iter: List[List[torch.Tensor]] = []
306
+ self.loss_fn = torch.nn.CrossEntropyLoss()
307
 
308
  def _attention_map(self, images_norm: torch.Tensor, save: bool = False) -> torch.Tensor:
309
  """Extrai mapa de atenção (rollout) e retorna tensor expandido [B,3,H,W] em [0,1].
 
352
  attentions = [a.cpu() for a in attentions]
353
  return outputs, attentions
354
 
355
def _load_cnn_backbone(self) -> Optional[torch.nn.Module]:
    """Lazily load the auxiliary CNN backbone used when ``use_resnet=True``.

    Resolution order:
      1. the cached ``self.cnn_model`` (loaded at most once per instance);
      2. the checkpoint at ``self.cnn_checkpoint_path`` — either a pickled
         ``nn.Module`` or a state dict (BiT/ResNetV2-style ``stem.*`` keys
         go to timm, anything else to a torchvision ResNet-101);
      3. pretrained fallback weights from timm, then torchvision.

    Returns None (after a warning) when CNN mode is disabled or no backbone
    can be constructed, so SAGA degrades to pure ViT gradients.
    """
    if not self.use_resnet:
        return None
    if self.cnn_model is not None:
        return self.cnn_model
    if tv_models is None:
        warnings.warn("torchvision não disponível; desabilitando modo CNN do SAGA.")
        return None

    model: Optional[torch.nn.Module] = None
    checkpoint_model_name = "resnetv2_101x1_bit.goog_in21k_ft_in1k"
    if self.cnn_checkpoint_path and self.cnn_checkpoint_path.exists():
        try:
            # SECURITY NOTE: torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources.
            checkpoint = torch.load(self.cnn_checkpoint_path, map_location=self.device)
            if isinstance(checkpoint, torch.nn.Module):
                model = checkpoint
            elif isinstance(checkpoint, dict):
                state_dict = checkpoint.get('model_state_dict') or checkpoint.get('state_dict') or checkpoint
                if timm is not None and any(key.startswith("stem.") for key in state_dict.keys()):
                    # BiT/ResNetV2-style (timm) checkpoints start with "stem." keys.
                    num_classes = None
                    head_bias = state_dict.get('head.fc.bias')
                    if isinstance(head_bias, torch.Tensor):
                        num_classes = head_bias.shape[0]
                    model = timm.create_model(
                        checkpoint.get("model_name", checkpoint_model_name),
                        pretrained=False,
                        num_classes=num_classes or 1000
                    )
                    load_result = model.load_state_dict(state_dict, strict=False)
                else:
                    model = tv_models.resnet101(weights=None)
                    load_result = model.load_state_dict(state_dict, strict=False)
                missing = load_result.missing_keys
                unexpected = load_result.unexpected_keys
                if missing or unexpected:
                    warn_msg = "[SAGA] ResNet checkpoint keys mismatch."
                    if missing:
                        warn_msg += f" Missing: {missing[:5]}{'...' if len(missing) > 5 else ''}."
                    if unexpected:
                        warn_msg += f" Unexpected: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}."
                    warnings.warn(warn_msg + " Using available weights (strict=False).")
            else:
                warnings.warn(f"Formato de checkpoint desconhecido em {self.cnn_checkpoint_path}; utilizando pesos padrão.")
        except Exception as exc:  # pragma: no cover - fallback resiliente
            warnings.warn(f"Falha ao carregar {self.cnn_checkpoint_path}: {exc}. Usando ResNet padrão.")

    if model is None:
        if timm is not None:
            try:
                model = timm.create_model(checkpoint_model_name, pretrained=True)
            except Exception:
                model = None
        if model is None and tv_models is not None:
            try:
                model = tv_models.resnet101(weights="IMAGENET1K_V2")
            except Exception:
                try:
                    # Legacy torchvision (<0.13) API; removed in newer versions.
                    model = tv_models.resnet101(pretrained=True)
                except Exception:
                    model = None

    # FIX: the original called model.to(...) unconditionally; when every
    # fallback failed (timm missing/failing and both torchvision constructor
    # calls raising), this crashed with AttributeError instead of degrading
    # gracefully to ViT-only gradients like the other failure paths.
    if model is None:
        warnings.warn("[SAGA] Nenhum backbone CNN disponível; usando apenas gradiente ViT.")
        return None

    model = model.to(self.device)
    model.eval()
    self.cnn_model = model
    return self.cnn_model
418
+
419
+ def _compute_cnn_gradient(self, images_norm: torch.Tensor, labels: torch.Tensor) -> Optional[torch.Tensor]:
420
+ """Obtém gradientes do backbone CNN auxiliar para a mesma imagem normalizada."""
421
+ cnn_model = self._load_cnn_backbone()
422
+ if cnn_model is None:
423
+ return None
424
+
425
+ cnn_input = images_norm.detach().clone().requires_grad_(True)
426
+ outputs = cnn_model(cnn_input)
427
+ loss = self.loss_fn(outputs, labels)
428
+ grad = torch.autograd.grad(loss, cnn_input, retain_graph=False, create_graph=False)[0]
429
+ return grad
430
+
431
  def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
432
  """Executa o ataque SAGA (FGSM iterativo com ponderação por atenção).
433
 
 
471
  mask0_resized = cv2.resize(mask0, (w, h))
472
  self.attention_masks_cache.append(mask0.copy())
473
 
474
+ for step_idx in range(self.steps):
 
 
475
  # Normalizar para forward
476
  adv_norm = (adv_denorm - mean) / std
477
  adv_norm.requires_grad = True
478
  outputs, attentions = self._capture_outputs_and_attentions(adv_norm)
479
  if isinstance(outputs, tuple): # compatibilidade com modelos que retornam extras
480
  outputs = outputs[0]
481
+ loss = self.loss_fn(outputs, labels)
482
  grad = torch.autograd.grad(loss, adv_norm, retain_graph=False, create_graph=False)[0]
483
 
484
  # Atenção da imagem adversarial atual (já capturada)
 
494
  self.attention_masks_cache.append(mask.copy())
495
  grad_weighted = grad * att_map
496
 
497
+ grad_final = grad_weighted
498
+ if self.use_resnet:
499
+ cnn_grad = self._compute_cnn_gradient(adv_norm, labels)
500
+ if cnn_grad is not None:
501
+ vit_contrib = grad_weighted.detach().abs().mean().item()
502
+ cnn_contrib = cnn_grad.detach().abs().mean().item()
503
+ grad_final = self.vit_weight * grad_weighted + (1 - self.vit_weight) * cnn_grad
504
+ blended_contrib = grad_final.detach().abs().mean().item()
505
+ print(
506
+ f"[SAGA][step {step_idx+1}/{self.steps}] vit_weight={self.vit_weight:.2f} "
507
+ f"|ViT|={vit_contrib:.4e} |CNN|={cnn_contrib:.4e} |blend|={blended_contrib:.4e}"
508
+ )
509
+
510
  # FGSM step em pixel space (sign do gradiente normalizado equivale ao do desnormalizado)
511
+ adv_denorm = adv_denorm.detach() + self.eps_step * grad_final.sign()
512
 
513
  # Projeção na bola L_inf de raio eps em relação à imagem original
514
  delta = torch.clamp(adv_denorm - images_denorm, min=-self.eps, max=self.eps)