from __future__ import print_function import argparse import os from tqdm import tqdm import time import random import warnings # import wandb import copy import torch import torch.backends.cudnn as cudnn from torch.cuda.amp import GradScaler, autocast from torch.utils.data import DataLoader, Sampler from torchvision.datasets import StanfordCars, Food101, SUN397, EuroSAT, \ Caltech256, Country211, Flowers102, PCAM, FGVCAircraft from torchvision.datasets import * import torchvision.transforms as transforms import torchvision from modified_clip import clip from models.model import * from models.prompters import TokenPrompter, NullPrompter, PromptLearner from attacks import * from utils import accuracy, AverageMeter, ProgressMeter, save_checkpoint, str2bool from utils import cosine_lr, convert_models_to_fp32, refine_classname from data_utils.autoaugment import ImageNetPolicy import torch.nn.functional as F import numpy as np import torch.nn as nn import functools from autoattack import AutoAttack import ssl ssl._create_default_https_context = ssl._create_unverified_context import matplotlib.pyplot as plt from matplotlib import rc rc('font',family='Arial') from sklearn import manifold,datasets from sklearn.manifold import TSNE """ Tuning Text Prompts (Embeddings) to generate adversarial examples. Default Training Setting: Batch_size=256, Dataset=ImageNet, train_stepsize=1 Default Evaluation Setting: 20-step PGD test_stepsize==1 (img epsilon=1) == (text tototal perturbation=0.01) --------------------------------- eval_type: fast_motivation|motivation CUDA_VISIBLE_DEVICES=0,1 python Visulization_TSNE.py --batch_size 250 --evaluate --resume Source_PT/TeCoAmodel_best.pth.tar --test_eps 1 --save_path TeCoA_eps1 CUDA_VISIBLE_DEVICES=0,1 python Visulization_TSNE.py --batch_size 250 --evaluate --resume Source_PT/TeCoAmodel_best.pth.tar --test_eps 2 --save_path TeCoA_eps2 CUDA_VISIBLE_DEVICES=0,1 python Visulization_TSNE.py --batch_size 250 --evaluate --resume Source_PT/TeCoAmodel_best.pth.tar --test_eps 3 --save_path TeCoA_eps3 CUDA_VISIBLE_DEVICES=0,1 python Visulization_TSNE.py --batch_size 250 --evaluate --resume Source_PT/TeCoAmodel_best.pth.tar --test_eps 4 --save_path TeCoA_eps4 """ def parse_option(): parser = argparse.ArgumentParser('Visual Prompting for CLIP') parser.add_argument('--print_freq', type=int, default=2000, help='print frequency') parser.add_argument('--save_freq', type=int, default=50, help='save frequency') parser.add_argument('--validate_freq', type=int, default=1, help='validate frequency') parser.add_argument('--batch_size', type=int, default=256, help='batch_size') parser.add_argument('--num_workers', type=int, default=32, help='num of workers to use') parser.add_argument('--epochs', type=int, default=10, help='number of training epoch5s') parser.add_argument("--mix_alpha", type=float, default=-1, help="interpolation") # optimization parser.add_argument('--optim', type=str, default='sgd', help='optimizer to use') parser.add_argument('--learning_rate', type=float, default=1e-5, ## Change from 1e-7 to 1e-5 help='learning rate') parser.add_argument("--weight_decay", type=float, default=0, help="weight decay") parser.add_argument("--warmup", type=int, default=1000, help="number of steps to warmup for") parser.add_argument('--momentum', type=float, default=0.9, help='momentum') parser.add_argument('--train_eps', type=float, default=2, help='momentum') parser.add_argument('--train_numsteps', type=int, default=5) parser.add_argument('--train_stepsize', type=int, default=1) parser.add_argument('--test_eps', type=float, default=2, help='momentum') parser.add_argument('--test_numsteps', type=int, default=20) parser.add_argument('--test_stepsize', type=int, default=1) parser.add_argument('--patience', type=int, default=1000) # model parser.add_argument('--model', type=str, default='clip') parser.add_argument('--imagenet_root', type=str, default='temp') parser.add_argument('--arch', type=str, default='vit_b32') parser.add_argument('--method', type=str, default='null_patch', choices=['null_patch'], help='choose visual prompting method') parser.add_argument('--name', type=str, default='') parser.add_argument('--prompt_size', type=int, default=30, help='size for visual prompts') parser.add_argument('--add_prompt_size', type=int, default=0, help='size for additional visual prompts') # dataset parser.add_argument('--root', type=str, default='/home/data1/junhao/datasets/', help='dataset') parser.add_argument('--dataset', type=str, default='ImageNet', help='Pre-training Dataset: cifar10|cifar100|ImageNet') parser.add_argument('--image_size', type=int, default=224, help='image size') # other parser.add_argument('--seed', type=int, default=None, help='seed for initializing training') parser.add_argument('--model_dir', type=str, default='../save_ckpts', help='path to save models') parser.add_argument('--image_dir', type=str, default='./save/images', help='path to save images') parser.add_argument('--filename', type=str, default=None, help='filename to save') parser.add_argument('--trial', type=int, default=1, help='number of trials') parser.add_argument('--gpu', type=int, default=None, help='gpu to use') parser.add_argument('--debug', action='store_true') parser.add_argument('--VPbaseline', action='store_true') parser.add_argument('--CW', action='store_true') parser.add_argument('--autoattack', action='store_true') parser.add_argument('--train_class_count', type=int, default=90) parser.add_argument('--last_num_ft', type=int, default=-1) parser.add_argument('--noimginprop', action='store_true') parser.add_argument('--exp_name', type=str, default=None) # Augmentation parser.add_argument('--aug_type', type=str, default='Vanilla', help='Vanilla|Vanilla_Flip|Resizecrop|Resizecrop_Flip|Autoaug') # Evaluation parser.add_argument('--resume', type=str, default=None, help='path to resume from checkpoint') parser.add_argument('--evaluate', default=False, action="store_true", help='evaluate model test set') parser.add_argument('--eval_type', type=str, default="motivation", help='fast|full') # Text Prompt Tuning parser.add_argument('--adv_prompt_gen', type=str2bool, default="False", help='Whether to conduct adversarial prompt generation') parser.add_argument('--ctx', type=int, default=16, help='number of context vector') parser.add_argument('--ctx_init', type=str, default='This is a photo of a', help='Initialization for context prompt (e.g., (This is a photo of a)|(a photo of a))') parser.add_argument('--position', type=str, default='end', help='CLS prompt position: end|middle|front') parser.add_argument('--text_perb_stepsize', type=float, default=0.001, help='perturbation step size for texts, the perturbation share the same step for adv images') # Extra modules parser.add_argument('--W_Pred_Align', type=float, default=0.0, help='Prediction alignment between clean and adv logits') parser.add_argument('--W_Nat_CE', type=float, default=0.0, help='Natural classification of clean logit') parser.add_argument('--W_Pred_Align_Ori', type=float, default=0.0, help='Prediction alignment between adv logits to the original clip-clean logits') # Motivation modules parser.add_argument('--adv_type', type=str, default="Img_Only", help='Img_Only|Text_Only|Joint') # Visualization parser.add_argument('--save_path', type=str, default="temp", help='save path for TSNE results') args = parser.parse_args() if args.exp_name is not None: args.filename = args.exp_name else: args.filename = '{}_{}_{}_{}_{}_{}_{}_lr_{}_decay_{}_bsz_{}_warmup_{}_trial_{}_addp_{}'. \ format(args.name, args.method, args.prompt_size, args.dataset, args.model, args.arch, args.optim, args.learning_rate, args.weight_decay, args.batch_size, args.warmup, args.trial, args.add_prompt_size) return args class BalancedBatchSampler(Sampler): def __init__(self, dataset): self.dataset = dataset self.labels = np.array([sample[1] for sample in dataset.imgs]) self.labels_set = np.unique(self.labels) self.label_to_indices = {label: np.where(self.labels == label)[0] for label in self.labels_set} self.used_labels_indices = {label: 0 for label in self.labels_set} self.count = len(self.labels) // len(self.labels_set) def __iter__(self): count = self.count for _ in range(count): indices = [] for label in self.labels_set: start = self.used_labels_indices[label] end = (start + 1) % len(self.label_to_indices[label]) indices.append(self.label_to_indices[label][start]) self.used_labels_indices[label] = end #np.random.shuffle(indices) for index in indices: yield index def __len__(self): return self.count * len(self.labels_set) def on_epoch_end(self): self.used_labels_indices = {label: 0 for label in self.labels_set} best_acc1 = 0 device = "cuda" if torch.cuda.is_available() else "cpu" def train(train_loader, texts, model, original_model, prompter, add_prompter, optimizer, scheduler, criterion, scaler, epoch, prompt_learner, args): batch_time = AverageMeter('Time', ':6.3f') data_time = AverageMeter('Data', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') progress = ProgressMeter( len(train_loader), [batch_time, data_time, losses, top1], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.module.visual.train() num_batches_per_epoch = len(train_loader) alpha = args.train_stepsize attack_iters = args.train_numsteps # print('text token', texts) end = time.time() # original prompter state if args.adv_type == 'Text_Only' or args.adv_type == 'Joint': original_prompter_state = copy.deepcopy(prompt_learner.state_dict()) for i, (images, target) in enumerate(tqdm(train_loader, ncols = 80)): # measure data loading time data_time.update(time.time() - end) BATCH_SIZE = images.size(0) # print('bs', BATCH_SIZE) # adjust learning rate step = num_batches_per_epoch * epoch + i scheduler(step) optimizer.zero_grad() images = images.to(device) target = target.to(device) text_tokens = clip.tokenize(texts).to(device) # print(images.min(), images.max()) # with automatic mixed precision with autocast(): loss_Pred_Align = 0.0 loss_Nat_CE = 0.0 loss_Pred_Align_Ori = 0.0 output_Inat_Tnat = None if not args.VPbaseline: if args.adv_type == 'Text_Only' or args.adv_type == 'Joint': # Reset prompt first prompt_learner.load_state_dict(original_prompter_state) delta = attack_pgd_adv_prompt(prompter, model, add_prompter, criterion, images, target, text_tokens, alpha, attack_iters, 'l_inf', prompt_learner, args.text_perb_stepsize, epsilon=args.train_eps) else: delta = attack_pgd(prompter, model, add_prompter, criterion, images, target, text_tokens, alpha, attack_iters, 'l_inf', epsilon=args.train_eps) # print('delta', delta.min(), delta.max()) tmp = clip_img_preprocessing(images + delta) else: tmp = clip_img_preprocessing(images) prompted_images = prompter(tmp) prompt_token = None if args.adv_type == 'Text_Only' or args.adv_type == 'Joint': output_Iadv_Tnat, _ = multiGPU_CLIP_Text_Prompt_Tuning(model, prompted_images, text_tokens, prompt_token, prompt_learner) else: # Compute logits_image(256, 1000), logits_text(1000, 256) (Image-Text Alignment) output_Iadv_Tnat, _ = multiGPU_CLIP(model, prompted_images, text_tokens, prompt_token) if args.W_Pred_Align > 0.0: criterion_KL = nn.KLDivLoss(reduction='batchmean').to(device) tmp_nat = clip_img_preprocessing(images) prompted_nat_images = prompter(tmp_nat) output_Inat_Tnat, _ = multiGPU_CLIP(model, prompted_nat_images, text_tokens, prompt_token) loss_Pred_Align = criterion_KL(F.log_softmax(output_Iadv_Tnat, dim=1), F.softmax(output_Inat_Tnat, dim=1)) if args.W_Pred_Align_Ori > 0.0: criterion_KL = nn.KLDivLoss(reduction='batchmean').to(device) tmp_nat = clip_img_preprocessing(images) prompted_nat_images = prompter(tmp_nat) with torch.no_grad(): Ori_output_Inat_Tnat, _ = multiGPU_CLIP(original_model, prompted_nat_images, text_tokens, prompt_token) loss_Pred_Align_Ori = criterion_KL(F.log_softmax(output_Iadv_Tnat, dim=1), F.softmax(Ori_output_Inat_Tnat, dim=1)) loss = criterion(output_Iadv_Tnat, target) + args.W_Pred_Align * loss_Pred_Align + args.W_Nat_CE * loss_Nat_CE + args.W_Pred_Align_Ori * loss_Pred_Align_Ori scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() # Note: we clamp to 4.6052 = ln(100), as in the original paper. model.module.logit_scale.data = torch.clamp(model.module.logit_scale.data, 0, 4.6052) # measure accuracy acc1 = accuracy(output_Iadv_Tnat, target, topk=(1,)) losses.update(loss.item(), images.size(0)) top1.update(acc1[0].item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0 and i != 0: progress.display(i) if args.debug: break # break # if args.use_wandb: # wandb.log({ # 'training_loss': losses.avg, # 'training_acc': top1.avg # }) if i % args.save_freq == 0: save_checkpoint({ 'epoch': epoch + 1, 'state_dict': prompter.state_dict(), 'add_prompter': add_prompter.state_dict(), 'vision_encoder_state_dict': model.module.visual.state_dict(), 'best_acc1': best_acc1, 'optimizer': optimizer.state_dict(), }, args) return losses.avg, top1.avg def main(): global best_acc1, device args = parse_option() args.train_eps = args.train_eps / 255. args.test_eps = args.test_eps / 255. args.train_stepsize = args.train_stepsize / 255. args.test_stepsize = args.test_stepsize / 255. if args.resume is not None: args.resume = os.path.join("../save_ckpts", args.resume) args.save_path = os.path.join("../save_TSNE_Vis", args.save_path) os.makedirs(args.save_path, exist_ok=True) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True else: cudnn.benchmark = True import socket if socket.gethostname() == 'junhao': args.root = '/home/data1/junhao/datasets/' elif socket.gethostname() == 'ai-planning-p4de-02': args.root = '/data_3/teddy_research/datasets_jh/' # if args.imagenet_root is not None: imagenet_root = os.path.join(args.root, "ImageNet") imgnet_full = imagenet_root # create model # add_prompt_len = args.add_prompt_size # No prompts during the inference statge add_prompt_len = 0 model, preprocess = clip.load('ViT-B/32', device, jit=False, prompt_len=add_prompt_len) # model_text, model_image = None, None convert_models_to_fp32(model) model = torch.nn.DataParallel(model) # .to(device) model.eval() original_model = None if args.W_Pred_Align_Ori > 0.0: original_model, preprocess = clip.load('ViT-B/32', device, jit=False, prompt_len=add_prompt_len) convert_models_to_fp32(original_model) original_model = torch.nn.DataParallel(original_model) # .to(device) original_model.eval() ### !!! These two are prompters for the images prompter = NullPrompter() # .to(device) add_prompter = TokenPrompter(add_prompt_len) # .to(device) prompter = torch.nn.DataParallel(prompter).cuda() add_prompter = torch.nn.DataParallel(add_prompter).cuda() # define criterion and optimizer # we finetune the image module parameters only if args.last_num_ft == -1: optimizer = torch.optim.SGD(model.module.visual.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) else: optimizer = torch.optim.SGD(list(model.module.visual.parameters())[-args.last_num_ft:], lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) criterion = torch.nn.CrossEntropyLoss().to(device) args.start_epoch = 0 # optionally resume from a checkpoint if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) if args.gpu is None: checkpoint = torch.load(args.resume) else: # Map model to be loaded to specified single gpu. loc = 'cuda:{}'.format(args.gpu) checkpoint = torch.load(args.resume, map_location=loc) args.start_epoch = checkpoint['epoch'] best_acc1 = checkpoint['best_acc1'] if args.gpu is not None: # best_acc1 may be from a checkpoint from a different GPU best_acc1 = best_acc1.to(args.gpu) if args.mix_alpha > 0: alpha = args.mix_alpha # model1, preprocess = clip.load('ViT-B/32', device, jit=False, prompt_len=add_prompt_len) # model2, preprocess = clip.load('ViT-B/32', device, jit=False, prompt_len=add_prompt_len) # model1 = torch.nn.DataParallel(model1) # model2 = torch.nn.DataParallel(model2) checkpoint_ori = torch.load('original_clip.pth.tar') theta_ori = checkpoint_ori['vision_encoder_state_dict'] theta_rob = checkpoint['vision_encoder_state_dict'] theta = { key: (1 - alpha) * theta_ori[key] + alpha * theta_rob[key] for key in theta_ori.keys() } model.module.visual.load_state_dict(theta) else: model.module.visual.load_state_dict(checkpoint['vision_encoder_state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) # prompter.load_state_dict(checkpoint['state_dict']) # add_prompter.load_state_dict(checkpoint['add_prompter']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume)) # create data template = 'This is a photo of a {}' print(f'template: {template}') # TODO: we can train on cifar10 and test on cifar10, 100 in zero shot way, to see if generalize. preprocess = transforms.Compose([ # transforms.RandomHorizontalFlip(), # transforms.RandomRotation(15), # TODO: may use later transforms.ToTensor() ]) preprocess224 = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), # transforms.RandomHorizontalFlip(), # transforms.RandomRotation(15), # TODO: may use later transforms.ToTensor() ]) preprocess224_interpolate = transforms.Compose([ transforms.Resize((224, 224)), # transforms.RandomHorizontalFlip(), # transforms.RandomRotation(15), # TODO: may use later transforms.ToTensor() ]) ############################ Augmentation ############################ preprocess224_vanilla_flip = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor() ]) preprocess224_resizecrop = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.ToTensor() ]) preprocess224_resizecrop_flip = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor() ]) preprocess_autoaug = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), ImageNetPolicy(), transforms.ToTensor() ]) # Vanilla|Vanilla_Flip|Resizecrop|Resizecrop_Flip|Autoaug if args.aug_type == 'Vanilla': IN_aug_type = preprocess224 elif args.aug_type == 'Vanilla_Flip': IN_aug_type = preprocess224_vanilla_flip elif args.aug_type == 'Resizecrop': IN_aug_type = preprocess224_resizecrop elif args.aug_type == 'Resizecrop_Flip': IN_aug_type = preprocess224_resizecrop_flip elif args.aug_type == 'Autoaug': IN_aug_type = preprocess_autoaug ############################ Augmentation ############################ if args.dataset == 'cifar100': print('hi') train_dataset = CIFAR100(args.root, transform=preprocess, download=True, train=True) val_dataset = CIFAR100(args.root, transform=preprocess, download=True, train=False) elif args.dataset == 'cifar10': train_dataset = CIFAR10(args.root, transform=preprocess, download=True, train=True) val_dataset = CIFAR10(args.root, transform=preprocess, download=True, train=False) elif args.dataset == 'ImageNet': train_dataset = torchvision.datasets.ImageFolder( os.path.join(imagenet_root, 'train'), transform=IN_aug_type ) val_dataset_list = [] val_dataset_name = ['StanfordCars', 'Food101', 'PCAM', 'cifar100', 'oxfordpet', 'flowers102', 'Country211', 'dtd', 'EuroSAT', 'fgvc_aircraft', 'ImageNet', 'cifar10', 'SUN397'] if args.evaluate: if args.eval_type == 'fast': val_dataset_name = ['ImageNet', 'SUN397', 'Food101', 'flowers102', 'Caltech101', 'Caltech256'] elif args.eval_type == 'full': val_dataset_name = ['ImageNet', 'cifar10', 'STL10', 'cifar100', 'SUN397', 'StanfordCars', 'Food101', 'oxfordpet', 'flowers102', 'dtd', 'EuroSAT', 'fgvc_aircraft', 'PCAM', 'Caltech101', 'Caltech256'] elif args.eval_type == 'motivation': val_dataset_name = ['ImageNet'] elif args.eval_type == 'fast_motivation': val_dataset_name = ['Caltech101'] else: val_dataset_name = ['cifar10', 'cifar100', 'dtd', 'EuroSAT'] for each in val_dataset_name: if each == 'cifar10': val_dataset_list.append(CIFAR10(args.root, transform=preprocess, download=True, train=False)) elif each == 'cifar100': val_dataset_list.append(CIFAR100(args.root, transform=preprocess, download=True, train=False)) elif each == 'Caltech101': val_dataset_list.append(Caltech101(args.root, target_type='category', transform=preprocess224, download=True)) elif each == 'PCAM': val_dataset_list.append(PCAM(args.root, split='test', transform=preprocess224, download=True)) elif each == 'STL10': val_dataset_list.append(STL10(args.root, split='test', transform=preprocess, download=True)) elif each == 'SUN397': val_dataset_list.append(SUN397(args.root, transform=preprocess224, download=True)) elif each == 'StanfordCars': val_dataset_list.append(StanfordCars(args.root, split='test', transform=preprocess224, download=True)) elif each == 'Food101': val_dataset_list.append(Food101(args.root, split='test', transform=preprocess224, download=True)) elif each == 'oxfordpet': val_dataset_list.append(OxfordIIITPet(args.root, split='test', transform=preprocess224, download=True)) elif each == 'EuroSAT': val_dataset_list.append(EuroSAT(args.root, transform=preprocess224, download=True)) elif each == 'Caltech256': val_dataset_list.append(Caltech256(args.root, transform=preprocess224, download=True)) # elif each == 'FER2013': # val_dataset_list.append(OxfordIIITPet(args.root, split='test', # transform=preprocess224, download=True)) elif each == 'flowers102': val_dataset_list.append(Flowers102(args.root, split='test', transform=preprocess224, download=True)) elif each == 'Country211': val_dataset_list.append(Country211(args.root, split='test', transform=preprocess224, download=True)) elif each == 'dtd': val_dataset_list.append(DTD(args.root, split='test', transform=preprocess224, download=True)) elif each == 'fgvc_aircraft': val_dataset_list.append(FGVCAircraft(args.root, split='test', transform=preprocess224, download=True)) elif each == 'ImageNet': val_dataset_list.append(torchvision.datasets.ImageFolder( os.path.join(imgnet_full, 'val'), transform=preprocess224)) # val_dataset_list.append(torchvision.datasets.ImageNet( # root=imagenet_root, # split='val', # transform=preprocess224)) train_sampler = None val_sampler = None ############################ Subset to simulate the last batch (For test only) ############################ # from torch.utils.data import Subset # class_names = train_dataset.classes # subset_indices = torch.randperm(len(train_dataset))[:143] # temp_train_dataset = Subset(train_dataset, subset_indices) # temp_train_dataset.classes = train_dataset.classes # train_dataset = temp_train_dataset ############################ Subset to simulate the last batch ############################ # Sampler definition class OneImagePerClassSampler(Sampler): def __init__(self, data_source): self.data_source = data_source self.indices_map = {label: np.where(np.array(data_source.targets) == label)[0] for label in set(data_source.targets)} def __iter__(self): batch = [] for indices in self.indices_map.values(): index = np.random.choice(indices) batch.append(index) np.random.shuffle(batch) return iter(batch) def __len__(self): return len(self.indices_map) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=args.num_workers, shuffle=True, sampler=train_sampler) # shuffle need to be True val_loader_list = [DataLoader(each, batch_size=args.batch_size, pin_memory=True, num_workers=args.num_workers, shuffle=False, sampler=BalancedBatchSampler(each)) for each in val_dataset_list] ### serial number (not semantic classes) class_names = train_dataset.classes if args.dataset == 'ImageNet': from utils import load_imagenet_folder2name folder2name = load_imagenet_folder2name('imagenet_classes_names.txt') new_class_names = [] for each in class_names: new_class_names.append(folder2name[each]) class_names = new_class_names # Original class name class_names = refine_classname(class_names) # Context + Class name texts_train = [template.format(label) for label in class_names] ###### Save the original classnames for Text Prompt Tuning training_original_classnames = class_names texts_list = [] val_original_classnames = [] for cnt, each in enumerate(val_dataset_list): if hasattr(each, 'clip_prompts'): texts_tmp = each.clip_prompts else: class_names = each.classes if val_dataset_name[cnt] == 'ImageNet': from utils import load_imagenet_folder2name folder2name = load_imagenet_folder2name('imagenet_classes_names.txt') new_class_names = [] for class_name in class_names: new_class_names.append(folder2name[class_name]) class_names = new_class_names val_original_classnames.append(class_names) class_names = refine_classname(class_names) texts_tmp = [template.format(label) for label in class_names] texts_list.append(texts_tmp) assert len(texts_list) == len(val_dataset_list) scaler = GradScaler() total_steps = len(train_loader) * args.epochs scheduler = cosine_lr(optimizer, args.learning_rate, args.warmup, total_steps) # make dir refined_template = template.lower().replace(' ', '_') # args.filename = f'{args.filename}_template_{refined_template}' args.model_folder = os.path.join(args.model_dir, args.filename) if not os.path.isdir(args.model_folder): os.makedirs(args.model_folder) #################################### Constructing Text Prompter #################################### if not args.evaluate: prompt_learner = None if args.adv_type == 'Text_Only' or args.adv_type == 'Joint': prompt_learner = PromptLearner(args, training_original_classnames, model) prompt_learner = torch.nn.DataParallel(prompt_learner).cuda() # prompter_optim = torch.optim.SGD(prompt_learner, # lr=args.text_perb_stepsize, # momentum=0, # weight_decay=0) #################################### Constructing Text Prompter #################################### ################################################## Evaluation ################################################## if args.evaluate: # !!! Need to create each prompt learner for each dataset acc1_mean = validate(val_loader_list, val_dataset_name, val_original_classnames, texts_list, model, prompter, add_prompter, criterion, args) return ################################################## Evaluation ################################################## epochs_since_improvement = 0 for epoch in range(args.start_epoch, args.epochs): # train for one epoch train(train_loader, texts_train, model, original_model, prompter, add_prompter, optimizer, scheduler, criterion, scaler, epoch, prompt_learner, args) # evaluate on validation set if epoch % args.validate_freq == 0: acc1_mean = validate(val_loader_list, val_dataset_name, val_original_classnames, texts_list, model, prompter, add_prompter, criterion, args) # remember best acc@1 and save checkpoint is_best = acc1_mean > best_acc1 best_acc1 = max(acc1_mean, best_acc1) save_checkpoint({ 'epoch': epoch + 1, 'state_dict': prompter.state_dict(), 'add_prompter': add_prompter.state_dict(), 'vision_encoder_state_dict': model.module.visual.state_dict(), 'best_acc1': best_acc1, 'optimizer': optimizer.state_dict(), }, args, is_best=is_best) if is_best: epochs_since_improvement = 0 else: epochs_since_improvement += 1 print(f"There's no improvement for {epochs_since_improvement} epochs.") if epochs_since_improvement >= args.patience: print("The training halted by early stopping criterion.") break # wandb.run.finish() def cov_matrix(x, y, epsilon=1e0): B = x.shape[0] x_mean = x.mean(dim=0, keepdim=True) y_mean = y.mean(dim=0, keepdim=True) x_centered = x - x_mean y_centered = y - y_mean cov_xy = x_centered.t().mm(y_centered) / (B - 1) # 添加对角线扰动 cov_xy += torch.eye(cov_xy.size(0)).to(cov_xy.device) * epsilon return cov_xy def logdet_divergence(cov_matrix): # Compute the trace of the covariance matrix trace = torch.trace(cov_matrix) # Compute the log determinant of the covariance matrix logdet = torch.logdet(cov_matrix) dim = cov_matrix.size(0) # Compute the LogDet Divergence L = trace - logdet - dim return L def compute_mean_and_chol(embeddings): """Compute the mean and Cholesky decomposition of the covariance matrix of the embeddings.""" mean = embeddings.mean(dim=0) xm = embeddings - mean cov = xm.T @ xm / (embeddings.size(0) - 1) # Regularization for numerical stability in Cholesky decomposition cov += 1e-3 * torch.eye(cov.size(0), device=cov.device) chol = torch.linalg.cholesky(cov) return mean, chol def kl_divergence_chol(embeddings0, embeddings1): """Compute KL divergence using Cholesky factors.""" mu0, L0 = compute_mean_and_chol(embeddings0) mu1, L1 = compute_mean_and_chol(embeddings1) L1_inv = torch.linalg.inv(L1) Sigma1_inv = L1_inv.T @ L1_inv # Sigma1_inv * Sigma0 M = Sigma1_inv @ (L0 @ L0.T) trace_term = torch.trace(M) diff = mu1 - mu0 quadratic_term = diff.T @ Sigma1_inv @ diff logdet_Sigma0 = 2 * torch.log(torch.diagonal(L0)).sum() logdet_Sigma1 = 2 * torch.log(torch.diagonal(L1)).sum() logdet_term = logdet_Sigma1 - logdet_Sigma0 kl = 0.5 * (trace_term + quadratic_term - mu0.numel() + logdet_term) return kl # def validate(val_loader, texts, model, prompter, add_prompter, criterion, args): def validate(val_loader_list, val_dataset_name, val_original_classnames, texts_list, model, prompter, add_prompter, criterion, args): dataset_num = len(val_loader_list) acc_all_nat = [] acc_all_adv = [] test_stepsize = args.test_stepsize for cnt in range(dataset_num): val_loader = val_loader_list[cnt] texts = texts_list[cnt] dataset_name = val_dataset_name[cnt] original_classname = val_original_classnames[cnt] ## Results for Motivation # Image-Text cosine sim Cos_IT_sim_nat_all = torch.tensor([], device=device) Cos_IT_sim_adv_all = torch.tensor([], device=device) # Image/Text-level cosine sim Cos_Img_sim_all = torch.tensor([], device=device) Cos_Text_sim_all = torch.tensor([], device=device) # Image/Text Embeddings Nat_Img_emb_all = torch.tensor([], device=device) Nat_Text_emb_all = torch.tensor([], device=device) Adv_Img_emb_all = torch.tensor([], device=device) Adv_Text_emb_all = torch.tensor([], device=device) # Gradient Norm grad_norm_all_nat = torch.tensor([], device=device) grad_norm_all_adv = torch.tensor([], device=device) num_classes = len(val_original_classnames[cnt]) #################################### Constructing Text Prompter #################################### prompt_learner = None if args.adv_type == 'Text_Only' or args.adv_type == 'Joint': prompt_learner = PromptLearner(args, original_classname, model) prompt_learner = torch.nn.DataParallel(prompt_learner).cuda() # prompter_optim = torch.optim.SGD(prompt_learner, # lr=args.text_perb_stepsize, # momentum=0, # weight_decay=0) #################################### Constructing Text Prompter #################################### binary = ['PCAM'] attacks_to_run=['apgd-ce', 'apgd-dlr'] if dataset_name in binary: attacks_to_run=['apgd-ce'] batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1_org = AverageMeter('Original Acc@1', ':6.2f') top1_prompt = AverageMeter('Prompt Acc@1', ':6.2f') top1_adv_org = AverageMeter('Adv Original Acc@1', ':6.2f') top1_adv_prompt = AverageMeter('Adv Prompt Acc@1', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, losses, top1_org, top1_adv_org], prefix=dataset_name + '_Validate: ') # switch to evaluation mode prompter.eval() add_prompter.eval() model.eval() # original prompter state if args.adv_type == 'Text_Only' or args.adv_type == 'Joint': original_prompter_state = copy.deepcopy(prompt_learner.state_dict()) end = time.time() print("len(val_loader)", len(val_loader)) Nat_Img_Emb_Full = torch.tensor([]) Adv_Img_Emb_Full = torch.tensor([]) Nat_Text_Emb_Full = torch.tensor([]) Adv_Text_Emb_Full = torch.tensor([]) for i, (images, target) in enumerate(tqdm(val_loader, ncols = 80)): if i % 4 == 0: # batch_size=100 only Nat_Img_Emb_Full = torch.tensor([]) Adv_Img_Emb_Full = torch.tensor([]) Nat_Text_Emb_Full = torch.tensor([]) Adv_Text_Emb_Full = torch.tensor([]) if 'cifar' not in val_dataset_name: if i % 20 != 0 and not args.evaluate: continue images = images.to(device) target = target.to(device) text_tokens = clip.tokenize(texts).to(device) with autocast(): # compute output # with torch.no_grad(): # prompt_token = add_prompter() with torch.no_grad(): prompt_token = None # output_prompt, _ = model(prompter(clip_img_preprocessing(images)), text_tokens, prompt_token) images.requires_grad = True output_prompt, _, nat_img_emb, nat_scaled_text_emb = multiGPU_CLIP(model, prompter(clip_img_preprocessing(images)), text_tokens, prompt_token, is_embedding=True) nat_scaled_text_emb = nat_scaled_text_emb[target]/model.module.logit_scale.exp() # Gradient norm for natural samples loss = criterion(output_prompt, target) # measure accuracy and record loss acc1 = accuracy(output_prompt, target, topk=(1,)) losses.update(loss.item(), images.size(0)) # top1_prompt.update(acc1[0].item(), images.size(0)) top1_org.update(acc1[0].item(), images.size(0)) torch.cuda.empty_cache() # generate adv example if args.CW: delta_prompt = attack_CW(prompter, model, add_prompter, criterion, images, target, text_tokens, test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps) attacked_images = images + delta_prompt elif args.autoattack: attacked_images = attack_auto(model, images, target, text_tokens, None, None, epsilon=args.test_eps, attacks_to_run=attacks_to_run) else: if args.adv_type == 'Joint': prompt_learner.load_state_dict(original_prompter_state) delta_prompt = attack_pgd_adv_prompt(prompter, model, add_prompter, criterion, images, target, text_tokens, test_stepsize, args.test_numsteps, 'l_inf', prompt_learner, args.text_perb_stepsize, epsilon=args.test_eps) elif args.adv_type == 'Text_Only': prompt_learner.load_state_dict(original_prompter_state) delta_prompt = attack_pgd_adv_promptONLY(prompter, model, add_prompter, criterion, images, target, text_tokens, test_stepsize, args.test_numsteps, 'l_inf', prompt_learner, args.text_perb_stepsize, epsilon=args.test_eps) else: delta_prompt = attack_pgd_motivation(prompter, model, add_prompter, criterion, images, target, text_tokens, test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps) attacked_images = images + delta_prompt # compute output torch.cuda.empty_cache() # with torch.no_grad(): prompt_token = add_prompter() # output_prompt_adv, _ = model(prompter(clip_img_preprocessing(images + delta_prompt)), text_tokens, prompt_token) if args.adv_type == 'Text_Only' or args.adv_type == 'Joint': output_prompt_adv, _, adv_img_emb, adv_scaled_text_emb = multiGPU_CLIP_Text_Prompt_Tuning(model, prompter(clip_img_preprocessing(attacked_images)), text_tokens, prompt_token, prompt_learner, is_embedding=True) else: output_prompt_adv, _, adv_img_emb, adv_scaled_text_emb = multiGPU_CLIP(model, prompter(clip_img_preprocessing(attacked_images)), text_tokens, prompt_token, is_embedding=True) adv_scaled_text_emb = adv_scaled_text_emb[target]/model.module.logit_scale.exp() # Gradient norm for natural samples loss = criterion(output_prompt_adv, target) # bl attack torch.cuda.empty_cache() # measure accuracy and record loss acc1 = accuracy(output_prompt_adv, target, topk=(1,)) losses.update(loss.item(), images.size(0)) top1_adv_org.update(acc1[0].item(), images.size(0)) # top1_adv_prompt.update(acc1[0].item(), images.size(0)) # acc1 = accuracy(output_org_adv, target, topk=(1,)) # top1_adv_org.update(acc1[0].item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() ############ Ensemble 4 embeddings ############ nat_img_emb = nat_img_emb.detach().cpu() adv_img_emb = adv_img_emb.detach().cpu() nat_scaled_text_emb = nat_scaled_text_emb.detach().cpu() adv_scaled_text_emb = adv_scaled_text_emb.detach().cpu() Nat_Img_Emb_Full = torch.cat((Nat_Img_Emb_Full, nat_img_emb), dim=0) Adv_Img_Emb_Full = torch.cat((Adv_Img_Emb_Full, adv_img_emb), dim=0) Nat_Text_Emb_Full = torch.cat((Nat_Text_Emb_Full, nat_scaled_text_emb), dim=0) Adv_Text_Emb_Full = torch.cat((Adv_Text_Emb_Full, adv_scaled_text_emb), dim=0) # T-SNE Visulization if (i+1) % 4 == 0: ################################### Natural ################################### full_emb = torch.cat((Nat_Img_Emb_Full, Nat_Text_Emb_Full), dim=0) full_emb = full_emb.numpy() tsne = TSNE(n_components=2, perplexity=30.0, early_exaggeration=12.0, \ learning_rate=100, n_iter=2000, random_state=200, verbose=1) result = tsne.fit_transform(full_emb) x_min, x_max = result.min(0), result.max(0) result = (result - x_min) / (x_max - x_min) # Normalization results_img = result[:1000,:] results_text = result[1000:,:] fig = plt.figure() L_color = ['#b0c4de'] * 1000 plt.scatter(results_img[:,0], results_img[:,1], c=L_color, marker='.', alpha=0.7) L_color = ['#f08080'] * 1000 plt.scatter(results_text[:,0], results_text[:,1], c=L_color, marker='.', alpha=0.7) plt.tight_layout() temp_NO = (i+1) // 4 # full_emb_array = full_emb.numpy() npy_path = os.path.join(args.save_path, 'tsne_{}_nat.npy'.format(str(temp_NO))) np.save(npy_path, full_emb) tsne_path = os.path.join(args.save_path, 'tsne_{}_nat.pdf'.format(str(temp_NO))) plt.savefig(tsne_path, dpi=300) tsne_path = os.path.join(args.save_path, 'tsne_{}_nat.png'.format(str(temp_NO))) plt.savefig(tsne_path, dpi=300) tsne_path = os.path.join(args.save_path, 'tsne_{}_nat.svg'.format(str(temp_NO))) plt.savefig(tsne_path, dpi=300) ################################### Natural ################################### ################################### Adversarial ################################### full_emb = torch.cat((Adv_Img_Emb_Full, Adv_Text_Emb_Full), dim=0) full_emb = full_emb.numpy() tsne = TSNE(n_components=2, perplexity=30.0, early_exaggeration=12.0, \ learning_rate=100, n_iter=2000, random_state=200, verbose=1) result = tsne.fit_transform(full_emb) x_min, x_max = result.min(0), result.max(0) result = (result - x_min) / (x_max - x_min) # Normalization results_img = result[:1000,:] results_text = result[1000:,:] fig = plt.figure() L_color = ['#b0c4de'] * 1000 plt.scatter(results_img[:,0], results_img[:,1], c=L_color, marker='.', alpha=0.7) L_color = ['#f08080'] * 1000 plt.scatter(results_text[:,0], results_text[:,1], c=L_color, marker='.', alpha=0.7) plt.tight_layout() temp_NO = (i+1) // 4 # full_emb_array = full_emb.numpy() npy_path = os.path.join(args.save_path, 'tsne_{}_adv.npy'.format(str(temp_NO))) np.save(npy_path, full_emb) tsne_path = os.path.join(args.save_path, 'tsne_{}_adv.pdf'.format(str(temp_NO))) plt.savefig(tsne_path, dpi=300) tsne_path = os.path.join(args.save_path, 'tsne_{}_adv.png'.format(str(temp_NO))) plt.savefig(tsne_path, dpi=300) tsne_path = os.path.join(args.save_path, 'tsne_{}_adv.svg'.format(str(temp_NO))) plt.savefig(tsne_path, dpi=300) ################################### Adversarial ################################### ############ Ensemble 4 embeddings ############ if i % args.print_freq == 0 and i != 0: progress.display(i) if args.debug: break torch.cuda.empty_cache() print("Eps: {} Step: {} Adversarial Type: {}".format(args.test_eps, args.test_numsteps, args.adv_type)) print(dataset_name + '--- Clean Acc.: {top1_org.avg:.2f} Adv Acc.: {top1_adv_org.avg:.2f}.' .format(top1_org=top1_org, top1_adv_org=top1_adv_org)) acc_all_nat.append(top1_org.avg) acc_all_adv.append(top1_adv_org.avg) # if args.use_wandb: # wandb.log({ # 'val_loss': losses.avg, # 'val_acc_prompt': top1_prompt.avg, # 'val_acc_org': top1_org.avg, # }) print('Average on all datasets --- Clean Acc.: {:.2f} Adv Acc.: {:.2f}.' .format(np.mean(acc_all_nat), np.mean(acc_all_adv))) return np.mean(acc_all_adv) if __name__ == '__main__': main()