File size: 5,725 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils_ import *
from models.generator import CrossAttenGenerator
from image_transformer import rotation
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import argparse
Image.MAX_IMAGE_PIXELS = None
import torch.optim as optim
# Command-line options for the training script.
parser = argparse.ArgumentParser(description='Clip-based Generative Networks')
parser.add_argument('--train_dir', default='./dataset/ImageNet/train', help='imagenet')
parser.add_argument('--batch_size', type=int, default=20, help='Number of training samples/batch')
parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
parser.add_argument('--lr', type=float, default=2e-4, help='Initial learning rate')
parser.add_argument('--eps', type=int, default=16, help='Perturbation budget')
parser.add_argument('--model_type', type=str, default='res152', help='Source model')
parser.add_argument('--start_epoch', type=int, default=0, help='Start epoch')
parser.add_argument('--label_flag', type=str, default='N8', help='Label nums: N8, C20,...,C200')
parser.add_argument('--nz', type=int, default=16, help='nz')
parser.add_argument('--save_dir', type=str, default='checkpoints', help='Directory to save the model')
parser.add_argument('--load_path', type=str, help='Path to checkpoint')
parser.add_argument('--finetune', action='store_true', help='Finetune for single class attack')
parser.add_argument('--finetune_class', type=int, help='Class id to be finetuned')
# A real float default (was the string '2e-1'); the parsed value is the same 0.2.
parser.add_argument('--mask_ratio', type=float, default=0.2, help='Mask ratio in finetune stage')
args = parser.parse_args()

# Fail fast on option combinations that would otherwise only crash later:
# resuming loads a checkpoint from --load_path, and finetuning builds the
# label batch from --finetune_class.
if args.start_epoch > 0 and args.load_path is None:
    parser.error('--load_path is required when --start_epoch > 0')
if args.finetune and args.finetune_class is None:
    parser.error('--finetune_class is required when --finetune is set')

print(args)
# Class count (not referenced again in this script section).
n_class = 1000

# Perturbation budget is given on the 0-255 scale; rescale it to 0-1.
eps = args.eps / 255.

# cuDNN autotuning and CUDA seeding only apply when a GPU is present.
use_gpu = torch.cuda.is_available()
if use_gpu:
    torch.backends.cudnn.benchmark = True
    torch.cuda.manual_seed_all(1111)

# Every visible GPU index; the primary device is the first GPU, else the CPU.
device_ids = list(range(torch.cuda.device_count()))
print(device_ids)
device = torch.device("cuda:0" if use_gpu else "cpu")
print(device)

# Make sure the checkpoint root exists.
if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)
# Input dimension and generator
# 'incv3' uses scale/img sizes 300/299 and turns on the generator's
# `inception` flag; every other model type uses 256/224.
use_inception = args.model_type == 'incv3'
scale_size, img_size = (300, 299) if use_inception else (256, 224)
if use_inception:
    netG = CrossAttenGenerator(inception=True, nz=args.nz, device=device)
else:
    netG = CrossAttenGenerator(nz=args.nz, device=device)

# Resuming: restore generator weights before any multi-GPU wrapping.
if args.start_epoch > 0:
    netG.load_state_dict(torch.load(args.load_path, map_location=device))

# Replicate the generator across all GPUs when more than one is available.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    netG = nn.DataParallel(netG, device_ids=device_ids)

netG = netG.to(device)
# Optimizer
# Adam over the generator's parameters only.
# The optimizer is deliberately NOT wrapped in nn.DataParallel: DataParallel
# is a container for modules, and the earlier wrap-then-`.module` round trip
# simply handed back this same optimizer object.
optimG = optim.Adam(netG.parameters(), lr=args.lr, betas=(0.5, 0.999))
# Data
# Dataset built from the training directory with the sizes chosen above.
train_set = get_data(args.train_dir, scale_size, img_size)

# Shuffled loader with 8 worker processes and pinned host memory.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
# Surrogate model
# Pretrained torchvision classifier whose predictions drive the training loss.
if args.model_type == 'incv3':
    model = torchvision.models.inception_v3(pretrained=True).to(device)
