# EW-LoRA / overwrite_attack / attack_with_stegastamp.py
# Last commit: 658e26c "first commit" by Donnyll (verified)
import importlib.util
import sys
import os
import torch
import numpy as np
import argparse
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import pandas as pd
from torchvision.utils import save_image
from accelerate.utils import set_seed
from utils_img import normalize_vqgan, unnormalize_vqgan, psnr
# Default per-image preprocessing used by get_dataloader: PIL image -> tensor
# via ToTensor, followed by the normalize_vqgan step imported from utils_img.
default_transform = transforms.Compose(
    [transforms.ToTensor(), normalize_vqgan]
)
class CustomImageDataset(torch.utils.data.Dataset):
    """Flat-folder image dataset: every matching file in *image_dir* is one sample.

    No class sub-folders are expected. Each item is returned as ``(image, 0)``;
    the constant dummy label keeps the ``(inputs, labels)`` tuple shape.
    """

    def __init__(self, image_dir, transform=None, extensions=('.png', '.jpg', '.jpeg')):
        """
        Args:
            image_dir: directory scanned (non-recursively) for image files.
            transform: optional callable applied to each RGB PIL image.
            extensions: file-name suffixes to accept, matched case-sensitively.
        """
        self.image_dir = image_dir
        self.transform = transform
        # Sorted so that index i maps to the same file on every run.
        self.image_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.endswith(tuple(extensions))
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, 0
def get_dataloader(data_dir, transform=default_transform, batch_size=128, shuffle=False, num_workers=4):
    """Return a DataLoader over the flat image folder *data_dir*.

    Unlike ``torchvision.datasets.ImageFolder``, no class sub-folders are
    expected: the samples are whatever ``CustomImageDataset`` finds directly
    inside *data_dir*.

    Args:
        data_dir: folder containing the images.
        transform: per-image transform forwarded to the dataset.
        batch_size, shuffle, num_workers: forwarded unchanged to ``DataLoader``.
    """
    return DataLoader(
        CustomImageDataset(data_dir, transform=transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
def get_parser():
    """Build the command-line parser for the StegaStamp overwrite attack.

    Returns:
        argparse.ArgumentParser with model/weight paths, the input dataset
        folder, an output images folder (``--images_dir``, not read by main()
        in this file), batch size and random seed.
    """
    options = (
        # (flag, type, default, help)
        ('--file_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/models.py',
         'Path to the stegastamp models.py file'),
        ('--encoder_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_encoder_099.pth',
         'Path to the encoder weights'),
        ('--decoder_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_decoder_099.pth',
         'Path to the decoder weights'),
        ('--data_dir', str,
         '/pubdata/ldd/projects/sd-lora-wm/smattacks/from_222/smattacks/gen_with_prompt_lawa_test/019-exps/images',
         'Path to the dataset'),
        ('--images_dir', str, '', 'Path to save the images'),
        ('--batch_size', int, 1, 'Batch size'),
        ('--seed', int, 1337, 'Random seed'),
    )
    cli = argparse.ArgumentParser(description='StegaStamp Attack')
    for flag, arg_type, default_value, help_text in options:
        cli.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    return cli
def _load_stegastamp_module(file_path):
    """Import the StegaStamp ``models.py`` at *file_path* as module ``stagastamp_models``.

    The file's directory is appended to ``sys.path`` first, presumably so that
    ``models.py`` can import sibling files from its own folder.
    """
    sys.path.append(os.path.dirname(file_path))
    spec = importlib.util.spec_from_file_location("stagastamp_models", file_path)
    module = importlib.util.module_from_spec(spec)
    # Register under its name before executing, following the importlib recipe.
    sys.modules["stagastamp_models"] = module
    spec.loader.exec_module(module)
    return module


def _load_models(module, encoder_path, decoder_path, device='cuda'):
    """Build the StegaStamp encoder/decoder from *module*, load weights, set eval mode.

    200 is the fingerprint length used by main(). The other constructor
    arguments (256, 3) are presumably image resolution and channel count;
    confirm against ``models.py``.
    """
    encoder = module.StegaStampEncoder(256, 3, 200, return_residual=False)
    decoder = module.StegaStampDecoder(256, 3, 200)
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    # Inference only: eval() disables any train-time layer behaviour.
    return encoder.to(device).eval(), decoder.to(device).eval()


def _generate_random_fingerprints(fingerprint_length, batch_size=1):
    """Return a ``(batch_size, fingerprint_length)`` float tensor of random 0/1 bits."""
    return torch.zeros((batch_size, fingerprint_length), dtype=torch.float).random_(0, 2)


def _bit_accuracy(fingerprints, decoder_output):
    """Per-sample fraction of bits where ``decoder_output > 0`` agrees with ``fingerprints > 0``."""
    ori_msgs = fingerprints > 0
    decoded_msgs = decoder_output > 0
    return (ori_msgs == decoded_msgs).float().mean(dim=-1)


def _attack_folder(ckpt_prefix, args, encoder, decoder, fingerprints, transform, device='cuda'):
    """Re-watermark every image of one checkpoint folder and return per-step metrics.

    Reads ``<data_dir>/save_imgs_<prefix>`` and writes the StegaStamp-encoded
    images to ``<data_dir>/overwrite_stegastamp_<prefix>``. Returns one dict
    per batch with keys ``iteration``, ``bit_acc_avg`` and ``psnr_avg``.
    """
    dataloader = get_dataloader(
        os.path.join(args.data_dir, 'save_imgs_' + ckpt_prefix),
        transform=transform,
        batch_size=args.batch_size,
    )
    save_image_path = os.path.join(args.data_dir, 'overwrite_stegastamp_' + ckpt_prefix)
    os.makedirs(save_image_path, exist_ok=True)

    records = []
    # Pure inference: no autograd graph is needed for encoder/decoder passes.
    with torch.no_grad():
        for i, (images, _) in enumerate(tqdm(dataloader)):
            images = images.to(device)
            fingerprinted_images = encoder(fingerprints, images)
            decoder_output = decoder(fingerprinted_images)
            save_image(fingerprinted_images, os.path.join(save_image_path, f'overwrite_img_w_{i:07}.png'))
            records.append({
                "iteration": i,
                "bit_acc_avg": _bit_accuracy(fingerprints, decoder_output).mean().item(),
                "psnr_avg": psnr(fingerprinted_images, images).mean().item(),
            })
    return records


def main(args):
    """Run the StegaStamp overwrite attack on every evaluated checkpoint folder.

    For each prefix in ``ckpt_prefixes``, the images in
    ``<data_dir>/save_imgs_<prefix>`` are re-watermarked with one fixed random
    200-bit fingerprint. The results are saved to
    ``<data_dir>/overwrite_stegastamp_<prefix>``. Per-image bit accuracy and
    PSNR are written to ``<data_dir>/overwrite_att_<prefix>/bit_acc_stegastamp.csv``,
    and the averages are printed. ``args.images_dir`` is not used.

    Args:
        args: namespace produced by ``get_parser()``. ``args.batch_size`` is
            overwritten with 1.
    """
    stegastamp_models = _load_stegastamp_module(args.file_path)
    encoder, decoder = _load_models(stegastamp_models, args.encoder_path, args.decoder_path)

    # Only one fingerprint row is generated and output files are named per
    # batch index, so the attack processes exactly one image per step.
    args.batch_size = 1
    # ToTensor only (pixels stay in [0, 1]); the normalize_vqgan step of the
    # module-level default_transform is left out here.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    set_seed(args.seed)
    # The fingerprint has its own fixed seed so every run embeds the same bits.
    torch.manual_seed(42)
    fingerprints = _generate_random_fingerprints(200, batch_size=1).to('cuda')

    # Checkpoint-folder prefixes to attack.
    ckpt_prefixes = [
        "SS_fix_weights",
        "SS_dlwt",
        "WMA_fix_weights",
        "WMA_dlwt",
        "LaWa_fix_weights",
        "LaWa_dlwt",
        "EW-LoRA_fix_weights",
        "EW-LoRA_dlwt",
    ]
    for ckpt_prefix in ckpt_prefixes:
        records = _attack_folder(ckpt_prefix, args, encoder, decoder, fingerprints, transform)
        if not records:
            print(f"Model: {ckpt_prefix}, no images found, skipping")
            continue

        # The CSV lives in a different folder from the images; create it so
        # to_csv does not fail on a missing directory.
        csv_dir = os.path.join(args.data_dir, 'overwrite_att_' + ckpt_prefix)
        os.makedirs(csv_dir, exist_ok=True)
        df = pd.DataFrame(records, columns=["iteration", "bit_acc_avg", "psnr_avg"])
        df.to_csv(os.path.join(csv_dir, "bit_acc_stegastamp.csv"), index=False)

        overall_avg_bit_accs = sum(r["bit_acc_avg"] for r in records) / len(records)
        overall_avg_psnr = sum(r["psnr_avg"] for r in records) / len(records)
        print(f"Model: {ckpt_prefix}, ACC: {overall_avg_bit_accs}, PSNR: {overall_avg_psnr}")
if __name__ == '__main__':
    # Build the CLI parser, parse sys.argv and run the attack.
    cli_args = get_parser().parse_args()
    main(cli_args)