|
|
import os |
|
|
import argparse |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchvision |
|
|
from torch.autograd import Variable |
|
|
import torch.optim as optim |
|
|
from torchvision import datasets, transforms |
|
|
import numpy as np |
|
|
|
|
|
def PGD_Revise(model, img, gt, epsilon, step_size, num_steps, device):
    """Run an L-infinity PGD attack and also return an intermediate iterate.

    Starting from ``img`` plus uniform noise in ``[-epsilon, epsilon]``, take
    ``num_steps`` signed-gradient ascent steps on the cross-entropy loss,
    projecting back into the epsilon-ball around ``img`` and into ``[0, 1]``
    after every step.

    Side effect: puts ``model`` into eval mode and does not restore its
    previous mode. Model parameter ``.grad`` fields are not modified.

    Args:
        model: classifier whose output, together with ``gt``, is accepted by
            ``nn.CrossEntropyLoss``.
        img: clean input batch; every iterate is clamped to [0, 1], so this
            presumably expects inputs already scaled to that range.
        gt: targets for ``nn.CrossEntropyLoss`` (presumably class indices).
        epsilon: L-infinity radius of the perturbation ball.
        step_size: magnitude of each signed-gradient step.
        num_steps: number of PGD iterations.
        device: device the random-start noise is moved to; presumably the
            device ``img`` lives on.

    Returns:
        ``(x_adv, x_adv_mid)``: the final adversarial batch and the iterate
        after ``num_steps // 2 + 1`` steps, both detached. If
        ``num_steps == 0`` both are the clamped random start.
    """
    model.eval()
    # Detach so the projection bounds never pull img's autograd graph into x_adv
    # (which would make x_adv a non-leaf and break gradient retrieval).
    img = img.detach()
    criterion = nn.CrossEntropyLoss()

    noise = torch.from_numpy(np.random.uniform(-epsilon, epsilon, img.shape)).float().to(device)
    x_adv = torch.clamp(img + noise, 0.0, 1.0)
    # Fallback snapshot so the second return value is defined even if the loop never runs.
    x_adv_mid = x_adv.clone()

    for k in range(num_steps):
        x_adv = x_adv.detach().requires_grad_()
        # Forward and gradient computation both happen under enable_grad so the
        # attack still works when the caller is inside torch.no_grad().
        with torch.enable_grad():
            loss_adv = criterion(model(x_adv), gt)
            # autograd.grad differentiates w.r.t. x_adv only, so model parameter
            # .grad fields are not polluted by the attack.
            grad = torch.autograd.grad(loss_adv, [x_adv])[0]

        x_adv = x_adv.detach() + step_size * grad.sign()
        # Project back into the epsilon-ball around img, then into the valid range.
        x_adv = torch.min(torch.max(x_adv, img - epsilon), img + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

        if k == num_steps // 2:
            x_adv_mid = x_adv.clone()

    return x_adv.detach(), x_adv_mid.detach()
|
|
|
|
|
|
|
|
def PGD_Revise2(model, img, gt, epsilon, step_size, num_steps, device):
    """Run an L-infinity PGD attack and return a half-way iterate plus the final one.

    Starting from ``img`` plus uniform noise in ``[-epsilon, epsilon]``, take
    ``num_steps`` signed-gradient ascent steps on the cross-entropy loss,
    projecting back into the epsilon-ball around ``img`` and into ``[0, 1]``
    after every step.

    Side effect: puts ``model`` into eval mode and does not restore its
    previous mode. Model parameter ``.grad`` fields are not modified.

    Args:
        model: classifier whose output, together with ``gt``, is accepted by
            ``nn.CrossEntropyLoss``.
        img: clean input batch; every iterate is clamped to [0, 1], so this
            presumably expects inputs already scaled to that range.
        gt: targets for ``nn.CrossEntropyLoss`` (presumably class indices).
        epsilon: L-infinity radius of the perturbation ball.
        step_size: magnitude of each signed-gradient step.
        num_steps: number of PGD iterations.
        device: device the random-start noise is moved to; presumably the
            device ``img`` lives on.

    Returns:
        ``(x_adv_now, x_adv)`` -- note the order: first the iterate after
        ``num_steps // 2`` steps, then the final adversarial batch. Both are
        detached. When ``num_steps // 2 == 0`` the first value is the clamped
        random start.
    """
    model.eval()
    # Detach so the projection bounds never pull img's autograd graph into x_adv.
    img = img.detach()
    criterion = nn.CrossEntropyLoss()

    noise = torch.from_numpy(np.random.uniform(-epsilon, epsilon, img.shape)).float().to(device)
    x_adv = torch.clamp(img + noise, 0.0, 1.0)
    # The iterate "after 0 steps" is the random start; this also keeps the
    # snapshot defined for num_steps < 2, where the in-loop index is -1.
    x_adv_now = x_adv.clone()

    for k in range(num_steps):
        x_adv = x_adv.detach().requires_grad_()
        # Forward and gradient computation both happen under enable_grad so the
        # attack still works when the caller is inside torch.no_grad().
        with torch.enable_grad():
            loss_adv = criterion(model(x_adv), gt)
            # Gradient w.r.t. the input only; model parameter .grad stays untouched.
            grad = torch.autograd.grad(loss_adv, [x_adv])[0]

        x_adv = x_adv.detach() + step_size * grad.sign()
        # Project back into the epsilon-ball around img, then into the valid range.
        x_adv = torch.min(torch.max(x_adv, img - epsilon), img + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

        # k is 0-based, so after this iteration exactly k + 1 == num_steps // 2 steps were taken.
        if k == num_steps // 2 - 1:
            x_adv_now = x_adv.clone()

    return x_adv_now.detach(), x_adv.detach()