lucasddmc commited on
Commit
ceb2d70
·
1 Parent(s): b877d29

fix: saga attack is now correct

Browse files

feat: attention is now cached during attack instead of calculating it afterwards

Files changed (4) hide show
  1. app.py +88 -91
  2. utils/attacks.py +218 -4
  3. utils/model_loader.py +13 -7
  4. utils/visualization.py +64 -13
app.py CHANGED
@@ -7,7 +7,7 @@ from utils.model_loader import load_model_and_labels
7
  from utils.preprocessing import get_default_transform, preprocess_image
8
  from utils.inference import predict_topk
9
  from utils.attacks import PGDIterations, FGSM, SAGA, MIFGSM
10
- from utils.visualization import extract_attention_maps, attention_rollout, create_attention_overlay, extract_attention_for_iterations, create_iteration_attention_overlays
11
 
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
@@ -56,7 +56,7 @@ def classify_image(model_file, image):
56
  img_tensor = preprocess_image(image, transform=transform).to(DEVICE)
57
 
58
  # Inferência
59
- top_prob, top_idx, num_classes, probabilities = predict_topk(model, img_tensor, top_k=5, device=DEVICE)
60
 
61
  top_k = len(top_prob)
62
  result = f"**Top {top_k} Predictions:**\n\n"
@@ -177,7 +177,7 @@ def run_attack(
177
  discard_ratio: float,
178
  head_fusion: str,
179
  alpha_overlay: float
180
- ) -> Tuple[List[Image.Image], str, Image.Image, List[Image.Image]]:
181
  """
182
  Executa ataque adversarial (FGSM ou PGD) untargeted e extrai atenção.
183
 
@@ -193,13 +193,13 @@ def run_attack(
193
  alpha_overlay: transparência da sobreposição
194
 
195
  Returns:
196
- (iteration_images, result_text, final_adv_image, attention_overlays)
197
  """
198
  try:
199
  if model_file is None:
200
- return [], "Please upload a model file (.pth)", None, []
201
  if image is None:
202
- return [], "Please upload an image", None, []
203
 
204
  # Carregar modelo e labels
205
  model_path = _to_path(model_file)
@@ -230,18 +230,26 @@ def run_attack(
230
  adv_tensor, iteration_images = attack(img_tensor, original_label)
231
 
232
  # Extrair atenção para todas as iterações
233
- # SAGA calcula atenção internamente, podemos reutilizar!
234
- if attack_type == "SAGA" and hasattr(attack, 'attention_masks_cache') and len(attack.attention_masks_cache) > 0:
235
- # Usar atenção calculada pelo SAGA
236
- attention_masks = attack.attention_masks_cache
 
 
 
 
 
 
 
237
  else:
238
- # Para FGSM e PGD, calcular atenção normalmente
239
  attention_masks = extract_attention_for_iterations(
240
  model,
241
  attack.iteration_tensors,
242
  discard_ratio=discard_ratio,
243
  head_fusion=head_fusion
244
  )
 
245
 
246
  # Criar overlays de atenção
247
  attention_overlays = create_iteration_attention_overlays(
@@ -255,10 +263,6 @@ def run_attack(
255
  adv_class = top_idx_adv[0].item()
256
  adv_prob = top_prob_adv[0].item()
257
 
258
- # Converter imagem adversarial final para PIL
259
- from utils.attacks import tensor_to_pil
260
- final_adv_image = tensor_to_pil(adv_tensor[0])
261
-
262
  # Format result
263
  result = f"## {attack_type} Attack Result (Untargeted)\n\n"
264
  result += f"**Configuration:**\n"
@@ -305,12 +309,12 @@ def run_attack(
305
  result += f"- Iteration 0 = Original image\n"
306
  result += f"- Iteration {steps} = Final adversarial image\n"
307
 
308
- return iteration_images, result, final_adv_image, attention_overlays
309
 
310
  except Exception as e:
311
  import traceback
312
  error_msg = f"Error executing attack:\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
313
- return [], error_msg, None, []
314
 
315
 
316
  def create_app():
@@ -318,7 +322,7 @@ def create_app():
318
 
319
  with gr.Blocks(title="ViTViz - Classifier & Attacks", theme=gr.themes.Citrus()) as app:
320
  gr.Markdown("""
321
- # ViTViz: Vision Transformer Classifier & Adversarial Attacks
322
  """)
323
 
324
  with gr.Tabs():
@@ -406,12 +410,13 @@ def create_app():
406
  outputs=[attention_output, output_text_attention]
407
  )
408
 
409
- # Tab 3: AAdversarial Attack + Attention
410
  with gr.Tab("Adversarial Attack + Attention"):
411
  gr.Markdown("### Execute adversarial attacks and visualize how attention evolves")
412
 
413
- with gr.Row():
414
- with gr.Column(scale=1):
 
415
  model_upload_attack = gr.File(
416
  label="Upload Model (.pth/.pt)",
417
  file_types=[".pth", ".pt"]
@@ -420,45 +425,48 @@ def create_app():
420
  label="Upload Image",
421
  type="pil"
422
  )
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
- gr.Markdown("#### Attack Configuration")
425
-
426
- attack_type = gr.Dropdown(
427
- choices=["PGD", "FGSM", "MIFGSM", "SAGA"],
428
- value="PGD",
429
- label="Attack Type",
430
- info="PGD: iterative | FGSM: single-step | MIFGSM: momentum | SAGA: gradient × attention"
431
- )
432
-
433
- eps_input = gr.Slider(
434
- minimum=0.0,
435
- maximum=1.0,
436
- value=8/255,
437
- step=1/255,
438
- label="Epsilon (ε) - Maximum Perturbation"
439
- )
440
-
441
- # Parâmetros iterativos
442
- alpha_group = gr.Group(visible=True)
443
- with alpha_group:
444
- alpha_input = gr.Slider(
445
  minimum=0.0,
446
- maximum=0.1,
447
- value=2/255,
448
  step=1/255,
449
- label="Alpha (α) - Step Size"
450
- )
451
-
452
- steps_group = gr.Group(visible=True)
453
- with steps_group:
454
- steps_input = gr.Slider(
455
- minimum=1,
456
- maximum=100,
457
- value=10,
458
- step=1,
459
- label="Steps - NNumber of Iterations"
460
  )
 
 
 
 
 
 
 
 
 
 
 
461
 
 
 
 
 
 
 
 
 
 
 
462
  # Função para atualizar visibilidade dos parâmetros
463
  def update_attack_params(attack_type):
464
  if attack_type == "FGSM":
@@ -477,15 +485,15 @@ def create_app():
477
  outputs=[alpha_group, steps_group]
478
  )
479
 
480
- gr.Markdown("#### Attention Configuration")
 
481
 
482
- head_fusion_attack = gr.Radio(
483
- choices=["mean", "max", "min"],
484
- value="max",
485
- label="Head Fusion"
486
- )
487
-
488
- with gr.Row():
489
  discard_ratio_attack = gr.Slider(
490
  minimum=0.0,
491
  maximum=1.0,
@@ -501,22 +509,10 @@ def create_app():
501
  label="Alpha Overlay"
502
  )
503
 
504
- attack_btn = gr.Button("Execute Full Analysis", variant="primary", size="lg")
505
-
506
- with gr.Column(scale=2):
507
- output_text_attack = gr.Markdown(label="Result")
508
-
509
- with gr.Row():
510
- with gr.Column():
511
- gr.Markdown("**Final Adversarial Image**")
512
- final_adv_image = gr.Image(type="pil", show_label=False)
513
- with gr.Column():
514
- gr.Markdown("**All Iterations**")
515
- iteration_gallery = gr.Gallery(
516
- columns=5,
517
- height="auto",
518
- show_label=False
519
- )
520
 
521
  # Seção de Evolução da Atenção
522
  gr.Markdown("---")
@@ -538,10 +534,11 @@ def create_app():
538
  maximum=10,
539
  step=1,
540
  value=0,
541
- label="Attack Iteration"
 
542
  )
543
  show_attention_checkbox = gr.Checkbox(
544
- value=False,
545
  label="Show Attention Heatmap",
546
  info="Overlay attention map on images"
547
  )
@@ -584,11 +581,7 @@ def create_app():
584
  attack_type, eps_input, alpha_input, steps_input,
585
  discard_ratio_attack, head_fusion_attack, alpha_overlay_attack
586
  ],
