"""Precompute normalized CLIP text embeddings for a set of class labels.

For every class selected by --label_flag, the script formats the ImageNet
prompt templates with the class name, encodes them with an OpenCLIP
ViT-B-32 text encoder, averages over templates, L2-normalizes, and saves
the resulting {class_index: feature_tensor} dict with torch.save.
"""
import argparse
import os

from attacks.UnivIntruder.utils_.text_templates import imagenet_templates
import open_clip

parser = argparse.ArgumentParser(description='Text Condition')
parser.add_argument('--gpu_id', type=str, default='0', help='GPU id')
parser.add_argument('--label_flag', type=str, default='CL', help='label nums: N8, D1,...,D20')
parser.add_argument('--save_path', type=str, default='text_feature_10.pth', help='save path')
args = parser.parse_args()
print(args)

# CUDA_VISIBLE_DEVICES must be set BEFORE torch is imported, hence the
# deliberately late imports below.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

import clip
import torch
from utils_ import getImageNetClassIndex, getCIFAR100ClassIndex, get_classes
from utils.plot import load_cifar100_classes

use_gpu = torch.cuda.is_available()
if use_gpu:
    # Seed every CUDA RNG and force deterministic cuDNN for reproducibility.
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

clip_model = 'ViT-B-32'
model, _, clip_preprocess = open_clip.create_model_and_transforms(
    clip_model,
    pretrained='D:\\Sharing\\Programs\\3-Durable-Adv\\PyCIL-master\\CLIP-ViT-B-32-laion2B-s34B-b79K\\open_clip_pytorch_model.bin',
    jit=True,
    device=device,
)
model.to(device)
tokenizer = open_clip.get_tokenizer(clip_model)

# label_set holds the class indices to encode; all_classes maps index -> name.
label_set = get_classes(args.label_flag)
all_classes = load_cifar100_classes()

text_class_features = dict()
templates = imagenet_templates
for label_idx in label_set:
    # Iterate the label indices directly instead of repeated
    # all_classes.index(...) lookups, which are O(n) each and ambiguous
    # if two classes share a name.
    if label_idx in text_class_features:
        continue
    class_name = all_classes[label_idx]
    prompt_text = [template.format(class_name) for template in templates]
    tokens = tokenizer(prompt_text).to(device)
    # No gradients are needed for feature extraction.
    with torch.no_grad():
        text_features = model.encode_text(tokens).detach()
    # Average over all prompt templates, then L2-normalize.
    text_features = text_features.mean(dim=0)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_class_features[label_idx] = text_features

# Bug fix: honor --save_path (its default, 'text_feature_10.pth', matches
# the previously hard-coded filename, so default behavior is unchanged).
torch.save(text_class_features, args.save_path)

# NOTE: a legacy alternative implementation based on clip.load('ViT-B/16')
# was removed here; use version control history to recover it if needed.