lucasddmc committed on
Commit
641929c
·
1 Parent(s): 98d71cd

feat: add new SAGA attack with attention capture

Browse files
Files changed (2) hide show
  1. app.py +23 -12
  2. utils/attacks.py +126 -0
app.py CHANGED
@@ -6,7 +6,7 @@ from typing import Optional, List, Tuple
6
  from utils.model_loader import load_model_and_labels
7
  from utils.preprocessing import get_default_transform, preprocess_image
8
  from utils.inference import predict_topk
9
- from utils.attacks import PGDIterations, FGSM
10
  from utils.visualization import extract_attention_maps, attention_rollout, create_attention_overlay, extract_attention_for_iterations, create_iteration_attention_overlays
11
 
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -224,6 +224,8 @@ def run_attack(
224
  # Configurar ataque baseado no tipo selecionado
225
  if attack_type == "FGSM":
226
  attack = FGSM(model, eps=eps)
 
 
227
  else: # PGD
228
  attack = PGDIterations(model, eps=eps, alpha=alpha, steps=steps)
229
 
@@ -233,13 +235,19 @@ def run_attack(
233
  # Executar ataque
234
  adv_tensor, iteration_images = attack(img_tensor, original_label)
235
 
236
- # Extrair atenção para todas as iterações (incluindo original)
237
- attention_masks = extract_attention_for_iterations(
238
- model,
239
- attack.iteration_tensors,
240
- discard_ratio=discard_ratio,
241
- head_fusion=head_fusion
242
- )
 
 
 
 
 
 
243
 
244
  # Criar overlays de atenção
245
  attention_overlays = create_iteration_attention_overlays(
@@ -267,7 +275,10 @@ def run_attack(
267
  if attack_type == "PGD":
268
  result += f"- Alpha (α): {alpha:.4f}\n"
269
  result += f"- Steps: {steps}\n"
270
- else:
 
 
 
271
  result += f"- Single-step (sem iterações)\n"
272
 
273
  result += f"\n**Predição Original:**\n"
@@ -364,10 +375,10 @@ def create_app():
364
  gr.Markdown("#### ⚔️ Configuração do Ataque")
365
 
366
  attack_type = gr.Dropdown(
367
- choices=["PGD", "FGSM"],
368
  value="PGD",
369
  label="Tipo de Ataque",
370
- info="PGD: iterativo (múltiplos steps) | FGSM: single-step (mais rápido)"
371
  )
372
 
373
  eps_input = gr.Slider(
@@ -400,7 +411,7 @@ def create_app():
400
  def update_attack_params(attack_type):
401
  if attack_type == "FGSM":
402
  return gr.update(visible=False)
403
- else: # PGD
404
  return gr.update(visible=True)
405
 
406
  attack_type.change(
 
6
  from utils.model_loader import load_model_and_labels
7
  from utils.preprocessing import get_default_transform, preprocess_image
8
  from utils.inference import predict_topk
9
+ from utils.attacks import PGDIterations, FGSM, SAGA
10
  from utils.visualization import extract_attention_maps, attention_rollout, create_attention_overlay, extract_attention_for_iterations, create_iteration_attention_overlays
11
 
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
224
  # Configurar ataque baseado no tipo selecionado
225
  if attack_type == "FGSM":
226
  attack = FGSM(model, eps=eps)
227
+ elif attack_type == "SAGA":
228
+ attack = SAGA(model, eps=eps, steps=steps)
229
  else: # PGD
230
  attack = PGDIterations(model, eps=eps, alpha=alpha, steps=steps)
231
 
 
235
  # Executar ataque
236
  adv_tensor, iteration_images = attack(img_tensor, original_label)
237
 
238
+ # Extrair atenção para todas as iterações
239
+ # SAGA já calcula atenção internamente, podemos reutilizar!
240
+ if attack_type == "SAGA" and hasattr(attack, 'attention_masks_cache') and len(attack.attention_masks_cache) > 0:
241
+ # Usar atenção já calculada pelo SAGA
242
+ attention_masks = attack.attention_masks_cache
243
+ else:
244
+ # Para FGSM e PGD, calcular atenção normalmente
245
+ attention_masks = extract_attention_for_iterations(
246
+ model,
247
+ attack.iteration_tensors,
248
+ discard_ratio=discard_ratio,
249
+ head_fusion=head_fusion
250
+ )
251
 
252
  # Criar overlays de atenção
253
  attention_overlays = create_iteration_attention_overlays(
 
275
  if attack_type == "PGD":
276
  result += f"- Alpha (α): {alpha:.4f}\n"
277
  result += f"- Steps: {steps}\n"
278
+ elif attack_type == "SAGA":
279
+ result += f"- Steps: {steps}\n"
280
+ result += f"- Gradiente ponderado por atenção (ViT-specific)\n"
281
+ else: # FGSM
282
  result += f"- Single-step (sem iterações)\n"
283
 
284
  result += f"\n**Predição Original:**\n"
 
375
  gr.Markdown("#### ⚔️ Configuração do Ataque")
376
 
377
  attack_type = gr.Dropdown(
378
+ choices=["PGD", "FGSM", "SAGA"],
379
  value="PGD",
380
  label="Tipo de Ataque",
381
+ info="PGD: iterativo | FGSM: single-step | SAGA: gradient × attention (ViT-specific)"
382
  )
383
 
384
  eps_input = gr.Slider(
 
411
  def update_attack_params(attack_type):
412
  if attack_type == "FGSM":
413
  return gr.update(visible=False)
414
+ else: # PGD ou SAGA
415
  return gr.update(visible=True)
416
 
417
  attack_type.change(
utils/attacks.py CHANGED
@@ -191,4 +191,130 @@ class PGDIterations(torchattacks.PGD):
191
 
192
  # Retornar imagem normalizada para o modelo
193
  adv_images = (adv_images_denorm - mean) / std
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  return adv_images, self.iteration_images
 
191
 
192
  # Retornar imagem normalizada para o modelo
193
  adv_images = (adv_images_denorm - mean) / std
194
+ return adv_images, self.iteration_images
195
+
196
+
197
class SAGA(torch.nn.Module):
    """
    SAGA: Self-Attention Gradient Attack.

    Adversarial attack specific to Vision Transformers that multiplies the
    FGSM-style gradient by the model's attention-rollout map, concentrating
    the perturbation on the regions the model attends to.

    Based on: https://github.com/MetaMain/ViTRobust
    Paper: "On the Robustness of Vision Transformers to Adversarial Examples"
    (ICCV 2021)

    Attributes populated by ``forward`` (reset on every call):
        iteration_images: PIL image of each iteration (original first).
        iteration_tensors: normalized tensor of each iteration.
        attention_masks_cache: attention-rollout mask (numpy [h, w]) for each
            entry of ``iteration_tensors``, index-aligned so callers can reuse
            them for visualization without recomputing.
    """

    def __init__(self, model, eps=0.03, steps=10):
        """
        Args:
            model: ViT classifier; assumed to take ImageNet-normalized input.
            eps: total L-inf budget (in normalized space; scaled by std when
                applied in pixel space).
            steps: number of gradient steps; per-step size is eps / steps.
        """
        super().__init__()
        self.model = model
        self.eps = eps
        self.steps = steps
        self.device = next(model.parameters()).device
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        # Cache of attention masks gathered while the attack runs.
        self.attention_masks_cache: List[np.ndarray] = []

    def get_attention_map(self, images: torch.Tensor, save_for_viz: bool = False) -> tuple:
        """
        Extract the ViT attention map via attention rollout.

        Args:
            images: normalized input batch [B, C, H, W].
            save_for_viz: if True, append the raw patch-grid mask to
                ``attention_masks_cache``.

        Returns:
            Tuple of:
            - mask_tensor: [B, 3, H, W] float tensor on the attack device,
              ready to be multiplied against the input gradient.
            - mask_np: the raw numpy mask when ``save_for_viz`` is True,
              otherwise None.
        """
        from utils.visualization import extract_attention_maps, attention_rollout
        import cv2

        batch_size = images.shape[0]
        img_size = images.shape[2]

        # Per-layer attention matrices from the ViT.
        attentions = extract_attention_maps(self.model, images)

        # Collapse all layers into a single token-relevance map.
        mask = attention_rollout(attentions, discard_ratio=0.9, head_fusion='max')

        if save_for_viz:
            self.attention_masks_cache.append(mask.copy())

        # Upsample the patch-grid mask (e.g. 14x14) to image resolution.
        mask_resized = cv2.resize(mask, (img_size, img_size))

        # [H, W] -> [B, 3, H, W] so it multiplies the gradient elementwise.
        mask_tensor = torch.from_numpy(mask_resized).float().to(self.device)
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
        mask_tensor = mask_tensor.repeat(batch_size, 3, 1, 1)  # [B, 3, H, W]

        return mask_tensor, mask if save_for_viz else None

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """
        Run the SAGA attack.

        Args:
            images: normalized input batch [B, C, H, W].
            labels: ground-truth labels [B].

        Returns:
            Tuple of:
            - adv_images: final normalized adversarial batch.
            - iteration_images: per-iteration PIL images (original included).
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        loss_fn = torch.nn.CrossEntropyLoss()

        # ImageNet statistics used to move between normalized and pixel space.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

        images_denorm = images * std + mean
        adv_images_denorm = images_denorm.clone().detach()

        self.iteration_images = []
        self.iteration_tensors = []
        self.attention_masks_cache = []

        # Iteration 0: the unperturbed image and its attention map.
        pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
        self.iteration_images.append(pil_img_orig)
        self.iteration_tensors.append(images.clone().detach())
        self.get_attention_map(images, save_for_viz=True)

        # Split the total budget evenly across steps.
        eps_step = self.eps / self.steps

        for step in range(self.steps):
            # Normalize so the model sees its expected input distribution.
            adv_images = (adv_images_denorm - mean) / std
            adv_images.requires_grad = True

            outputs = self.model(adv_images)
            cost = loss_fn(outputs, labels)

            self.model.zero_grad()
            cost.backward()
            grad = adv_images.grad.detach()

            # SAGA core: weight the gradient by the attention map so the
            # perturbation targets the regions the model attends to.
            #
            # BUGFIX (cache alignment): the mask computed here belongs to the
            # *input* of this step, i.e. the output of the previous step. Cache
            # it only for step >= 1 so it lines up with iteration_tensors[step];
            # the step-0 input is the original image, whose mask was already
            # cached above. (Previously the original's mask was cached twice
            # and the final image's mask was never cached, so visualization
            # overlays were shifted by one iteration.)
            attention_map, _ = self.get_attention_map(
                adv_images.detach(), save_for_viz=(step > 0)
            )
            grad_weighted = grad * attention_map

            # Apply the signed step in pixel space (scaled by std so the
            # budget matches eps in normalized space), then clip to [0, 1].
            adv_images_denorm = adv_images_denorm + eps_step * grad_weighted.sign() * std
            adv_images_denorm = torch.clamp(adv_images_denorm, min=0, max=1).detach()

            # Keep a normalized copy of this iteration's result.
            adv_images_normalized = (adv_images_denorm - mean) / std

            pil_img = tensor_to_pil(adv_images_denorm[0], denormalize=False)
            self.iteration_images.append(pil_img)
            self.iteration_tensors.append(adv_images_normalized.clone().detach())

        # Cache the attention of the FINAL adversarial image so that
        # attention_masks_cache stays index-aligned with iteration_tensors.
        self.get_attention_map(self.iteration_tensors[-1], save_for_viz=True)

        # Return the normalized adversarial batch.
        adv_images = (adv_images_denorm - mean) / std
        return adv_images, self.iteration_images