|
|
import argparse |
|
|
import os |
|
|
from attacks.UnivIntruder.utils_.text_templates import imagenet_templates |
|
|
import open_clip |
|
|
|
|
|
# CLI for exporting CLIP text features for a chosen subset of class labels.
parser = argparse.ArgumentParser(description='Text Condition')


parser.add_argument('--gpu_id', type=str, default='0', help='GPU id')


parser.add_argument('--label_flag', type=str, default='CL', help='label nums: N8, D1,...,D20')


# NOTE(review): this flag is parsed but the save call below hard-codes the same
# default name — confirm --save_path is meant to be honored.
parser.add_argument('--save_path', type=str, default='text_feature_10.pth', help='save path')


args = parser.parse_args()


print(args)


# Must be set BEFORE the torch/clip imports further down so that only the
# requested GPU is visible to CUDA; moving this after those imports would
# silently break GPU selection.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
|
|
|
|
|
import clip |
|
|
import torch |
|
|
from utils_ import getImageNetClassIndex, getCIFAR100ClassIndex, get_classes |
|
|
from utils.plot import load_cifar100_classes |
|
|
|
|
|
# Detect CUDA once and reuse the flag for both seeding and device selection.
use_gpu = torch.cuda.is_available()

if use_gpu:
    # Pin every RNG source so the extracted features are reproducible run-to-run.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(1)
    # Force deterministic cuDNN kernels (disables autotuning heuristics).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if use_gpu else "cpu")
print(device)
|
|
|
|
|
# Load the OpenCLIP backbone used to embed the class-name prompts.
clip_model = 'ViT-B-32'
model, _, clip_preprocess = open_clip.create_model_and_transforms(
    clip_model,
    # NOTE(review): machine-specific checkpoint path — consider promoting to a CLI flag.
    pretrained='D:\\Sharing\\Programs\\3-Durable-Adv\\PyCIL-master\\CLIP-ViT-B-32-laion2B-s34B-b79K\\open_clip_pytorch_model.bin',
    jit=True,
    device=device,
)
model.to(device)
tokenizer = open_clip.get_tokenizer(clip_model)

# Resolve the label subset selected by --label_flag into class-name strings.
label_set = get_classes(args.label_flag)
all_classes = load_cifar100_classes()
class_str = [all_classes[i] for i in label_set]

# Maps class index (position in all_classes) -> mean-pooled, L2-normalized text feature.
text_class_features = dict()
templates = imagenet_templates

for classes in class_str:
    # Hoisted: list.index() is O(n) and was previously called three times per class.
    class_idx = all_classes.index(classes)
    if class_idx not in text_class_features:
        # Embed every prompt template for this class and average the embeddings.
        template_text = [template.format(classes) for template in templates]
        tokens = tokenizer(template_text).to(device)
        with torch.no_grad():  # inference only — skip building an autograd graph
            text_features = model.encode_text(tokens)
        text_features = text_features.mean(dim=0)
        # L2-normalize so downstream cosine-similarity scores are well-scaled.
        # (Replaces the original's redundant `if True:` guard.)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_class_features[class_idx] = text_features

# BUG FIX: honor the --save_path flag. The old code hard-coded
# 'text_feature_10.pth', which happens to equal the flag's default,
# so default invocations behave identically.
torch.save(text_class_features, args.save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|