from torch.utils.data import DataLoader
from tqdm import tqdm
from utils_ import *
from models.generator import CrossAttenGenerator
from image_transformer import rotation
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import argparse
Image.MAX_IMAGE_PIXELS = None
import torch.optim as optim

# ---------------------------------------------------------------------------
# Command-line arguments for training the CLIP-conditioned generator.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Clip-based Generative Networks')
parser.add_argument('--train_dir', default='./dataset/ImageNet/train', help='imagenet')
parser.add_argument('--batch_size', type=int, default=20, help='Number of training samples/batch')
parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
parser.add_argument('--lr', type=float, default=2e-4, help='Initial learning rate')
parser.add_argument('--eps', type=int, default=16, help='Perturbation budget')
parser.add_argument('--model_type', type=str, default='res152', help='Source model')
parser.add_argument('--start_epoch', type=int, default=0, help='Start epoch')
parser.add_argument('--label_flag', type=str, default='N8', help='Label nums: N8, C20,...,C200')
parser.add_argument('--nz', type=int, default=16, help='nz')
parser.add_argument('--save_dir', type=str, default='checkpoints', help='Dictionary to save the model')
parser.add_argument('--load_path', type=str, help='Path to checkpoint')
parser.add_argument('--finetune', action='store_true', help='Finetune for single class attack')
parser.add_argument('--finetune_class', type=int, help='Class id to be finetuned')
# BUGFIX: the default was the *string* '2e-1', which only worked because
# argparse silently passes string defaults through `type`. Use a float literal.
parser.add_argument('--mask_ratio', type=float, default=0.2, help='Mask ratio in finetune stage')
args = parser.parse_args()
print(args)

# Number of ImageNet classes.
n_class = 1000

# Scale the integer perturbation budget into the [0, 1] pixel range.
eps = args.eps / 255.

use_gpu = torch.cuda.is_available()
if use_gpu:
    # Enable cudnn autotuning and fix the CUDA RNG for reproducibility.
    torch.backends.cudnn.benchmark = True
    torch.cuda.manual_seed_all(1111)

# Enumerate every visible GPU; fall back to CPU when none is available.
device_ids = list(range(torch.cuda.device_count()))
print(device_ids)
device = torch.device("cuda:0" if use_gpu else "cpu")
print(device)

# Make sure the checkpoint root exists before training starts.
if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)

# Pick input resolution and build the conditional generator.
# Inception-v3 expects 299x299 inputs; every other backbone uses 224x224.
if args.model_type == 'incv3':
    scale_size, img_size = 300, 299
    netG = CrossAttenGenerator(inception=True, nz=args.nz, device=device)
else:
    scale_size, img_size = 256, 224
    netG = CrossAttenGenerator(nz=args.nz, device=device)

# Resume from a saved checkpoint when training restarts mid-way.
if args.start_epoch > 0:
    state = torch.load(args.load_path, map_location=device)
    netG.load_state_dict(state)

# Spread the generator over all visible GPUs when more than one is present.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print("Let's use", n_gpus, "GPUs!")
    netG = nn.DataParallel(netG, device_ids=device_ids)
netG = netG.to(device)

# Optimizer for the generator.
# BUGFIX: the original wrapped the optimizer in nn.DataParallel and
# immediately unwrapped it via `.module`. DataParallel only applies to
# nn.Module instances, not optimizers, so the wrap/unwrap was at best a
# no-op (and can raise on some torch versions). Multi-GPU parallelism is
# handled by wrapping the *model*, not the optimizer.
optimG = optim.Adam(netG.parameters(), lr=args.lr, betas=(0.5, 0.999))

# Training data: get_data() yields pairs of differently-augmented views
# (consumed as imgs[0] / imgs[1] in the training loop).
train_set = get_data(args.train_dir, scale_size, img_size)
train_loader = DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Surrogate (white-box) classifier that scores the adversarial examples.
if args.model_type == 'incv3':
    model = torchvision.models.inception_v3(pretrained=True).to(device)
elif args.model_type == 'res152':
    model = torchvision.models.resnet152(pretrained=True).to(device)
else:
    # Fail fast with a clear message instead of a NameError at model.eval().
    raise ValueError('Unsupported model_type: {}'.format(args.model_type))
