fix: fixes MIM attack to init momentum on denormalized space
Browse files- utils/attacks.py +14 -9
utils/attacks.py
CHANGED
|
@@ -359,8 +359,8 @@ class MIFGSM(torchattacks.MIFGSM):
|
|
| 359 |
images_denorm = images * std + mean
|
| 360 |
adv_images_denorm = images_denorm.clone().detach()
|
| 361 |
|
| 362 |
-
# Inicializar momentum
|
| 363 |
-
momentum = torch.zeros_like(
|
| 364 |
|
| 365 |
self.iteration_images = []
|
| 366 |
self.iteration_tensors = []
|
|
@@ -382,19 +382,24 @@ class MIFGSM(torchattacks.MIFGSM):
|
|
| 382 |
else:
|
| 383 |
cost = loss(outputs, labels)
|
| 384 |
|
| 385 |
-
# Calcular gradiente
|
| 386 |
grad = torch.autograd.grad(cost, adv_images,
|
| 387 |
retain_graph=False, create_graph=False)[0]
|
| 388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
# Normalizar gradiente (chave do MI-FGSM!)
|
| 390 |
-
|
|
|
|
| 391 |
|
| 392 |
-
# Aplicar momentum
|
| 393 |
-
|
| 394 |
-
momentum =
|
| 395 |
|
| 396 |
-
#
|
| 397 |
-
adv_images_denorm = adv_images_denorm.detach() + self.alpha *
|
| 398 |
delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
|
| 399 |
adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()
|
| 400 |
|
|
|
|
| 359 |
images_denorm = images * std + mean
|
| 360 |
adv_images_denorm = images_denorm.clone().detach()
|
| 361 |
|
| 362 |
+
# Inicializar momentum no espaço desnormalizado
|
| 363 |
+
momentum = torch.zeros_like(images_denorm).detach().to(self.device)
|
| 364 |
|
| 365 |
self.iteration_images = []
|
| 366 |
self.iteration_tensors = []
|
|
|
|
| 382 |
else:
|
| 383 |
cost = loss(outputs, labels)
|
| 384 |
|
| 385 |
+
# Calcular gradiente no espaço normalizado
|
| 386 |
grad = torch.autograd.grad(cost, adv_images,
|
| 387 |
retain_graph=False, create_graph=False)[0]
|
| 388 |
|
| 389 |
+
# Converter gradiente para espaço desnormalizado
|
| 390 |
+
# Isso é necessário porque o momentum precisa estar no mesmo espaço que as perturbações
|
| 391 |
+
grad_denorm = grad * std
|
| 392 |
+
|
| 393 |
# Normalizar gradiente (chave do MI-FGSM!)
|
| 394 |
+
# Normalização acontece NO ESPAÇO DESNORMALIZADO para manter consistência
|
| 395 |
+
grad_denorm = grad_denorm / torch.mean(torch.abs(grad_denorm), dim=(1, 2, 3), keepdim=True)
|
| 396 |
|
| 397 |
+
# Aplicar momentum no espaço desnormalizado
|
| 398 |
+
grad_denorm = grad_denorm + momentum * self.decay
|
| 399 |
+
momentum = grad_denorm
|
| 400 |
|
| 401 |
+
# Aplicar perturbação no espaço desnormalizado
|
| 402 |
+
adv_images_denorm = adv_images_denorm.detach() + self.alpha * grad_denorm.sign()
|
| 403 |
delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
|
| 404 |
adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()
|
| 405 |
|