lucasddmc commited on
Commit
cfe9e41
·
1 Parent(s): 7aad02c

fix: fixes MIM attack to init momentum on denormalized space

Browse files
Files changed (1) hide show
  1. utils/attacks.py +14 -9
utils/attacks.py CHANGED
@@ -359,8 +359,8 @@ class MIFGSM(torchattacks.MIFGSM):
359
  images_denorm = images * std + mean
360
  adv_images_denorm = images_denorm.clone().detach()
361
 
362
- # Inicializar momentum
363
- momentum = torch.zeros_like(images).detach().to(self.device)
364
 
365
  self.iteration_images = []
366
  self.iteration_tensors = []
@@ -382,19 +382,24 @@ class MIFGSM(torchattacks.MIFGSM):
382
  else:
383
  cost = loss(outputs, labels)
384
 
385
- # Calcular gradiente
386
  grad = torch.autograd.grad(cost, adv_images,
387
  retain_graph=False, create_graph=False)[0]
388
 
 
 
 
 
389
  # Normalizar gradiente (chave do MI-FGSM!)
390
- grad = grad / torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True)
 
391
 
392
- # Aplicar momentum
393
- grad = grad + momentum * self.decay
394
- momentum = grad
395
 
396
- # Voltar para espaço desnormalizado para aplicar perturbação
397
- adv_images_denorm = adv_images_denorm.detach() + self.alpha * grad.sign() * std
398
  delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
399
  adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()
400
 
 
359
  images_denorm = images * std + mean
360
  adv_images_denorm = images_denorm.clone().detach()
361
 
362
+ # Inicializar momentum no espaço desnormalizado
363
+ momentum = torch.zeros_like(images_denorm).detach().to(self.device)
364
 
365
  self.iteration_images = []
366
  self.iteration_tensors = []
 
382
  else:
383
  cost = loss(outputs, labels)
384
 
385
+ # Calcular gradiente no espaço normalizado
386
  grad = torch.autograd.grad(cost, adv_images,
387
  retain_graph=False, create_graph=False)[0]
388
 
389
+ # Converter gradiente para espaço desnormalizado
390
+ # Isso é necessário porque o momentum precisa estar no mesmo espaço que as perturbações
391
+ grad_denorm = grad * std
392
+
393
  # Normalizar gradiente (chave do MI-FGSM!)
394
+ # Normalização acontece NO ESPAÇO DESNORMALIZADO para manter consistência
395
+ grad_denorm = grad_denorm / torch.mean(torch.abs(grad_denorm), dim=(1, 2, 3), keepdim=True)
396
 
397
+ # Aplicar momentum no espaço desnormalizado
398
+ grad_denorm = grad_denorm + momentum * self.decay
399
+ momentum = grad_denorm
400
 
401
+ # Aplicar perturbação no espaço desnormalizado
402
+ adv_images_denorm = adv_images_denorm.detach() + self.alpha * grad_denorm.sign()
403
  delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
404
  adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()
405