import argparse
import importlib.util
import os
import sys

import numpy as np
import pandas as pd
import torch
from accelerate.utils import set_seed
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

from utils_img import normalize_vqgan, unnormalize_vqgan, psnr

# File extensions (matched case-insensitively) picked up by CustomImageDataset.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Watermarked-image folders evaluated when --ckpt_prefixes is not given.
DEFAULT_CKPT_PREFIXES = (
    "SS_fix_weights", "SS_dlwt",
    "WMA_fix_weights", "WMA_dlwt",
    "LaWa_fix_weights", "LaWa_dlwt",
    "EW-LoRA_fix_weights", "EW-LoRA_dlwt",
)

default_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize_vqgan,
])


class CustomImageDataset(torch.utils.data.Dataset):
    """Images stored flat in one directory (no class sub-folders).

    Every file directly inside ``image_dir`` whose name ends with one of
    ``IMAGE_EXTENSIONS`` (case-insensitive) is included. Paths are sorted so
    the iteration order, and therefore any output file index, is
    deterministic. Each item is ``(image, 0)``; the dummy label only keeps
    the usual ``(input, target)`` tuple shape.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.lower().endswith(IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # convert() returns an independent copy, so the file can be closed here.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, 0


def get_dataloader(data_dir, transform=default_transform, batch_size=128, shuffle=False, num_workers=4):
    """Return a DataLoader over the images stored flat in ``data_dir``.

    Unlike ``torchvision.datasets.ImageFolder``, no class sub-folders are
    expected; every item carries the dummy label ``0``.
    """
    dataset = CustomImageDataset(data_dir, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


def get_parser():
    """Build the command-line parser for the StegaStamp overwrite attack."""
    parser = argparse.ArgumentParser(description='StegaStamp Attack')
    parser.add_argument('--file_path', type=str,
                        default='/pubdata/ldd/models/wm_encdec/stegastamp/models.py',
                        help='Path to the stegastamp models.py file')
    parser.add_argument('--encoder_path', type=str,
                        default='/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_encoder_099.pth',
                        help='Path to the encoder weights')
    parser.add_argument('--decoder_path', type=str,
                        default='/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_decoder_099.pth',
                        help='Path to the decoder weights')
    parser.add_argument('--data_dir', type=str,
                        default='/pubdata/ldd/projects/sd-lora-wm/smattacks/from_222/smattacks/gen_with_prompt_lawa_test/019-exps/images',
                        help='Path to the dataset')
    # NOTE(review): --images_dir is parsed but not read anywhere in this script.
    parser.add_argument('--images_dir', type=str, default='', help='Path to save the images')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--seed', type=int, default=1337, help='Random seed')
    parser.add_argument('--fingerprint_seed', type=int, default=42,
                        help='Seed used only to draw the embedded fingerprint bits')
    parser.add_argument('--fingerprint_length', type=int, default=200,
                        help='Number of fingerprint bits (encoder/decoder message length)')
    parser.add_argument('--image_resolution', type=int, default=256,
                        help='First StegaStampEncoder/Decoder constructor argument '
                             '(image resolution, per the "coco_256" checkpoint name)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Torch device used for the models and images')
    parser.add_argument('--ckpt_prefixes', type=str, nargs='+',
                        default=list(DEFAULT_CKPT_PREFIXES),
                        help='Sub-folder suffixes to evaluate (reads save_imgs_<prefix>)')
    return parser


def _load_stegastamp_module(file_path):
    """Import the StegaStamp ``models.py`` at ``file_path`` and return the module."""
    module_dir = os.path.dirname(file_path)
    # Let models.py resolve imports of its sibling files.
    if module_dir not in sys.path:
        sys.path.append(module_dir)
    spec = importlib.util.spec_from_file_location("stagastamp_models", file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules["stagastamp_models"] = module
    spec.loader.exec_module(module)
    return module


def _generate_random_fingerprints(fingerprint_length, batch_size=4):
    """Draw a (batch_size, fingerprint_length) float tensor of random 0/1 bits."""
    return torch.zeros((batch_size, fingerprint_length), dtype=torch.float).random_(0, 2)


def _evaluate_prefix(encoder, decoder, fingerprints, data_dir, ckpt_prefix, transform, batch_size):
    """Overwrite-watermark one image folder with StegaStamp and score the result.

    Reads ``<data_dir>/save_imgs_<prefix>``, saves every re-watermarked image to
    ``<data_dir>/overwrite_stegastamp_<prefix>`` and writes per-batch scores to
    ``<data_dir>/overwrite_att_<prefix>/bit_acc_stegastamp.csv``.

    Returns ``(bit_accuracy, psnr)`` averaged over all images, or ``None``
    when the folder contains no images.
    """
    dataloader = get_dataloader(os.path.join(data_dir, f'save_imgs_{ckpt_prefix}'),
                                transform=transform, batch_size=batch_size)
    save_image_path = os.path.join(data_dir, f'overwrite_stegastamp_{ckpt_prefix}')
    csv_dir = os.path.join(data_dir, f'overwrite_att_{ckpt_prefix}')
    os.makedirs(save_image_path, exist_ok=True)
    # The CSV folder was never created before, which made to_csv fail.
    os.makedirs(csv_dir, exist_ok=True)

    records = []
    batch_sizes = []
    img_idx = 0
    with torch.no_grad():
        for i, (images, _) in enumerate(tqdm(dataloader, desc=ckpt_prefix)):
            images = images.to(fingerprints.device)
            n_images = images.shape[0]
            # The same fingerprint is embedded into every image of the batch.
            batch_fingerprints = fingerprints.repeat(n_images, 1)
            fingerprinted_images = encoder(batch_fingerprints, images)
            decoder_output = decoder(fingerprinted_images)

            # Save one file per image; with batch_size=1 img_idx equals i.
            for img in fingerprinted_images:
                save_image(img, os.path.join(save_image_path, f'overwrite_img_w_{img_idx:07}.png'))
                img_idx += 1

            # Message stats: a bit counts as 1 when its value is positive.
            ori_msgs = batch_fingerprints > 0                          # b k
            decoded_msgs = decoder_output > 0                          # b k
            bit_accs = torch.eq(ori_msgs, decoded_msgs).float().mean(dim=-1)  # b
            bit_accs_avg = bit_accs.mean().item()
            psnr_avg = psnr(fingerprinted_images, images).mean().item()

            records.append({"iteration": i, "bit_acc_avg": bit_accs_avg, "psnr_avg": psnr_avg})
            batch_sizes.append(n_images)

    if not records:
        return None

    pd.DataFrame(records, columns=["iteration", "bit_acc_avg", "psnr_avg"]).to_csv(
        os.path.join(csv_dir, "bit_acc_stegastamp.csv"), index=False)

    # Weight each batch by its size so a short last batch is not over-counted.
    total = sum(batch_sizes)
    overall_bit_acc = sum(r["bit_acc_avg"] * n for r, n in zip(records, batch_sizes)) / total
    overall_psnr = sum(r["psnr_avg"] * n for r, n in zip(records, batch_sizes)) / total
    return overall_bit_acc, overall_psnr


def main(args):
    """Re-watermark every image folder in ``args.ckpt_prefixes`` with StegaStamp.

    For each folder, print the mean bit accuracy of the decoded StegaStamp
    fingerprint and the mean PSNR between the re-watermarked and input images.
    """
    device = torch.device(args.device)

    stegastamp_models = _load_stegastamp_module(args.file_path)
    encoder = stegastamp_models.StegaStampEncoder(
        args.image_resolution, 3, args.fingerprint_length, return_residual=False)
    decoder = stegastamp_models.StegaStampDecoder(
        args.image_resolution, 3, args.fingerprint_length)
    encoder.load_state_dict(torch.load(args.encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(args.decoder_path, map_location=device))
    # Inference only: switch any train-time layers to evaluation behaviour.
    encoder = encoder.to(device).eval()
    decoder = decoder.to(device).eval()

    # normalize_vqgan is deliberately left out: images stay in ToTensor's [0, 1] range.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    set_seed(args.seed)
    # Reseed so the embedded fingerprint depends only on --fingerprint_seed.
    torch.manual_seed(args.fingerprint_seed)
    fingerprints = _generate_random_fingerprints(args.fingerprint_length, batch_size=1).to(device)

    for ckpt_prefix in args.ckpt_prefixes:
        result = _evaluate_prefix(encoder, decoder, fingerprints, args.data_dir,
                                  ckpt_prefix, eval_transform, args.batch_size)
        if result is None:
            print(f"Model: {ckpt_prefix}, no images found, skipped")
            continue
        overall_avg_bit_accs, overall_avg_psnr = result
        print(f"Model: {ckpt_prefix}, ACC: {overall_avg_bit_accs}, PSNR: {overall_avg_psnr}")


if __name__ == '__main__':
    # generate parser / parse parameters
    parser = get_parser()
    args = parser.parse_args()
    main(args)