|
|
import importlib.util
|
|
|
import sys
|
|
|
import os
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
|
|
|
import argparse
|
|
|
from torch.utils.data import DataLoader
|
|
|
from torchvision import transforms
|
|
|
from PIL import Image
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
import pandas as pd
|
|
|
from torchvision.utils import save_image
|
|
|
from accelerate.utils import set_seed
|
|
|
|
|
|
from utils_img import normalize_vqgan, unnormalize_vqgan, psnr
|
|
|
|
|
|
# Default preprocessing used by get_dataloader: PIL image -> float tensor
# (transforms.ToTensor), then the VQGAN normalization from utils_img.
_DEFAULT_TRANSFORM_STEPS = [
    transforms.ToTensor(),
    normalize_vqgan,
]
default_transform = transforms.Compose(_DEFAULT_TRANSFORM_STEPS)
|
|
|
|
|
|
class CustomImageDataset(torch.utils.data.Dataset):
    """Dataset over a flat directory of images (no class subfolders).

    Every regular file directly inside ``image_dir`` whose extension is
    .png, .jpg or .jpeg (compared case-insensitively) becomes one item.
    Items are ordered by sorted path, so indices are stable between runs.
    """

    # Accepted file extensions, lower-case; filenames are lower-cased before matching.
    _EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir: directory to scan (not recursive).
            transform: optional callable applied to each RGB PIL image.
        """
        self.image_dir = image_dir
        self.image_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            # Case-insensitive so that e.g. "IMG_01.JPG" is not silently skipped;
            # isfile() drops directories that happen to end in an image extension.
            if fname.lower().endswith(self._EXTENSIONS)
            and os.path.isfile(os.path.join(image_dir, fname))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, 0)``.

        The ``0`` is a placeholder label: the directory has no classes.
        """
        img_path = self.image_paths[idx]
        # Close the underlying file handle deterministically. convert() loads
        # the pixel data and returns a new image, so it stays valid after close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, 0
