CGSCORE / examples /image /test_onepixel.py
Yaning1001's picture
Add files using upload-large-folder tool
4113c4d verified
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F #233
import torch.optim as optim
from torchvision import datasets,models,transforms
from PIL import Image
import argparse
from deeprobust.image.attack.onepixel import Onepixel
from deeprobust.image.netmodels import resnet
from deeprobust.image.config import attack_params
from deeprobust.image.utils import download_model
def parameter_parser():
parser = argparse.ArgumentParser(description = "Run attack algorithms.")
parser.add_argument("--destination",
default = './trained_models/',
help = "choose destination to load the pretrained models.")
parser.add_argument("--filename",
default = "MNIST_CNN_epoch_20.pt")
return parser.parse_args()
args = parameter_parser() # read argument and creat an argparse object
model = resnet.ResNet18().to('cuda')
print("Load network")
model.load_state_dict(torch.load("./trained_models/CIFAR10_ResNet18_epoch_20.pt"))
model.eval()
transform_val = transforms.Compose([
transforms.ToTensor(),
])
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('deeprobust/image/data', train = False, download=True,
transform = transform_val),
batch_size = 1, shuffle=True) #, **kwargs)
classes = np.array(('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'))
xx, yy = next(iter(test_loader))
xx = xx.to('cuda').float()
"""
Generate adversarial examples
"""
F1 = Onepixel(model, device = "cuda") ### or cuda
AdvExArray = F1.generate(xx, yy)
predict0 = model(xx)
predict0= predict0.argmax(dim=1, keepdim=True)
predict1 = model(AdvExArray)
predict1= predict1.argmax(dim=1, keepdim=True)
print("original prediction:")
print(predict0)
print("attack prediction:")
print(predict1)
xx = xx.cpu().detach().numpy()
AdvExArray = AdvExArray.cpu().detach().numpy()
import matplotlib.pyplot as plt
xx = xx[0].transpose(1, 2, 0)
AdvExArray = AdvExArray[0].transpose(1, 2, 0)
plt.imshow(xx, vmin=0, vmax=255)
plt.savefig('./adversary_examples/cifar10_advexample_ori.png')
plt.imshow(AdvExArray, vmin=0, vmax=255)
plt.savefig('./adversary_examples/cifar10_advexample_onepixel_adv.png')