from __future__ import print_function import argparse, os, time, random from tqdm import tqdm import torch, torchvision import torch.backends.cudnn as cudnn from torch.cuda.amp import GradScaler, autocast from torch.utils.data import DataLoader from torchvision.datasets import * from modified_clip import clip from models import prompters from models.prompters import TokenPrompter from models.model import * from attacks import * from utils import accuracy, AverageMeter, ProgressMeter, save_checkpoint from utils import cosine_lr, convert_models_to_fp32, refine_classname from utils import load_train_dataset, load_val_datasets, get_text_prompts_train, \ get_text_prompts_val import numpy as np import torch.nn as nn import torch.nn.functional as F """ CUDA_VISIBLE_DEVICES=0,1 python visual_prompt.py --batch_size 256 --dataset ImageNet --add_prompt_size 100 --learning_rate 40 --exp_name VPT_TeCoA --train_eps 1 --train_numsteps 2 --train_stepsize 1 """ def parse_option(): parser = argparse.ArgumentParser('Adapting CLIP for zero-shot adv robustness') parser.add_argument('--print_freq', type=int, default=2000, help='print frequency') parser.add_argument('--save_freq', type=int, default=50, help='save frequency') parser.add_argument('--validate_freq', type=int, default=1, help='validate frequency') parser.add_argument('--batch_size', type=int, default=16, help='batch_size') parser.add_argument('--num_workers', type=int, default=64, help='num of workers to use') parser.add_argument('--epochs', type=int, default=10, help='number of training epochs') # optimization parser.add_argument('--optim', type=str, default='sgd', help='optimizer to use') parser.add_argument('--learning_rate', type=float, default=40, ## Why so large help='learning rate') parser.add_argument("--weight_decay", type=float, default=0, help="weight decay") parser.add_argument("--warmup", type=int, default=1000, help="number of steps to warmup for") parser.add_argument('--momentum', type=float, default=0.9, help='momentum') parser.add_argument('--train_eps', type=float, default=2, help='momentum') parser.add_argument('--train_numsteps', type=int, default=5) parser.add_argument('--train_stepsize', type=int, default=1) parser.add_argument('--test_eps', type=float, default=2, help='momentum') parser.add_argument('--test_numsteps', type=int, default=5) parser.add_argument('--test_stepsize', type=int, default=1) parser.add_argument('--patience', type=int, default=1000) # model parser.add_argument('--model', type=str, default='clip') parser.add_argument('--arch', type=str, default='vit_b32') parser.add_argument('--method', type=str, default='padding', choices=['padding', 'random_patch', 'fixed_patch', 'null_patch'], help='choose visual prompting method') parser.add_argument('--name', type=str, default='') parser.add_argument('--prompt_size', type=int, default=30, help='size for visual prompts --- padding the original image') parser.add_argument('--add_prompt_size', type=int, default=100, help='size for additional visual prompts --- token level prompt') # dataset parser.add_argument('--root', type=str, default='/home/data1/junhao/datasets/', help='dataset') parser.add_argument('--dataset', type=str, default='ImageNet', help='dataset') parser.add_argument('--image_size', type=int, default=224, help='image size') parser.add_argument('--imagenet_root', type=str, default='/home/data1/junhao/datasets/ImageNet/') # other parser.add_argument('--seed', type=int, default=0, help='seed for initializing training') parser.add_argument('--model_dir', type=str, default='../save_ckpts', help='path to save models') parser.add_argument('--image_dir', type=str, default='./save/images', help='path to save images') parser.add_argument('--filename', type=str, default=None, help='filename to save') parser.add_argument('--trial', type=int, default=1, help='number of trials') parser.add_argument('--resume', type=str, default=None, help='path to resume from checkpoint') parser.add_argument('--evaluate', default=False, action="store_true", help='evaluate model test set') parser.add_argument('--gpu', type=int, default=None, help='gpu to use') parser.add_argument('--debug', action='store_true') parser.add_argument('--attack', choices=['pgd', 'CW'], default='pgd') parser.add_argument('--train_class_count', type=int, default=90) parser.add_argument('--noimginprop', action='store_true') parser.add_argument('--exp_name', type=str, default=None) # Extra modules parser.add_argument('--W_inner_CE', type=float, default=0.0, help='Weighting for inner CE for adv gen') parser.add_argument('--W_outer_CE', type=float, default=1.0, help='Weighting for outer CE for network optimization') parser.add_argument('--W_Pred_Align', type=float, default=0.0, help='Prediction alignment between clean and adv logits') parser.add_argument('--W_Nat_CE', type=float, default=0.0, help='Natural classification of clean logit') parser.add_argument('--W_Pred_Align_Ori', type=float, default=0.0, help='Prediction alignment between adv logits to the original clip-clean logits') parser.add_argument('--align_type', type=str, default='KL', help='KL|Nuc|KL_NuAT') args = parser.parse_args() if args.exp_name is not None: args.filename = args.exp_name else: args.filename = '{}_{}_{}_{}_{}_{}_{}_lr_{}_decay_{}_bsz_{}_warmup_{}_trial_{}_addp_{}'. \ format(args.name, args.method, args.prompt_size, args.dataset, args.model, args.arch, args.optim, args.learning_rate, args.weight_decay, args.batch_size, args.warmup, args.trial, args.add_prompt_size) return args best_acc1 = 0 device = "cuda" if torch.cuda.is_available() else "cpu" def train(train_loader, texts, model, prompter, add_prompter, optimizer, scheduler, criterion, scaler, epoch, args): batch_time = AverageMeter('Time', ':6.3f') data_time = AverageMeter('Data', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') progress = ProgressMeter( len(train_loader), [batch_time, data_time, losses, top1], prefix="Epoch: [{}]".format(epoch)) # switch to train mode prompter.train() add_prompter.train() num_batches_per_epoch = len(train_loader) alpha = args.train_stepsize attack_iters = args.train_numsteps end = time.time() for i, (images, target) in enumerate(tqdm(train_loader, ncols = 80)): # measure data loading time data_time.update(time.time() - end) BATCH_SIZE = images.size(0) # print('bs', BATCH_SIZE) # adjust learning rate step = num_batches_per_epoch * epoch + i scheduler(step) optimizer.zero_grad() images = images.to(device) target = target.to(device) text_tokens = clip.tokenize(texts).to(device) # print(images.min(), images.max()) # with automatic mixed precision with autocast(): delta = attack_pgd(prompter, model, add_prompter, criterion, images, target, text_tokens, alpha, attack_iters, 'l_inf', epsilon=args.train_eps) tmp = clip_img_preprocessing(images + delta) prompted_images = prompter(tmp) prompt_token = add_prompter() # for multiple GPU output, _ = multiGPU_CLIP(model, prompted_images, text_tokens, prompt_token) loss = criterion(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() # Note: we clamp to 4.6052 = ln(100), as in the original paper. model.module.logit_scale.data = torch.clamp(model.module.logit_scale.data, 0, 4.6052) # measure accuracy acc1 = accuracy(output, target, topk=(1,)) losses.update(loss.item(), images.size(0)) top1.update(acc1[0].item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if args.debug: break if i % args.save_freq == 0: save_checkpoint({ 'epoch': epoch + 1, 'state_dict': prompter.state_dict(), 'add_prompter': add_prompter.state_dict(), 'best_acc1': best_acc1, 'optimizer': optimizer.state_dict(), }, args) return losses.avg, top1.avg def main(): global best_acc1, device args = parse_option() args.train_eps = args.train_eps / 255. args.test_eps = args.test_eps / 255. args.train_stepsize = args.train_stepsize / 255. args.test_stepsize = args.test_stepsize / 255. if args.resume is not None: args.resume = os.path.join("../save_ckpts", args.resume) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True import socket if socket.gethostname() == 'junhao': args.root = '/home/data1/junhao/datasets/' elif socket.gethostname() == 'ai-planning-p4de-02': args.root = '/data_3/teddy_research/datasets_jh/' # if args.imagenet_root is not None: imagenet_root = os.path.join(args.root, "ImageNet") # create model add_prompt_len = args.add_prompt_size model, preprocess = clip.load('ViT-B/32', device, jit=False, prompt_len=add_prompt_len) convert_models_to_fp32(model) model = torch.nn.DataParallel(model) # .to(device) model.eval() prompter = prompters.__dict__[args.method](args) add_prompter = TokenPrompter(add_prompt_len) prompter = torch.nn.DataParallel(prompter).to(device) add_prompter = torch.nn.DataParallel(add_prompter).to(device) # optionally resume from a checkpoint args.start_epoch = 0 if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) if args.gpu is None: checkpoint = torch.load(args.resume) else: # Map model to be loaded to specified single gpu. loc = 'cuda:{}'.format(args.gpu) checkpoint = torch.load(args.resume, map_location=loc) args.start_epoch = checkpoint['epoch'] best_acc1 = checkpoint['best_acc1'] if args.gpu is not None: # best_acc1 may be from a checkpoint from a different GPU best_acc1 = best_acc1.to(args.gpu) if 'vision_encoder_state_dict' in checkpoint.keys(): # Load backbone for complementary experiment, # this assume that the finetuned model does not have the following prompts model.module.visual.load_state_dict(checkpoint['vision_encoder_state_dict'], strict=False) else: # load only prompts, not backbone prompter.load_state_dict(checkpoint['state_dict']) add_prompter.load_state_dict(checkpoint['add_prompter']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume)) # create data template = 'This is a photo of a {}' print(f'template: {template}') # load training dataset train_dataset = load_train_dataset(args) # load val dataset(s) if args.evaluate: val_dataset_name = ['cifar10', 'cifar100', 'STL10', 'SUN397', 'StanfordCars', 'Food101', 'oxfordpet', 'flowers102', 'Country211', 'dtd', 'EuroSAT', 'fgvc_aircraft', 'PCAM', 'hateful_memes', 'ImageNet', 'Caltech101', 'Caltech256'] else: val_dataset_name = ['cifar10', 'cifar100', 'dtd', 'EuroSAT',] val_dataset_list = load_val_datasets(args, val_dataset_name) # create dataloaders train_sampler = None val_sampler = None train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=args.num_workers, shuffle=True, sampler=train_sampler) val_loader_list = [DataLoader(each, batch_size=args.batch_size, pin_memory=True, num_workers=args.num_workers, shuffle=False, sampler=val_sampler) for each in val_dataset_list] # get text prompts for training/val texts_train = get_text_prompts_train(args, train_dataset, template=template) texts_list = get_text_prompts_val(val_dataset_list, val_dataset_name, template=template) # define criterion and optimizer optimizer = torch.optim.SGD(list(prompter.parameters()) + list(add_prompter.parameters()), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) criterion = torch.nn.CrossEntropyLoss().to(device) scaler = GradScaler() total_steps = len(train_loader) * args.epochs scheduler = cosine_lr(optimizer, args.learning_rate, args.warmup, total_steps) cudnn.benchmark = True # make dir refined_template = template.lower().replace(' ', '_') args.filename = f'{args.filename}_template_{refined_template}' args.model_folder = os.path.join(args.model_dir, args.filename) if not os.path.isdir(args.model_folder): os.makedirs(args.model_folder) if args.evaluate: acc1_mean = validate(val_loader_list, val_dataset_name, texts_list, model, prompter, add_prompter, criterion, args) return epochs_since_improvement = 0 for epoch in range(args.epochs): # train for one epoch train(train_loader, texts_train, model, prompter, add_prompter, optimizer, scheduler, criterion, scaler, epoch, args) # evaluate on validation set if epoch % args.validate_freq == 0: acc1_mean = validate(val_loader_list, val_dataset_name, texts_list, model, prompter, add_prompter, criterion, args) # remember best acc@1 and save checkpoint is_best = acc1_mean > best_acc1 best_acc1 = max(acc1_mean, best_acc1) save_checkpoint({ 'epoch': epoch + 1, 'state_dict': prompter.state_dict(), 'add_prompter': add_prompter.state_dict(), 'best_acc1': best_acc1, 'optimizer': optimizer.state_dict(), }, args, is_best=is_best) if is_best: epochs_since_improvement = 0 else: epochs_since_improvement += 1 print(f"There's no improvement for {epochs_since_improvement} epochs.") if epochs_since_improvement >= args.patience: print("The training halted by early stopping criterion.") break # def validate(val_loader, texts, model, prompter, add_prompter, criterion, args): def validate(val_loader_list, val_dataset_name, texts_list, model, prompter, add_prompter, criterion, args): dataset_num = len(val_loader_list) acc_all = [] test_stepsize = args.test_stepsize for cnt in range(dataset_num): val_loader = val_loader_list[cnt] texts = texts_list[cnt] dataset_name = val_dataset_name[cnt] batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1_org = AverageMeter('Original Acc@1', ':6.2f') top1_prompt = AverageMeter('Prompt Acc@1', ':6.2f') top1_adv_org = AverageMeter('Adv Original Acc@1', ':6.2f') top1_adv_prompt = AverageMeter('Adv Prompt Acc@1', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, losses, top1_org, top1_prompt, top1_adv_org, top1_adv_prompt], prefix=dataset_name + '_Validate: ') # switch to evaluation mode prompter.eval() add_prompter.eval() end = time.time() for i, (images, target) in enumerate(tqdm(val_loader, ncols = 80)): if 'cifar' not in val_dataset_name: if i % 20 != 0 and not args.evaluate: continue images = images.to(device) target = target.to(device) text_tokens = clip.tokenize(texts).to(device) with autocast(): # clean images, with prompt and without prompt # compute output with torch.no_grad(): prompt_token = add_prompter() output_prompt, _ = multiGPU_CLIP(model, prompter(clip_img_preprocessing(images)), text_tokens, prompt_token) output_org, _ = multiGPU_CLIP(model, clip_img_preprocessing(images), text_tokens, None) loss = criterion(output_prompt, target) # measure accuracy and record loss acc1 = accuracy(output_prompt, target, topk=(1,)) losses.update(loss.item(), images.size(0)) top1_prompt.update(acc1[0].item(), images.size(0)) acc1 = accuracy(output_org, target, topk=(1,)) top1_org.update(acc1[0].item(), images.size(0)) torch.cuda.empty_cache() # generate adv example if args.attack == 'CW': delta_prompt = attack_CW(prompter, model, add_prompter, criterion, images, target, text_tokens, test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps) else: delta_prompt = attack_pgd(prompter, model, add_prompter, criterion, images, target, text_tokens, test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps) # compute output torch.cuda.empty_cache() with torch.no_grad(): prompt_token = add_prompter() output_prompt_adv, _ = multiGPU_CLIP(model, prompter(clip_img_preprocessing(images + delta_prompt)), text_tokens, prompt_token) loss = criterion(output_prompt_adv, target) # bl attack torch.cuda.empty_cache() if args.attack == 'CW': delta_noprompt = attack_CW(None, model, None, criterion, images, target, text_tokens, test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps) else: delta_noprompt = attack_pgd(None, model, None, criterion, images, target, text_tokens, test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps) torch.cuda.empty_cache() with torch.no_grad(): output_org_adv, _ = multiGPU_CLIP(model, clip_img_preprocessing(images + delta_noprompt), text_tokens, None) torch.cuda.empty_cache() # measure accuracy and record loss acc1 = accuracy(output_prompt_adv, target, topk=(1,)) losses.update(loss.item(), images.size(0)) top1_adv_prompt.update(acc1[0].item(), images.size(0)) acc1 = accuracy(output_org_adv, target, topk=(1,)) top1_adv_org.update(acc1[0].item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if args.debug: break torch.cuda.empty_cache() print(dataset_name + ' * Adv Prompt Acc@1 {top1_adv_prompt.avg:.3f} Adv Original Acc@1 {top1_adv_org.avg:.3f} ' '* Prompt Acc@1 {top1_prompt.avg:.3f} Original Acc@1 {top1_org.avg:.3f}' .format(top1_adv_prompt=top1_adv_prompt, top1_adv_org=top1_adv_org, top1_prompt=top1_prompt, top1_org=top1_org)) acc_all.append(top1_adv_prompt.avg) return np.mean(acc_all) if __name__ == '__main__': main()