from __future__ import print_function

import argparse
import copy
import functools
import os
import random
import ssl
import time
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Sampler
from torchvision.datasets import (StanfordCars, Food101, SUN397, EuroSAT,
                                  Caltech101, Caltech256, Country211, Flowers102,
                                  PCAM, FGVCAircraft, CIFAR10, CIFAR100, STL10,
                                  OxfordIIITPet, DTD)
from tqdm import tqdm

from modified_clip import clip
from models.model import *
from models.prompters import TokenPrompter, NullPrompter, PromptLearner
from attacks import *
from utils import accuracy, AverageMeter, ProgressMeter, save_checkpoint, str2bool
from utils import cosine_lr, convert_models_to_fp32, refine_classname
from data_utils.autoaugment import ImageNetPolicy
from autoattack import AutoAttack

# Allow torchvision dataset downloads from hosts with certificate issues.
ssl._create_default_https_context = ssl._create_unverified_context

import matplotlib.pyplot as plt
from matplotlib import rc

rc('font', family='Arial')

from sklearn.manifold import TSNE

"""
|
|
|
Tuning Text Prompts (Embeddings) to generate adversarial examples.
|
|
|
|
|
|
Default Training Setting: Batch_size=256, Dataset=ImageNet, train_stepsize=1
|
|
|
|
|
|
Default Evaluation Setting: 20-step PGD test_stepsize==1
|
|
|
|
|
|
(img epsilon=1) == (text tototal perturbation=0.01)
|
|
|
---------------------------------
|
|
|
eval_type: fast_motivation|motivation
|
|
|
|
|
|
|
|
|
|
|
|
CUDA_VISIBLE_DEVICES=0,1 python Visulization_TSNE.py --batch_size 250 --evaluate --resume Source_PT/TeCoAmodel_best.pth.tar --test_eps 1 --save_path TeCoA_eps1
|
|
|
CUDA_VISIBLE_DEVICES=0,1 python Visulization_TSNE.py --batch_size 250 --evaluate --resume Source_PT/TeCoAmodel_best.pth.tar --test_eps 2 --save_path TeCoA_eps2
|
|
|
CUDA_VISIBLE_DEVICES=0,1 python Visulization_TSNE.py --batch_size 250 --evaluate --resume Source_PT/TeCoAmodel_best.pth.tar --test_eps 3 --save_path TeCoA_eps3
|
|
|
CUDA_VISIBLE_DEVICES=0,1 python Visulization_TSNE.py --batch_size 250 --evaluate --resume Source_PT/TeCoAmodel_best.pth.tar --test_eps 4 --save_path TeCoA_eps4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
def parse_option():
    parser = argparse.ArgumentParser('Visual Prompting for CLIP')

    parser.add_argument('--print_freq', type=int, default=2000,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--validate_freq', type=int, default=1,
                        help='validate frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=32,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of training epochs')
    parser.add_argument("--mix_alpha", type=float, default=-1,
                        help="interpolation")

    parser.add_argument('--optim', type=str, default='sgd',
                        help='optimizer to use')
    parser.add_argument('--learning_rate', type=float, default=1e-5,
                        help='learning rate')
    parser.add_argument("--weight_decay", type=float, default=0,
                        help="weight decay")
    parser.add_argument("--warmup", type=int, default=1000,
                        help="number of steps to warmup for")
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')
    parser.add_argument('--train_eps', type=float, default=2,
                        help='training perturbation budget (in 1/255 units)')
    parser.add_argument('--train_numsteps', type=int, default=5)
    parser.add_argument('--train_stepsize', type=int, default=1)
    parser.add_argument('--test_eps', type=float, default=2,
                        help='test perturbation budget (in 1/255 units)')
    parser.add_argument('--test_numsteps', type=int, default=20)
    parser.add_argument('--test_stepsize', type=int, default=1)
    parser.add_argument('--patience', type=int, default=1000)

    parser.add_argument('--model', type=str, default='clip')
    parser.add_argument('--imagenet_root', type=str, default='temp')
    parser.add_argument('--arch', type=str, default='vit_b32')
    parser.add_argument('--method', type=str, default='null_patch',
                        choices=['null_patch'],
                        help='choose visual prompting method')
    parser.add_argument('--name', type=str, default='')
    parser.add_argument('--prompt_size', type=int, default=30,
                        help='size for visual prompts')
    parser.add_argument('--add_prompt_size', type=int, default=0,
                        help='size for additional visual prompts')

    parser.add_argument('--root', type=str, default='/home/data1/junhao/datasets/',
                        help='dataset root')
    parser.add_argument('--dataset', type=str, default='ImageNet',
                        help='pre-training dataset: cifar10|cifar100|ImageNet')
    parser.add_argument('--image_size', type=int, default=224,
                        help='image size')

    parser.add_argument('--seed', type=int, default=None,
                        help='seed for initializing training')
    parser.add_argument('--model_dir', type=str, default='../save_ckpts',
                        help='path to save models')
    parser.add_argument('--image_dir', type=str, default='./save/images',
                        help='path to save images')
    parser.add_argument('--filename', type=str, default=None,
                        help='filename to save')
    parser.add_argument('--trial', type=int, default=1,
                        help='number of trials')
    parser.add_argument('--gpu', type=int, default=None,
                        help='gpu to use')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--VPbaseline', action='store_true')
    parser.add_argument('--CW', action='store_true')
    parser.add_argument('--autoattack', action='store_true')
    parser.add_argument('--train_class_count', type=int, default=90)
    parser.add_argument('--last_num_ft', type=int, default=-1)
    parser.add_argument('--noimginprop', action='store_true')
    parser.add_argument('--exp_name', type=str, default=None)

    parser.add_argument('--aug_type', type=str, default='Vanilla',
                        help='Vanilla|Vanilla_Flip|Resizecrop|Resizecrop_Flip|Autoaug')

    parser.add_argument('--resume', type=str, default=None,
                        help='path to resume from checkpoint')
    parser.add_argument('--evaluate', default=False,
                        action="store_true",
                        help='evaluate model on the test set')
    parser.add_argument('--eval_type', type=str, default="motivation",
                        help='fast|full|motivation|fast_motivation')

    parser.add_argument('--adv_prompt_gen', type=str2bool, default="False",
                        help='whether to conduct adversarial prompt generation')
    parser.add_argument('--ctx', type=int, default=16,
                        help='number of context vectors')
    parser.add_argument('--ctx_init', type=str, default='This is a photo of a',
                        help='initialization for the context prompt, e.g. "This is a photo of a" or "a photo of a"')
    parser.add_argument('--position', type=str, default='end',
                        help='CLS prompt position: end|middle|front')
    parser.add_argument('--text_perb_stepsize', type=float, default=0.001,
                        help='perturbation step size for texts; the text perturbation shares the same number of steps as the adversarial images')

    parser.add_argument('--W_Pred_Align', type=float, default=0.0,
                        help='weight for prediction alignment between clean and adversarial logits')
    parser.add_argument('--W_Nat_CE', type=float, default=0.0,
                        help='weight for the natural (clean) cross-entropy loss')
    parser.add_argument('--W_Pred_Align_Ori', type=float, default=0.0,
                        help='weight for aligning adversarial logits with the original CLIP clean logits')

    parser.add_argument('--adv_type', type=str, default="Img_Only",
                        help='Img_Only|Text_Only|Joint')

    parser.add_argument('--save_path', type=str, default="temp",
                        help='save path for TSNE results')

    args = parser.parse_args()

    if args.exp_name is not None:
        args.filename = args.exp_name
    else:
        args.filename = '{}_{}_{}_{}_{}_{}_{}_lr_{}_decay_{}_bsz_{}_warmup_{}_trial_{}_addp_{}'. \
            format(args.name, args.method, args.prompt_size, args.dataset, args.model, args.arch,
                   args.optim, args.learning_rate, args.weight_decay, args.batch_size, args.warmup, args.trial,
                   args.add_prompt_size)

    return args

class BalancedBatchSampler(Sampler):
    """Yields one sample per class in round-robin order, so every batch is class-balanced.

    Note: this sampler assumes an ImageFolder-style dataset exposing `.imgs`
    (path, label) pairs.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.labels = np.array([sample[1] for sample in dataset.imgs])
        self.labels_set = np.unique(self.labels)
        self.label_to_indices = {label: np.where(self.labels == label)[0]
                                 for label in self.labels_set}
        # Per-class cursor into label_to_indices, advanced with wraparound.
        self.used_labels_indices = {label: 0 for label in self.labels_set}
        self.count = len(self.labels) // len(self.labels_set)

    def __iter__(self):
        count = self.count
        for _ in range(count):
            indices = []
            for label in self.labels_set:
                start = self.used_labels_indices[label]
                end = (start + 1) % len(self.label_to_indices[label])
                indices.append(self.label_to_indices[label][start])
                self.used_labels_indices[label] = end

            for index in indices:
                yield index

    def __len__(self):
        return self.count * len(self.labels_set)

    def on_epoch_end(self):
        self.used_labels_indices = {label: 0 for label in self.labels_set}

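# A minimal usage sketch for BalancedBatchSampler (mirroring the validation
# loaders built in main() below). `image_folder` is a placeholder path, not
# part of this script:
#
#   image_folder = torchvision.datasets.ImageFolder('path/to/val')
#   loader = DataLoader(image_folder, batch_size=250, shuffle=False,
#                       sampler=BalancedBatchSampler(image_folder))
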
best_acc1 = 0
device = "cuda" if torch.cuda.is_available() else "cpu"

def train(train_loader, texts, model, original_model, prompter, add_prompter,
          optimizer, scheduler, criterion, scaler, epoch, prompt_learner, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(epoch))

    model.module.visual.train()

    num_batches_per_epoch = len(train_loader)

    alpha = args.train_stepsize
    attack_iters = args.train_numsteps

    end = time.time()

    # Text-prompt attacks perturb the prompt learner in place, so keep a clean
    # copy of its parameters to restore before each attack.
    if args.adv_type == 'Text_Only' or args.adv_type == 'Joint':
        original_prompter_state = copy.deepcopy(prompt_learner.state_dict())

    for i, (images, target) in enumerate(tqdm(train_loader, ncols=80)):

        data_time.update(time.time() - end)

        BATCH_SIZE = images.size(0)

        step = num_batches_per_epoch * epoch + i
        scheduler(step)

        optimizer.zero_grad()

        images = images.to(device)
        target = target.to(device)
        text_tokens = clip.tokenize(texts).to(device)

        with autocast():
            loss_Pred_Align = 0.0
            loss_Nat_CE = 0.0
            loss_Pred_Align_Ori = 0.0
            output_Inat_Tnat = None
            if not args.VPbaseline:
                if args.adv_type == 'Text_Only' or args.adv_type == 'Joint':
                    prompt_learner.load_state_dict(original_prompter_state)
                    delta = attack_pgd_adv_prompt(prompter, model, add_prompter, criterion, images,
                                                  target, text_tokens, alpha, attack_iters, 'l_inf',
                                                  prompt_learner, args.text_perb_stepsize, epsilon=args.train_eps)
                else:
                    delta = attack_pgd(prompter, model, add_prompter, criterion, images,
                                       target, text_tokens, alpha, attack_iters, 'l_inf', epsilon=args.train_eps)

                tmp = clip_img_preprocessing(images + delta)
            else:
                tmp = clip_img_preprocessing(images)

            prompted_images = prompter(tmp)
            prompt_token = None

            if args.adv_type == 'Text_Only' or args.adv_type == 'Joint':
                output_Iadv_Tnat, _ = multiGPU_CLIP_Text_Prompt_Tuning(model, prompted_images, text_tokens, prompt_token, prompt_learner)
            else:
                output_Iadv_Tnat, _ = multiGPU_CLIP(model, prompted_images, text_tokens, prompt_token)

            if args.W_Pred_Align > 0.0:
                criterion_KL = nn.KLDivLoss(reduction='batchmean').to(device)
                tmp_nat = clip_img_preprocessing(images)
                prompted_nat_images = prompter(tmp_nat)
                output_Inat_Tnat, _ = multiGPU_CLIP(model, prompted_nat_images, text_tokens, prompt_token)
                loss_Pred_Align = criterion_KL(F.log_softmax(output_Iadv_Tnat, dim=1),
                                               F.softmax(output_Inat_Tnat, dim=1))
            if args.W_Pred_Align_Ori > 0.0:
                criterion_KL = nn.KLDivLoss(reduction='batchmean').to(device)
                tmp_nat = clip_img_preprocessing(images)
                prompted_nat_images = prompter(tmp_nat)
                with torch.no_grad():
                    Ori_output_Inat_Tnat, _ = multiGPU_CLIP(original_model, prompted_nat_images, text_tokens, prompt_token)
                loss_Pred_Align_Ori = criterion_KL(F.log_softmax(output_Iadv_Tnat, dim=1),
                                                   F.softmax(Ori_output_Inat_Tnat, dim=1))

            loss = criterion(output_Iadv_Tnat, target) + args.W_Pred_Align * loss_Pred_Align \
                + args.W_Nat_CE * loss_Nat_CE + args.W_Pred_Align_Ori * loss_Pred_Align_Ori

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Keep the logit scale in CLIP's original range (exp(4.6052) = 100).
        model.module.logit_scale.data = torch.clamp(model.module.logit_scale.data, 0, 4.6052)

        acc1 = accuracy(output_Iadv_Tnat, target, topk=(1,))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0].item(), images.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and i != 0:
            progress.display(i)
            if args.debug:
                break

        if i % args.save_freq == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': prompter.state_dict(),
                'add_prompter': add_prompter.state_dict(),
                'vision_encoder_state_dict': model.module.visual.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, args)

    return losses.avg, top1.avg

def main():
    global best_acc1, device

    args = parse_option()
    # Convert pixel-space budgets from 1/255 units to the [0, 1] image scale.
    args.train_eps = args.train_eps / 255.
    args.test_eps = args.test_eps / 255.
    args.train_stepsize = args.train_stepsize / 255.
    args.test_stepsize = args.test_stepsize / 255.

    if args.resume is not None:
        args.resume = os.path.join("../save_ckpts", args.resume)

    args.save_path = os.path.join("../save_TSNE_Vis", args.save_path)
    os.makedirs(args.save_path, exist_ok=True)

    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
    else:
        cudnn.benchmark = True

    # Machine-specific dataset roots.
    import socket
    if socket.gethostname() == 'junhao':
        args.root = '/home/data1/junhao/datasets/'
    elif socket.gethostname() == 'ai-planning-p4de-02':
        args.root = '/data_3/teddy_research/datasets_jh/'

    imagenet_root = os.path.join(args.root, "ImageNet")

    imgnet_full = imagenet_root

    add_prompt_len = 0

    model, preprocess = clip.load('ViT-B/32', device, jit=False, prompt_len=add_prompt_len)

    convert_models_to_fp32(model)
    model = torch.nn.DataParallel(model)
    model.eval()

    # A frozen copy of the original CLIP, used only when aligning adversarial
    # logits with the original clean logits (--W_Pred_Align_Ori > 0).
    original_model = None
    if args.W_Pred_Align_Ori > 0.0:
        original_model, preprocess = clip.load('ViT-B/32', device, jit=False, prompt_len=add_prompt_len)
        convert_models_to_fp32(original_model)
        original_model = torch.nn.DataParallel(original_model)
        original_model.eval()

    prompter = NullPrompter()
    add_prompter = TokenPrompter(add_prompt_len)

    prompter = torch.nn.DataParallel(prompter).cuda()
    add_prompter = torch.nn.DataParallel(add_prompter).cuda()

    # Fine-tune either the full vision encoder or only its last few parameter tensors.
    if args.last_num_ft == -1:
        optimizer = torch.optim.SGD(model.module.visual.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(list(model.module.visual.parameters())[-args.last_num_ft:],
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    criterion = torch.nn.CrossEntropyLoss().to(device)
    args.start_epoch = 0

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                best_acc1 = best_acc1.to(args.gpu)

            if args.mix_alpha > 0:
                # Weight-space interpolation between the original CLIP vision
                # encoder and the robust (resumed) one.
                alpha = args.mix_alpha

                checkpoint_ori = torch.load('original_clip.pth.tar')
                theta_ori = checkpoint_ori['vision_encoder_state_dict']
                theta_rob = checkpoint['vision_encoder_state_dict']

                theta = {
                    key: (1 - alpha) * theta_ori[key] + alpha * theta_rob[key]
                    for key in theta_ori.keys()
                }
                model.module.visual.load_state_dict(theta)
            else:
                model.module.visual.load_state_dict(checkpoint['vision_encoder_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])

            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    template = 'This is a photo of a {}'
    print(f'template: {template}')

    preprocess = transforms.Compose([
        transforms.ToTensor()
    ])
    preprocess224 = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])
    preprocess224_interpolate = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    preprocess224_vanilla_flip = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    preprocess224_resizecrop = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor()
    ])
    preprocess224_resizecrop_flip = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    preprocess_autoaug = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        ImageNetPolicy(),
        transforms.ToTensor()
    ])

    if args.aug_type == 'Vanilla':
        IN_aug_type = preprocess224
    elif args.aug_type == 'Vanilla_Flip':
        IN_aug_type = preprocess224_vanilla_flip
    elif args.aug_type == 'Resizecrop':
        IN_aug_type = preprocess224_resizecrop
    elif args.aug_type == 'Resizecrop_Flip':
        IN_aug_type = preprocess224_resizecrop_flip
    elif args.aug_type == 'Autoaug':
        IN_aug_type = preprocess_autoaug

    if args.dataset == 'cifar100':
        train_dataset = CIFAR100(args.root, transform=preprocess,
                                 download=True, train=True)
        val_dataset = CIFAR100(args.root, transform=preprocess,
                               download=True, train=False)
    elif args.dataset == 'cifar10':
        train_dataset = CIFAR10(args.root, transform=preprocess,
                                download=True, train=True)
        val_dataset = CIFAR10(args.root, transform=preprocess,
                              download=True, train=False)
    elif args.dataset == 'ImageNet':
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(imagenet_root, 'train'),
            transform=IN_aug_type
        )

    val_dataset_list = []
    val_dataset_name = ['StanfordCars', 'Food101', 'PCAM', 'cifar100', 'oxfordpet', 'flowers102',
                        'Country211', 'dtd', 'EuroSAT', 'fgvc_aircraft', 'ImageNet', 'cifar10', 'SUN397']

    if args.evaluate:
        if args.eval_type == 'fast':
            val_dataset_name = ['ImageNet', 'SUN397', 'Food101', 'flowers102', 'Caltech101', 'Caltech256']
        elif args.eval_type == 'full':
            val_dataset_name = ['ImageNet', 'cifar10', 'STL10', 'cifar100',
                                'SUN397', 'StanfordCars', 'Food101', 'oxfordpet',
                                'flowers102', 'dtd', 'EuroSAT', 'fgvc_aircraft', 'PCAM', 'Caltech101', 'Caltech256']
        elif args.eval_type == 'motivation':
            val_dataset_name = ['ImageNet']
        elif args.eval_type == 'fast_motivation':
            val_dataset_name = ['Caltech101']
    else:
        val_dataset_name = ['cifar10', 'cifar100', 'dtd', 'EuroSAT']

    for each in val_dataset_name:
        if each == 'cifar10':
            val_dataset_list.append(CIFAR10(args.root, transform=preprocess,
                                            download=True, train=False))
        elif each == 'cifar100':
            val_dataset_list.append(CIFAR100(args.root, transform=preprocess,
                                             download=True, train=False))
        elif each == 'Caltech101':
            val_dataset_list.append(Caltech101(args.root, target_type='category', transform=preprocess224,
                                               download=True))
        elif each == 'PCAM':
            val_dataset_list.append(PCAM(args.root, split='test', transform=preprocess224,
                                         download=True))
        elif each == 'STL10':
            val_dataset_list.append(STL10(args.root, split='test',
                                          transform=preprocess, download=True))
        elif each == 'SUN397':
            val_dataset_list.append(SUN397(args.root,
                                           transform=preprocess224, download=True))
        elif each == 'StanfordCars':
            val_dataset_list.append(StanfordCars(args.root, split='test',
                                                 transform=preprocess224, download=True))
        elif each == 'Food101':
            val_dataset_list.append(Food101(args.root, split='test',
                                            transform=preprocess224, download=True))
        elif each == 'oxfordpet':
            val_dataset_list.append(OxfordIIITPet(args.root, split='test',
                                                  transform=preprocess224, download=True))
        elif each == 'EuroSAT':
            val_dataset_list.append(EuroSAT(args.root,
                                            transform=preprocess224, download=True))
        elif each == 'Caltech256':
            val_dataset_list.append(Caltech256(args.root, transform=preprocess224,
                                               download=True))
        elif each == 'flowers102':
            val_dataset_list.append(Flowers102(args.root, split='test',
                                               transform=preprocess224, download=True))
        elif each == 'Country211':
            val_dataset_list.append(Country211(args.root, split='test',
                                               transform=preprocess224, download=True))
        elif each == 'dtd':
            val_dataset_list.append(DTD(args.root, split='test',
                                        transform=preprocess224, download=True))
        elif each == 'fgvc_aircraft':
            val_dataset_list.append(FGVCAircraft(args.root, split='test',
                                                 transform=preprocess224, download=True))
        elif each == 'ImageNet':
            val_dataset_list.append(torchvision.datasets.ImageFolder(
                os.path.join(imgnet_full, 'val'),
                transform=preprocess224))

    train_sampler = None
    val_sampler = None

    class OneImagePerClassSampler(Sampler):
        """Samples one random image per class per epoch."""

        def __init__(self, data_source):
            self.data_source = data_source
            self.indices_map = {label: np.where(np.array(data_source.targets) == label)[0]
                                for label in set(data_source.targets)}

        def __iter__(self):
            batch = []
            for indices in self.indices_map.values():
                index = np.random.choice(indices)
                batch.append(index)
            np.random.shuffle(batch)
            return iter(batch)

        def __len__(self):
            return len(self.indices_map)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size, pin_memory=True,
                              num_workers=args.num_workers, shuffle=True, sampler=train_sampler)

    # Each validation loader draws class-balanced batches; note that
    # BalancedBatchSampler assumes an ImageFolder-style dataset with `.imgs`.
    val_loader_list = [DataLoader(each,
                                  batch_size=args.batch_size, pin_memory=True,
                                  num_workers=args.num_workers, shuffle=False,
                                  sampler=BalancedBatchSampler(each)) for each in
                       val_dataset_list]

    class_names = train_dataset.classes

    if args.dataset == 'ImageNet':
        from utils import load_imagenet_folder2name
        folder2name = load_imagenet_folder2name('imagenet_classes_names.txt')
        new_class_names = []
        for each in class_names:
            new_class_names.append(folder2name[each])

        class_names = new_class_names

    class_names = refine_classname(class_names)

    texts_train = [template.format(label) for label in class_names]

    training_original_classnames = class_names

    texts_list = []
    val_original_classnames = []
    for cnt, each in enumerate(val_dataset_list):
        if hasattr(each, 'clip_prompts'):
            texts_tmp = each.clip_prompts
        else:
            class_names = each.classes
            if val_dataset_name[cnt] == 'ImageNet':
                from utils import load_imagenet_folder2name
                folder2name = load_imagenet_folder2name('imagenet_classes_names.txt')
                new_class_names = []
                for class_name in class_names:
                    new_class_names.append(folder2name[class_name])
                class_names = new_class_names

            val_original_classnames.append(class_names)

            class_names = refine_classname(class_names)
            texts_tmp = [template.format(label) for label in class_names]
        texts_list.append(texts_tmp)
    assert len(texts_list) == len(val_dataset_list)

    scaler = GradScaler()
    total_steps = len(train_loader) * args.epochs
    scheduler = cosine_lr(optimizer, args.learning_rate, args.warmup, total_steps)

    refined_template = template.lower().replace(' ', '_')

    args.model_folder = os.path.join(args.model_dir, args.filename)
    if not os.path.isdir(args.model_folder):
        os.makedirs(args.model_folder)

    if not args.evaluate:
        prompt_learner = None
        if args.adv_type == 'Text_Only' or args.adv_type == 'Joint':
            prompt_learner = PromptLearner(args, training_original_classnames, model)
            prompt_learner = torch.nn.DataParallel(prompt_learner).cuda()

    if args.evaluate:
        acc1_mean = validate(val_loader_list, val_dataset_name, val_original_classnames, texts_list, model,
                             prompter, add_prompter, criterion, args)
        return

    epochs_since_improvement = 0

    for epoch in range(args.start_epoch, args.epochs):

        train(train_loader, texts_train, model, original_model, prompter, add_prompter,
              optimizer, scheduler, criterion, scaler, epoch, prompt_learner, args)

        if epoch % args.validate_freq == 0:
            acc1_mean = validate(val_loader_list, val_dataset_name, val_original_classnames, texts_list, model,
                                 prompter, add_prompter, criterion, args)

        is_best = acc1_mean > best_acc1
        best_acc1 = max(acc1_mean, best_acc1)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': prompter.state_dict(),
            'add_prompter': add_prompter.state_dict(),
            'vision_encoder_state_dict': model.module.visual.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, args, is_best=is_best)

        if is_best:
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print(f"There's no improvement for {epochs_since_improvement} epochs.")

        if epochs_since_improvement >= args.patience:
            print("Training halted by the early-stopping criterion.")
            break

def cov_matrix(x, y, epsilon=1e0):
    """Cross-covariance of two batches of embeddings, regularized with epsilon * I."""
    B = x.shape[0]

    x_mean = x.mean(dim=0, keepdim=True)
    y_mean = y.mean(dim=0, keepdim=True)
    x_centered = x - x_mean
    y_centered = y - y_mean

    cov_xy = x_centered.t().mm(y_centered) / (B - 1)

    # Regularize so the log-determinant below is well defined.
    cov_xy += torch.eye(cov_xy.size(0)).to(cov_xy.device) * epsilon

    return cov_xy


def logdet_divergence(cov):
    """LogDet (Stein) divergence between cov and the identity:
    D(cov, I) = tr(cov) - logdet(cov) - dim, which is 0 iff cov == I."""
    trace = torch.trace(cov)

    logdet = torch.logdet(cov)

    dim = cov.size(0)

    L = trace - logdet - dim
    return L

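# A minimal sanity-check sketch for the two helpers above (not executed by this
# script): the sample covariance of standard-normal embeddings is close to the
# identity, so its LogDet divergence should be near 0.
#
#   x = torch.randn(4096, 32)
#   c = cov_matrix(x, x, epsilon=0.0)   # sample covariance of x, roughly I
#   print(logdet_divergence(c))          # close to 0 for a near-identity covariance
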
def compute_mean_and_chol(embeddings):
    """Compute the mean and Cholesky decomposition of the covariance matrix of the embeddings."""
    mean = embeddings.mean(dim=0)
    xm = embeddings - mean
    cov = xm.T @ xm / (embeddings.size(0) - 1)

    # Regularize so the covariance is positive definite before factorizing.
    cov += 1e-3 * torch.eye(cov.size(0), device=cov.device)
    chol = torch.linalg.cholesky(cov)
    return mean, chol


def kl_divergence_chol(embeddings0, embeddings1):
    """Compute KL(N(mu0, Sigma0) || N(mu1, Sigma1)) for the Gaussians fitted to
    the two embedding sets, using Cholesky factors."""
    mu0, L0 = compute_mean_and_chol(embeddings0)
    mu1, L1 = compute_mean_and_chol(embeddings1)

    L1_inv = torch.linalg.inv(L1)
    Sigma1_inv = L1_inv.T @ L1_inv

    # tr(Sigma1^{-1} Sigma0)
    M = Sigma1_inv @ (L0 @ L0.T)
    trace_term = torch.trace(M)

    # (mu1 - mu0)^T Sigma1^{-1} (mu1 - mu0)
    diff = mu1 - mu0
    quadratic_term = diff.T @ Sigma1_inv @ diff

    # log-determinants via the Cholesky diagonals: logdet(Sigma) = 2 * sum(log(diag(L)))
    logdet_Sigma0 = 2 * torch.log(torch.diagonal(L0)).sum()
    logdet_Sigma1 = 2 * torch.log(torch.diagonal(L1)).sum()
    logdet_term = logdet_Sigma1 - logdet_Sigma0

    kl = 0.5 * (trace_term + quadratic_term - mu0.numel() + logdet_term)
    return kl

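# A minimal self-check sketch (not executed by this script): two large batches
# drawn from the same Gaussian should give a KL divergence close to 0.
#
#   emb0 = torch.randn(4096, 64)
#   emb1 = torch.randn(4096, 64)
#   print(kl_divergence_chol(emb0, emb1))  # small value, near 0
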
def validate(val_loader_list, val_dataset_name, val_original_classnames, texts_list, model,
             prompter, add_prompter, criterion, args):
    dataset_num = len(val_loader_list)
    acc_all_nat = []
    acc_all_adv = []

    test_stepsize = args.test_stepsize

    for cnt in range(dataset_num):

        val_loader = val_loader_list[cnt]
        texts = texts_list[cnt]
        dataset_name = val_dataset_name[cnt]
        original_classname = val_original_classnames[cnt]

        Cos_IT_sim_nat_all = torch.tensor([], device=device)
        Cos_IT_sim_adv_all = torch.tensor([], device=device)

        Cos_Img_sim_all = torch.tensor([], device=device)
        Cos_Text_sim_all = torch.tensor([], device=device)

        Nat_Img_emb_all = torch.tensor([], device=device)
        Nat_Text_emb_all = torch.tensor([], device=device)
        Adv_Img_emb_all = torch.tensor([], device=device)
        Adv_Text_emb_all = torch.tensor([], device=device)

        grad_norm_all_nat = torch.tensor([], device=device)
        grad_norm_all_adv = torch.tensor([], device=device)

        num_classes = len(val_original_classnames[cnt])

        prompt_learner = None
        if args.adv_type == 'Text_Only' or args.adv_type == 'Joint':
            prompt_learner = PromptLearner(args, original_classname, model)
            prompt_learner = torch.nn.DataParallel(prompt_learner).cuda()

        # AutoAttack's DLR loss needs more than two classes, so binary datasets
        # are restricted to APGD-CE.
        binary = ['PCAM']
        attacks_to_run = ['apgd-ce', 'apgd-dlr']
        if dataset_name in binary:
            attacks_to_run = ['apgd-ce']

        batch_time = AverageMeter('Time', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1_org = AverageMeter('Original Acc@1', ':6.2f')
        top1_prompt = AverageMeter('Prompt Acc@1', ':6.2f')
        top1_adv_org = AverageMeter('Adv Original Acc@1', ':6.2f')
        top1_adv_prompt = AverageMeter('Adv Prompt Acc@1', ':6.2f')

        progress = ProgressMeter(
            len(val_loader),
            [batch_time, losses, top1_org, top1_adv_org],
            prefix=dataset_name + '_Validate: ')

        prompter.eval()
        add_prompter.eval()
        model.eval()

        if args.adv_type == 'Text_Only' or args.adv_type == 'Joint':
            original_prompter_state = copy.deepcopy(prompt_learner.state_dict())

        end = time.time()

        print("len(val_loader)", len(val_loader))

        Nat_Img_Emb_Full = torch.tensor([])
        Adv_Img_Emb_Full = torch.tensor([])
        Nat_Text_Emb_Full = torch.tensor([])
        Adv_Text_Emb_Full = torch.tensor([])

        for i, (images, target) in enumerate(tqdm(val_loader, ncols=80)):

            # Embeddings are accumulated over groups of 4 batches and then
            # visualized; reset the buffers at the start of each group.
            if i % 4 == 0:
                Nat_Img_Emb_Full = torch.tensor([])
                Adv_Img_Emb_Full = torch.tensor([])
                Nat_Text_Emb_Full = torch.tensor([])
                Adv_Text_Emb_Full = torch.tensor([])

            # For non-CIFAR datasets, subsample batches unless fully evaluating.
            if 'cifar' not in dataset_name:
                if i % 20 != 0 and not args.evaluate:
                    continue

            images = images.to(device)
            target = target.to(device)
            text_tokens = clip.tokenize(texts).to(device)

            with autocast():

                with torch.no_grad():
                    prompt_token = None

                images.requires_grad = True
                output_prompt, _, nat_img_emb, nat_scaled_text_emb = multiGPU_CLIP(model, prompter(clip_img_preprocessing(images)), text_tokens, prompt_token, is_embedding=True)
                # Undo the logit scaling and keep the text embedding of each sample's class.
                nat_scaled_text_emb = nat_scaled_text_emb[target] / model.module.logit_scale.exp()

                loss = criterion(output_prompt, target)

                acc1 = accuracy(output_prompt, target, topk=(1,))
                losses.update(loss.item(), images.size(0))

                top1_org.update(acc1[0].item(), images.size(0))

                torch.cuda.empty_cache()

                if args.CW:
                    delta_prompt = attack_CW(prompter, model, add_prompter, criterion,
                                             images, target, text_tokens,
                                             test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps)
                    attacked_images = images + delta_prompt
                elif args.autoattack:
                    attacked_images = attack_auto(model, images, target, text_tokens,
                                                  None, None, epsilon=args.test_eps, attacks_to_run=attacks_to_run)
                else:
                    if args.adv_type == 'Joint':
                        prompt_learner.load_state_dict(original_prompter_state)
                        delta_prompt = attack_pgd_adv_prompt(prompter, model, add_prompter, criterion,
                                                             images, target, text_tokens,
                                                             test_stepsize, args.test_numsteps, 'l_inf',
                                                             prompt_learner, args.text_perb_stepsize, epsilon=args.test_eps)
                    elif args.adv_type == 'Text_Only':
                        prompt_learner.load_state_dict(original_prompter_state)
                        delta_prompt = attack_pgd_adv_promptONLY(prompter, model, add_prompter, criterion,
                                                                 images, target, text_tokens,
                                                                 test_stepsize, args.test_numsteps, 'l_inf',
                                                                 prompt_learner, args.text_perb_stepsize, epsilon=args.test_eps)
                    else:
                        delta_prompt = attack_pgd_motivation(prompter, model, add_prompter, criterion,
                                                             images, target, text_tokens,
                                                             test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps)
                    attacked_images = images + delta_prompt

                torch.cuda.empty_cache()

                prompt_token = add_prompter()

                if args.adv_type == 'Text_Only' or args.adv_type == 'Joint':
                    output_prompt_adv, _, adv_img_emb, adv_scaled_text_emb = multiGPU_CLIP_Text_Prompt_Tuning(model, prompter(clip_img_preprocessing(attacked_images)),
                                                                                                              text_tokens, prompt_token, prompt_learner, is_embedding=True)
                else:
                    output_prompt_adv, _, adv_img_emb, adv_scaled_text_emb = multiGPU_CLIP(model, prompter(clip_img_preprocessing(attacked_images)),
                                                                                           text_tokens, prompt_token, is_embedding=True)

                adv_scaled_text_emb = adv_scaled_text_emb[target] / model.module.logit_scale.exp()

                loss = criterion(output_prompt_adv, target)

                torch.cuda.empty_cache()

                acc1 = accuracy(output_prompt_adv, target, topk=(1,))
                losses.update(loss.item(), images.size(0))
                top1_adv_org.update(acc1[0].item(), images.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            nat_img_emb = nat_img_emb.detach().cpu()
            adv_img_emb = adv_img_emb.detach().cpu()
            nat_scaled_text_emb = nat_scaled_text_emb.detach().cpu()
            adv_scaled_text_emb = adv_scaled_text_emb.detach().cpu()

            Nat_Img_Emb_Full = torch.cat((Nat_Img_Emb_Full, nat_img_emb), dim=0)
            Adv_Img_Emb_Full = torch.cat((Adv_Img_Emb_Full, adv_img_emb), dim=0)
            Nat_Text_Emb_Full = torch.cat((Nat_Text_Emb_Full, nat_scaled_text_emb), dim=0)
            Adv_Text_Emb_Full = torch.cat((Adv_Text_Emb_Full, adv_scaled_text_emb), dim=0)

            if (i + 1) % 4 == 0:
                # With batch_size 250, four batches give 1000 image/text embedding
                # pairs per t-SNE plot; split the joint embedding back accordingly.
                n_img = Nat_Img_Emb_Full.size(0)

                full_emb = torch.cat((Nat_Img_Emb_Full, Nat_Text_Emb_Full), dim=0)
                full_emb = full_emb.numpy()
                tsne = TSNE(n_components=2, perplexity=30.0, early_exaggeration=12.0,
                            learning_rate=100, n_iter=2000, random_state=200, verbose=1)
                result = tsne.fit_transform(full_emb)
                x_min, x_max = result.min(0), result.max(0)
                result = (result - x_min) / (x_max - x_min)

                results_img = result[:n_img, :]
                results_text = result[n_img:, :]

                fig = plt.figure()
                plt.scatter(results_img[:, 0], results_img[:, 1], c=['#b0c4de'] * n_img, marker='.', alpha=0.7)
                plt.scatter(results_text[:, 0], results_text[:, 1], c=['#f08080'] * n_img, marker='.', alpha=0.7)

                plt.tight_layout()
                temp_NO = (i + 1) // 4

                npy_path = os.path.join(args.save_path, 'tsne_{}_nat.npy'.format(str(temp_NO)))
                np.save(npy_path, full_emb)
                tsne_path = os.path.join(args.save_path, 'tsne_{}_nat.pdf'.format(str(temp_NO)))
                plt.savefig(tsne_path, dpi=300)
                tsne_path = os.path.join(args.save_path, 'tsne_{}_nat.png'.format(str(temp_NO)))
                plt.savefig(tsne_path, dpi=300)
                tsne_path = os.path.join(args.save_path, 'tsne_{}_nat.svg'.format(str(temp_NO)))
                plt.savefig(tsne_path, dpi=300)

                # Repeat the same visualization for the adversarial embeddings.
                full_emb = torch.cat((Adv_Img_Emb_Full, Adv_Text_Emb_Full), dim=0)
                full_emb = full_emb.numpy()
                tsne = TSNE(n_components=2, perplexity=30.0, early_exaggeration=12.0,
                            learning_rate=100, n_iter=2000, random_state=200, verbose=1)
                result = tsne.fit_transform(full_emb)
                x_min, x_max = result.min(0), result.max(0)
                result = (result - x_min) / (x_max - x_min)

                results_img = result[:n_img, :]
                results_text = result[n_img:, :]

                fig = plt.figure()
                plt.scatter(results_img[:, 0], results_img[:, 1], c=['#b0c4de'] * n_img, marker='.', alpha=0.7)
                plt.scatter(results_text[:, 0], results_text[:, 1], c=['#f08080'] * n_img, marker='.', alpha=0.7)

                plt.tight_layout()
                temp_NO = (i + 1) // 4

                npy_path = os.path.join(args.save_path, 'tsne_{}_adv.npy'.format(str(temp_NO)))
                np.save(npy_path, full_emb)
                tsne_path = os.path.join(args.save_path, 'tsne_{}_adv.pdf'.format(str(temp_NO)))
                plt.savefig(tsne_path, dpi=300)
                tsne_path = os.path.join(args.save_path, 'tsne_{}_adv.png'.format(str(temp_NO)))
                plt.savefig(tsne_path, dpi=300)
                tsne_path = os.path.join(args.save_path, 'tsne_{}_adv.svg'.format(str(temp_NO)))
                plt.savefig(tsne_path, dpi=300)

            if i % args.print_freq == 0 and i != 0:
                progress.display(i)
                if args.debug:
                    break

            torch.cuda.empty_cache()

print("Eps: {} Step: {} Adversarial Type: {}".format(args.test_eps, args.test_numsteps, args.adv_type))
|
|
|
print(dataset_name + '--- Clean Acc.: {top1_org.avg:.2f} Adv Acc.: {top1_adv_org.avg:.2f}.'
|
|
|
.format(top1_org=top1_org, top1_adv_org=top1_adv_org))
|
|
|
|
|
|
acc_all_nat.append(top1_org.avg)
|
|
|
acc_all_adv.append(top1_adv_org.avg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print('Average on all datasets --- Clean Acc.: {:.2f} Adv Acc.: {:.2f}.'
|
|
|
.format(np.mean(acc_all_nat), np.mean(acc_all_adv)))
|
|
|
|
|
|
return np.mean(acc_all_adv)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    main()