elif args.model_type == 'res152':
    model = torchvision.models.resnet152(pretrained=True).to(device)
else:
    # Previously an unknown type fell through and failed later with a
    # NameError on `model`; report the real problem instead.
    raise ValueError(
        "Unsupported --model_type '{}': expected 'incv3' or 'res152'".format(args.model_type)
    )

# NOTE: the surrogate is used unwrapped. The former
# `nn.DataParallel(model).module` round trip returned this same module, so
# dropping it does not change which device(s) the surrogate runs on.
model.eval()
# class
# Candidate target labels selected by --label_flag.
label_set = get_classes(args.label_flag)

# Loss
criterion = nn.CrossEntropyLoss()

# text condition
# Mapping indexed by class label; its entries are stacked into the
# generator's `cond` input during training.
text_cond_dict = torch.load('text_feature.pth')

# save dir
# Checkpoints go to <save_dir>/<model_type>. Only the root is created
# earlier in the script, so create the per-model subdirectory here;
# torch.save does not create missing parent directories.
save_dir = os.path.join(args.save_dir, args.model_type)
os.makedirs(save_dir, exist_ok=True)
# Training
# Each step perturbs three views of the batch with the generator, classifies
# the clamped results with the surrogate, and minimises the summed
# cross-entropy against the chosen target labels.

# Number of batches averaged into each progress line.
log_interval = 10
# Decides whether the generator must be unwrapped (`.module`) when saving.
multi_gpu = torch.cuda.device_count() > 1

# Nothing in the loop switches the generator out of train mode, so set it once.
netG.train()

for epoch in range(args.start_epoch, args.epochs):
    running_loss = 0
    for i, (imgs, _) in enumerate(tqdm(train_loader)):
        # The loader yields two image tensors per batch (indices 0 and 1);
        # a third view is derived by rotating the first.
        # NOTE(review): imgs[1] is presumably an augmented copy - confirm in get_data.
        img = imgs[0].to(device)
        img_rot = rotation(img)[0]
        img_aug = imgs[1].to(device)
        views = [img, img_rot, img_aug]

        # Target labels: the fixed finetune class, or a random draw from label_set.
        if args.finetune:
            label = np.array([args.finetune_class] * img.size(0))
        else:
            # Kept as in the original: reorders label_set in place before sampling.
            np.random.shuffle(label_set)
            label = np.random.choice(label_set, img.size(0))
        # NOTE(review): cond stays on whatever device the text features were
        # loaded to; presumably the generator moves it - confirm.
        cond = torch.stack([text_cond_dict[j] for j in label], dim=0)
        label = torch.from_numpy(label).long().to(device)

        optimG.zero_grad()

        # generate img
        # All generator passes run first, then all masks, then all surrogate
        # passes - the same call order as before.
        noises = [netG(input=view, cond=cond, eps=eps) for view in views]
        if args.finetune:
            noises = [get_mask(noise, args.mask_ratio, device) for noise in noises]

        # Add the perturbation and keep pixel values inside [0, 1].
        advs = [torch.clamp(noise + view, 0.0, 1.0) for noise, view in zip(noises, views)]
        outputs = [model(normalize(adv)) for adv in advs]

        loss = sum(criterion(out, label) for out in outputs)
        loss.backward()
        optimG.step()

        # Accumulate first, then report the mean over exactly `log_interval`
        # batches (previously the sum of the preceding batches was divided
        # by 100 and excluded the current batch).
        running_loss += abs(loss.item())
        if i % log_interval == log_interval - 1:
            print('Epoch: {} \t Batch: {}/{} \t loss: {:.5f}'.format(
                epoch, i, len(train_loader), running_loss / log_interval))
            running_loss = 0

    # Save a checkpoint after every epoch; unwrap DataParallel so the saved
    # keys match a bare generator.
    state_dict = netG.module.state_dict() if multi_gpu else netG.state_dict()
    torch.save(state_dict, '{}/model-{}.pth'.format(save_dir, epoch))
|