587
- outputs=[iteration_images_state, output_text_attack, final_adv_image, attention_overlays_state]
588
- ).then(
589
- fn=lambda x: x,
590
- inputs=[iteration_images_state],
591
- outputs=[iteration_gallery]
592
  ).then(
593
  fn=lambda imgs: gr.update(maximum=len(imgs)-1 if imgs else 0, value=0),
594
  inputs=[iteration_images_state],
@@ -626,8 +619,12 @@ def create_app():
626
 
627
  if __name__ == "__main__":
628
  app = create_app()
629
- app.launch(
630
- server_name="0.0.0.0",
631
- server_port=7861,
632
- share=False
633
- )
 
 
 
 
 
7
  from utils.preprocessing import get_default_transform, preprocess_image
8
  from utils.inference import predict_topk
9
  from utils.attacks import PGDIterations, FGSM, SAGA, MIFGSM
10
+ from utils.visualization import extract_attention_maps, attention_rollout, create_attention_overlay, extract_attention_for_iterations, create_iteration_attention_overlays, rollout_from_cached_attentions
11
 
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
 
56
  img_tensor = preprocess_image(image, transform=transform).to(DEVICE)
57
 
58
  # Inferência
59
+ top_prob, top_idx, num_classes, probabilities = predict_topk(model, img_tensor, top_k=5)
60
 
61
  top_k = len(top_prob)
62
  result = f"**Top {top_k} Predictions:**\n\n"
 
177
  discard_ratio: float,
178
  head_fusion: str,
179
  alpha_overlay: float
180
+ ) -> Tuple[List[Image.Image], str, List[Image.Image]]:
181
  """
182
  Executa ataque adversarial (FGSM ou PGD) untargeted e extrai atenção.
183
 
 
193
  alpha_overlay: transparência da sobreposição
194
 
195
  Returns:
196
+ (iteration_images, result_text, attention_overlays)
197
  """
198
  try:
199
  if model_file is None:
200
+ return [], "Please upload a model file (.pth)", []
201
  if image is None:
202
+ return [], "Please upload an image", []
203
 
204
  # Carregar modelo e labels
205
  model_path = _to_path(model_file)
 
230
  adv_tensor, iteration_images = attack(img_tensor, original_label)
231
 
232
  # Extrair atenção para todas as iterações
