# SAE/attacks/CGNC/get_text_feature.py
# (cleaned: removed model-hub page residue — uploader "Ttius", commit 998bb30 "Upload 192 files")
import argparse
import os

from attacks.UnivIntruder.utils_.text_templates import imagenet_templates
import open_clip


def _build_arg_parser():
    """Build the CLI parser: GPU selection, label subset, and output path."""
    cli = argparse.ArgumentParser(description='Text Condition')
    cli.add_argument('--gpu_id', type=str, default='0', help='GPU id')
    cli.add_argument('--label_flag', type=str, default='CL', help='label nums: N8, D1,...,D20')
    cli.add_argument('--save_path', type=str, default='text_feature_10.pth', help='save path')
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
print(args)

# Restrict CUDA to the requested GPU. This must happen before torch first
# touches CUDA, which is why torch is only imported after this line.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
import clip
import torch
from utils_ import getImageNetClassIndex, getCIFAR100ClassIndex, get_classes
from utils.plot import load_cifar100_classes
# Reproducibility: pin every torch RNG source and force deterministic cuDNN
# kernels (no autotuning) whenever a CUDA device is available.
use_gpu = torch.cuda.is_available()
if use_gpu:
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Compute on the first visible GPU when present, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# Load the OpenCLIP ViT-B-32 model and its tokenizer from a local checkpoint.
clip_model = 'ViT-B-32'
model, _, clip_preprocess = open_clip.create_model_and_transforms(
    clip_model,
    # NOTE(review): hard-coded absolute Windows path — only works on the original
    # author's machine; should probably become a CLI argument. TODO confirm.
    pretrained= 'D:\\Sharing\\Programs\\3-Durable-Adv\\PyCIL-master\\CLIP-ViT-B-32-laion2B-s34B-b79K\\open_clip_pytorch_model.bin', # pretrained,
    jit=True,
    device=device,
)
# Likely redundant: `device=` above should already place the model; harmless either way.
model.to(device)
tokenizer = open_clip.get_tokenizer(clip_model)
# Build one averaged, L2-normalized CLIP text embedding per target class.
# The result dict is keyed by each class's index in the full CIFAR-100 name
# list, so downstream code can look features up by global class id.
label_set = get_classes(args.label_flag)
all_classes = load_cifar100_classes()
class_str = [all_classes[i] for i in label_set]

text_class_features = dict()
templates = imagenet_templates
for class_name in class_str:
    class_idx = all_classes.index(class_name)  # hoisted: list.index() is O(n)
    if class_idx in text_class_features:
        continue  # skip duplicates in label_set
    # Fill every prompt template with the class name and encode the whole batch.
    template_text = [template.format(class_name) for template in templates]
    tokens = tokenizer(template_text).to(device)
    with torch.no_grad():  # inference only — no autograd graph needed
        text_features = model.encode_text(tokens)
    # Average over templates, then normalize to unit L2 length.
    text_features = text_features.mean(dim=0)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    text_class_features[class_idx] = text_features

# BUG FIX: honour --save_path; the original hard-coded 'text_feature_10.pth'
# and never used the declared argument (its default keeps the old filename).
torch.save(text_class_features, args.save_path)
# clip_model, _ = clip.load('ViT-B/16')
# clip_model = clip_model.to(device)
# class_list = getCIFAR100ClassIndex()
# class_text = torch.cat([clip.tokenize(f"a photo of a {class_list[c][1]}") for c in range(len(class_list))]).to(device)
# text_embed = clip_model.encode_text(class_text)
# text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
# label_set = get_classes(args.label_flag)
# text_cond_dict = dict()
# for label in label_set:
# text_cond_dict[label] = text_embed[label]
# print(text_cond_dict)
# torch.save(text_cond_dict, args.save_path)