File size: 7,197 Bytes
658e26c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import importlib.util
import sys
import os
import torch
import numpy as np

import argparse
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

from tqdm import tqdm
import pandas as pd
from torchvision.utils import save_image
from accelerate.utils import set_seed

from utils_img import normalize_vqgan, unnormalize_vqgan, psnr

# Module-wide default preprocessing: PIL image -> tensor via ToTensor, then
# utils_img.normalize_vqgan. Used as the default for get_dataloader().
default_transform = transforms.Compose([transforms.ToTensor(), normalize_vqgan])

class CustomImageDataset(torch.utils.data.Dataset):
    """Flat image-folder dataset over the image files directly inside ``image_dir``.

    No class sub-folders are expected. Each item is ``(image, 0)``; the ``0`` is
    a dummy label so the dataset fits the usual ``(inputs, targets)`` unpacking.

    Args:
        image_dir: Directory to scan (non-recursively) for image files.
        transform: Optional callable applied to the RGB PIL image.
    """

    # Matched case-insensitively so files such as ``IMG_01.JPG`` are not skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sorted so the index -> file mapping is stable across runs; output files
        # elsewhere in this script are named by dataloader position.
        # The isfile check skips sub-directories that merely look like images.
        self.image_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(image_dir, fname))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, return ``(image, 0)``."""
        img_path = self.image_paths[idx]
        # The context manager closes the file promptly. convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0

def get_dataloader(data_dir, transform=default_transform, batch_size=128, shuffle=False, num_workers=4):
    """Return a DataLoader over every image file found directly in ``data_dir``.

    Images are read with ``CustomImageDataset``, so no class sub-folders are
    needed. Each batch is ``(images, labels)``, and the labels are always 0.

    Args:
        data_dir: Directory containing the images.
        transform: Per-image transform; defaults to the module-level ``default_transform``.
        batch_size: Images per batch.
        shuffle: Whether to shuffle the (sorted) file order each epoch.
        num_workers: Number of DataLoader worker processes.
    """
    return DataLoader(
        CustomImageDataset(data_dir, transform=transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

def get_parser():
    """Build the command-line parser for this StegaStamp attack script.

    Every option has a default, so the script runs without arguments. The
    default paths are hard-coded absolute paths and usually need overriding.
    """
    # (flag, type, default, help), registered in this order.
    arg_specs = [
        ('--file_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/models.py',
         'Path to the stegastamp models.py file'),
        ('--encoder_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_encoder_099.pth',
         'Path to the encoder weights'),
        ('--decoder_path', str,
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_decoder_099.pth',
         'Path to the decoder weights'),
        ('--data_dir', str,
         '/pubdata/ldd/projects/sd-lora-wm/smattacks/from_222/smattacks/gen_with_prompt_lawa_test/019-exps/images',
         'Path to the dataset'),
        ('--images_dir', str, '', 'Path to save the images'),
        ('--batch_size', int, 1, 'Batch size'),
        ('--seed', int, 1337, 'Random seed'),
    ]

    parser = argparse.ArgumentParser(description='StegaStamp Attack')
    for flag, arg_type, default_value, help_text in arg_specs:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    return parser

def _load_stegastamp(file_path, encoder_path, decoder_path, device='cuda'):
    """Import the StegaStamp ``models.py`` at ``file_path`` and return ``(encoder, decoder)``.

    The module's directory is appended to ``sys.path`` so imports of sibling
    files inside ``models.py`` resolve. The module is registered in
    ``sys.modules`` as ``stagastamp_models``. Both networks get their
    checkpoint weights, are moved to ``device`` and are switched to eval mode
    for inference.
    """
    sys.path.append(os.path.dirname(file_path))

    spec = importlib.util.spec_from_file_location("stagastamp_models", file_path)
    stagastamp_models = importlib.util.module_from_spec(spec)
    sys.modules["stagastamp_models"] = stagastamp_models
    spec.loader.exec_module(stagastamp_models)

    # NOTE(review): positional args are presumably (resolution, channels,
    # fingerprint length); 200 must match the fingerprint built in main().
    encoder = stagastamp_models.StegaStampEncoder(256, 3, 200, return_residual=False)
    decoder = stagastamp_models.StegaStampDecoder(256, 3, 200)

    # torch.load unpickles the checkpoint -- only point it at trusted files.
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    return encoder.to(device).eval(), decoder.to(device).eval()


def _evaluate_prefix(encoder, decoder, fingerprint, data_dir, ckpt_prefix, transform, batch_size, device='cuda'):
    """Re-watermark one checkpoint's images with StegaStamp and score the result.

    Reads images from ``<data_dir>/save_imgs_<ckpt_prefix>`` and embeds
    ``fingerprint`` into each. Each watermarked image is written to
    ``<data_dir>/overwrite_stegastamp_<ckpt_prefix>/overwrite_img_w_<index>.png``.
    Per-batch metrics go to
    ``<data_dir>/overwrite_att_<ckpt_prefix>/bit_acc_stegastamp.csv``.

    Args:
        fingerprint: ``(1, L)`` tensor of 0/1 bits on ``device``; the same
            message is embedded into every image.

    Returns:
        ``(mean_bit_acc, mean_psnr)`` averaged over batches, or ``None`` when the
        input directory contains no images.
    """
    dataloader = get_dataloader(os.path.join(data_dir, 'save_imgs_' + ckpt_prefix),
                                transform=transform, batch_size=batch_size)
    if len(dataloader.dataset) == 0:
        return None

    save_image_path = os.path.join(data_dir, 'overwrite_stegastamp_' + ckpt_prefix)
    csv_dir = os.path.join(data_dir, 'overwrite_att_' + ckpt_prefix)
    # Create both output directories up front; nothing else creates csv_dir,
    # and to_csv() fails if its parent directory is missing.
    os.makedirs(save_image_path, exist_ok=True)
    os.makedirs(csv_dir, exist_ok=True)

    rows = []
    image_idx = 0
    # Inference only: no autograd graph is needed for encoding/decoding.
    with torch.no_grad():
        for i, (images, _) in enumerate(tqdm(dataloader)):
            images = images.to(device)
            # Broadcast the single message to the batch (handles a short last batch).
            batch_fingerprints = fingerprint.expand(images.shape[0], -1)
            fingerprinted_images = encoder(batch_fingerprints, images)
            decoder_output = decoder(fingerprinted_images)

            # One file per image, so the name carries a global image index for any batch size.
            for fingerprinted in fingerprinted_images:
                save_image(fingerprinted, os.path.join(save_image_path, f'overwrite_img_w_{image_idx:07}.png'))
                image_idx += 1

            # Bit accuracy: fraction of positions where "embedded bit > 0"
            # agrees with "decoder output > 0".
            ori_msgs = torch.sign(batch_fingerprints) > 0
            decoded_msgs = torch.sign(decoder_output) > 0
            matches = ~torch.logical_xor(ori_msgs, decoded_msgs)
            bit_accs = matches.sum(dim=-1) / matches.shape[-1]

            rows.append({
                "iteration": i,
                "bit_acc_avg": bit_accs.mean().item(),
                "psnr_avg": psnr(fingerprinted_images, images).mean().item(),
            })

    # Write once per checkpoint rather than after every batch (which is O(n^2)).
    # Use the public DataFrame constructor instead of the private ``_append``.
    pd.DataFrame(rows, columns=["iteration", "bit_acc_avg", "psnr_avg"]).to_csv(
        os.path.join(csv_dir, "bit_acc_stegastamp.csv"), index=False)

    mean_bit_acc = sum(row["bit_acc_avg"] for row in rows) / len(rows)
    mean_psnr = sum(row["psnr_avg"] for row in rows) / len(rows)
    return mean_bit_acc, mean_psnr


def main(args):
    """Run the StegaStamp overwrite attack on each checkpoint's images and report metrics.

    For every prefix in ``ckpt_prefixes``, images are read from
    ``<args.data_dir>/save_imgs_<prefix>``, re-watermarked and scored. The
    mean bit accuracy and PSNR are printed. ``args`` is the namespace from
    ``get_parser()``; ``--batch_size`` and ``--seed`` are honoured and ``args``
    is not modified.
    """
    device = 'cuda'
    encoder, decoder = _load_stegastamp(args.file_path, args.encoder_path, args.decoder_path, device)

    # Images enter the encoder as ToTensor output in [0, 1]; unlike the
    # module-level default_transform, normalize_vqgan is not applied here.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    set_seed(args.seed)
    # The fingerprint uses its own fixed seed (re-seeding torch after set_seed).
    # The embedded message is therefore the same on every run, and the single
    # tensor is shared by every checkpoint prefix below.
    fingerprint_seed = 42
    fingerprint_length = 200  # must match the StegaStamp models built in _load_stegastamp
    torch.manual_seed(fingerprint_seed)
    fingerprint = torch.zeros((1, fingerprint_length), dtype=torch.float).random_(0, 2).to(device)

    # Checkpoint prefixes to evaluate.
    ckpt_prefixes = [
        "SS_fix_weights",
        "SS_dlwt",
        "WMA_fix_weights",
        "WMA_dlwt",
        "LaWa_fix_weights",
        "LaWa_dlwt",
        "EW-LoRA_fix_weights",
        "EW-LoRA_dlwt"
    ]

    for ckpt_prefix in ckpt_prefixes:
        result = _evaluate_prefix(encoder, decoder, fingerprint, args.data_dir, ckpt_prefix,
                                  eval_transform, args.batch_size, device)
        if result is None:
            print(f"Model: {ckpt_prefix}, no images found, skipped")
            continue
        mean_bit_acc, mean_psnr = result
        print(f"Model: {ckpt_prefix}, ACC: {mean_bit_acc}, PSNR: {mean_psnr}")

if __name__ == '__main__':
    # Script entry point: parse CLI options and run the evaluation.
    main(get_parser().parse_args())