233
+ # Importante: para visualização, respeitar parâmetros do UI (discard_ratio, head_fusion)
234
+ # Mesmo quando o ataque é SAGA, recomputar os mapas para permitir controle do usuário.
235
+ # Preferir atenções em cache capturadas durante o ataque
236
+ attention_masks = []
237
+ cached_attns = getattr(attack, 'attentions_per_iter', None)
238
+ if cached_attns and len(cached_attns) == len(iteration_images):
239
+ attention_masks = rollout_from_cached_attentions(
240
+ cached_attns,
241
+ discard_ratio=discard_ratio,
242
+ head_fusion=head_fusion
243
+ )
244
  else:
245
+ # Fallback: recomputar via hooks (depreciado)
246
  attention_masks = extract_attention_for_iterations(
247
  model,
248
  attack.iteration_tensors,
249
  discard_ratio=discard_ratio,
250
  head_fusion=head_fusion
251
  )
252
+ print("⚠️ Warning: recomputing attentions via hooks. Consider using an attack class that caches attentions.")
253
 
254
  # Criar overlays de atenção
255
  attention_overlays = create_iteration_attention_overlays(
 
263
  adv_class = top_idx_adv[0].item()
264
  adv_prob = top_prob_adv[0].item()
265
 
 
 
 
 
266
  # Format result
267
  result = f"## {attack_type} Attack Result (Untargeted)\n\n"
268
  result += f"**Configuration:**\n"
 
309
  result += f"- Iteration 0 = Original image\n"
310
  result += f"- Iteration {steps} = Final adversarial image\n"
311
 
312
+ return iteration_images, result, attention_overlays
313
 
314
  except Exception as e:
315
  import traceback
316
  error_msg = f"Error executing attack:\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
317
+ return [], error_msg, []
318
 
319
 
320
  def create_app():
 
322
 
323
  with gr.Blocks(title="ViTViz - Classifier & Attacks", theme=gr.themes.Citrus()) as app:
324
  gr.Markdown("""
325
+ # ViTViz: Vision Transformer Adversarial Attack & Attention Visualization Tool
326
  """)
327
 
328
  with gr.Tabs():
 
410
  outputs=[attention_output, output_text_attention]
411
  )
412
 
413
+ # Tab 3: Adversarial Attack + Attention
414
  with gr.Tab("Adversarial Attack + Attention"):
415
  gr.Markdown("### Execute adversarial attacks and visualize how attention evolves")
416
 
417
+ # with gr.Row():
418
+ with gr.Column(scale=1):
419
+ with gr.Row():
420
  model_upload_attack = gr.File(
421
  label="Upload Model (.pth/.pt)",
422
  file_types=[".pth", ".pt"]
 
425
  label="Upload Image",
426
  type="pil"
427
  )
428
+ gr.Markdown("---")
429
+
430
+ with gr.Row():
431
+ with gr.Column():
432
+ gr.Markdown("#### Attack Configuration")
433
+
434
+ attack_type = gr.Dropdown(
435
+ choices=["PGD", "FGSM", "MIFGSM", "SAGA"],
436
+ value="PGD",
437
+ label="Attack Type",
438
+ info="PGD: iterative | FGSM: single-step | MIFGSM: momentum | SAGA: gradient × attention"
439
+ )
440
 
441
+ eps_input = gr.Slider(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  minimum=0.0,
443
+ maximum=1.0,
444
+ value=8/255,
445
  step=1/255,
446
+ label="Epsilon (ε) - Maximum Perturbation"
 
 
 
 
 
 
 
 
 
 
447
  )
448
+
449
+ # Parâmetros iterativos
450
+ alpha_group = gr.Group(visible=True)
451
+ with alpha_group:
452
+ alpha_input = gr.Slider(
453
+ minimum=0.0,
454
+ maximum=0.1,
455
+ value=2/255,
456
+ step=1/255,
457
+ label="Alpha (α) - Step Size"
458
+ )
459
 
460
+ steps_group = gr.Group(visible=True)
461
+ with steps_group:
462
+ steps_input = gr.Slider(
463
+ minimum=1,
464
+ maximum=100,
465
+ value=10,
466
+ step=1,
467
+ label="Steps - NNumber of Iterations"
468
+ )
469
+
470
  # Função para atualizar visibilidade dos parâmetros
471
  def update_attack_params(attack_type):
472
  if attack_type == "FGSM":
 
485
  outputs=[alpha_group, steps_group]
486
  )
487
 
488
+ with gr.Column():
489
+ gr.Markdown("#### Attention Rollout Configuration")
490
 
491
+ head_fusion_attack = gr.Radio(
492
+ choices=["mean", "max", "min"],
493
+ value="max",
494
+ label="Head Fusion"
495
+ )
496
+
 
497
  discard_ratio_attack = gr.Slider(
498
  minimum=0.0,
499
  maximum=1.0,
 
509
  label="Alpha Overlay"
510
  )
511
 
512
+ attack_btn = gr.Button("Execute Full Analysis", variant="primary", size="lg")
513
+
514
+ with gr.Column(scale=2):
515
+ output_text_attack = gr.Markdown(label="Result")
 
 
 
 
 
 
 
 
 
 
 
 
516
 
517
  # Seção de Evolução da Atenção
518
  gr.Markdown("---")
 
534
  maximum=10,
535
  step=1,
536
  value=0,
537
+ label="Attack Iteration",
538
+ info="0 for the original image, last is the final adversarial image"
539
  )
540
  show_attention_checkbox = gr.Checkbox(
541
+ value=True,
542
  label="Show Attention Heatmap",
543
  info="Overlay attention map on images"
544
  )
 
581
  attack_type, eps_input, alpha_input, steps_input,
582
  discard_ratio_attack, head_fusion_attack, alpha_overlay_attack