|
|
|
|
|
|
def get_dataloader(data_dir, transform=default_transform, batch_size=128, shuffle=False, num_workers=4):
    """Return a DataLoader over the images directly inside ``data_dir``.

    Unlike torchvision's ImageFolder, no per-class subfolders are expected;
    the files are selected by CustomImageDataset.

    Args:
        data_dir: directory containing the image files.
        transform: callable applied to each loaded image.
        batch_size: number of images per batch.
        shuffle: whether to shuffle item order each epoch.
        num_workers: number of DataLoader worker processes.

    Returns:
        torch.utils.data.DataLoader yielding ``(images, labels)`` batches.
    """
    return DataLoader(
        CustomImageDataset(data_dir, transform=transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the StegaStamp attack script.

    Returns:
        argparse.ArgumentParser with options for the StegaStamp models file,
        encoder/decoder checkpoints, dataset/output directories, batch size
        and random seed.
    """
    # (flag, type, default, help) for each option, in registration order.
    option_specs = [
        ('--file_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/models.py',
         'Path to the stegastamp models.py file'),
        ('--encoder_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_encoder_099.pth',
         'Path to the encoder weights'),
        ('--decoder_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_decoder_099.pth',
         'Path to the decoder weights'),
        ('--data_dir', str,
         '/pubdata/ldd/projects/sd-lora-wm/smattacks/from_222/smattacks/gen_with_prompt_lawa_test/019-exps/images',
         'Path to the dataset'),
        ('--images_dir', str, '', 'Path to save the images'),
        ('--batch_size', int, 1, 'Batch size'),
        ('--seed', int, 1337, 'Random seed'),
    ]

    parser = argparse.ArgumentParser(description='StegaStamp Attack')
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return parser
|
|
|
|
|
|
# Device used for model weights, fingerprints and images.
_DEVICE = 'cuda'

# Number of fingerprint bits; also passed as the fingerprint size when the
# StegaStamp encoder/decoder are constructed in _load_stegastamp.
_FINGERPRINT_LENGTH = 200

# Fixed seed for the fingerprint, so the same bits are embedded whatever --seed is.
_FINGERPRINT_SEED = 42

# Image sets to attack: images are read from <data_dir>/save_imgs_<prefix>.
_CKPT_PREFIXES = (
    "SS_fix_weights",
    "SS_dlwt",
    "WMA_fix_weights",
    "WMA_dlwt",
    "LaWa_fix_weights",
    "LaWa_dlwt",
    "EW-LoRA_fix_weights",
    "EW-LoRA_dlwt",
)


def _generate_random_fingerprints(fingerprint_length, batch_size=1):
    """Draw random binary fingerprints from the global torch RNG.

    Returns:
        Float tensor of shape (batch_size, fingerprint_length) with entries 0.0 or 1.0.
    """
    return torch.zeros((batch_size, fingerprint_length), dtype=torch.float).random_(0, 2)


def _bit_accuracy(fingerprints, decoder_output):
    """Return the per-sample fraction of fingerprint bits recovered by the decoder.

    A target bit is "set" when the fingerprint entry is > 0, and a decoded bit
    is "set" when the decoder output is > 0. Accuracy is the fraction of
    positions along the last dimension where the two agree.
    """
    matches = (fingerprints > 0) == (decoder_output > 0)
    return matches.sum(dim=-1) / matches.shape[-1]


def _load_stegastamp(file_path, encoder_path, decoder_path, device):
    """Import the StegaStamp models module from ``file_path`` and load both networks.

    Returns:
        (encoder, decoder), moved to ``device`` and switched to eval mode.
    """
    # Make modules next to models.py importable from inside it.
    sys.path.append(os.path.dirname(file_path))

    spec = importlib.util.spec_from_file_location("stagastamp_models", file_path)
    stagastamp_models = importlib.util.module_from_spec(spec)
    sys.modules["stagastamp_models"] = stagastamp_models
    spec.loader.exec_module(stagastamp_models)

    # 256 and 3 are presumably the image resolution and channel count -- confirm
    # against models.py. The third argument is the fingerprint length.
    encoder = stagastamp_models.StegaStampEncoder(256, 3, _FINGERPRINT_LENGTH, return_residual=False)
    decoder = stagastamp_models.StegaStampDecoder(256, 3, _FINGERPRINT_LENGTH)

    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))

    # Inference only: disable any train-time behaviour (dropout, batch-norm updates).
    return encoder.to(device).eval(), decoder.to(device).eval()


def _evaluate_prefix(data_dir, ckpt_prefix, encoder, decoder, fingerprint, transform, batch_size, device):
    """Overwrite one image set with a StegaStamp watermark and score the result.

    Reads images from ``<data_dir>/save_imgs_<ckpt_prefix>`` and embeds
    ``fingerprint`` (shape (1, bits), already on ``device``) into each one.
    Watermarked images are written to
    ``<data_dir>/overwrite_stegastamp_<ckpt_prefix>/overwrite_img_w_XXXXXXX.png``.
    Per-batch bit accuracy and PSNR are written to
    ``<data_dir>/overwrite_att_<ckpt_prefix>/bit_acc_stegastamp.csv``.

    Returns:
        (mean bit accuracy, mean PSNR) over all images, or None when the
        input directory contains no images.
    """
    dataloader = get_dataloader(
        os.path.join(data_dir, 'save_imgs_' + ckpt_prefix),
        transform=transform,
        batch_size=batch_size,
    )
    if len(dataloader) == 0:
        return None

    save_image_path = os.path.join(data_dir, 'overwrite_stegastamp_' + ckpt_prefix)
    csv_dir = os.path.join(data_dir, 'overwrite_att_' + ckpt_prefix)
    os.makedirs(save_image_path, exist_ok=True)
    # Nothing else creates this directory; without this, to_csv raises FileNotFoundError.
    os.makedirs(csv_dir, exist_ok=True)

    rows = []
    bit_acc_total = 0.0
    psnr_total = 0.0
    num_images = 0

    with torch.no_grad():
        for i, (images, _) in enumerate(tqdm(dataloader)):
            images = images.to(device)
            n = images.shape[0]
            # The same fingerprint is embedded in every image of the batch.
            batch_fingerprints = fingerprint.repeat(n, 1)

            fingerprinted_images = encoder(batch_fingerprints, images)
            decoder_output = decoder(fingerprinted_images)

            # One file per image, numbered by its position in the dataset
            # (equals the batch index when batch_size == 1).
            for j in range(n):
                image_idx = i * batch_size + j
                save_image(
                    fingerprinted_images[j],
                    os.path.join(save_image_path, f'overwrite_img_w_{image_idx:07}.png'),
                )

            bit_accs_avg = _bit_accuracy(batch_fingerprints, decoder_output).mean().item()
            # NOTE(review): images here are in ToTensor range (no normalize_vqgan).
            # Confirm utils_img.psnr does not assume VQGAN-normalized inputs.
            psnr_avg = psnr(fingerprinted_images, images).mean().item()
            rows.append({"iteration": i, "bit_acc_avg": bit_accs_avg, "psnr_avg": psnr_avg})

            # Weight batch means by batch size so a short final batch does not skew
            # the overall mean (assumes psnr returns one value per image -- confirm).
            bit_acc_total += bit_accs_avg * n
            psnr_total += psnr_avg * n
            num_images += n

    # Write once at the end instead of rewriting the whole file every iteration.
    results = pd.DataFrame(rows, columns=["iteration", "bit_acc_avg", "psnr_avg"])
    results.to_csv(os.path.join(csv_dir, "bit_acc_stegastamp.csv"), index=False)

    return bit_acc_total / num_images, psnr_total / num_images


def main(args):
    """Run the StegaStamp overwrite attack on every image set in ``_CKPT_PREFIXES``.

    Args:
        args: parsed CLI namespace (see get_parser). Uses ``file_path``,
            ``encoder_path``, ``decoder_path``, ``data_dir``, ``batch_size``
            and ``seed``. ``args`` is not modified.

    Prints one "Model: <prefix>, ACC: <bit acc>, PSNR: <psnr>" line per image set.
    """
    device = _DEVICE
    encoder, decoder = _load_stegastamp(args.file_path, args.encoder_path, args.decoder_path, device)

    # Deliberately only ToTensor: unlike the module-level default_transform, no
    # normalize_vqgan is applied. Presumably this is the input range StegaStamp
    # expects -- confirm.
    stegastamp_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    set_seed(args.seed)
    # Re-seed torch so the fingerprint is fixed regardless of --seed.
    torch.manual_seed(_FINGERPRINT_SEED)
    fingerprint = _generate_random_fingerprints(_FINGERPRINT_LENGTH, batch_size=1).to(device)

    for ckpt_prefix in _CKPT_PREFIXES:
        result = _evaluate_prefix(
            args.data_dir, ckpt_prefix, encoder, decoder, fingerprint,
            stegastamp_transform, args.batch_size, device,
        )
        if result is None:
            print(f"Model: {ckpt_prefix}, no images found, skipped")
            continue

        overall_avg_bit_accs, overall_avg_psnr = result
        print(f"Model: {ckpt_prefix}, ACC: {overall_avg_bit_accs}, PSNR: {overall_avg_psnr}")
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line and run the attack.
    main(get_parser().parse_args())