File size: 2,684 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import os
from attacks.UnivIntruder.utils_.text_templates import imagenet_templates
import open_clip

# Build per-class CLIP text features for the classes selected by --label_flag.
#
# For each selected class name, every prompt template in `imagenet_templates`
# is filled in, tokenised, and encoded with an open_clip ViT-B-32 model. The
# per-template embeddings are averaged, and the mean is L2-normalised. The
# result is saved with torch.save to --save_path as a dict:
# {index of the class name in load_cifar100_classes(): feature tensor}.
parser = argparse.ArgumentParser(description='Text Condition')
parser.add_argument('--gpu_id', type=str, default='0', help='GPU id')
parser.add_argument('--label_flag', type=str, default='CL', help='label nums: N8, D1,...,D20')
parser.add_argument('--save_path', type=str, default='text_feature_10.pth', help='save path')
# The checkpoint location used to be hard-coded; the default keeps the old value.
parser.add_argument(
    '--pretrained',
    type=str,
    default='D:\\Sharing\\Programs\\3-Durable-Adv\\PyCIL-master\\CLIP-ViT-B-32-laion2B-s34B-b79K\\open_clip_pytorch_model.bin',
    help='open_clip pretrained tag or path to a checkpoint file',
)
args = parser.parse_args()
print(args)

# Restrict the visible GPUs. This has to happen before CUDA is initialised
# (the first CUDA call), which is why clip/torch are imported only below.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

import clip
import torch
from utils_ import getImageNetClassIndex, getCIFAR100ClassIndex, get_classes
from utils.plot import load_cifar100_classes

use_gpu = torch.cuda.is_available()
if use_gpu:
    # Fixed seeds and deterministic cuDNN for reproducible runs.
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if use_gpu else "cpu")
print(device)

clip_model = 'ViT-B-32'
model, _, clip_preprocess = open_clip.create_model_and_transforms(
    clip_model,
    pretrained=args.pretrained,
    jit=True,
    device=device,
)
model.to(device)
# open_clip returns the model in training mode; switch to inference mode.
model.eval()
tokenizer = open_clip.get_tokenizer(clip_model)

label_set = get_classes(args.label_flag)
all_classes = load_cifar100_classes()
class_str = [all_classes[i] for i in label_set]
templates = imagenet_templates

text_class_features = dict()
# Pure inference: no autograd graph is needed for the text embeddings.
with torch.no_grad():
    for class_name in class_str:
        # Keyed by the first position of the name in all_classes; repeated
        # names in the selection are encoded only once.
        class_idx = all_classes.index(class_name)
        if class_idx in text_class_features:
            continue
        template_text = [template.format(class_name) for template in templates]
        tokens = tokenizer(template_text).to(device)
        # Average over templates (dim 0 = one row per template), then L2-normalise.
        text_features = model.encode_text(tokens).mean(dim=0)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        text_class_features[class_idx] = text_features

# Previously this ignored --save_path and always wrote 'text_feature_10.pth'
# (which is still the argument's default).
torch.save(text_class_features, args.save_path)

# clip_model, _ = clip.load('ViT-B/16')
# clip_model = clip_model.to(device)
# class_list = getCIFAR100ClassIndex()
# class_text = torch.cat([clip.tokenize(f"a photo of a {class_list[c][1]}") for c in range(len(class_list))]).to(device)
# text_embed = clip_model.encode_text(class_text)
# text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
# label_set = get_classes(args.label_flag)
# text_cond_dict = dict()
# for label in label_set:
#     text_cond_dict[label] = text_embed[label]
# print(text_cond_dict)
# torch.save(text_cond_dict, args.save_path)