|
|
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torchvision import datasets, models, transforms

import deeprobust.image.netmodels.CNN as CNN
import deeprobust.image.netmodels.resnet as resnet
from deeprobust.image.attack.pgd import PGD
from deeprobust.image.config import attack_params
|
|
|
|
|
# Build the ResNet-18 from deeprobust's model zoo on the GPU and restore the
# weights stored in the CIFAR-10 checkpoint (epoch 20, per the filename).
model = resnet.ResNet18().to('cuda')

print("Load network")

# map_location loads the tensors straight onto the device the model lives on,
# regardless of which device the checkpoint was originally saved from.
model.load_state_dict(
    torch.load("./trained_models/CIFAR10_ResNet18_epoch_20.pt", map_location='cuda')
)

# Switch to evaluation mode (affects layers such as dropout / batch-norm, if
# the network contains them) before running inference and the attack.
model.eval()
|
|
|
|
|
# Evaluation preprocessing: convert each PIL image to a float tensor with
# ToTensor only -- no normalization step is applied.
transform_val = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 test split (fetched on first use), served one image per batch in
# shuffled order, so the sample picked below differs between runs.
_cifar10_test_set = datasets.CIFAR10(
    root='deeprobust/image/data',
    train=False,
    download=True,
    transform=transform_val,
)
test_loader = torch.utils.data.DataLoader(
    _cifar10_test_set,
    batch_size=1,
    shuffle=True,
)

# Label id -> human-readable class name lookup table.
classes = np.array((
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
))
|
|
|
|
|
# Draw one test sample (the loader uses batch_size=1) and move the image to
# the GPU as float32.
xx, yy = next(iter(test_loader))
xx = xx.to('cuda').float()

# Clean prediction. Only the argmax is needed, so skip building an autograd
# graph for this forward pass.
with torch.no_grad():
    predict0 = model(xx).argmax(dim=1)

# Craft the adversarial example with deeprobust's preset CIFAR-10 PGD
# parameters. Gradients ARE required here, so this stays outside no_grad.
# NOTE(review): yy is passed while still on the CPU; this relies on
# PGD.generate moving the label to the right device -- confirm in deeprobust.
adversary = PGD(model)
AdvExArray = adversary.generate(xx, yy, **attack_params['PGD_CIFAR10']).float()

# Prediction on the adversarial image, again without tracking gradients.
with torch.no_grad():
    predict1 = model(AdvExArray).argmax(dim=1)

# Index the class-name table with explicit NumPy integer arrays rather than
# relying on NumPy coercing a torch.Tensor index. All three indices are 1-D,
# so the three lines print in the same format.
print('====== RESULT =====')
print('true label', classes[yy.cpu().numpy()])
print('predict_orig', classes[predict0.cpu().numpy()])
print('predict_adv', classes[predict1.cpu().numpy()])
|
|
|
|
|
# Save the clean and the adversarial image as PNGs. savefig does not create
# missing directories, so make sure the output folder exists first.
os.makedirs('./adversary_examples', exist_ok=True)

# Reorder the batch from channels-first to channels-last, as imshow expects,
# and take the single image in the batch.
x_show = xx.detach().cpu().numpy().transpose(0, 2, 3, 1)[0]

# The pixels come from ToTensor, i.e. floats in [0, 1], which is what imshow
# expects for RGB data (vmin/vmax would be ignored for RGB input anyway).
# Each image gets its own figure so the two plots never share axes.
plt.figure()
plt.imshow(x_show)
plt.savefig('./adversary_examples/cifar_advexample_orig.png')
plt.close()

# Same conversion for the adversarial example. AdvExArray is rebound to the
# channels-last NumPy image, as before.
AdvExArray = AdvExArray.cpu().detach().numpy()
AdvExArray = AdvExArray.transpose(0, 2, 3, 1)[0]

plt.figure()
plt.imshow(AdvExArray)
plt.savefig('./adversary_examples/cifar_advexample_pgd.png')
plt.close()
|
|
|
|
|
|
|
|
|