# SAE/attacks/Gaker/Generator/craftadv.py
# Uploaded by Ttius ("Upload 192 files"), commit 998bb30 (verified).
import os
from typing import Dict
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
from .Generator import Generator
from .gaussian_smoothing import get_gaussian_kernel
import numpy as np
import torch.nn as nn
import random
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
def normalize(t):
    """Standardize the first three channels of an image batch in place.

    Channel ``c`` (for c = 0, 1, 2) becomes ``(t[:, c] - mean[c]) / std[c]``
    using the ImageNet mean/std constants below; any further channels are
    left untouched. The argument itself is modified and also returned, so
    callers that still need the raw values must pass a copy.

    Args:
        t: tensor indexed as ``t[:, channel, :, :]`` (at least 4-D);
            presumably an (N, C, H, W) batch with values in [0, 1] —
            TODO confirm for callers outside this file.

    Returns:
        The same tensor object ``t``, now normalized.
    """
    channel_stats = zip([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    for channel, (mu, sigma) in enumerate(channel_stats):
        t[:, channel, :, :] = (t[:, channel, :, :] - mu) / sigma
    return t
class CustomDataset(Dataset):
    """Flat-directory image dataset that tags every image with one fixed label.

    Every regular file directly inside ``img_dir`` (no recursion) is treated
    as an image. Items are ``(image, target_label, file_name)`` tuples, where
    ``file_name`` is the bare file name without its directory.
    """

    def __init__(self, img_dir, target_label, transform=None):
        """
        Args:
            img_dir: directory whose files are the samples.
            target_label: label returned for every sample.
            transform: optional callable applied to the RGB ``PIL.Image``;
                when falsy the PIL image itself is returned.
        """
        # Sorted so the sample order (and anything indexing into it, e.g.
        # random.choice) is reproducible; os.listdir order is arbitrary.
        # Sub-directories are skipped because Image.open cannot read them.
        self.img_paths = [
            os.path.join(img_dir, img_name)
            for img_name in sorted(os.listdir(img_dir))
            if os.path.isfile(os.path.join(img_dir, img_name))
        ]
        self.transform = transform
        self.target_label = target_label

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        # The context manager closes the file; convert() returns a new,
        # fully loaded image that remains valid after the file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        # os.path.basename is separator-aware: os.path.join uses '\\' on
        # Windows, where splitting on '/' would return the whole path.
        img_name = os.path.basename(img_path)
        return img, self.target_label, img_name
class CustomDenseNet121(nn.Module):
    """Pooled-feature extractor wrapped around a DenseNet-121-style model.

    ``forward`` runs ``original_model.features``, applies ReLU, averages to
    a 1x1 map and flattens from dim 1, returning one feature vector per
    input instead of logits. The classifier head is kept as a submodule (so
    it stays in ``state_dict``) but is never applied.
    """

    def __init__(self, original_model):
        super().__init__()
        self.features = original_model.features
        self.classifier = original_model.classifier  # retained, unused in forward

    def forward(self, x):
        """Return flattened, globally average-pooled features of ``x``.

        Assumes ``self.features`` yields an (N, C, H, W) map, so the result
        is (N, C) — TODO confirm for non-standard backbones.
        """
        feature_map = self.features(x)
        activated = F.relu(feature_map, inplace=True)
        pooled = F.adaptive_avg_pool2d(activated, (1, 1))
        return pooled.flatten(1)
class CustomResnet50(nn.Module):
    """Pooled-feature extractor built from a torchvision-style ResNet-50.

    Borrows the stem, the four residual stages and ``avgpool`` from
    ``original_model``; the final ``fc`` layer is not copied, so ``forward``
    returns the pooled features flattened from dim 1 instead of logits.
    """

    # Submodules taken from the source model, in execution order. Assigning
    # them in this order keeps the same submodule / state_dict key order.
    _STAGES = ('conv1', 'bn1', 'relu', 'maxpool',
               'layer1', 'layer2', 'layer3', 'layer4', 'avgpool')

    def __init__(self, original_model):
        super().__init__()
        for stage_name in self._STAGES:
            setattr(self, stage_name, getattr(original_model, stage_name))

    def forward(self, x):
        """Run every stage in order and flatten the result from dim 1."""
        out = x
        for stage_name in self._STAGES:
            out = getattr(self, stage_name)(out)
        return torch.flatten(out, 1)
# ImageNet class indices used by the 'targets_200' setting.
_TARGETS_200 = (
    22, 30, 43, 51, 53, 67, 76, 84, 107, 111, 116, 139, 156, 163,
    174, 191, 194, 199, 228, 241, 251, 288, 301, 310, 313, 323, 324, 354,
    393, 398, 399, 401, 405, 418, 419, 420, 422, 428, 429, 439, 441, 451,
    455, 457, 465, 467, 478, 480, 481, 488, 489, 490, 493, 496, 498, 499,
    500, 507, 508, 514, 515, 519, 523, 530, 532, 533, 539, 540, 550, 552,
    553, 557, 565, 566, 575, 576, 579, 583, 588, 592, 593, 594, 599, 601,
    604, 605, 606, 607, 608, 611, 614, 622, 627, 640, 644, 646, 647, 659,
    660, 666, 668, 674, 678, 683, 684, 687, 688, 691, 694, 700, 704, 712,
    714, 715, 722, 726, 729, 738, 739, 740, 741, 749, 751, 761, 766, 769,
    772, 773, 783, 785, 789, 790, 793, 794, 796, 798, 800, 807, 815, 822,
    825, 826, 831, 843, 844, 851, 853, 854, 855, 858, 860, 862, 863, 869,
    876, 877, 879, 880, 884, 888, 891, 897, 898, 901, 903, 904, 908, 910,
    912, 914, 916, 918, 919, 924, 925, 927, 931, 932, 933, 934, 937, 938,
    943, 946, 950, 952, 954, 958, 959, 961, 963, 971, 974, 977, 979, 980,
    984, 985, 995, 996,
)


def _resolve_targets(set_targets, unknown):
    """Return the class indices to attack for the given target-set options.

    ``set_targets`` selects the base list; when ``unknown == 'True'`` the
    complement of that list within ``range(1000)`` (ascending) is returned.
    ``unknown`` must already have been validated by the caller.
    """
    if set_targets == 'targets_200':
        targets = list(_TARGETS_200)
    elif set_targets == 'targets_1000':
        targets = list(range(1000))
    else:
        raise ValueError(
            "please choose target: set_targets must be 'targets_200' or "
            "'targets_1000', got %r" % (set_targets,))
    print(targets)
    if unknown == 'True':
        excluded = set(targets)  # O(1) membership instead of list.remove
        targets = [i for i in range(1000) if i not in excluded]
    print(targets)
    return targets


def _build_feature_extractor(source_model, device):
    """Build the frozen source model that embeds (normalized) target images.

    Returns ``(extract, feature_channel, release)``: ``extract`` maps an image
    batch to features, ``feature_channel`` is the feature width passed to the
    Generator, and ``release`` removes any forward hook that was registered.
    """
    if source_model == "ResNet50":
        backbone = CustomResnet50(torchvision.models.resnet50(pretrained=True))
        return backbone.eval().to(device), 2048, lambda: None
    if source_model == "DenseNet121":
        backbone = CustomDenseNet121(torchvision.models.densenet121(pretrained=True))
        return backbone.eval().to(device), 1024, lambda: None
    if source_model == "vgg19bn":
        vgg19bn = torchvision.models.vgg19_bn(pretrained=True).eval().to(device)
        captured = {}

        def hook(module, inputs, output):
            captured['output'] = output

        # Features are the output of classifier[5]; a hook captures them
        # while the full network runs (replaces the former module global).
        handle = vgg19bn.classifier[5].register_forward_hook(hook)

        def extract(x):
            vgg19bn(x)  # logits are discarded; only the hooked output is used
            return captured['output']

        return extract, 4096, handle.remove
    raise ValueError(
        "Source_Model must be 'ResNet50', 'DenseNet121' or 'vgg19bn', got %r"
        % (source_model,))


def _load_class_names(path='imagenet_numtotarget.txt'):
    """Return the second whitespace-separated field of each non-blank line."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.split()[1] for line in f if line.strip()]


def _build_target_loader(target_select, target_img_dir, target_label,
                         transform, num_workers):
    """Build the loader of target-class images that condition the generator.

    '10' uses every image in ``target_img_dir``; '1' draws 10 images at
    random (with replacement) from it.
    """
    target_dataset = CustomDataset(target_img_dir, target_label, transform)
    if target_select == '1':
        target_dataset = [random.choice(target_dataset) for _ in range(10)]
        print(len(target_dataset))
    return DataLoader(target_dataset, batch_size=10, shuffle=True,
                      num_workers=num_workers, pin_memory=True)


def craftadv(modelConfig: Dict):
    """Generate and save targeted adversarial examples with a trained Generator.

    For every target class, batches of source images are passed through the
    generator together with features of target-class images (from the chosen
    source model). The output is smoothed with
    ``get_gaussian_kernel(kernel_size=3, pad=2, sigma=1)`` and projected
    element-wise into ``[img - eps, img + eps]`` (an L-inf ball). It is then
    clipped to [0, 1] and each image is saved as 8-bit under the source file
    name.

    Keys read from ``modelConfig``:
        device, target_select ('10' or '1'),
        set_targets ('targets_200' or 'targets_1000'),
        unknown ('True' or 'False'), ran_best (only 'random' is supported),
        Source_Model ('ResNet50', 'DenseNet121' or 'vgg19bn'),
        channel, channel_mult, num_res_blocks, Generator_save_dir,
        test_load_weight.
        Optional: eps (in 0-255 pixel units, default 16.0),
        save_adv_path (default "..."), num_workers (default 12).

    Raises:
        ValueError: if an option has an unsupported value.
        NotImplementedError: if ``ran_best == 'best'``.
    """
    device = torch.device(modelConfig["device"])
    target_select = modelConfig["target_select"]
    set_targets = modelConfig["set_targets"]
    unknown = modelConfig["unknown"]
    ran_best = modelConfig["ran_best"]

    # Validate every option before any heavy work. Previously bad values
    # either called quit() or failed later with a NameError.
    if unknown == 'False':
        src = "source samples for known classes test"
    elif unknown == 'True':
        src = "source samples for unknown classes test"
    else:
        raise ValueError("please choose True or False (unknown=%r)" % (unknown,))
    if target_select not in ('10', '1'):
        raise ValueError("target_select must be '10' or '1', got %r" % (target_select,))
    if ran_best == 'best':
        raise NotImplementedError("ran_best='best' is not used; choose 'random'")
    if ran_best != 'random':
        raise ValueError("please choose random or best (ran_best=%r)" % (ran_best,))

    scale_size = 256
    img_size = 224
    # Source and target images share the same preprocessing.
    transform = transforms.Compose([
        transforms.Resize(scale_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])

    targets = _resolve_targets(modelConfig["set_targets"], unknown)
    extract, feature_channel, release_hook = _build_feature_extractor(
        modelConfig["Source_Model"], device)

    try:
        generator = Generator(
            num_target=len(targets),
            ch=modelConfig["channel"],
            ch_mult=modelConfig["channel_mult"],
            num_res_blocks=modelConfig["num_res_blocks"],
            feature_channel_num=feature_channel,
        )
        # NOTE(review): torch.load unpickles the checkpoint; only load
        # trusted files.
        ckpt = torch.load(
            os.path.join(modelConfig["Generator_save_dir"], modelConfig["test_load_weight"]),
            map_location=device)
        incompatible = generator.load_state_dict(ckpt, strict=False)
        # strict=False tolerates mismatches; report them instead of hiding them.
        if incompatible is not None and (incompatible.missing_keys or incompatible.unexpected_keys):
            print('warning: checkpoint mismatch; missing:', incompatible.missing_keys,
                  'unexpected:', incompatible.unexpected_keys)
        print("model load weight done.")
        generator.eval().to(device)

        eps = float(modelConfig.get("eps", 16.0)) / 255
        print('eps:', eps * 255)
        num_workers = modelConfig.get("num_workers", 12)
        save_adv_path = modelConfig.get("save_adv_path", "...")
        os.makedirs(save_adv_path, exist_ok=True)

        # Loop invariants: built once instead of once per target / batch.
        # The kernel follows the configured device (was hard-coded .cuda()).
        kernel = get_gaussian_kernel(kernel_size=3, pad=2, sigma=1).to(device)
        class_names = _load_class_names()

        with torch.no_grad():
            for target_label in targets:
                print('##########################')
                target_name = class_names[target_label]
                source_dataset = CustomDataset(
                    os.path.join(src, target_name), target_label, transform)
                print('target==>', target_label, ' num:', len(source_dataset))
                source_loader = DataLoader(source_dataset, batch_size=10, shuffle=True,
                                           num_workers=num_workers, pin_memory=True)
                target_loader = _build_target_loader(
                    target_select, os.path.join('target images folder', target_name),
                    target_label, transform, num_workers)
                target_iter = iter(target_loader)

                for imgs, _labels, file_names in source_loader:
                    imgs = imgs.to(device)
                    # Cycle through the target images, restarting when exhausted.
                    try:
                        imgs_target, _, _ = next(target_iter)
                    except StopIteration:
                        target_iter = iter(target_loader)
                        imgs_target, _, _ = next(target_iter)
                    imgs_target = imgs_target.to(device)

                    # normalize() works in place, hence the copy.
                    target_feature = extract(normalize(imgs_target.clone().detach()))
                    # NOTE(review): squeeze() also drops the batch dim when the
                    # target batch holds a single image, and the last source
                    # batch may be smaller than the target batch; whether the
                    # Generator accepts either case is not visible here.
                    output_to_mix = target_feature.squeeze()

                    perturbated_imgs = kernel(generator(imgs, mix=output_to_mix))
                    adv = torch.min(torch.max(perturbated_imgs, imgs - eps), imgs + eps)
                    adv = torch.clamp(adv, 0, 1.0)

                    adv_img = (adv.detach().permute((0, 2, 3, 1)).cpu().numpy() * 255).astype(np.uint8)
                    for img_array, file_name in zip(adv_img, file_names):
                        Image.fromarray(img_array).save(os.path.join(save_adv_path, file_name))
    finally:
        release_hook()  # removes the VGG forward hook, if one was registered