# StegaStamp overwrite-attack evaluation script: embeds a fixed random
# fingerprint into pre-generated images and measures bit accuracy / PSNR.
import importlib.util
import sys
import os
import torch
import numpy as np
import argparse
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import pandas as pd
from torchvision.utils import save_image
from accelerate.utils import set_seed
from utils_img import normalize_vqgan, unnormalize_vqgan, psnr
# Default preprocessing pipeline: PIL image -> float tensor in [0, 1],
# then VQGAN-style normalization (project helper imported from utils_img).
default_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize_vqgan,
])
class CustomImageDataset(torch.utils.data.Dataset):
    """Flat-directory image dataset (no per-class subfolders expected).

    Yields ``(image, 0)`` pairs: the label is a constant dummy so the
    dataset can feed loaders that expect ``(x, y)`` tuples.
    """

    # Accepted image extensions (matched case-insensitively).
    _EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sort paths so sample order is deterministic across runs.
        # FIX: extension check is now case-insensitive, so files such as
        # IMG.PNG / photo.JPG are no longer silently skipped.
        self.image_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.lower().endswith(self._EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Force RGB so grayscale/RGBA inputs yield 3-channel tensors.
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, 0
def get_dataloader(data_dir, transform=default_transform, batch_size=128, shuffle=False, num_workers=4):
    """Build a DataLoader over a flat directory of images.

    Unlike ``torchvision.datasets.ImageFolder``, this does not require
    per-class subdirectories — every image file in ``data_dir`` is one
    sample (paired with a dummy label of 0).
    """
    flat_dataset = CustomImageDataset(data_dir, transform=transform)
    return DataLoader(
        flat_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
def get_parser():
    """Build the command-line argument parser for the StegaStamp attack."""
    parser = argparse.ArgumentParser(description='StegaStamp Attack')
    # String-valued options registered table-style: (flag, default, help).
    for flag, default, help_text in (
        ('--file_path',
         '/pubdata/ldd/models/wm_encdec/stegastamp/models.py',
         'Path to the stegastamp models.py file'),
        ('--encoder_path',
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_encoder_099.pth',
         'Path to the encoder weights'),
        ('--decoder_path',
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_decoder_099.pth',
         'Path to the decoder weights'),
        ('--data_dir',
         '/pubdata/ldd/projects/sd-lora-wm/smattacks/from_222/smattacks/gen_with_prompt_lawa_test/019-exps/images',
         'Path to the dataset'),
        ('--images_dir', '', 'Path to save the images'),
    ):
        parser.add_argument(flag, type=str, default=default, help=help_text)
    # Integer-valued options.
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--seed', type=int, default=1337, help='Random seed')
    return parser
def main(args):
    """Run the StegaStamp overwrite attack over several checkpoint folders.

    Dynamically loads StegaStamp encoder/decoder definitions from
    ``args.file_path``, embeds a fixed random 200-bit fingerprint into every
    image under ``<data_dir>/save_imgs_<prefix>`` for each checkpoint prefix,
    saves the re-watermarked images, and records per-image bit accuracy and
    PSNR to a CSV per prefix.

    NOTE(review): requires a CUDA device. ``args.batch_size`` and
    ``args.seed`` are overwritten below (as in the original script), so the
    corresponding CLI options are effectively ignored.
    """
    # --- Dynamically import the StegaStamp model definitions. ---
    module_dir = os.path.dirname(args.file_path)
    sys.path.append(module_dir)
    spec = importlib.util.spec_from_file_location("stagastamp_models", args.file_path)
    stagastamp_models = importlib.util.module_from_spec(spec)
    sys.modules["stagastamp_models"] = stagastamp_models
    spec.loader.exec_module(stagastamp_models)

    # 256x256 RGB images, 200-bit fingerprints.
    encoder = stagastamp_models.StegaStampEncoder(256, 3, 200, return_residual=False)
    decoder = stagastamp_models.StegaStampDecoder(256, 3, 200)
    encoder.load_state_dict(torch.load(args.encoder_path, map_location='cuda'))
    decoder.load_state_dict(torch.load(args.decoder_path, map_location='cuda'))
    encoder = encoder.to('cuda')
    decoder = decoder.to('cuda')

    # Hard-coded override kept from the original: one image per batch
    # (each iteration writes exactly one output PNG).
    args.batch_size = 1
    # Evaluation transform deliberately skips VQGAN normalization.
    # Renamed from `default_transform` to avoid shadowing the module-level
    # transform of the same name.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    args.seed = 1337
    set_seed(args.seed)

    def generate_random_fingerprints(fingerprint_length, batch_size=4, size=(400, 400)):
        # Uniform random bits in {0, 1}. `size` is unused; kept only for
        # signature compatibility.
        z = torch.zeros((batch_size, fingerprint_length), dtype=torch.float).random_(0, 2)
        return z

    # Re-seed with a second fixed value so the fingerprint itself is
    # reproducible independently of set_seed above (kept from the original).
    args.seed = 42
    torch.manual_seed(args.seed)
    # Hoisted to CUDA once — the fingerprint is loop-invariant.
    fingerprints = generate_random_fingerprints(200, batch_size=1, size=(256, 3)).to('cuda')
    ori_msgs = torch.sign(fingerprints) > 0  # ground-truth bit pattern, b x k

    # Checkpoint prefixes to evaluate; each has a matching image subfolder.
    ckpt_prefixes = [
        "SS_fix_weights",
        "SS_dlwt",
        "WMA_fix_weights",
        "WMA_dlwt",
        "LaWa_fix_weights",
        "LaWa_dlwt",
        "EW-LoRA_fix_weights",
        "EW-LoRA_dlwt",
    ]
    for ckpt_prefix in ckpt_prefixes:
        dataloader = get_dataloader(
            os.path.join(args.data_dir, 'save_imgs_' + ckpt_prefix),
            transform=eval_transform,
            batch_size=args.batch_size,
        )
        save_image_path = os.path.join(args.data_dir, 'overwrite_stegastamp_' + ckpt_prefix)
        os.makedirs(save_image_path, exist_ok=True)
        # FIX: the CSV output directory was never created, so to_csv raised
        # FileNotFoundError on a fresh run.
        csv_dir = os.path.join(args.data_dir, 'overwrite_att_' + ckpt_prefix)
        os.makedirs(csv_dir, exist_ok=True)

        rows = []
        bit_accs_avg_list = []
        psnr_avg_list = []
        for i, (images, _) in enumerate(tqdm(dataloader)):
            images = images.to('cuda')
            fingerprinted_images = encoder(fingerprints, images)
            decoder_output = decoder(fingerprinted_images)
            save_image(fingerprinted_images, os.path.join(save_image_path, f'overwrite_img_w_{i:07}.png'))
            # Bit accuracy: fraction of fingerprint bits the decoder recovers.
            decoded_msgs = torch.sign(decoder_output) > 0  # b k -> b k
            diff = ~torch.logical_xor(ori_msgs, decoded_msgs)  # b k -> b k
            bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1]  # b k -> b
            bit_accs_avg = torch.mean(bit_accs).item()
            psnr_avg = psnr(fingerprinted_images, images).mean().item()
            psnr_avg_list.append(psnr_avg)
            bit_accs_avg_list.append(bit_accs_avg)
            rows.append({"iteration": i, "bit_acc_avg": bit_accs_avg, "psnr_avg": psnr_avg})

        # FIX: build the DataFrame once from collected rows instead of the
        # private DataFrame._append (removed in pandas 2.x) per iteration.
        df = pd.DataFrame(rows, columns=["iteration", "bit_acc_avg", "psnr_avg"])
        df.to_csv(os.path.join(csv_dir, "bit_acc_stegastamp.csv"), index=False)
        # Guard against an empty image folder (avoids ZeroDivisionError).
        n_batches = len(bit_accs_avg_list)
        overall_avg_bit_accs = sum(bit_accs_avg_list) / n_batches if n_batches else float('nan')
        overall_avg_psnr = sum(psnr_avg_list) / n_batches if n_batches else float('nan')
        print(f"Model: {ckpt_prefix}, ACC: {overall_avg_bit_accs}, PSNR: {overall_avg_psnr}")
if __name__ == '__main__':
    # Parse command-line options and run the overwrite-attack pipeline.
    cli_args = get_parser().parse_args()
    main(cli_args)