# NOTE(review): the original wrapped `model` in nn.DataParallel and then
# immediately took `.module` back, which is a no-op; the dead wrap/unwrap is
# removed. If multi-GPU inference is actually wanted here, keep the
# DataParallel wrapper instead of unwrapping it.
model.eval()

# Candidate target classes for the targeted attack, selected by --label_flag
# (e.g. 'N8' or 'C20'..'C200'; semantics defined by get_classes in utils_).
label_set = get_classes(args.label_flag)

# Cross-entropy towards the *target* label drives the targeted objective.
criterion = nn.CrossEntropyLoss()

# Pre-computed text-condition embeddings, indexed by class id in the loop.
# NOTE(review): assumes 'text_feature.pth' maps class id -> embedding tensor
# on the right device — confirm against how the file was produced.
text_cond_dict = torch.load('text_feature.pth')

# Checkpoints are grouped per surrogate model under the save root.
save_dir = os.path.join(args.save_dir, args.model_type)

# ---------------------------------------------------------------------------
# Training loop: for each batch, perturb three views (clean, rotated,
# augmented) with the generator and minimize targeted cross-entropy on all.
# ---------------------------------------------------------------------------
for epoch in range(args.start_epoch, args.epochs):
    running_loss = 0.0
    for i, (imgs, _) in enumerate(tqdm(train_loader)):
        # imgs is a pair of differently-augmented views of the same batch.
        img = imgs[0].to(device)
        img_rot = rotation(img)[0]      # rotated copy of the clean view
        img_aug = imgs[1].to(device)    # second augmented view

        # Target labels: a fixed class when finetuning, otherwise a random
        # draw from the candidate label set.
        if args.finetune:
            label = np.array([args.finetune_class] * img.size(0))
        else:
            # NOTE: the shuffle is redundant (np.random.choice samples
            # uniformly regardless) but is kept to preserve the RNG stream.
            np.random.shuffle(label_set)
            label = np.random.choice(label_set, img.size(0))
        # Text-condition embeddings for the drawn target classes.
        cond = torch.stack([text_cond_dict[j] for j in label], dim=0)
        label = torch.from_numpy(label).long().to(device)

        netG.train()
        optimG.zero_grad()

        # Generate eps-bounded perturbations for all three views.
        noise = netG(input=img, cond=cond, eps=eps)
        noise_rot = netG(input=img_rot, cond=cond, eps=eps)
        noise_aug = netG(input=img_aug, cond=cond, eps=eps)

        # In the finetune stage only a masked sub-region is perturbed.
        if args.finetune:
            noise = get_mask(noise, args.mask_ratio, device)
            noise_rot = get_mask(noise_rot, args.mask_ratio, device)
            noise_aug = get_mask(noise_aug, args.mask_ratio, device)

        # Apply perturbations and clamp back to the valid pixel range.
        adv = torch.clamp(noise + img, 0.0, 1.0)
        adv_rot = torch.clamp(noise_rot + img_rot, 0.0, 1.0)
        adv_aug = torch.clamp(noise_aug + img_aug, 0.0, 1.0)

        # Targeted attack: push every view towards the target label.
        adv_out = model(normalize(adv))
        adv_rot_out = model(normalize(adv_rot))
        adv_aug_out = model(normalize(adv_aug))
        loss = criterion(adv_out, label) + criterion(adv_rot_out, label) + criterion(adv_aug_out, label)
        loss.backward()
        optimG.step()

        # BUGFIX: accumulate *before* reporting and average over the actual
        # window size (10 batches). The original added the current batch's
        # loss after resetting (offsetting every window by one) and divided
        # the ~10-batch sum by 100.
        running_loss += abs(loss.item())
        if i % 10 == 9:
            print('Epoch: {} \t Batch: {}/{} \t loss: {:.5f}'.format(epoch, i, len(train_loader), running_loss / 10))
            running_loss = 0.0

    # Save a checkpoint every epoch. BUGFIX: the per-model subdirectory
    # (save_dir = args.save_dir/model_type) was never created, so torch.save
    # raised FileNotFoundError; also dropped the `epoch >= args.start_epoch`
    # guard, which is always true inside this loop.
    os.makedirs(save_dir, exist_ok=True)
    if torch.cuda.device_count() > 1:
        torch.save(netG.module.state_dict(), '{}/model-{}.pth'.format(save_dir, epoch))
    else:
        torch.save(netG.state_dict(), '{}/model-{}.pth'.format(save_dir, epoch))