583
  ],
584
+ outputs=[iteration_images_state, output_text_attack, attention_overlays_state]
 
 
 
 
585
  ).then(
586
  fn=lambda imgs: gr.update(maximum=len(imgs)-1 if imgs else 0, value=0),
587
  inputs=[iteration_images_state],
 
619
 
620
  if __name__ == "__main__":
621
  app = create_app()
622
+ try:
623
+ app.launch(
624
+ server_name="0.0.0.0",
625
+ server_port=7861,
626
+ share=False
627
+ )
628
+
629
+ except KeyboardInterrupt:
630
+ print("\nShutting down gracefully...")
utils/attacks.py CHANGED
@@ -3,6 +3,41 @@ import torchattacks
3
  from PIL import Image
4
  from typing import List, Tuple
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  def denormalize_imagenet(tensor: torch.Tensor) -> torch.Tensor:
@@ -57,6 +92,8 @@ class FGSM(torchattacks.FGSM):
57
  super().__init__(model, eps=eps)
58
  self.iteration_images: List[Image.Image] = []
59
  self.iteration_tensors: List[torch.Tensor] = []
 
 
60
 
61
  def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
62
  """
@@ -77,6 +114,7 @@ class FGSM(torchattacks.FGSM):
77
 
78
  self.iteration_images = []
79
  self.iteration_tensors = []
 
80
 
81
  # Salvar imagem original
82
  pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
@@ -85,7 +123,9 @@ class FGSM(torchattacks.FGSM):
85
 
86
  # Calcular gradiente
87
  images.requires_grad = True
88
- outputs = self.get_logits(images)
 
 
89
 
90
  if self.targeted:
91
  target_labels = self.get_target_label(images, labels)
@@ -106,6 +146,10 @@ class FGSM(torchattacks.FGSM):
106
  pil_img_adv = tensor_to_pil(adv_images_denorm[0], denormalize=False)
107
  self.iteration_images.append(pil_img_adv)
108
  self.iteration_tensors.append(adv_images.clone().detach())
 
 
 
 
109
 
110
  return adv_images, self.iteration_images
111
 
@@ -120,6 +164,7 @@ class PGDIterations(torchattacks.PGD):
120
  super().__init__(model, eps=eps, alpha=alpha, steps=steps, random_start=random_start)
121
  self.iteration_images: List[Image.Image] = []
122
  self.iteration_tensors: List[torch.Tensor] = []
 
123
 
124
  def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
125
  """
@@ -154,17 +199,21 @@ class PGDIterations(torchattacks.PGD):
154
 
155
  self.iteration_images = []
156
  self.iteration_tensors = []
 
157
 
158
  # Salvar iteração 0 (imagem original)
159
  pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
160
  self.iteration_images.append(pil_img_orig)
161
  self.iteration_tensors.append(images.clone().detach())
 
 
 
162
 
163
  for _ in range(self.steps):
164
  # Normalizar para passar pelo modelo
165
  adv_images = (adv_images_denorm - mean) / std
166
  adv_images.requires_grad = True
167
- outputs = self.get_logits(adv_images)
168
 
169
  # Calculate loss
170
  if self.targeted:
@@ -188,12 +237,13 @@ class PGDIterations(torchattacks.PGD):
188
  pil_img = tensor_to_pil(adv_images_denorm[0], denormalize=False)
189
  self.iteration_images.append(pil_img)
190
  self.iteration_tensors.append(adv_images_normalized.clone().detach())
 
 
191
 
192
  # Retornar imagem normalizada para o modelo
193
  adv_images = (adv_images_denorm - mean) / std
194
  return adv_images, self.iteration_images
195
-
196
-
197
  class SAGA(torch.nn.Module):
198
  """
199
  SAGA: Self-Attention Gradient Attack
@@ -205,6 +255,170 @@ class SAGA(torch.nn.Module):
205
  Baseado em: https://github.com/MetaMain/ViTRobust
206
  Paper: "On the Robustness of Vision Transformers to Adversarial Examples" (ICCV 2021)
207
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def __init__(self, model, eps=0.03, steps=10):
209
  super().__init__()
210
  self.model = model
 
3
  from PIL import Image
4
  from typing import List, Tuple
5
  import numpy as np
6
+ def capture_outputs_and_attentions(model, x_norm: torch.Tensor):
7
+ """Executa um forward único capturando atenções via hooks nas camadas de atenção do ViT.
8
+ Retorna (outputs, attentions_list) onde attentions_list é lista de tensores [B,H,T,T] por camada.
9
+ Funciona para modelos do timm com atributo 'blocks' e submódulo 'attn'.
10
+ """
11
+ attentions: List[torch.Tensor] = []
12
+
13
+ def make_attention_hook():
14
+ def hook(module, input, output):
15
+ x = input[0]
16
+ B, N, C = x.shape
17
+ if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')):
18
+ return
19
+ qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
20
+ q, k, v = qkv.unbind(0)
21
+ scale = (C // module.num_heads) ** -0.5
22
+ attn = (q @ k.transpose(-2, -1)) * scale
23
+ attn = attn.softmax(dim=-1)
24
+ attentions.append(attn.detach())
25
+ return hook
26
+
27
+ hooks = []
28
+ if hasattr(model, 'blocks'):
29
+ for block in model.blocks:
30
+ if hasattr(block, 'attn'):
31
+ hooks.append(block.attn.register_forward_hook(make_attention_hook()))
32
+
33
+ model.eval()
34
+ outputs = model(x_norm)
35
+
36
+ for h in hooks:
37
+ h.remove()
38
+
39
+ attentions = [a.cpu() for a in attentions]
40
+ return outputs, attentions
41
 
42
 
43
  def denormalize_imagenet(tensor: torch.Tensor) -> torch.Tensor:
 
92
  super().__init__(model, eps=eps)
93
  self.iteration_images: List[Image.Image] = []
94
  self.iteration_tensors: List[torch.Tensor] = []
95
+ # Atenções por iteração (iteração 0: original, iteração 1: adversarial)
96
+ self.attentions_per_iter: List[List[torch.Tensor]] = []
97
 
98
  def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
99
  """
 
114
 
115
  self.iteration_images = []
116
  self.iteration_tensors = []
117
+ self.attentions_per_iter = []
118
 
119
  # Salvar imagem original
120
  pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
 
123
 
124
  # Calcular gradiente
125
  images.requires_grad = True
126
+ # Capturar atenções e logits para imagem original
127
+ outputs, attentions0 = capture_outputs_and_attentions(self.model, images)
128
+ self.attentions_per_iter.append([att for att in attentions0])
129
 
130
  if self.targeted:
131
  target_labels = self.get_target_label(images, labels)
 
146
  pil_img_adv = tensor_to_pil(adv_images_denorm[0], denormalize=False)
147
  self.iteration_images.append(pil_img_adv)
148
  self.iteration_tensors.append(adv_images.clone().detach())
149
+
150
+ # Capturar atenções para imagem adversarial final
151
+ outputs_adv, attentions1 = capture_outputs_and_attentions(self.model, adv_images)
152
+ self.attentions_per_iter.append([att for att in attentions1])
153
 
154
  return adv_images, self.iteration_images
155
 
 
164
  super().__init__(model, eps=eps, alpha=alpha, steps=steps, random_start=random_start)
165
  self.iteration_images: List[Image.Image] = []
166
  self.iteration_tensors: List[torch.Tensor] = []
167
+ self.attentions_per_iter: List[List[torch.Tensor]] = []
168
 
169
  def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
170
  """
 
199
 
200
  self.iteration_images = []
201
  self.iteration_tensors = []
202
+ self.attentions_per_iter = []
203
 
204
  # Salvar iteração 0 (imagem original)
205
  pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
206
  self.iteration_images.append(pil_img_orig)
207
  self.iteration_tensors.append(images.clone().detach())
208
+ # Atenções da imagem original
209
+ outputs0, attentions0 = capture_outputs_and_attentions(self.model, images)
210
+ self.attentions_per_iter.append([att for att in attentions0])
211
 
212
  for _ in range(self.steps):
213
  # Normalizar para passar pelo modelo
214
  adv_images = (adv_images_denorm - mean) / std
215
  adv_images.requires_grad = True
216
+ outputs, attentions = capture_outputs_and_attentions(self.model, adv_images)
217
 
218
  # Calculate loss
219
  if self.targeted:
 
237
  pil_img = tensor_to_pil(adv_images_denorm[0], denormalize=False)
238
  self.iteration_images.append(pil_img)
239
  self.iteration_tensors.append(adv_images_normalized.clone().detach())
240
+ # Atenções desta iteração
241
+ self.attentions_per_iter.append([att for att in attentions])
242
 
243
  # Retornar imagem normalizada para o modelo
244
  adv_images = (adv_images_denorm - mean) / std
245
  return adv_images, self.iteration_images
246
+
 
247
  class SAGA(torch.nn.Module):
248
  """
249
  SAGA: Self-Attention Gradient Attack
 
255
  Baseado em: https://github.com/MetaMain/ViTRobust
256
  Paper: "On the Robustness of Vision Transformers to Adversarial Examples" (ICCV 2021)
257
  """
258
+
259
+ def __init__(self, model, eps=8/255, steps=10, discard_ratio: float = 0.0, head_fusion: str = "mean"):
260
+ """Implementação correta do SAGA baseada no código original (SelfAttentionGradientAttack).
261
+
262
+ Parâmetros:
263
+ - model: Vision Transformer (deve expor atenções via forward ou função auxiliar em visualization utils)
264
+ - eps: orçamento L_inf máximo (em pixel space [0,1])
265
+ - steps: número de iterações (FGSM iterativo)
266
+ - discard_ratio: razão de descarte usada no attention rollout
267
+ - head_fusion: estratégia de fusão de heads ('mean','max','min')
268
+ """
269
+ super().__init__()
270
+ self.model = model
271
+ self.eps = eps
272
+ self.steps = steps
273
+ self.eps_step = self.eps / max(1, steps)
274
+ self.discard_ratio = discard_ratio
275
+ self.head_fusion = head_fusion
276
+ self.device = next(model.parameters()).device
277
+ self.iteration_images: List[Image.Image] = []
278
+ self.iteration_tensors: List[torch.Tensor] = []
279
+ self.attention_masks_cache: List[np.ndarray] = []
280
+ # Cache opcional: atenções por camada/head em cada iteração
281
+ # Formato: lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada
282
+ self.attentions_per_iter: List[List[torch.Tensor]] = []
283
+
284
+ def _attention_map(self, images_norm: torch.Tensor, save: bool = False) -> torch.Tensor:
285
+ """Extrai mapa de atenção (rollout) e retorna tensor expandido [B,3,H,W] em [0,1].
286
+ images_norm: imagens já normalizadas para o forward do modelo.
287
+ """
288
+ # Esta função agora assume que as atenções foram capturadas no mesmo forward
289
+ # e serão passadas externamente; mantida para compatibilidade se necessário.
290
+ raise RuntimeError("_attention_map should not be called directly; use integrated forward attention capture.")
291
+
292
+ def _capture_outputs_and_attentions(self, x_norm: torch.Tensor):
293
+ """Executa um forward único capturando atenções via hooks nas camadas de atenção do ViT.
294
+ Retorna (outputs, attentions_list) onde attentions_list é lista de tensores [B,H,T,T] por camada.
295
+ """
296
+ attentions: List[torch.Tensor] = []
297
+
298
+ def make_attention_hook():
299
+ def hook(module, input, output):
300
+ # input[0] é o embedding antes de atenção (B, N, C)
301
+ x = input[0]
302
+ B, N, C = x.shape
303
+ if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')):
304
+ return
305
+ qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
306
+ q, k, v = qkv.unbind(0)
307
+ scale = (C // module.num_heads) ** -0.5
308
+ attn = (q @ k.transpose(-2, -1)) * scale
309
+ attn = attn.softmax(dim=-1)
310
+ attentions.append(attn.detach())
311
+ return hook
312
+
313
+ hooks = []
314
+ if not hasattr(self.model, 'blocks'):
315
+ outputs = self.model(x_norm)
316
+ return outputs, []
317
+ for block in self.model.blocks:
318
+ if hasattr(block, 'attn'):
319
+ hooks.append(block.attn.register_forward_hook(make_attention_hook()))
320
+
321
+ self.model.eval()
322
+ outputs = self.model(x_norm)
323
+
324
+ for h in hooks:
325
+ h.remove()
326
+
327
+ # mover atenções para CPU para cache leve
328
+ attentions = [a.cpu() for a in attentions]
329
+ return outputs, attentions
330
+
331
+ def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
332
+ """Executa o ataque SAGA (FGSM iterativo com ponderação por atenção).
333
+
334
+ Fluxo por iteração:
335
+ 1. Normaliza a imagem adversarial atual.
336
+ 2. Calcula loss e gradiente.
337
+ 3. Extrai mapa de atenção da imagem atual e pondera gradiente.
338
+ 4. Aplica passo FGSM (sign) em pixel space [0,1].
339
+ 5. Projeta em L_inf (clamp delta) e clip final para [0,1].
340
+ 6. Salva imagem e tensor normalizado.
341
+ """
342
+ images = images.clone().detach().to(self.device)
343
+ labels = labels.clone().detach().to(self.device)
344
+
345
+ # Mean/std ImageNet para conversão entre espaços
346
+ mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
347
+ std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
348
+
349
+ # Pixel space [0,1]
350
+ images_denorm = images * std + mean
351
+ adv_denorm = images_denorm.clone().detach()
352
+
353
+ # Reset buffers
354
+ self.iteration_images = []
355
+ self.iteration_tensors = []
356
+ self.attention_masks_cache = []
357
+ self.attentions_per_iter = []
358
+
359
+ # Iteração 0 (imagem original)
360
+ self.iteration_images.append(tensor_to_pil(images_denorm[0], denormalize=False))
361
+ self.iteration_tensors.append(images.clone().detach())
362
+ # Atenção da imagem original: captura integrada
363
+ outputs0, attentions0 = self._capture_outputs_and_attentions(images)
364
+ # Guardar atenções brutas
365
+ self.attentions_per_iter.append([att for att in attentions0])
366
+ # Gerar máscara de rollout para cache visual
367
+ from utils.visualization import attention_rollout
368
+ import cv2
369
+ b, _, h, w = images.shape
370
+ mask0 = attention_rollout(attentions0, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion)
371
+ mask0_resized = cv2.resize(mask0, (w, h))
372
+ self.attention_masks_cache.append(mask0.copy())
373
+
374
+ loss_fn = torch.nn.CrossEntropyLoss()
375
+
376
+ for _ in range(self.steps):
377
+ # Normalizar para forward
378
+ adv_norm = (adv_denorm - mean) / std
379
+ adv_norm.requires_grad = True
380
+ outputs, attentions = self._capture_outputs_and_attentions(adv_norm)
381
+ if isinstance(outputs, tuple): # compatibilidade com modelos que retornam extras
382
+ outputs = outputs[0]
383
+ loss = loss_fn(outputs, labels)
384
+ grad = torch.autograd.grad(loss, adv_norm, retain_graph=False, create_graph=False)[0]
385
+
386
+ # Atenção da imagem adversarial atual (já capturada)
387
+ # Cache de atenções por camada/head
388
+ self.attentions_per_iter.append([att for att in attentions])
389
+ # Rollout para gerar mapa usado na ponderação
390
+ mask = attention_rollout(attentions, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion)
391
+ mask_resized = cv2.resize(mask, (adv_norm.shape[-1], adv_norm.shape[-2]))
392
+ mmax = mask_resized.max() if mask_resized.max() > 0 else 1.0
393
+ mask_resized = (mask_resized / mmax).astype('float32')
394
+ att_map = torch.from_numpy(mask_resized).to(self.device).unsqueeze(0).unsqueeze(0).repeat(adv_norm.size(0), 3, 1, 1)
395
+ # Cache visual
396
+ self.attention_masks_cache.append(mask.copy())
397
+ grad_weighted = grad * att_map
398
+
399
+ # FGSM step em pixel space (sign do gradiente normalizado equivale ao do desnormalizado)
400
+ adv_denorm = adv_denorm.detach() + self.eps_step * grad_weighted.sign()
401
+
402
+ # Projeção na bola L_inf de raio eps em relação à imagem original
403
+ delta = torch.clamp(adv_denorm - images_denorm, min=-self.eps, max=self.eps)
404
+ adv_denorm = torch.clamp(images_denorm + delta, 0.0, 1.0).detach()
405
+
406
+ # Salvar artefatos desta iteração
407
+ self.iteration_images.append(tensor_to_pil(adv_denorm[0], denormalize=False))
408
+ self.iteration_tensors.append(((adv_denorm - mean) / std).clone().detach())
409
+
410
+ # Retorna tensor normalizado final
411
+ adv_final = (adv_denorm - mean) / std
412
+ return adv_final, self.iteration_images
413
+
414
+
415
+
416
+ class AttentionWeightedPGD(torch.nn.Module):
417
+ """
418
+ [Deprecated]
419
+ Implementação errada do ataque SAGA, mas que consegue fazer ataques
420
+ adversariais eficazes em ViTs usando mapas de atenção para pesar o gradiente.
421
+ """
422
  def __init__(self, model, eps=0.03, steps=10):
423
  super().__init__()
424
  self.model = model
utils/model_loader.py CHANGED
@@ -110,17 +110,20 @@ def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device]
110
  elif 'state_dict' in checkpoint:
111
  state_dict = checkpoint['state_dict']
112
  num_classes = infer_num_classes(state_dict)
 
113
  model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
114
  model.load_state_dict(state_dict)
115
  elif 'model_state_dict' in checkpoint:
116
  # Novo formato com class_names embutidas
117
  state_dict = checkpoint['model_state_dict']
118
  num_classes = infer_num_classes(state_dict)
 
119
  model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
120
  model.load_state_dict(state_dict)
121
  else:
122
  # assume dict é um state_dict
123
  num_classes = infer_num_classes(checkpoint)
 
124
  model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
125
  model.load_state_dict(checkpoint)
126
  else:
@@ -147,13 +150,16 @@ def load_model_and_labels(
147
  device = device or DEVICE_DEFAULT
148
  checkpoint = load_checkpoint(model_path, device=device)
149
  class_names_ckpt = extract_class_names(checkpoint)
150
- class_names_file = load_class_names_from_file(labels_file)
151
- class_names = class_names_file or class_names_ckpt
152
- source: Optional[str] = None
153
- if class_names_file:
154
- source = 'file'
155
- elif class_names_ckpt:
156
- source = 'checkpoint'
 
 
 
157
 
158
  model = build_model_from_checkpoint(checkpoint, device=device)
159
  return model, class_names, source
 
110
  elif 'state_dict' in checkpoint:
111
  state_dict = checkpoint['state_dict']
112
  num_classes = infer_num_classes(state_dict)
113
+ # TODO: fazer tratamento dinâmico para timm, pytorch, etc.
114
  model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
115
  model.load_state_dict(state_dict)
116
  elif 'model_state_dict' in checkpoint:
117
  # Novo formato com class_names embutidas
118
  state_dict = checkpoint['model_state_dict']
119
  num_classes = infer_num_classes(state_dict)
120
+ # TODO: fazer tratamento dinâmico para timm, pytorch, etc.
121
  model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
122
  model.load_state_dict(state_dict)
123
  else:
124
  # assume dict é um state_dict
125
  num_classes = infer_num_classes(checkpoint)
126
+ # TODO: fazer tratamento dinâmico para timm, pytorch, etc.
127
  model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
128
  model.load_state_dict(checkpoint)
129
  else:
 
150
  device = device or DEVICE_DEFAULT
151
  checkpoint = load_checkpoint(model_path, device=device)
152
  class_names_ckpt = extract_class_names(checkpoint)
153
+ # class_names_file = load_class_names_from_file(labels_file)
154
+ # class_names = class_names_file or class_names_ckpt
155
+ # source: Optional[str] = None
156
+ # if class_names_file:
157
+ # source = 'file'
158
+ # elif class_names_ckpt:
159
+ # source = 'checkpoint'
160
+
161
+ class_names = class_names_ckpt
162
+ source = 'checkpoint' if class_names_ckpt else None
163
 
164
  model = build_model_from_checkpoint(checkpoint, device=device)
165
  return model, class_names, source
utils/visualization.py CHANGED
@@ -77,9 +77,9 @@ def extract_attention_maps(model, image: torch.Tensor) -> list:
77
  return attentions
78
 
79
 
80
- def attention_rollout(attentions: list,
81
- discard_ratio: float = 0.9,
82
- head_fusion: str = 'max') -> np.ndarray:
83
  """
84
  Implementa Attention Rollout seguindo a implementação original.
85
 
@@ -93,12 +93,12 @@ def attention_rollout(attentions: list,
93
  Returns:
94
  mask: Array numpy [grid_size, grid_size] com valores normalizados [0, 1]
95
  """
96
- # Inicializar com matriz identidade (CORREÇÃO 1)
97
  result = torch.eye(attentions[0].size(-1))
98
 
99
  with torch.no_grad():
100
  for attention in attentions:
101
- # Agregar heads (CORREÇÃO 2: usar axis=1, não dim=0)
102
  if head_fusion == 'mean':
103
  attention_heads_fused = attention.mean(axis=1)
104
  elif head_fusion == 'max':
@@ -107,12 +107,20 @@ def attention_rollout(attentions: list,
107
  attention_heads_fused = attention.min(axis=1)[0]
108
  else:
109
  raise ValueError(f"head_fusion deve ser 'mean', 'max' ou 'min'")
110
-
111
- # Descartar atenções fracas, mas proteger CLS token
112
- flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
113
- _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
114
- indices = indices[indices != 0] # Proteger CLS token
115
- flat[0, indices] = 0
 
 
 
 
 
 
 
 
116
 
117
  # Adicionar identidade e normalizar
118
  I = torch.eye(attention_heads_fused.size(-1))
@@ -124,7 +132,6 @@ def attention_rollout(attentions: list,
124
  # Rollout recursivo
125
  result = torch.matmul(a, result)
126
 
127
- # CORREÇÃO 4: Extrair atenção do CLS token (batch, CLS, patches)
128
  # Look at the total attention between the class token and the image patches
129
  mask = result[0, 0, 1:]
130
 
@@ -188,7 +195,8 @@ def extract_attention_for_iterations(
188
  head_fusion: str = 'max'
189
  ) -> list:
190
  """
191
- Extrai mapas de atenção para cada iteração do ataque PGD.
 
192
 
193
  Args:
194
  model: Modelo ViT
@@ -217,6 +225,49 @@ def extract_attention_for_iterations(
217
  return attention_masks
218
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def create_iteration_attention_overlays(
221
  iteration_images: list,
222
  attention_masks: list,
 
77
  return attentions
78
 
79
 
80
+ def attention_rollout(attentions: list,
81
+ discard_ratio: float = 0.9,
82
+ head_fusion: str = 'max') -> np.ndarray:
83
  """
84
  Implementa Attention Rollout seguindo a implementação original.
85
 
 
93
  Returns:
94
  mask: Array numpy [grid_size, grid_size] com valores normalizados [0, 1]
95
  """
96
+ # Inicializar com matriz identidade
97
  result = torch.eye(attentions[0].size(-1))
98
 
99
  with torch.no_grad():
100
  for attention in attentions:
101
+ # Agregar heads
102
  if head_fusion == 'mean':
103
  attention_heads_fused = attention.mean(axis=1)
104
  elif head_fusion == 'max':
 
107
  attention_heads_fused = attention.min(axis=1)[0]
108
  else:
109
  raise ValueError(f"head_fusion deve ser 'mean', 'max' ou 'min'")
110
+ # Aplicar descarte condicional das atenções fracas por amostra
111
+ if discard_ratio > 0.0:
112
+ bsz, tokens, _ = attention_heads_fused.shape
113
+ flat = attention_heads_fused.view(bsz, -1)
114
+ k = int(flat.size(-1) * discard_ratio)
115
+ if k > 0:
116
+ # Menores valores (largest=False)
117
+ vals, idxs = torch.topk(flat, k, dim=-1, largest=False)
118
+ for b in range(bsz):
119
+ idxs_b = idxs[b]
120
+ # proteger CLS (posição 0 nas matrizes quadradas)
121
+ idxs_b = idxs_b[idxs_b != 0]
122
+ flat[b, idxs_b] = 0
123
+ attention_heads_fused = flat.view(bsz, tokens, tokens)
124
 
125
  # Adicionar identidade e normalizar
126
  I = torch.eye(attention_heads_fused.size(-1))
 
132
  # Rollout recursivo
133
  result = torch.matmul(a, result)
134
 
 
135
  # Look at the total attention between the class token and the image patches
136
  mask = result[0, 0, 1:]
137
 
 
195
  head_fusion: str = 'max'
196
  ) -> list:
197
  """
198
+ [Deprecated when cached attentions are present]
199
+ Extrai mapas de atenção para cada iteração do ataque usando hooks.
200
 
201
  Args:
202
  model: Modelo ViT
 
225
  return attention_masks
226
 
227
 
228
+ def rollout_from_cached_attentions(
229
+ attentions_per_iter: list,
230
+ discard_ratio: float = 0.9,
231
+ head_fusion: str = 'max'
232
+ ) -> list:
233
+ """
234
+ Gera máscaras de atenção por iteração a partir de atenções já capturadas no ataque.
235
+
236
+ Args:
237
+ attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada
238
+ discard_ratio: Proporção de atenções fracas a descartar
239
+ head_fusion: Como agregar heads ('mean', 'max', 'min')
240
+
241
+ Returns:
242
+ Lista de máscaras de atenção [grid, grid] normalizadas [0, 1]
243
+ """
244
+ attention_masks = []
245
+
246
+ if attentions_per_iter is None or len(attentions_per_iter) == 0:
247
+ return attention_masks
248
+
249
+ for layer_attns in attentions_per_iter:
250
+ # layer_attns: lista de tensores por camada [B, H, T, T]
251
+ # Garantir CPU e detach
252
+ attentions_cpu = []
253
+ for att in layer_attns:
254
+ if isinstance(att, torch.Tensor):
255
+ attentions_cpu.append(att.detach().cpu())
256
+ else:
257
+ # já é CPU numpy/tensor? tentar converter via torch.as_tensor
258
+ attentions_cpu.append(torch.as_tensor(att))
259
+
260
+ # Aplicar rollout padrão sobre a lista de camadas
261
+ mask = attention_rollout(
262
+ attentions_cpu,
263
+ discard_ratio=discard_ratio,
264
+ head_fusion=head_fusion
265
+ )
266
+ attention_masks.append(mask)
267
+
268
+ return attention_masks
269
+
270
+
271
  def create_iteration_attention_overlays(
272
  iteration_images: list,
273
  attention_masks: list,