Donnyll committed on
Commit
658e26c
·
verified ·
1 Parent(s): bbd3c90

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. .gitignore +7 -0
  3. README.md +92 -3
  4. diagram.png +3 -0
  5. overwrite_attack/attack_with_stegastamp.py +165 -0
  6. overwrite_attack/utils_img.py +176 -0
  7. requirements.txt +119 -0
  8. watermarker/LaWa/configs/SD14_LaWa.yaml +108 -0
  9. watermarker/LaWa/configs/SD14_LaWa_dlwt.yaml +108 -0
  10. watermarker/LaWa/configs/SD14_LaWa_inference.yaml +57 -0
  11. watermarker/LaWa/configs/SD14_LaWa_inference_dlwt.yaml +58 -0
  12. watermarker/LaWa/configs/SD14_LaWa_ldm.yaml +107 -0
  13. watermarker/LaWa/configs/SD14_LaWa_modified.yaml +103 -0
  14. watermarker/LaWa/dlwt.py +251 -0
  15. watermarker/LaWa/ecc.py +281 -0
  16. watermarker/LaWa/examples/gen_wmimgs_EW-LoRA_dlwt.ipynb +267 -0
  17. watermarker/LaWa/examples/gen_wmimgs_EW-LoRA_fix_weights.ipynb +275 -0
  18. watermarker/LaWa/examples/gen_wmimgs_SS_dlwt.ipynb +225 -0
  19. watermarker/LaWa/examples/gen_wmimgs_SS_fix_weights.ipynb +236 -0
  20. watermarker/LaWa/examples/gen_wmimgs_WMA_dlwt.ipynb +233 -0
  21. watermarker/LaWa/examples/gen_wmimgs_WMA_fix_weights.ipynb +225 -0
  22. watermarker/LaWa/gen_wm_imgs.py +177 -0
  23. watermarker/LaWa/lawa_dataset/train_100k.csv +0 -0
  24. watermarker/LaWa/lawa_dataset/train_200k.csv +0 -0
  25. watermarker/LaWa/lawa_dataset/val_10k.csv +0 -0
  26. watermarker/LaWa/lawa_dataset/val_1k.csv +1001 -0
  27. watermarker/LaWa/ldm/__pycache__/util.cpython-38.pyc +0 -0
  28. watermarker/LaWa/ldm/data/__init__.py +0 -0
  29. watermarker/LaWa/ldm/data/util.py +24 -0
  30. watermarker/LaWa/ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  31. watermarker/LaWa/ldm/models/autoencoder.py +492 -0
  32. watermarker/LaWa/ldm/models/diffusion/__init__.py +0 -0
  33. watermarker/LaWa/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  34. watermarker/LaWa/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  35. watermarker/LaWa/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  36. watermarker/LaWa/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc +0 -0
  37. watermarker/LaWa/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc +0 -0
  38. watermarker/LaWa/ldm/models/diffusion/ddim.py +339 -0
  39. watermarker/LaWa/ldm/models/diffusion/ddpm.py +1798 -0
  40. watermarker/LaWa/ldm/models/diffusion/dpm_solver/__init__.py +1 -0
  41. watermarker/LaWa/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc +0 -0
  42. watermarker/LaWa/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc +0 -0
  43. watermarker/LaWa/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc +0 -0
  44. watermarker/LaWa/ldm/models/diffusion/dpm_solver/dpm_solver.py +1154 -0
  45. watermarker/LaWa/ldm/models/diffusion/dpm_solver/sampler.py +87 -0
  46. watermarker/LaWa/ldm/models/diffusion/plms.py +244 -0
  47. watermarker/LaWa/ldm/models/diffusion/sampling_util.py +22 -0
  48. watermarker/LaWa/ldm/modules/__pycache__/attention.cpython-38.pyc +0 -0
  49. watermarker/LaWa/ldm/modules/__pycache__/ema.cpython-38.pyc +0 -0
  50. watermarker/LaWa/ldm/modules/__pycache__/x_transformer.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ diagram.png filter=lfs diff=lfs merge=lfs -text
37
+ watermarker/LaWa/ldm/modules/image_degradation/utils/test.png filter=lfs diff=lfs merge=lfs -text
38
+ watermarker/LaWa/stable-diffusion/ldm/modules/image_degradation/utils/test.png filter=lfs diff=lfs merge=lfs -text
39
+ watermarker/stable_signature/ldm/modules/image_degradation/utils/test.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ watermark_attacker
2
+ watermarker/stable_signature/outputs
3
+ watermarker/LaWa/outputs
4
+ experiments
5
+ scripts
6
+ optimizers
7
+ plots
README.md CHANGED
@@ -1,3 +1,92 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # An efficient watermarking method for latent diffusion models via low-rank adaptation
3
+
4
+ Code for our paper "An efficient watermarking method for latent diffusion models via low-rank adaptation".
5
+
6
+ You can download the paper via: [[ArXiv]](https://arxiv.org/abs/2410.20202)
7
+
8
+
9
+ ## 😀Summary
10
+
11
+ A lightweight parameter fine-tuning strategy with low-rank adaptation and dynamic loss weight adjustment enables efficient watermark embedding in large-scale models while minimizing impact on image quality and maintaining robustness.
12
+
13
+ ![image](diagram.png)
14
+
15
+ ## 🍉Requirement
16
+
17
+ ```shell
18
+ pip install -r requirements.txt
19
+ ```
20
+
21
+ ## 🐬Preparation
22
+
23
+ ### Clone
24
+
25
+ ```shell
26
+ git clone https://github.com/MrDongdongLin/EW-LoRA
27
+ ```
28
+
29
+ ### Create an anaconda environment [Optional]:
30
+
31
+ ```shell
32
+ conda create -n ewlora python==3.8.18
33
+ conda activate ewlora
34
+ pip install -r requirements.txt
35
+ ```
36
+
37
+ ### Prepare the training data:
38
+
39
+ * Download the dataset files [here](https://cocodataset.org/).
40
+ * Extract them to the `data` folder.
41
+ * The directory structure will be as follows:
42
+
43
+ ```shell
44
+ coco2017
45
+ └── train
46
+ ├── img1.jpg
47
+ ├── img2.jpg
48
+ └── img3.jpg
49
+ └── test
50
+ ├── img4.jpg
51
+ ├── img5.jpg
52
+ └── img6.jpg
53
+ ```
54
+
55
+ ### Usage
56
+
57
+ #### Training
58
+
59
+ ```shell
60
+ cd ./watermarker/stable_signature
61
+ CUDA_VISIBLE_DEVICES=0 python train_SS.py --num_keys 1 \
62
+ --train_dir ./Datasets/coco2017/train2017 \
63
+ --val_dir ./Datasets/coco2017/val2017 \
64
+ --ldm_config ./watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml \
65
+ --ldm_ckpt ../models/ldm_ckpts/sd-v1-4-full-ema.ckpt \
66
+ --msg_decoder_path ../models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt \
67
+ --output_dir ./watermarker/stable_signature/outputs/ \
68
+ --task_name train_SS_fix_weights \
69
+ --do_validation \
70
+ --val_frep 50 \
71
+ --batch_size 4 \
72
+ --lambda_i 1.0 --lambda_w 0.2 \
73
+ --steps 20000 --val_size 100 \
74
+ --warmup_steps 20 \
75
+ --save_img_freq 100 \
76
+ --log_freq 1 --debug
77
+ ```
78
+
79
+ ## Citation
80
+
81
+ If this work is helpful, please cite as:
82
+
83
+ ```latex
84
+ @article{linEfficientWatermarkingMethod2024,
85
+ title = {An Efficient Watermarking Method for Latent Diffusion Models via Low-Rank Adaptation},
86
+ author = {Lin, Dongdong and Li, Yue and Tondi, Benedetta and Li, Bin and Barni, Mauro},
87
+ year = {2024},
88
+ month = oct,
89
+ number = {arXiv:2410.20202},
90
+ eprint = {2410.20202},
91
+ }
92
+ ```
diagram.png ADDED

Git LFS Details

  • SHA256: 5238de8ba3840e411e5e3abf475213bbfa19d78dc8df766f2ebf6a8b8645722f
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
overwrite_attack/attack_with_stegastamp.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import importlib.util
import sys
import os
import torch
import numpy as np

import argparse
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

from tqdm import tqdm
import pandas as pd
from torchvision.utils import save_image
from accelerate.utils import set_seed

from utils_img import normalize_vqgan, unnormalize_vqgan, psnr

# Default preprocessing: PIL image -> tensor in [0, 1], then VQGAN-style
# normalization ((x - 0.5) / 0.5), i.e. values roughly in [-1, 1].
default_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize_vqgan,
])
class CustomImageDataset(torch.utils.data.Dataset):
    """Flat-directory image dataset (no class subfolders expected).

    Yields ``(image, 0)`` pairs; the constant label keeps the interface
    compatible with ``(image, target)`` style consumers.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sort the listing so iteration order is deterministic across runs.
        entries = (os.path.join(image_dir, name) for name in os.listdir(image_dir))
        self.image_paths = sorted(p for p in entries if p.endswith(('.png', '.jpg', '.jpeg')))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        sample = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample, 0
def get_dataloader(data_dir, transform=default_transform, batch_size=128, shuffle=False, num_workers=4):
    """
    Custom dataloader that loads images from a directory without expecting class subfolders.
    """
    # Wrap the flat directory in the custom dataset, then hand it to DataLoader.
    dataset = CustomImageDataset(data_dir, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
def get_parser():
    """Build the command-line parser for the StegaStamp overwrite attack."""
    parser = argparse.ArgumentParser(description='StegaStamp Attack')
    # String-valued options: (flag, default, help).
    string_options = [
        ('--file_path',
         '/pubdata/ldd/models/wm_encdec/stegastamp/models.py',
         'Path to the stegastamp models.py file'),
        ('--encoder_path',
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_encoder_099.pth',
         'Path to the encoder weights'),
        ('--decoder_path',
         '/pubdata/ldd/models/wm_encdec/stegastamp/ckpts/stegastamp_coco_256_onefactor/stegastamp/checkpoints/stegastamp_decoder_099.pth',
         'Path to the decoder weights'),
        ('--data_dir',
         '/pubdata/ldd/projects/sd-lora-wm/smattacks/from_222/smattacks/gen_with_prompt_lawa_test/019-exps/images',
         'Path to the dataset'),
        ('--images_dir', '', 'Path to save the images'),
    ]
    for flag, default, help_text in string_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--seed', type=int, default=1337, help='Random seed')
    return parser
def main(args):
    """Run the StegaStamp overwrite attack over several checkpoint variants.

    Dynamically loads the StegaStamp encoder/decoder from ``args.file_path``
    and the weights at ``args.encoder_path`` / ``args.decoder_path``, then for
    every ``save_imgs_<prefix>`` folder under ``args.data_dir`` re-embeds one
    fixed random 200-bit fingerprint into each image, saves the overwritten
    images, and logs per-image bit accuracy and PSNR to
    ``overwrite_att_<prefix>/bit_acc_stegastamp.csv``.
    """
    # Make the directory containing models.py importable, then load the
    # module directly from its file path.
    module_dir = os.path.dirname(args.file_path)
    sys.path.append(module_dir)

    spec = importlib.util.spec_from_file_location("stagastamp_models", args.file_path)
    stagastamp_models = importlib.util.module_from_spec(spec)
    sys.modules["stagastamp_models"] = stagastamp_models
    spec.loader.exec_module(stagastamp_models)

    # 256x256 RGB images, 200-bit fingerprints.
    encoder = stagastamp_models.StegaStampEncoder(256, 3, 200, return_residual=False)
    decoder = stagastamp_models.StegaStampDecoder(256, 3, 200)

    # Load weights
    encoder.load_state_dict(torch.load(args.encoder_path, map_location='cuda'))
    decoder.load_state_dict(torch.load(args.decoder_path, map_location='cuda'))
    encoder = encoder.to('cuda')
    decoder = decoder.to('cuda')

    # NOTE(review): the batch size is forced to 1 regardless of the CLI value.
    args.batch_size = 1

    # StegaStamp expects inputs in [0, 1]; VQGAN normalization is deliberately
    # not applied here (unlike the module-level default_transform).
    default_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    args.seed = 1337
    set_seed(args.seed)

    def generate_random_fingerprints(fingerprint_length, batch_size=4, size=(400, 400)):
        # Uniform random bits in {0, 1}. `size` is unused but kept for
        # interface compatibility with the original StegaStamp code.
        z = torch.zeros((batch_size, fingerprint_length), dtype=torch.float).random_(0, 2)
        return z

    # Fixed seed so every checkpoint variant is attacked with the same message.
    args.seed = 42
    torch.manual_seed(args.seed)
    fingerprints = generate_random_fingerprints(200, batch_size=1, size=(256, 3))

    # Watermarker variants whose generated images are attacked.
    ckpt_prefixes = [
        "SS_fix_weights",
        "SS_dlwt",
        "WMA_fix_weights",
        "WMA_dlwt",
        "LaWa_fix_weights",
        "LaWa_dlwt",
        "EW-LoRA_fix_weights",
        "EW-LoRA_dlwt"
    ]

    for ckpt_prefix in ckpt_prefixes:
        dataloader = get_dataloader(os.path.join(args.data_dir, 'save_imgs_' + ckpt_prefix),
                                    transform=default_transform, batch_size=args.batch_size)

        # Output directories (hoisted out of the loop — they are per-prefix,
        # not per-image). FIX: the CSV directory was never created before,
        # so to_csv crashed when it did not already exist.
        save_image_path = os.path.join(args.data_dir, 'overwrite_stegastamp_' + ckpt_prefix)
        os.makedirs(save_image_path, exist_ok=True)
        csv_dir = os.path.join(args.data_dir, 'overwrite_att_' + ckpt_prefix)
        os.makedirs(csv_dir, exist_ok=True)

        rows = []  # accumulated per-image stats; FIX: replaces private df._append
        bit_accs_avg_list = []
        psnr_avg_list = []

        for i, (images, _) in enumerate(tqdm(dataloader)):
            fingerprints = fingerprints.to('cuda')
            images = images.to('cuda')
            fingerprinted_images = encoder(fingerprints, images)
            decoder_output = decoder(fingerprinted_images)

            save_image(fingerprinted_images, os.path.join(save_image_path, f'overwrite_img_w_{i:07}.png'))

            # msg stats: compare sign bits of the embedded vs. decoded message.
            ori_msgs = torch.sign(fingerprints) > 0
            decoded_msgs = torch.sign(decoder_output) > 0  # b k -> b k
            diff = (~torch.logical_xor(ori_msgs, decoded_msgs))  # b k -> b k
            bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1]  # b k -> b
            bit_accs_avg = torch.mean(bit_accs).item()

            # NOTE(review): psnr() defaults to img_space='vqgan' but these
            # images are in [0, 1] here — confirm this is intended.
            psnr_avg = psnr(fingerprinted_images, images).mean().item()
            psnr_avg_list.append(psnr_avg)
            bit_accs_avg_list.append(bit_accs_avg)

            # Persist progressively so partial results survive interruption.
            rows.append({"iteration": i, "bit_acc_avg": bit_accs_avg, "psnr_avg": psnr_avg})
            df = pd.DataFrame(rows, columns=["iteration", "bit_acc_avg", "psnr_avg"])
            df.to_csv(os.path.join(csv_dir, "bit_acc_stegastamp.csv"), index=False)

        if not bit_accs_avg_list:
            # FIX: avoid ZeroDivisionError when the image folder is empty.
            print(f"Model: {ckpt_prefix}, no images found — skipped")
            continue

        overall_avg_bit_accs = sum(bit_accs_avg_list) / len(bit_accs_avg_list)
        overall_avg_psnr = sum(psnr_avg_list) / len(psnr_avg_list)

        print(f"Model: {ckpt_prefix}, ACC: {overall_avg_bit_accs}, PSNR: {overall_avg_psnr}")
if __name__ == '__main__':
    # generate parser / parse parameters, then run the overwrite attack
    parser = get_parser()
    args = parser.parse_args()
    main(args)
overwrite_attack/utils_img.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# pyright: reportMissingModuleSource=false

import numpy as np
from augly.image import functional as aug_functional
import torch
from torchvision import transforms
from torchvision.transforms import functional
from torch.autograd.variable import Variable
import torch.nn.functional as F

# Module-wide default device; CUDA when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default preprocessing: PIL -> tensor in [0, 1], then ImageNet normalization.
default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Paired normalize/unnormalize transforms for the two image spaces used here.
normalize_vqgan = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize (x - 0.5) / 0.5
unnormalize_vqgan = transforms.Normalize(mean=[-1, -1, -1], std=[1/0.5, 1/0.5, 1/0.5])  # Unnormalize (x * 0.5) + 0.5
normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize (x - mean) / std
unnormalize_img = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225])  # Unnormalize (x * std) + mean
def psnr(x, y, img_space='vqgan'):
    """
    Return the per-image PSNR (dB) between two image batches.
    Args:
        x: Image tensor with values approx. between [-1,1]
        y: Image tensor with values approx. between [-1,1], ex: original image
        img_space: 'vqgan' or 'img' selects the un-normalization applied
            before comparing; any other value compares the tensors as-is.
    """
    if img_space == 'vqgan':
        a = torch.clamp(unnormalize_vqgan(x), 0, 1)
        b = torch.clamp(unnormalize_vqgan(y), 0, 1)
        delta = a - b
    elif img_space == 'img':
        a = torch.clamp(unnormalize_img(x), 0, 1)
        b = torch.clamp(unnormalize_img(y), 0, 1)
        delta = a - b
    else:
        delta = x - y
    # Scale to 8-bit range and flatten leading dims to BxCxHxW.
    delta = (255 * delta).reshape(-1, x.shape[-3], x.shape[-2], x.shape[-1])
    mse = torch.mean(delta ** 2, dim=(1, 2, 3))  # per-image MSE, shape B
    return 20 * np.log10(255) - 10 * torch.log10(mse)
def center_crop(x, scale):
    """ Center-crop so the cropped area is `scale` times the original area
    Args:
        x: PIL image
        scale: target area scale
    """
    # `scale` applies to the area, so each edge shrinks by sqrt(scale).
    edge_scale = np.sqrt(scale)
    new_size = [int(edge * edge_scale) for edge in x.shape[-2:]][::-1]
    return functional.center_crop(x, new_size)
def resize(x, scale):
    """ Resize so the resulting area is `scale` times the original area
    Args:
        x: PIL image
        scale: target area scale
    """
    # `scale` applies to the area, so each edge shrinks by sqrt(scale).
    edge_scale = np.sqrt(scale)
    new_size = [int(edge * edge_scale) for edge in x.shape[-2:]][::-1]
    return functional.resize(x, new_size)
def rotate(x, angle):
    """ Rotate an image by the given angle
    Args:
        x: image (PIL or tensor)
        angle: angle in degrees
    """
    return functional.rotate(x, angle)
def flip(x, direction='horizontal'):
    """ Flip an image horizontally or vertically.
    Args:
        x: image (PIL or tensor)
        direction: 'horizontal' or 'vertical'
    Raises:
        ValueError: if `direction` is neither 'horizontal' nor 'vertical'
            (the original silently returned None in that case).
    """
    if direction == 'horizontal':
        return functional.hflip(x)
    elif direction == 'vertical':
        return functional.vflip(x)
    raise ValueError(f"flip: unknown direction {direction!r}")
def adjust_brightness(x, brightness_factor):
    """ Adjust brightness of an image
    Args:
        x: ImageNet-normalized image tensor
        brightness_factor: brightness factor
    """
    # Un-normalize, adjust in pixel space, re-normalize.
    pixel = unnormalize_img(x)
    return normalize_img(functional.adjust_brightness(pixel, brightness_factor))

def adjust_contrast(x, contrast_factor):
    """ Adjust contrast of an image
    Args:
        x: ImageNet-normalized image tensor
        contrast_factor: contrast factor
    """
    pixel = unnormalize_img(x)
    return normalize_img(functional.adjust_contrast(pixel, contrast_factor))

def adjust_saturation(x, saturation_factor):
    """ Adjust saturation of an image
    Args:
        x: ImageNet-normalized image tensor
        saturation_factor: saturation factor
    """
    pixel = unnormalize_img(x)
    return normalize_img(functional.adjust_saturation(pixel, saturation_factor))

def adjust_hue(x, hue_factor):
    """ Adjust hue of an image
    Args:
        x: ImageNet-normalized image tensor
        hue_factor: hue factor
    """
    pixel = unnormalize_img(x)
    return normalize_img(functional.adjust_hue(pixel, hue_factor))

def adjust_gamma(x, gamma, gain=1):
    """ Adjust gamma of an image
    Args:
        x: ImageNet-normalized image tensor
        gamma: gamma factor
        gain: gain factor
    """
    pixel = unnormalize_img(x)
    return normalize_img(functional.adjust_gamma(pixel, gamma, gain))

def adjust_sharpness(x, sharpness_factor):
    """ Adjust sharpness of an image
    Args:
        x: ImageNet-normalized image tensor
        sharpness_factor: sharpness factor
    """
    pixel = unnormalize_img(x)
    return normalize_img(functional.adjust_sharpness(pixel, sharpness_factor))
def overlay_text(x, text='Lorem Ipsum'):
    """ Overlay text on each image of a batch
    Args:
        x: batch of ImageNet-normalized image tensors (BxCxHxW)
        text: text to overlay
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    # Round-trip each image through PIL: un-normalize, draw, re-tensorize.
    out = torch.zeros_like(x, device=x.device)
    for idx, img in enumerate(x):
        pil_img = to_pil(unnormalize_img(img))
        out[idx] = to_tensor(aug_functional.overlay_text(pil_img, text=text))
    return normalize_img(out)
def jpeg_compress(x, quality_factor):
    """ Apply JPEG compression to each image of a batch
    Args:
        x: batch of ImageNet-normalized image tensors (BxCxHxW)
        quality_factor: JPEG quality factor
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    # Round-trip each image through PIL: un-normalize, compress, re-tensorize.
    out = torch.zeros_like(x, device=x.device)
    for idx, img in enumerate(x):
        pil_img = to_pil(unnormalize_img(img))
        out[idx] = to_tensor(aug_functional.encoding_quality(pil_img, quality=quality_factor))
    return normalize_img(out)
def gaussian_noise(input, stddev):
    """Add white Gaussian noise with variance ``stddev`` in pixel space.

    The input is un-normalized (ImageNet stats), noise with std
    ``sqrt(stddev)`` is added, and the result is re-normalized.
    NOTE(review): the sum is clamped to [-1, 1] although un-normalized
    images live in [0, 1] — confirm the lower bound is intended.
    """
    output = torch.clamp(unnormalize_img(input).clone() + (torch.randn(
        [input.shape[0], input.shape[1], input.shape[2], input.shape[3]]) * (stddev**0.5)).to(input.device), -1, 1)
    return normalize_img(output)
def adjust_gaussian_blur(img, ks):
    """Gaussian-blur an ImageNet-normalized image with kernel size ``ks``."""
    # Blur in pixel space, then restore the normalization.
    pixel_img = unnormalize_img(img)
    blurred = functional.gaussian_blur(pixel_img, kernel_size=ks)
    return normalize_img(blurred)
requirements.txt ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==1.0.1
3
+ aiohappyeyeballs==2.4.0
4
+ aiohttp==3.10.5
5
+ aiosignal==1.3.1
6
+ antlr4-python3-runtime==4.9.3
7
+ async-timeout==4.0.3
8
+ attrs==24.2.0
9
+ augly==1.0.0
10
+ bm3d==4.0.1
11
+ bm4d==4.2.3
12
+ cachetools==5.5.0
13
+ certifi==2024.7.4
14
+ charset-normalizer==3.3.2
15
+ clip==0.2.0
16
+ cmake==3.30.2
17
+ coloredlogs==15.0.1
18
+ compressai==1.2.6
19
+ contourpy==1.1.1
20
+ cycler==0.12.1
21
+ datasets==3.0.1
22
+ diffusers==0.30.3
23
+ dill==0.3.8
24
+ einops==0.6.1
25
+ filelock==3.15.4
26
+ flatbuffers==24.3.25
27
+ fonttools==4.54.1
28
+ frozenlist==1.4.1
29
+ fsspec==2024.6.1
30
+ ftfy==6.2.3
31
+ future==1.0.0
32
+ google-auth==2.34.0
33
+ google-auth-oauthlib==1.0.0
34
+ grpcio==1.65.5
35
+ huggingface-hub==0.25.2
36
+ humanfriendly==10.0
37
+ idna==3.7
38
+ imageio==2.35.1
39
+ imhist==0.0.4
40
+ importlib_metadata==8.4.0
41
+ importlib_resources==6.4.5
42
+ invisible-watermark==0.2.0
43
+ iopath==0.1.10
44
+ Jinja2==3.1.4
45
+ kiwisolver==1.4.7
46
+ kornia==0.7.3
47
+ kornia_rs==0.1.5
48
+ lazy_loader==0.4
49
+ lightning-utilities==0.11.6
50
+ lit==18.1.8
51
+ lpips==0.1.4
52
+ Markdown==3.7
53
+ MarkupSafe==2.1.5
54
+ matplotlib==3.7.5
55
+ mpmath==1.3.0
56
+ multidict==6.0.5
57
+ multiprocess==0.70.16
58
+ networkx==3.1
59
+ numpy==1.24.4
60
+ nvidia-cublas-cu11==11.10.3.66
61
+ nvidia-cuda-cupti-cu11==11.7.101
62
+ nvidia-cuda-nvrtc-cu11==11.7.99
63
+ nvidia-cuda-runtime-cu11==11.7.99
64
+ nvidia-cudnn-cu11==8.5.0.96
65
+ nvidia-cufft-cu11==10.9.0.58
66
+ nvidia-curand-cu11==10.2.10.91
67
+ nvidia-cusolver-cu11==11.4.0.1
68
+ nvidia-cusparse-cu11==11.7.4.91
69
+ nvidia-nccl-cu11==2.14.3
70
+ nvidia-nvtx-cu11==11.7.91
71
+ oauthlib==3.2.2
72
+ omegaconf==2.3.0
73
+ onnxruntime==1.19.2
74
+ open_clip_torch==2.26.1
75
+ opencv-python==4.8.1.78
76
+ pandas==1.5.3
77
+ peft==0.13.2
78
+ pillow==10.4.0
79
+ portalocker==2.10.1
80
+ protobuf==5.27.3
81
+ pyarrow==17.0.0
82
+ pyasn1==0.6.0
83
+ pyasn1_modules==0.4.0
84
+ pyDeprecate==0.3.1
85
+ pyparsing==3.1.4
86
+ python-magic==0.4.27
87
+ pytorch-lightning==1.5.0
88
+ pytorch-msssim==1.0.0
89
+ pytz==2024.1
90
+ PyWavelets==1.4.1
91
+ PyYAML==6.0.2
92
+ regex==2024.7.24
93
+ requests==2.32.3
94
+ requests-oauthlib==2.0.0
95
+ rsa==4.9
96
+ safetensors==0.4.5
97
+ scikit-image==0.21.0
98
+ scipy==1.10.1
99
+ sympy==1.13.2
100
+ taming-transformers-rom1504==0.0.6
101
+ tensorboard==2.14.0
102
+ tensorboard-data-server==0.7.2
103
+ test_tube==0.7.5
104
+ tifffile==2023.7.10
105
+ timm==1.0.9
106
+ tokenizers==0.20.1
107
+ torch==2.0.1
108
+ torch-fidelity==0.3.0
109
+ torch-geometric==2.6.1
110
+ torchmetrics==1.4.1
111
+ torchvision==0.15.2
112
+ tqdm==4.66.5
113
+ transformers==4.45.2
114
+ triton==2.0.0
115
+ urllib3==2.2.2
116
+ Werkzeug==3.0.4
117
+ xxhash==3.5.0
118
+ yarl==1.9.4
119
+ zipp==3.20.0
watermarker/LaWa/configs/SD14_LaWa.yaml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: models.modifiedAEDecoder.LaWa
3
+ params:
4
+ scale_factor: 1.0 #0.18215
5
+ extraction_resize: False
6
+ start_attack_acc_thresh: 0.5
7
+ watermark_addition_weight: 0.1
8
+ learning_rate: 0.00008
9
+ epoch_num: 100
10
+ dis_update_freq: 0
11
+ noise_block_size: 8
12
+ first_stage_config:
13
+ target: stable-diffusion.ldm.models.autoencoder.AutoencoderKL
14
+ params:
15
+ ckpt_path: /pubdata/ldd/models/first_stage_models/first_stage_KL-f8.ckpt
16
+ embed_dim: 4
17
+ monitor: val/rec_loss
18
+ ddconfig:
19
+ double_z: true
20
+ z_channels: 4
21
+ resolution: 256
22
+ in_channels: 3
23
+ out_ch: 3
24
+ ch: 128
25
+ ch_mult:
26
+ - 1
27
+ - 2
28
+ - 4
29
+ - 4
30
+ num_res_blocks: 2
31
+ attn_resolutions: []
32
+ dropout: 0.0
33
+ lossconfig:
34
+ target: torch.nn.Identity
35
+
36
+ decoder_config:
37
+ target: models.messageDecoder.MessageDecoder
38
+ params:
39
+ message_len: 48
40
+
41
+
42
+ discriminator_config:
43
+ target: models.modifiedAEDecoder.Discriminator1
44
+
45
+ # dlwt configs:
46
+ apply_dlwt: False
47
+ psnr_threshold: 30.0
48
+ bitacc_target: 0.95
49
+ delta: 1.0
50
+ # loss config: (set message_absolute_loss_weight=0 if dlwt is applied)
51
+ recon_type: rgb
52
+ recon_loss_weight: 0.1
53
+ adversarial_loss_weight: 1.0
54
+ perceptual_loss_weight: 1.0
55
+ message_absolute_loss_weight: 2.0
56
+
57
+ noise_config:
58
+ target: models.transformations.TransformNet
59
+ params:
60
+ ramp: 10000
61
+ apply_many_crops: False
62
+ apply_required_attacks: True
63
+ required_attack_list: ['none'] #['rotation', 'resize','random_crop', 'center_crop', 'blur', 'noise','contrast','brightness', 'jpeg']
64
+
65
+ data:
66
+ target: tools.dataset.DataModule
67
+ params:
68
+ batch_size: 8
69
+ num_workers: 8
70
+ use_worker_init_fn: true
71
+ train:
72
+ target: tools.dataset.dataset
73
+ params:
74
+ data_dir: /pubdata/ldd/Datasets/Flicker
75
+ data_list: /pubdata/ldd/projects/EW-LoRA/watermarker/LaWa/lawa_dataset/train_100k.csv
76
+ resize: 256
77
+ validation:
78
+ target: tools.dataset.dataset
79
+ params:
80
+ data_dir: /pubdata/ldd/Datasets/Flicker
81
+ data_list: /pubdata/ldd/projects/EW-LoRA/watermarker/LaWa/lawa_dataset/val_1k.csv
82
+ resize: 256
83
+ limit_samples: 100
84
+
85
+ lightning:
86
+ callbacks:
87
+ image_logger:
88
+ target: models.logger.ImageLogger
89
+ params:
90
+ batch_frequency: 1
91
+ max_images: 0
92
+ increase_log_steps: False
93
+ fixed_input: True
94
+ progress_bar:
95
+ target: pytorch_lightning.callbacks.ProgressBar
96
+ params:
97
+ refresh_rate: 4
98
+ checkpoint:
99
+ target: pytorch_lightning.callbacks.ModelCheckpoint
100
+ params:
101
+ verbose: true
102
+ filename: '{epoch:06}-{step:09}'
103
+ every_n_train_steps: 5000
104
+
105
+ trainer:
106
+ benchmark: True
107
+ base_learning_rate: 2e-5
108
+ accumulate_grad_batches: 1
watermarker/LaWa/configs/SD14_LaWa_dlwt.yaml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: models.modifiedAEDecoder.LaWa
3
+ params:
4
+ scale_factor: 1.0 #0.18215
5
+ extraction_resize: False
6
+ start_attack_acc_thresh: 0.5
7
+ watermark_addition_weight: 0.1
8
+ learning_rate: 0.00008
9
+ epoch_num: 100
10
+ dis_update_freq: 0
11
+ noise_block_size: 8
12
+ first_stage_config:
13
+ target: stable-diffusion.ldm.models.autoencoder.AutoencoderKL
14
+ params:
15
+ ckpt_path: /pubdata/ldd/models/first_stage_models/first_stage_KL-f8.ckpt
16
+ embed_dim: 4
17
+ monitor: val/rec_loss
18
+ ddconfig:
19
+ double_z: true
20
+ z_channels: 4
21
+ resolution: 256
22
+ in_channels: 3
23
+ out_ch: 3
24
+ ch: 128
25
+ ch_mult:
26
+ - 1
27
+ - 2
28
+ - 4
29
+ - 4
30
+ num_res_blocks: 2
31
+ attn_resolutions: []
32
+ dropout: 0.0
33
+ lossconfig:
34
+ target: torch.nn.Identity
35
+
36
+ decoder_config:
37
+ target: models.messageDecoder.MessageDecoder
38
+ params:
39
+ message_len: 48
40
+
41
+
42
+ discriminator_config:
43
+ target: models.modifiedAEDecoder.Discriminator1
44
+
45
+ # dlwt configs:
46
+ apply_dlwt: True
47
+ psnr_threshold: 30.0
48
+ bitacc_target: 0.95
49
+ delta: 1.0
50
+ # loss config: (set message_absolute_loss_weight=0 if dlwt is applied)
51
+ recon_type: rgb
52
+ recon_loss_weight: 1.0
53
+ adversarial_loss_weight: 1.0
54
+ perceptual_loss_weight: 1.0
55
+ message_absolute_loss_weight: 0.0
56
+
57
+ noise_config:
58
+ target: models.transformations.TransformNet
59
+ params:
60
+ ramp: 10000
61
+ apply_many_crops: False
62
+ apply_required_attacks: True
63
+ required_attack_list: ['none'] #['rotation', 'resize','random_crop', 'center_crop', 'blur', 'noise','contrast','brightness', 'jpeg']
64
+
65
+ data:
66
+ target: tools.dataset.DataModule
67
+ params:
68
+ batch_size: 8
69
+ num_workers: 8
70
+ use_worker_init_fn: true
71
+ train:
72
+ target: tools.dataset.dataset
73
+ params:
74
+ data_dir: /pubdata/ldd/Datasets/Flicker
75
+ data_list: /pubdata/ldd/projects/EW-LoRA/watermarker/LaWa/lawa_dataset/train_100k.csv
76
+ resize: 256
77
+ validation:
78
+ target: tools.dataset.dataset
79
+ params:
80
+ data_dir: /pubdata/ldd/Datasets/Flicker
81
+ data_list: /pubdata/ldd/projects/EW-LoRA/watermarker/LaWa/lawa_dataset/val_1k.csv
82
+ resize: 256
83
+ limit_samples: 100
84
+
85
+ lightning:
86
+ callbacks:
87
+ image_logger:
88
+ target: models.logger.ImageLogger
89
+ params:
90
+ batch_frequency: 1
91
+ max_images: -1
92
+ increase_log_steps: False
93
+ fixed_input: True
94
+ progress_bar:
95
+ target: pytorch_lightning.callbacks.ProgressBar
96
+ params:
97
+ refresh_rate: 4
98
+ checkpoint:
99
+ target: pytorch_lightning.callbacks.ModelCheckpoint
100
+ params:
101
+ verbose: true
102
+ filename: '{epoch:06}-{step:09}'
103
+ every_n_train_steps: 5000
104
+
105
+ trainer:
106
+ benchmark: True
107
+ base_learning_rate: 2e-5
108
+ accumulate_grad_batches: 1
watermarker/LaWa/configs/SD14_LaWa_inference.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: models.modifiedAEDecoder.LaWa
3
+ params:
4
+ scale_factor: 0.18215 # 1.0
5
+ extraction_resize: False
6
+ start_attack_acc_thresh: 0.75
7
+ watermark_addition_weight: 0.1
8
+ learning_rate: 0.00008
9
+ epoch_num: 200
10
+ dis_update_freq: 0
11
+ noise_block_size: 8
12
+ first_stage_config:
13
+ target: stable-diffusion.ldm.models.autoencoder.AutoencoderKL
14
+ params:
15
+ ckpt_path: weights/first_stage_models/first_stage_KL-f8.ckpt
16
+ embed_dim: 4
17
+ monitor: val/rec_loss
18
+ ddconfig:
19
+ double_z: true
20
+ z_channels: 4
21
+ resolution: 256
22
+ in_channels: 3
23
+ out_ch: 3
24
+ ch: 128
25
+ ch_mult:
26
+ - 1
27
+ - 2
28
+ - 4
29
+ - 4
30
+ num_res_blocks: 2
31
+ attn_resolutions: []
32
+ dropout: 0.0
33
+ lossconfig:
34
+ target: torch.nn.Identity
35
+
36
+ decoder_config:
37
+ target: models.messageDecoder.MessageDecoder
38
+ params:
39
+ message_len: 48
40
+
41
+ discriminator_config:
42
+ target: models.modifiedAEDecoder.Discriminator1
43
+
44
+ # loss config:
45
+ recon_type: rgb
46
+ recon_loss_weight: 0.1
47
+ adversarial_loss_weight: 1.0
48
+ perceptual_loss_weight: 1.0
49
+ message_absolute_loss_weight: 2.0
50
+
51
+ noise_config:
52
+ target: models.transformations.TransformNet
53
+ params:
54
+ ramp: 10000
55
+ apply_many_crops: False
56
+ apply_required_attacks: True
57
+ required_attack_list: ['rotation', 'resize','random_crop', 'center_crop', 'blur', 'noise','contrast','brightness', 'jpeg']
watermarker/LaWa/configs/SD14_LaWa_inference_dlwt.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: models.modifiedAEDecoder.LaWa
3
+ params:
4
+ scale_factor: 0.18215 # 1.0
5
+ extraction_resize: False
6
+ start_attack_acc_thresh: 0.75
7
+ watermark_addition_weight: 0.1
8
+ learning_rate: 0.00008
9
+ epoch_num: 200
10
+ dis_update_freq: 0
11
+ noise_block_size: 8
12
+ first_stage_config:
13
+ target: stable-diffusion.ldm.models.autoencoder.AutoencoderKL
14
+ params:
15
+ ckpt_path: weights/first_stage_models/first_stage_KL-f8.ckpt
16
+ embed_dim: 4
17
+ monitor: val/rec_loss
18
+ ddconfig:
19
+ double_z: true
20
+ z_channels: 4
21
+ resolution: 256
22
+ in_channels: 3
23
+ out_ch: 3
24
+ ch: 128
25
+ ch_mult:
26
+ - 1
27
+ - 2
28
+ - 4
29
+ - 4
30
+ num_res_blocks: 2
31
+ attn_resolutions: []
32
+ dropout: 0.0
33
+ lossconfig:
34
+ target: torch.nn.Identity
35
+
36
+ decoder_config:
37
+ target: models.messageDecoder.HiDDeNDecoder #models.messageDecoder.MessageDecoder
38
+ params:
39
+ msg_decoder_dir: /pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt
40
+ message_len: 48
41
+
42
+ discriminator_config:
43
+ target: models.modifiedAEDecoder.Discriminator1
44
+
45
+ # loss config:
46
+ recon_type: rgb
47
+ recon_loss_weight: 0.1
48
+ adversarial_loss_weight: 1.0
49
+ perceptual_loss_weight: 1.0
50
+ message_absolute_loss_weight: 2.0
51
+
52
+ noise_config:
53
+ target: models.transformations.TransformNet
54
+ params:
55
+ ramp: 10000
56
+ apply_many_crops: False
57
+ apply_required_attacks: True
58
+ required_attack_list: ['rotation', 'resize','random_crop', 'center_crop', 'blur', 'noise','contrast','brightness', 'jpeg']
watermarker/LaWa/configs/SD14_LaWa_ldm.yaml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: models.LaWaAEDecoder.LaWa
3
+ params:
4
+ scale_factor: 1.0 #0.18215
5
+ extraction_resize: False
6
+ start_attack_acc_thresh: 0.5
7
+ watermark_addition_weight: 0.1
8
+ learning_rate: 0.00008
9
+ epoch_num: 100
10
+ dis_update_freq: 0
11
+ noise_block_size: 8
12
+ first_stage_config:
13
+ target: stable-diffusion.ldm.models.autoencoder.AutoencoderKL
14
+ params:
15
+ ckpt_path: /pubdata/ldd/models/first_stage_models/first_stage_KL-f8.ckpt
16
+ embed_dim: 4
17
+ monitor: val/rec_loss
18
+ ddconfig:
19
+ double_z: true
20
+ z_channels: 4
21
+ resolution: 256
22
+ in_channels: 3
23
+ out_ch: 3
24
+ ch: 128
25
+ ch_mult:
26
+ - 1
27
+ - 2
28
+ - 4
29
+ - 4
30
+ num_res_blocks: 2
31
+ attn_resolutions: []
32
+ dropout: 0.0
33
+ lossconfig:
34
+ target: torch.nn.Identity
35
+
36
+ decoder_config:
37
+ target: models.messageDecoder.MessageDecoder
38
+ params:
39
+ message_len: 48
40
+
41
+
42
+ discriminator_config:
43
+ target: models.modifiedAEDecoder.Discriminator1
44
+
45
+ # dlwt configs:
46
+ apply_dlwt: False
47
+ psnr_threshold: 30
48
+ bitacc_target: 0.95
49
+ delta: 1.0
50
+ # loss config: (set message_absolute_loss_weight=0 if dlwt is applied)
51
+ recon_type: rgb
52
+ recon_loss_weight: 0.1
53
+ adversarial_loss_weight: 1.0
54
+ perceptual_loss_weight: 1.0
55
+ message_absolute_loss_weight: 2.0
56
+
57
+ noise_config:
58
+ target: models.transformations.TransformNet
59
+ params:
60
+ ramp: 10000
61
+ apply_many_crops: False
62
+ apply_required_attacks: True
63
+ required_attack_list: ['none'] #['rotation', 'resize','random_crop', 'center_crop', 'blur', 'noise','contrast','brightness', 'jpeg']
64
+
65
+ data:
66
+ target: tools.dataset.DataModule
67
+ params:
68
+ batch_size: 8
69
+ num_workers: 8
70
+ use_worker_init_fn: true
71
+ train:
72
+ target: tools.dataset.dataset
73
+ params:
74
+ data_dir: /pubdata/ldd/Datasets/Flicker
75
+ data_list: /pubdata/ldd/projects/EW-LoRA/watermarker/LaWa/lawa_dataset/train_100k.csv
76
+ resize: 256
77
+ validation:
78
+ target: tools.dataset.dataset
79
+ params:
80
+ data_dir: /pubdata/ldd/Datasets/Flicker
81
+ data_list: /pubdata/ldd/projects/EW-LoRA/watermarker/LaWa/lawa_dataset/val_1k.csv
82
+ resize: 256
83
+
84
+ lightning:
85
+ callbacks:
86
+ image_logger:
87
+ target: models.logger.ImageLogger
88
+ params:
89
+ batch_frequency: 1
90
+ max_images: 0
91
+ increase_log_steps: False
92
+ fixed_input: True
93
+ progress_bar:
94
+ target: pytorch_lightning.callbacks.ProgressBar
95
+ params:
96
+ refresh_rate: 4
97
+ checkpoint:
98
+ target: pytorch_lightning.callbacks.ModelCheckpoint
99
+ params:
100
+ verbose: true
101
+ filename: '{epoch:06}-{step:09}'
102
+ every_n_train_steps: 5000
103
+
104
+ trainer:
105
+ benchmark: True
106
+ base_learning_rate: 2e-5
107
+ accumulate_grad_batches: 1
watermarker/LaWa/configs/SD14_LaWa_modified.yaml ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: models.modifiedAEDecoder.LaWa
3
+ params:
4
+ scale_factor: 1.0 #0.18215
5
+ extraction_resize: False
6
+ start_attack_acc_thresh: 0.5
7
+ watermark_addition_weight: 0.1
8
+ learning_rate: 0.00008
9
+ epoch_num: 100
10
+ dis_update_freq: 0
11
+ noise_block_size: 8
12
+ first_stage_config:
13
+ target: stable-diffusion.ldm.models.autoencoder.AutoencoderKL
14
+ params:
15
+ ckpt_path: weights/first_stage_models/first_stage_KL-f8.ckpt
16
+ embed_dim: 4
17
+ monitor: val/rec_loss
18
+ ddconfig:
19
+ double_z: true
20
+ z_channels: 4
21
+ resolution: 256
22
+ in_channels: 3
23
+ out_ch: 3
24
+ ch: 128
25
+ ch_mult:
26
+ - 1
27
+ - 2
28
+ - 4
29
+ - 4
30
+ num_res_blocks: 2
31
+ attn_resolutions: []
32
+ dropout: 0.0
33
+ lossconfig:
34
+ target: torch.nn.Identity
35
+
36
+ decoder_config:
37
+ target: models.messageDecoder.HiDDeNDecoder #models.messageDecoder.MessageDecoder
38
+ params:
39
+ msg_decoder_dir: /pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt
40
+ message_len: 48
41
+
42
+
43
+ discriminator_config:
44
+ target: models.modifiedAEDecoder.Discriminator1
45
+
46
+ # loss config: (set message_absolute_loss_weight=0 if dlwt is applied)
47
+ recon_type: rgb
48
+ recon_loss_weight: 0.1
49
+ adversarial_loss_weight: 1.0
50
+ perceptual_loss_weight: 1.0
51
+ message_absolute_loss_weight: 0
52
+
53
+ noise_config:
54
+ target: models.transformations.TransformNet
55
+ params:
56
+ ramp: 10000
57
+ apply_many_crops: False
58
+ apply_required_attacks: True
59
+ required_attack_list: ['none'] #['rotation', 'resize','random_crop', 'center_crop', 'blur', 'noise','contrast','brightness', 'jpeg']
60
+
61
+ data:
62
+ target: tools.dataset.DataModule
63
+ params:
64
+ batch_size: 8
65
+ num_workers: 8
66
+ use_worker_init_fn: true
67
+ train:
68
+ target: tools.dataset.dataset
69
+ params:
70
+ data_dir: /pubdata/ldd/Datasets/Flicker
71
+ data_list: lawa_dataset/train_100k.csv
72
+ resize: 256
73
+ validation:
74
+ target: tools.dataset.dataset
75
+ params:
76
+ data_dir: /pubdata/ldd/Datasets/Flicker
77
+ data_list: lawa_dataset/val_1k.csv
78
+ resize: 256
79
+
80
+ lightning:
81
+ callbacks:
82
+ image_logger:
83
+ target: models.logger.ImageLogger
84
+ params:
85
+ batch_frequency: 1
86
+ max_images: 4
87
+ increase_log_steps: False
88
+ fixed_input: True
89
+ progress_bar:
90
+ target: pytorch_lightning.callbacks.ProgressBar
91
+ params:
92
+ refresh_rate: 4
93
+ checkpoint:
94
+ target: pytorch_lightning.callbacks.ModelCheckpoint
95
+ params:
96
+ verbose: true
97
+ filename: '{epoch:06}-{step:09}'
98
+ every_n_train_steps: 5000
99
+
100
+ trainer:
101
+ benchmark: True
102
+ base_learning_rate: 2e-5
103
+ accumulate_grad_batches: 1
watermarker/LaWa/dlwt.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ def dynamic_lambda_scheduler(psnr, bitacc, lambda_i, lambda_w, psnr_threshold, bitacc_target, delta, min_increment=1e-6,
4
+ acc_increment=0.05, psnr_increment=1.0, max_psnr_threshold=50, max_bitacc_target=1.0, patience=5):
5
+ """
6
+ Dynamic Loss Weight Tuning with Adaptive Strategy for BitAcc and PSNR Adjustment.
7
+
8
+ Args:
9
+ psnr (float): Current PSNR value.
10
+ bitacc (float): Current BitAcc value.
11
+ lambda_i (float or list of float): Current weight(s) for image quality loss.
12
+ lambda_w (float): Current weight for watermark accuracy loss.
13
+ psnr_threshold (float): Current PSNR threshold.
14
+ bitacc_target (float): Current BitAcc target.
15
+ delta (float): Scaling factor for weight adjustment.
16
+ min_increment (float): Minimum increment to use when lambda_w is zero.
17
+ acc_increment (float): Increment for bitacc_target when targets are met.
18
+ psnr_increment (float): Increment for psnr_threshold when targets are met.
19
+ max_psnr_threshold (float): Maximum limit for psnr_threshold.
20
+ max_bitacc_target (float): Maximum limit for bitacc_target.
21
+ patience (int): Number of iterations to adjust based on unmet target.
22
+
23
+ Returns:
24
+ tuple: Updated values of lambda_i, lambda_w, psnr_threshold, bitacc_target.
25
+ """
26
+ # Static variables to track performance history
27
+ if not hasattr(dynamic_lambda_scheduler, "psnr_threshold_history"):
28
+ dynamic_lambda_scheduler.psnr_threshold_history = psnr_threshold
29
+ if not hasattr(dynamic_lambda_scheduler, "bitacc_target_history"):
30
+ dynamic_lambda_scheduler.bitacc_target_history = bitacc_target
31
+ if not hasattr(dynamic_lambda_scheduler, "success_i_counter"):
32
+ dynamic_lambda_scheduler.success_i_counter = 0
33
+ if not hasattr(dynamic_lambda_scheduler, "success_w_counter"):
34
+ dynamic_lambda_scheduler.success_w_counter = 0
35
+ if not hasattr(dynamic_lambda_scheduler, "success_counter"):
36
+ dynamic_lambda_scheduler.success_counter = 0
37
+
38
+ # Define proportional growth factors
39
+ bitacc_diff = bitacc_target - bitacc
40
+ bitacc_growth_factor = min(math.log(1 + abs(bitacc_diff) / bitacc_target), 1)
41
+ psnr_diff = psnr_threshold - psnr
42
+ psnr_growth_factor = min(math.log(1 + abs(psnr_diff) / psnr_threshold), 1)
43
+ # bitacc_growth_factor = 0.05 * 2
44
+ # psnr_growth_factor = 0.05
45
+
46
+ # Helper function to adjust lambda values
47
+ def adjust_lambda(lambda_value, growth_factor):
48
+ if isinstance(lambda_value, list):
49
+ return [lv + delta * growth_factor for lv in lambda_value]
50
+ else:
51
+ return lambda_value + delta * growth_factor
52
+
53
+ def decrease_lambda(lambda_value, decrease_step=0.001):
54
+ if isinstance(lambda_value, list):
55
+ return [max(lv - decrease_step, 0.0) for lv in lambda_value]
56
+ else:
57
+ return max(lambda_value - decrease_step, 0.0)
58
+
59
+ # Adjustment strategies
60
+ if bitacc < bitacc_target:
61
+ # Increase lambda_w when bitacc is below target
62
+ lambda_w = max(adjust_lambda(lambda_w, bitacc_growth_factor), min_increment)
63
+ # dynamic_lambda_scheduler.success_i_counter += 1
64
+ else: # Let bitacc meet target first, then do the rest of the adjustment
65
+        # bitacc meets its target: increment the success counter toward the patience window
66
+ dynamic_lambda_scheduler.success_i_counter += 1
67
+
68
+ if psnr < psnr_threshold:
69
+ # Increase lambda_i when psnr is below target
70
+ lambda_i = adjust_lambda(lambda_i, psnr_growth_factor)
71
+ # dynamic_lambda_scheduler.success_w_counter += 1
72
+ else:
73
+        # psnr meets its target: increment the success counter toward the patience window
74
+ dynamic_lambda_scheduler.success_w_counter += 1
75
+
76
+ # Increment targets if both bitacc and psnr meet their thresholds consistently
77
+ # if bitacc >= bitacc_target and psnr >= psnr_threshold:
78
+ # dynamic_lambda_scheduler.success_counter += 1
79
+ if dynamic_lambda_scheduler.success_i_counter >= patience:
80
+ # Increment targets and reset success counter
81
+ bitacc_target = min(bitacc_target + acc_increment, max_bitacc_target)
82
+ lambda_i = adjust_lambda(lambda_i, psnr_growth_factor)
83
+ dynamic_lambda_scheduler.success_i_counter = 0
84
+ elif dynamic_lambda_scheduler.success_w_counter >= patience:
85
+ # Increment targets and reset success counter
86
+ psnr_threshold = min(psnr_threshold + psnr_increment, max_psnr_threshold)
87
+ lambda_w = adjust_lambda(lambda_w, bitacc_growth_factor)
88
+ dynamic_lambda_scheduler.success_w_counter = 0
89
+
90
+ # # Revert thresholds if `patience` limit reached without meeting targets
91
+ # if dynamic_lambda_scheduler.success_i_counter >= patience:
92
+ # # Increase lambda_i when bitacc continuously misses target
93
+ # lambda_i = adjust_lambda(lambda_i, bitacc_growth_factor)
94
+ # dynamic_lambda_scheduler.success_i_counter = 0
95
+
96
+ # if dynamic_lambda_scheduler.success_w_counter >= patience:
97
+ # # Increase lambda_w when psnr continuously misses target
98
+ # lambda_w = adjust_lambda(lambda_w, psnr_growth_factor)
99
+ # dynamic_lambda_scheduler.success_w_counter = 0
100
+
101
+ # # Apply a small, fixed decrease to lambda_i and lambda_w when both targets are met
102
+ # if bitacc >= bitacc_target and psnr >= psnr_threshold:
103
+ # lambda_i = decrease_lambda(lambda_i)
104
+ # lambda_w = decrease_lambda(lambda_w)
105
+
106
+ return lambda_i, lambda_w, psnr_threshold, bitacc_target
107
+
108
+
109
+
110
+ # def adjust_multi_lambda_i(psnr, bitacc, lambda_i, lambda_w, psnr_threshold, bitacc_target, delta, min_increment=1e-6,
111
+ # psnr_increment=0.5, max_psnr_threshold=50, patience=5):
112
+ # """
113
+ # Dynamic Loss Weight Tuning with Logarithmic Proportional Increase for Fast BitAcc Adjustment
114
+
115
+ # Adjusts the weights of two loss components based on the performance metrics (PSNR and BitACC).
116
+ # Initially prioritizes increasing lambda_w for fast BitAcc convergence. Once BitAcc reaches its target,
117
+ # adjusts lambda_i or lambda_w based on PSNR and BitAcc conditions. Also dynamically adjusts psnr_threshold.
118
+
119
+ # Args:
120
+ # psnr (float): Current PSNR value.
121
+ # bitacc (float): Current BitACC value.
122
+ # lambda_i (float or list of float): Current weight(s) for image quality loss.
123
+ # lambda_w (float): Current weight for watermark accuracy loss.
124
+ # psnr_threshold (float): Current PSNR threshold.
125
+ # bitacc_target (float): Target BitACC threshold.
126
+ # delta (float): Scaling factor for weight adjustment.
127
+ # min_increment (float): Minimum increment to use when lambda_w is zero.
128
+ # psnr_increment (float): Increment for psnr_threshold when bitacc_target is met.
129
+ # max_psnr_threshold (float): Maximum limit for psnr_threshold.
130
+ # patience (int): Number of iterations to allow bitacc below target before reverting psnr_threshold.
131
+
132
+ # Returns:
133
+ # tuple: Updated values of lambda_i, lambda_w, psnr_threshold, patience_counter.
134
+ # """
135
+ # # Static variables to hold dynamic adjustment states
136
+ # if not hasattr(adjust_multi_lambda_i, "psnr_threshold_history"):
137
+ # adjust_multi_lambda_i.psnr_threshold_history = psnr_threshold
138
+ # if not hasattr(adjust_multi_lambda_i, "patience_counter"):
139
+ # adjust_multi_lambda_i.patience_counter = 0
140
+
141
+ # # Define logarithmic growth factors based on differences
142
+ # bitacc_diff = bitacc_target - bitacc
143
+ # bitacc_growth_factor = math.log(1 + abs(bitacc_diff) / bitacc_target)
144
+ # psnr_diff = psnr_threshold - psnr
145
+ # psnr_growth_factor = math.log(1 + abs(psnr_diff) / psnr_threshold)
146
+
147
+ # # Helper function to handle single float or list for lambda_i
148
+ # def adjust_lambda(lambda_value, growth_factor):
149
+ # if isinstance(lambda_value, list):
150
+ # return [lv + delta * growth_factor for lv in lambda_value]
151
+ # else:
152
+ # return lambda_value + delta * growth_factor
153
+
154
+ # def decrease_lambda(lambda_value, decrease_step=0.001):
155
+ # if isinstance(lambda_value, list):
156
+ # return [max(lv - decrease_step, 0.0) for lv in lambda_value]
157
+ # else:
158
+ # return max(lambda_value - decrease_step, 0.0)
159
+
160
+ # # Adjusting strategy
161
+ # if bitacc < bitacc_target:
162
+ # # Stage 1: Prioritize increasing lambda_w to quickly improve bitacc
163
+ # lambda_w = max(adjust_lambda(lambda_w, bitacc_growth_factor), min_increment)
164
+ # adjust_multi_lambda_i.patience_counter += 1
165
+ # else:
166
+ # # Reset patience counter when bitacc meets the target
167
+ # adjust_multi_lambda_i.patience_counter = 0
168
+
169
+ # # Stage 2: After bitacc reaches target, increase psnr_threshold and adjust weights
170
+ # if psnr < psnr_threshold:
171
+ # lambda_i = adjust_lambda(lambda_i, psnr_growth_factor)
172
+ # else:
173
+ # lambda_w = adjust_lambda(lambda_w, bitacc_growth_factor)
174
+
175
+ # # Attempt to increase psnr_threshold, but respect max threshold
176
+ # if psnr_threshold < max_psnr_threshold:
177
+ # psnr_threshold += psnr_increment
178
+ # adjust_multi_lambda_i.psnr_threshold_history = psnr_threshold
179
+
180
+ # # Stage 3: When bitacc and psnr meet targets, apply a small, fixed decrease
181
+ # if bitacc >= bitacc_target and psnr >= psnr_threshold:
182
+ # lambda_i = decrease_lambda(lambda_i)
183
+ # lambda_w = decrease_lambda(lambda_w)
184
+
185
+    # # Revert psnr_threshold if patience is exceeded: lower it when bitacc has missed its target for patience consecutive iterations
186
+ # if adjust_multi_lambda_i.patience_counter >= patience:
187
+ # psnr_threshold = adjust_multi_lambda_i.psnr_threshold_history - psnr_increment
188
+ # adjust_multi_lambda_i.patience_counter = 0 # Reset patience
189
+
190
+ # return lambda_i, lambda_w, psnr_threshold
191
+
192
+
193
+ # def adjust_multi_lambda_i(psnr, bitacc, lambda_i, lambda_w, psnr_threshold, bitacc_target, delta, min_increment=1e-6):
194
+ # """
195
+ # Dynamic Loss Weight Tuning with Logarithmic Proportional Increase for Fast BitAcc Adjustment
196
+
197
+ # Adjusts the weights of two loss components based on the performance metrics (PSNR and BitACC).
198
+ # Initially prioritizes increasing lambda_w for fast BitAcc convergence. Once BitAcc reaches its target,
199
+ # adjusts lambda_i or lambda_w based on PSNR and BitAcc conditions.
200
+
201
+ # Args:
202
+ # psnr (float): Current PSNR value.
203
+ # bitacc (float): Current BitACC value.
204
+ # lambda_i (float or list of float): Current weight(s) for image quality loss.
205
+ # lambda_w (float): Current weight for watermark accuracy loss.
206
+ # psnr_threshold (float): Target PSNR threshold.
207
+ # bitacc_target (float): Target BitACC threshold.
208
+ # delta (float): Scaling factor for weight adjustment.
209
+ # min_increment (float): Minimum increment to use when lambda_w is zero.
210
+
211
+ # Returns:
212
+ # tuple: Updated values of lambda_i (float or list) and lambda_w.
213
+ # """
214
+ # # Define logarithmic growth factors based on differences
215
+ # bitacc_diff = bitacc_target - bitacc
216
+ # bitacc_growth_factor = math.log(1 + abs(bitacc_diff) / bitacc_target)
217
+ # psnr_diff = psnr_threshold - psnr
218
+ # psnr_growth_factor = math.log(1 + min(abs(psnr_diff), 100) / psnr_threshold)
219
+
220
+ # # Helper function to handle single float or list for lambda_i
221
+ # def adjust_lambda(lambda_value, growth_factor):
222
+ # if isinstance(lambda_value, list):
223
+ # return [lv + delta * growth_factor for lv in lambda_value]
224
+ # else:
225
+ # return lambda_value + delta * growth_factor
226
+
227
+ # def decrease_lambda(lambda_value, decrease_step=0.001):
228
+ # if isinstance(lambda_value, list):
229
+ # return [max(lv - decrease_step, 0.0) for lv in lambda_value]
230
+ # else:
231
+ # return max(lambda_value - decrease_step, 0.0)
232
+
233
+ # # Adjusting strategy
234
+ # if bitacc < bitacc_target:
235
+ # # Stage 1: Prioritize increasing lambda_w to quickly improve bitacc
236
+ # lambda_w = max(adjust_lambda(lambda_w, bitacc_growth_factor), min_increment)
237
+ # else:
238
+ # # Stage 2: After bitacc reaches target, adjust based on psnr
239
+ # if psnr < psnr_threshold:
240
+ # # If psnr is below threshold, increase lambda_i to improve image quality
241
+ # lambda_i = adjust_lambda(lambda_i, psnr_growth_factor)
242
+ # else:
243
+ # # If psnr is above threshold, continue increasing lambda_w for better embedding
244
+ # lambda_w = adjust_lambda(lambda_w, bitacc_growth_factor)
245
+
246
+ # # Stage 3: When both bitacc and psnr meet targets, apply a small, fixed decrease
247
+ # if bitacc >= bitacc_target and psnr >= psnr_threshold:
248
+ # lambda_i = decrease_lambda(lambda_i)
249
+ # lambda_w = decrease_lambda(lambda_w)
250
+
251
+ # return lambda_i, lambda_w
watermarker/LaWa/ecc.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bchlib
2
+ import numpy as np
3
+ from typing import List, Tuple
4
+ import random
5
+ from copy import deepcopy
6
+
7
+ class RSC(object):
8
+ def __init__(self, data_bytes=16, ecc_bytes=4, verbose=False, **kwargs):
9
+ from reedsolo import RSCodec
10
+ self.rs = RSCodec(ecc_bytes)
11
+ if verbose:
12
+ print(f'Reed-Solomon ECC len: {ecc_bytes*8} bits')
13
+ self.data_len = data_bytes
14
+ self.dlen = data_bytes * 8 # data length in bits
15
+ self.ecc_len = ecc_bytes * 8 # ecc length in bits
16
+
17
+ def get_total_len(self):
18
+ return self.dlen + self.ecc_len
19
+
20
+ def encode_text(self, text: List[str]):
21
+ return np.array([self._encode_text(t) for t in text])
22
+
23
+ def _encode_text(self, text: str):
24
+ text = text + ' ' * (self.dlen // 8 - len(text))
25
+ out = self.rs.encode(text.encode('utf-8')) # bytearray
26
+ out = ''.join(format(x, '08b') for x in out) # bit string
27
+ out = np.array([int(x) for x in out], dtype=np.float32)
28
+ return out
29
+
30
+ def decode_text(self, data: np.array):
31
+ assert len(data.shape)==2
32
+ return [self._decode_text(d) for d in data]
33
+
34
+ def _decode_text(self, data: np.array):
35
+ assert len(data.shape)==1
36
+ data = ''.join([str(int(bit)) for bit in data])
37
+ data = bytes(int(data[i: i + 8], 2) for i in range(0, len(data), 8))
38
+ data = bytearray(data)
39
+ try:
40
+ data = self.rs.decode(data)[0]
41
+ data = data.decode('utf-8').strip()
42
+ except:
43
+ print('Error: Decode failed')
44
+ data = get_random_unicode(self.get_total_len()//8)
45
+
46
+ return data
47
+
48
+ def get_random_unicode(length):
49
+ # Update this to include code point ranges to be sampled
50
+ include_ranges = [
51
+ ( 0x0021, 0x0021 ),
52
+ ( 0x0023, 0x0026 ),
53
+ ( 0x0028, 0x007E ),
54
+ ( 0x00A1, 0x00AC ),
55
+ ( 0x00AE, 0x00FF ),
56
+ ( 0x0100, 0x017F ),
57
+ ( 0x0180, 0x024F ),
58
+ ( 0x2C60, 0x2C7F ),
59
+ ( 0x16A0, 0x16F0 ),
60
+ ( 0x0370, 0x0377 ),
61
+ ( 0x037A, 0x037E ),
62
+ ( 0x0384, 0x038A ),
63
+ ( 0x038C, 0x038C ),
64
+ ]
65
+ alphabet = [
66
+ chr(code_point) for current_range in include_ranges
67
+ for code_point in range(current_range[0], current_range[1] + 1)
68
+ ]
69
+ return ''.join(random.choice(alphabet) for i in range(length))
70
+
71
+
72
+ class BCH(object):
73
+ def __init__(self, BCH_POLYNOMIAL = 137, BCH_BITS = 5, payload_len=100, verbose=True,**kwargs):
74
+ self.bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
75
+ self.payload_len = payload_len # in bits
76
+ self.data_len = (self.payload_len - self.bch.ecc_bytes*8)//7 # in ascii characters
77
+ assert self.data_len*7+self.bch.ecc_bytes*8 <= self.bch.n, f'Error! BCH with poly {BCH_POLYNOMIAL} and bits {BCH_BITS} can only encode max {self.bch.n//8} bytes of total payload'
78
+ if verbose:
79
+ print(f'BCH: POLYNOMIAL={BCH_POLYNOMIAL}, protected bits={BCH_BITS}, payload_len={payload_len} bits, data_len={self.data_len*7} bits ({self.data_len} ascii chars), ecc len={self.bch.ecc_bytes*8} bits')
80
+
81
+ def get_total_len(self):
82
+ return self.payload_len
83
+
84
+ def encode_text(self, text: List[str]):
85
+ return np.array([self._encode_text(t) for t in text])
86
+
87
+ def _encode_text(self, text: str):
88
+ text = text + ' ' * (self.data_len - len(text))
89
+ # data = text.encode('utf-8') # bytearray
90
+ data = encode_text_ascii(text) # bytearray
91
+ ecc = self.bch.encode(data) # bytearray
92
+ packet = data + ecc # payload in bytearray
93
+ packet = ''.join(format(x, '08b') for x in packet)
94
+ packet = [int(x) for x in packet]
95
+ packet.extend([0]*(self.payload_len - len(packet)))
96
+ packet = np.array(packet, dtype=np.float32)
97
+ return packet
98
+
99
+ def decode_text(self, data: np.array):
100
+ assert len(data.shape)==2
101
+ return [self._decode_text(d) for d in data]
102
+
103
+ def _decode_text(self, packet: np.array):
104
+ assert len(packet.shape)==1
105
+ packet = ''.join([str(int(bit)) for bit in packet]) # bit string
106
+ packet = packet[:(len(packet)//8*8)] # trim to multiple of 8 bits
107
+ packet = bytes(int(packet[i: i + 8], 2) for i in range(0, len(packet), 8))
108
+ packet = bytearray(packet)
109
+ # assert len(packet) == self.data_len + self.bch.ecc_bytes
110
+ data, ecc = packet[:-self.bch.ecc_bytes], packet[-self.bch.ecc_bytes:]
111
+ data0 = decode_text_ascii(deepcopy(data)).strip()
112
+ bitflips = self.bch.decode_inplace(data, ecc)
113
+ if bitflips == -1: # error, return random text
114
+ data = data0
115
+ else:
116
+ # data = data.decode('utf-8').strip()
117
+ data = decode_text_ascii(data).strip()
118
+ return data
119
+
120
+
121
+ def encode_text_ascii(text: str):
122
+ # encode text to 7-bit ascii
123
+ # input: text, str
124
+ # output: encoded text, bytearray
125
+ text_int7 = [ord(t) & 127 for t in text]
126
+ text_bitstr = ''.join(format(t,'07b') for t in text_int7)
127
+ if len(text_bitstr) % 8 != 0:
128
+ text_bitstr = '0'*(8-len(text_bitstr)%8) + text_bitstr # pad to multiple of 8
129
+ text_int8 = [int(text_bitstr[i:i+8], 2) for i in range(0, len(text_bitstr), 8)]
130
+ return bytearray(text_int8)
131
+
132
+
133
+ def decode_text_ascii(text: bytearray):
134
+ # decode text from 7-bit ascii
135
+ # input: text, bytearray
136
+ # output: decoded text, str
137
+ text_bitstr = ''.join(format(t,'08b') for t in text) # bit string
138
+ pad = len(text_bitstr) % 7
139
+ if pad != 0: # has padding, remove
140
+ text_bitstr = text_bitstr[pad:]
141
+ text_int7 = [int(text_bitstr[i:i+7], 2) for i in range(0, len(text_bitstr), 7)]
142
+ text_bytes = bytes(text_int7)
143
+ return text_bytes.decode('utf-8')
144
+
145
+
146
+ class ECC(object):
147
+ def __init__(self, BCH_POLYNOMIAL = 137, BCH_BITS = 5, **kwargs):
148
+ self.bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
149
+
150
+ def get_total_len(self):
151
+ return 100
152
+
153
+ def _encode(self, x):
154
+ # x: 56 bits, {0, 1}, np.array
155
+ # return: 100 bits, {0, 1}, np.array
156
+ dlen = len(x)
157
+ data_str = ''.join(str(x) for x in x.astype(int))
158
+ packet = bytes(int(data_str[i: i + 8], 2) for i in range(0, dlen, 8))
159
+ packet = bytearray(packet)
160
+ ecc = self.bch.encode(packet)
161
+ packet = packet + ecc # 96 bits
162
+ packet = ''.join(format(x, '08b') for x in packet)
163
+ packet = [int(x) for x in packet]
164
+ packet.extend([0, 0, 0, 0])
165
+ packet = np.array(packet, dtype=np.float32) # 100
166
+ return packet
167
+
168
+ def _decode(self, x):
169
+ # x: 100 bits, {0, 1}, np.array
170
+ # return: 56 bits, {0, 1}, np.array
171
+ packet_binary = "".join([str(int(bit)) for bit in x])
172
+ packet = bytes(int(packet_binary[i: i + 8], 2) for i in range(0, len(packet_binary), 8))
173
+ packet = bytearray(packet)
174
+
175
+ data, ecc = packet[:-self.bch.ecc_bytes], packet[-self.bch.ecc_bytes:]
176
+ bitflips = self.bch.decode_inplace(data, ecc)
177
+ if bitflips == -1: # error, return random data
178
+ data = np.random.binomial(1, .5, 56)
179
+ else:
180
+ data = ''.join(format(x, '08b') for x in data)
181
+ data = np.array([int(x) for x in data], dtype=np.float32)
182
+ return data # 56 bits
183
+
184
+ def _generate(self):
185
+ dlen = 56
186
+ data= np.random.binomial(1, .5, dlen)
187
+ packet = self._encode(data)
188
+ return packet, data
189
+
190
+ def generate(self, nsamples=1):
191
+ # generate random 56 bit secret
192
+ data = [self._generate() for _ in range(nsamples)]
193
+ data = (np.array([d[0] for d in data]), np.array([d[1] for d in data]))
194
+ return data # data with ecc, data org
195
+
196
+ def _to_text(self, data):
197
+ # data: {0, 1}, np.array
198
+ # return: str
199
+ data = ''.join([str(int(bit)) for bit in data])
200
+ all_bytes = [ data[i: i+8] for i in range(0, len(data), 8) ]
201
+ text = ''.join([chr(int(byte, 2)) for byte in all_bytes])
202
+ return text.strip()
203
+
204
+ def _to_binary(self, s):
205
+ if isinstance(s, str):
206
+ out = ''.join([ format(ord(i), "08b") for i in s ])
207
+ elif isinstance(s, bytes):
208
+ out = ''.join([ format(i, "08b") for i in s ])
209
+ elif isinstance(s, np.ndarray) and s.dtype is np.dtype(bool):
210
+ out = ''.join([chr(int(i)) for i in s])
211
+ elif isinstance(s, int) or isinstance(s, np.uint8):
212
+ out = format(s, "08b")
213
+ elif isinstance(s, np.ndarray):
214
+ out = [ format(i, "08b") for i in s ]
215
+ else:
216
+ raise TypeError("Type not supported.")
217
+
218
+ return np.array([float(i) for i in out], dtype=np.float32)
219
+
220
+ def _encode_text(self, s):
221
+ s = s + ' '*(7-len(s)) # 7 chars
222
+ s = self._to_binary(s) # 56 bits
223
+ packet = self._encode(s) # 100 bits
224
+ return packet, s
225
+
226
+ def encode_text(self, secret_list, return_pre_ecc=False):
227
+ """encode secret with BCH ECC.
228
+ Input: secret (list of strings)
229
+        Output: secret (np array) with shape (B, 100) type float32, val {0,1}"""
230
+ assert np.all(np.array([len(s) for s in secret_list]) <= 7), 'Error! all strings must be less than 7 characters'
231
+ secret_list = [self._encode_text(s) for s in secret_list]
232
+ ecc = np.array([s[0] for s in secret_list], dtype=np.float32)
233
+ if return_pre_ecc:
234
+ return ecc, np.array([s[1] for s in secret_list], dtype=np.float32)
235
+ return ecc
236
+
237
+ def decode_text(self, data):
238
+ """Decode secret with BCH ECC and convert to string.
239
+ Input: secret (torch.tensor) with shape (B, 100) type bool
240
+ Output: secret (B, 56)"""
241
+ data = self.decode(data)
242
+ data = [self._to_text(d) for d in data]
243
+ return data
244
+
245
+ def decode(self, data):
246
+ """Decode secret with BCH ECC and convert to string.
247
+ Input: secret (torch.tensor) with shape (B, 100) type bool
248
+ Output: secret (B, 56)"""
249
+ data = data[:, :96]
250
+ data = [self._decode(d) for d in data]
251
+ return np.array(data)
252
+
253
+ def test_ecc():
254
+ ecc = ECC()
255
+ batch_size = 10
256
+ secret_ecc, secret_org = ecc.generate(batch_size) # 10x100 ecc secret, 10x56 org secret
257
+ # modify secret_ecc
258
+ secret_pred = secret_ecc.copy()
259
+ secret_pred[:,3:6] = 1 - secret_pred[:,3:6]
260
+ # pass secret_ecc to model and get predicted as secret_pred
261
+ secret_pred_org = ecc.decode(secret_pred) # 10x56
262
+ assert np.all(secret_pred_org == secret_org) # 10
263
+
264
+
265
+ def test_bch():
266
+ # test 100 bit
267
+ def check(text, poly, k, l):
268
+ bch = BCH(poly, k, l)
269
+ # text = 'secrets'
270
+ encode = bch.encode_text([text])
271
+ for ind in np.random.choice(l, k):
272
+ encode[0, ind] = 1 - encode[0, ind]
273
+ text_recon = bch.decode_text(encode)[0]
274
+ assert text==text_recon
275
+
276
+ check('secrets', 137, 5, 100)
277
+ check('some secret', 285, 10, 160)
278
+
279
+ if __name__ == '__main__':
280
+ test_ecc()
281
+ test_bch()
watermarker/LaWa/examples/gen_wmimgs_EW-LoRA_dlwt.ipynb ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "No module 'xformers'. Proceeding without it.\n"
21
+ ]
22
+ },
23
+ {
24
+ "name": "stderr",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.\n",
28
+ " rank_zero_deprecation(\n"
29
+ ]
30
+ },
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ ">>> Building LDM model with config /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml and weights from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt...\n",
36
+ "Loading model from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\n",
37
+ "Global Step: 470000\n",
38
+ "LatentDiffusion: Running in eps-prediction mode\n",
39
+ "DiffusionWrapper has 859.52 M params.\n",
40
+ "making attention of type 'vanilla' with 512 in_channels\n",
41
+ "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
42
+ "making attention of type 'vanilla' with 512 in_channels\n"
43
+ ]
44
+ },
45
+ {
46
+ "ename": "RuntimeError",
47
+ "evalue": "CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
48
+ "output_type": "error",
49
+ "traceback": [
50
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
51
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
52
+ "Cell \u001b[0;32mIn[1], line 40\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m>>> Building LDM model with config \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mldm_config\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and weights from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mldm_ckpt\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m...\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 39\u001b[0m config \u001b[38;5;241m=\u001b[39m OmegaConf\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mldm_config\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 40\u001b[0m ldm_ae: LatentDiffusion \u001b[38;5;241m=\u001b[39m \u001b[43mutils_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_model_from_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mldm_ckpt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m ldm_ae: AutoencoderKL \u001b[38;5;241m=\u001b[39m ldm_ae\u001b[38;5;241m.\u001b[39mfirst_stage_model\n\u001b[1;32m 42\u001b[0m ldm_ae\u001b[38;5;241m.\u001b[39meval()\n",
53
+ "File \u001b[0;32m/pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/utils_model.py:149\u001b[0m, in \u001b[0;36mload_model_from_config\u001b[0;34m(config, ckpt, verbose)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124munexpected keys:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(u)\n\u001b[0;32m--> 149\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 150\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model\n",
54
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/lightning_fabric/utilities/device_dtype_mixin.py:73\u001b[0m, in \u001b[0;36m_DeviceDtypeModuleMixin.cuda\u001b[0;34m(self, device)\u001b[0m\n\u001b[1;32m 71\u001b[0m device \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m, index\u001b[38;5;241m=\u001b[39mdevice)\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__update_properties(device\u001b[38;5;241m=\u001b[39mdevice)\n\u001b[0;32m---> 73\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n",
55
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:911\u001b[0m, in \u001b[0;36mModule.cuda\u001b[0;34m(self, device)\u001b[0m\n\u001b[1;32m 894\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcuda\u001b[39m(\u001b[38;5;28mself\u001b[39m: T, device: Optional[Union[\u001b[38;5;28mint\u001b[39m, device]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 895\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Move all model parameters and buffers to the GPU.\u001b[39;00m\n\u001b[1;32m 896\u001b[0m \n\u001b[1;32m 897\u001b[0m \u001b[38;5;124;03m This also makes associated parameters and buffers different objects. So\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 909\u001b[0m \u001b[38;5;124;03m Module: self\u001b[39;00m\n\u001b[1;32m 910\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 911\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43;01mlambda\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
56
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:802\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 800\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 801\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 802\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 806\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 813\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
57
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:802\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 800\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 801\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 802\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 806\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 813\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
58
+ " \u001b[0;31m[... skipping similar frames: Module._apply at line 802 (1 times)]\u001b[0m\n",
59
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:802\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 800\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 801\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 802\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 806\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 813\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
60
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:825\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 822\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 823\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 824\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 825\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 826\u001b[0m should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 827\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m should_use_set_data:\n",
61
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:911\u001b[0m, in \u001b[0;36mModule.cuda.<locals>.<lambda>\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 894\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcuda\u001b[39m(\u001b[38;5;28mself\u001b[39m: T, device: Optional[Union[\u001b[38;5;28mint\u001b[39m, device]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 895\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Move all model parameters and buffers to the GPU.\u001b[39;00m\n\u001b[1;32m 896\u001b[0m \n\u001b[1;32m 897\u001b[0m \u001b[38;5;124;03m This also makes associated parameters and buffers different objects. So\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 909\u001b[0m \u001b[38;5;124;03m Module: self\u001b[39;00m\n\u001b[1;32m 910\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 911\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_apply(\u001b[38;5;28;01mlambda\u001b[39;00m t: \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n",
62
+ "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n"
63
+ ]
64
+ }
65
+ ],
66
+ "source": [
67
+ "from omegaconf import OmegaConf\n",
68
+ "from ldm.models.autoencoder import AutoencoderKL\n",
69
+ "from ldm.models.diffusion.ddpm import LatentDiffusion\n",
70
+ "\n",
71
+ "import os\n",
72
+ "import torch\n",
73
+ "import utils\n",
74
+ "import utils_model\n",
75
+ "import utils_img\n",
76
+ "import torch.nn as nn\n",
77
+ "import numpy as np\n",
78
+ "from copy import deepcopy\n",
79
+ "from torchvision import transforms\n",
80
+ "import os\n",
81
+ "import pandas as pd\n",
82
+ "from torchvision.utils import save_image\n",
83
+ "from accelerate import Accelerator\n",
84
+ "accelerator = Accelerator()\n",
85
+ "\n",
86
+ "\n",
87
+ "apply_dlwt = True\n",
88
+ "ckpt_prefix = \"EW-LoRA_dlwt\" if apply_dlwt else \"EW-LoRA_fix_weights\"\n",
89
+ "exps_num = \"003-exps\"\n",
90
+ "\n",
91
+ "img_size = 256\n",
92
+ "batch_size = 4\n",
93
+ "seed = 0\n",
94
+ "ldm_config = \"/pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml\"\n",
95
+ "ldm_ckpt = \"/pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\"\n",
96
+ "msg_decoder_path = \"/pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt\"\n",
97
+ "val_dir = \"/pubdata/ldd/Datasets/coco2017/val2017\"\n",
98
+ "\n",
99
+ "torch.manual_seed(seed)\n",
100
+ "torch.cuda.manual_seed_all(seed)\n",
101
+ "np.random.seed(seed)\n",
102
+ "\n",
103
+ "# Loads LDM auto-encoder models\n",
104
+ "print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')\n",
105
+ "config = OmegaConf.load(f\"{ldm_config}\")\n",
106
+ "ldm_ae: LatentDiffusion = utils_model.load_model_from_config(config, ldm_ckpt)\n",
107
+ "ldm_ae: AutoencoderKL = ldm_ae.first_stage_model\n",
108
+ "ldm_ae.eval()\n",
109
+ "ldm_ae.to(accelerator.device)\n",
110
+ "\n",
111
+ "# Loads hidden decoder\n",
112
+ "print(f'>>> Building hidden decoder with weights from {msg_decoder_path}...')\n",
113
+ "if 'torchscript' in msg_decoder_path: \n",
114
+ " msg_decoder = torch.jit.load(msg_decoder_path)\n",
115
+ "\n",
116
+ "msg_decoder.eval()\n",
117
+ "nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(accelerator.device)).shape[-1]\n",
118
+ "\n",
119
+ "# Freeze LDM and hidden decoder\n",
120
+ "for param in [*msg_decoder.parameters(), *ldm_ae.parameters()]:\n",
121
+ " param.requires_grad = False\n",
122
+ "\n",
123
+ "vqgan_transform = transforms.Compose([\n",
124
+ " transforms.Resize(img_size),\n",
125
+ " transforms.CenterCrop(img_size),\n",
126
+ " transforms.ToTensor(),\n",
127
+ " utils_img.normalize_vqgan,\n",
128
+ "])\n",
129
+ "val_loader = utils.get_dataloader(val_dir, vqgan_transform, 1, num_imgs=1000, shuffle=False, num_workers=4, collate_fn=None)\n",
130
+ "vqgan_to_imnet = transforms.Compose([utils_img.unnormalize_vqgan, utils_img.normalize_img])\n",
131
+ "\n",
132
+ "key = torch.randint(0, 2, (1, nbit), dtype=torch.float32, device=accelerator.device)\n",
133
+ "key_str = \"\".join([ str(int(ii)) for ii in key.tolist()[0]])\n",
134
+ "print(f'Key: {key_str}')\n",
135
+ "\n",
136
+ "# Copy the LDM decoder and finetune the copy\n",
137
+ "ldm_decoder = deepcopy(ldm_ae)\n",
138
+ "ldm_decoder.encoder = nn.Identity()\n",
139
+ "ldm_decoder.quant_conv = nn.Identity()\n",
140
+ "# ldm_decoder.to(device)\n",
141
+ "for param in ldm_decoder.parameters():\n",
142
+ " param.requires_grad = False\n",
143
+ "\n",
144
+ "from peft import LoraConfig, get_peft_model\n",
145
+ "\n",
146
+ "wm_target = \"upsample.conv\"\n",
147
+ "rank = 4\n",
148
+ "lora_alpha = 4\n",
149
+ "\n",
150
+ "# Select the lora target model\n",
151
+ "def find_layers(model, wm_target=None):\n",
152
+ " layers = []\n",
153
+ " for name, layer in model.named_modules():\n",
154
+ " if any(wm_target in name.lower() for keyword in name):\n",
155
+ " layers.append(name)\n",
156
+ " all_layers = [name for name, _ in model.named_modules()]\n",
157
+ " return layers, all_layers\n",
158
+ "wm_target, _ = find_layers(ldm_decoder.decoder, wm_target)\n",
159
+ "\n",
160
+ "vae_lora_config = LoraConfig(\n",
161
+ " r=rank,\n",
162
+ " lora_alpha=lora_alpha,\n",
163
+ " init_lora_weights=\"gaussian\",\n",
164
+ " target_modules=wm_target,\n",
165
+ ")\n",
166
+ "vae_decoder_copy = get_peft_model(ldm_decoder.decoder, vae_lora_config)\n",
167
+ "trainable_params, all_param = vae_decoder_copy.get_nb_trainable_parameters()\n",
168
+ "print(f\"Parameters for PEFT watermarking: \"\n",
169
+ " f\"Trainable params: {trainable_params/1e6:.5f}M || \"\n",
170
+ " f\"PEFT Model size: {trainable_params*4/(1024*1024):.5f}M || \"\n",
171
+ " f\"All params: {all_param/1e6:.5f}M || \"\n",
172
+ " f\"Trainable%: {100 * trainable_params / all_param:.5f}\"\n",
173
+ ")\n",
174
+ "ldm_decoder.decoder = vae_decoder_copy\n",
175
+ "\n",
176
+ "saveimgs_dir = f'/pubdata/ldd/projects/EW-LoRA/experiments/evals/save_imgs_{ckpt_prefix}'\n",
177
+ "os.makedirs(saveimgs_dir, exist_ok=True)\n",
178
+ "vae_decoder_ckpt_dir = f'/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/outputs/train_{ckpt_prefix}/{exps_num}/checkpoints/checkpoint-latest'\n",
179
+ "\n",
180
+ "msg_decoder, ldm_decoder, val_loader, key = accelerator.prepare(\n",
181
+ " msg_decoder, ldm_decoder, val_loader, key\n",
182
+ ")\n",
183
+ "accelerator.load_state(os.path.join(vae_decoder_ckpt_dir)) # Load the LoRA watermark checkpoint\n",
184
+ "print(f\"Loaded the Stable Signature checkpoint from {vae_decoder_ckpt_dir}\")\n",
185
+ "for param in ldm_decoder.parameters():\n",
186
+ " param.requires_grad = False\n",
187
+ "\n",
188
+ "df_EWLoRA = pd.DataFrame(columns=[\n",
189
+ " \"iteration\",\n",
190
+ " \"psnr\",\n",
191
+ " \"bit_acc_avg\",\n",
192
+ "])\n",
193
+ "attacks = {\n",
194
+ " 'none': lambda x: x,\n",
195
+ " 'crop_01': lambda x: utils_img.center_crop(x, 0.1),\n",
196
+ " 'crop_05': lambda x: utils_img.center_crop(x, 0.5),\n",
197
+ " 'rot_25': lambda x: utils_img.rotate(x, 25),\n",
198
+ " 'rot_90': lambda x: utils_img.rotate(x, 90),\n",
199
+ " 'resize_03': lambda x: utils_img.resize(x, 0.3),\n",
200
+ " 'resize_07': lambda x: utils_img.resize(x, 0.7),\n",
201
+ " 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5),\n",
202
+ " 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2),\n",
203
+ " 'contrast_1p5': lambda x: utils_img.adjust_contrast(x, 1.5),\n",
204
+ " 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2),\n",
205
+ " 'sharpness_1p5': lambda x: utils_img.adjust_sharpness(x, 1.5),\n",
206
+ " 'sharpness_2': lambda x: utils_img.adjust_sharpness(x, 2),\n",
207
+ " 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80),\n",
208
+ " 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50),\n",
209
+ "}\n",
210
+ "for ii, imgs in enumerate(val_loader):\n",
211
+ " imgs = imgs.to(accelerator.device)\n",
212
+ " keys = key.repeat(imgs.shape[0], 1)\n",
213
+ "\n",
214
+ " imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f\n",
215
+ " imgs_z = imgs_z.mode()\n",
216
+ "\n",
217
+ " # decode latents with original and finetuned decoder\n",
218
+ " imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w\n",
219
+ " imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w\n",
220
+ "\n",
221
+ " # extract watermark\n",
222
+ " decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k\n",
223
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
224
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
225
+ "\n",
226
+ " log_stats = {\n",
227
+ " \"iteration\": ii,\n",
228
+ " \"psnr\": utils_img.psnr(imgs_w, imgs_d0).mean().item(),\n",
229
+ " \"bit_acc_avg\": torch.mean(bit_accs).item(),\n",
230
+ " }\n",
231
+ " \n",
232
+ " for name, attack in attacks.items():\n",
233
+ " imgs_aug = attack(vqgan_to_imnet(imgs_w))\n",
234
+ " decoded = msg_decoder(imgs_aug) # b c h w -> b k\n",
235
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
236
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
237
+ " word_accs = (bit_accs == 1) # b\n",
238
+ " log_stats[f'bit_acc_{name}'] = torch.mean(bit_accs).item()\n",
239
+ "\n",
240
+ " df_EWLoRA = df_EWLoRA._append(log_stats, ignore_index=True)\n",
241
+ " save_image(utils_img.unnormalize_vqgan(imgs_w), os.path.join(saveimgs_dir, f'{ii:03}_wm_orig.png'))\n",
242
+ "df_EWLoRA.to_csv(os.path.join(saveimgs_dir, 'bitacc.csv'), index=False)"
243
+ ]
244
+ }
245
+ ],
246
+ "metadata": {
247
+ "kernelspec": {
248
+ "display_name": "ldm",
249
+ "language": "python",
250
+ "name": "python3"
251
+ },
252
+ "language_info": {
253
+ "codemirror_mode": {
254
+ "name": "ipython",
255
+ "version": 3
256
+ },
257
+ "file_extension": ".py",
258
+ "mimetype": "text/x-python",
259
+ "name": "python",
260
+ "nbconvert_exporter": "python",
261
+ "pygments_lexer": "ipython3",
262
+ "version": "3.8.18"
263
+ }
264
+ },
265
+ "nbformat": 4,
266
+ "nbformat_minor": 2
267
+ }
watermarker/LaWa/examples/gen_wmimgs_EW-LoRA_fix_weights.ipynb ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "No module 'xformers'. Proceeding without it.\n"
21
+ ]
22
+ },
23
+ {
24
+ "name": "stderr",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.\n",
28
+ " rank_zero_deprecation(\n"
29
+ ]
30
+ },
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ ">>> Building LDM model with config /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml and weights from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt...\n",
36
+ "Loading model from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\n",
37
+ "Global Step: 470000\n",
38
+ "LatentDiffusion: Running in eps-prediction mode\n",
39
+ "DiffusionWrapper has 859.52 M params.\n",
40
+ "making attention of type 'vanilla' with 512 in_channels\n",
41
+ "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
42
+ "making attention of type 'vanilla' with 512 in_channels\n"
43
+ ]
44
+ },
45
+ {
46
+ "ename": "OutOfMemoryError",
47
+ "evalue": "CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 45.44 MiB is free. Process 3764900 has 3.92 GiB memory in use. Process 3831541 has 582.00 MiB memory in use. Process 3843346 has 4.55 GiB memory in use. Process 3844737 has 4.55 GiB memory in use. Process 3849706 has 4.55 GiB memory in use. Process 3850599 has 4.55 GiB memory in use. Including non-PyTorch memory, this process has 892.00 MiB memory in use. Of the allocated memory 482.41 MiB is allocated by PyTorch, and 25.59 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)",
48
+ "output_type": "error",
49
+ "traceback": [
50
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
51
+ "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
52
+ "Cell \u001b[0;32mIn[1], line 40\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m>>> Building LDM model with config \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mldm_config\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and weights from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mldm_ckpt\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m...\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 39\u001b[0m config \u001b[38;5;241m=\u001b[39m OmegaConf\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mldm_config\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 40\u001b[0m ldm_ae: LatentDiffusion \u001b[38;5;241m=\u001b[39m \u001b[43mutils_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_model_from_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mldm_ckpt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m ldm_ae: AutoencoderKL \u001b[38;5;241m=\u001b[39m ldm_ae\u001b[38;5;241m.\u001b[39mfirst_stage_model\n\u001b[1;32m 42\u001b[0m ldm_ae\u001b[38;5;241m.\u001b[39meval()\n",
53
+ "File \u001b[0;32m/pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/utils_model.py:149\u001b[0m, in \u001b[0;36mload_model_from_config\u001b[0;34m(config, ckpt, verbose)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124munexpected keys:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(u)\n\u001b[0;32m--> 149\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 150\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model\n",
54
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/lightning_fabric/utilities/device_dtype_mixin.py:73\u001b[0m, in \u001b[0;36m_DeviceDtypeModuleMixin.cuda\u001b[0;34m(self, device)\u001b[0m\n\u001b[1;32m 71\u001b[0m device \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m, index\u001b[38;5;241m=\u001b[39mdevice)\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__update_properties(device\u001b[38;5;241m=\u001b[39mdevice)\n\u001b[0;32m---> 73\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n",
55
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:911\u001b[0m, in \u001b[0;36mModule.cuda\u001b[0;34m(self, device)\u001b[0m\n\u001b[1;32m 894\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcuda\u001b[39m(\u001b[38;5;28mself\u001b[39m: T, device: Optional[Union[\u001b[38;5;28mint\u001b[39m, device]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 895\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Move all model parameters and buffers to the GPU.\u001b[39;00m\n\u001b[1;32m 896\u001b[0m \n\u001b[1;32m 897\u001b[0m \u001b[38;5;124;03m This also makes associated parameters and buffers different objects. So\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 909\u001b[0m \u001b[38;5;124;03m Module: self\u001b[39;00m\n\u001b[1;32m 910\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 911\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43;01mlambda\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
56
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:802\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 800\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 801\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 802\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 806\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 813\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
57
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:802\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 800\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 801\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 802\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 806\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 813\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
58
+ " \u001b[0;31m[... skipping similar frames: Module._apply at line 802 (4 times)]\u001b[0m\n",
59
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:802\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 800\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 801\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 802\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 806\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 813\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
60
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:825\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 822\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 823\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 824\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 825\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 826\u001b[0m should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 827\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m should_use_set_data:\n",
61
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:911\u001b[0m, in \u001b[0;36mModule.cuda.<locals>.<lambda>\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 894\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcuda\u001b[39m(\u001b[38;5;28mself\u001b[39m: T, device: Optional[Union[\u001b[38;5;28mint\u001b[39m, device]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 895\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Move all model parameters and buffers to the GPU.\u001b[39;00m\n\u001b[1;32m 896\u001b[0m \n\u001b[1;32m 897\u001b[0m \u001b[38;5;124;03m This also makes associated parameters and buffers different objects. So\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 909\u001b[0m \u001b[38;5;124;03m Module: self\u001b[39;00m\n\u001b[1;32m 910\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 911\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_apply(\u001b[38;5;28;01mlambda\u001b[39;00m t: \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n",
62
+ "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 45.44 MiB is free. Process 3764900 has 3.92 GiB memory in use. Process 3831541 has 582.00 MiB memory in use. Process 3843346 has 4.55 GiB memory in use. Process 3844737 has 4.55 GiB memory in use. Process 3849706 has 4.55 GiB memory in use. Process 3850599 has 4.55 GiB memory in use. Including non-PyTorch memory, this process has 892.00 MiB memory in use. Of the allocated memory 482.41 MiB is allocated by PyTorch, and 25.59 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)"
63
+ ]
64
+ },
65
+ {
66
+ "ename": "",
67
+ "evalue": "",
68
+ "output_type": "error",
69
+ "traceback": [
70
+ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
71
+ ]
72
+ }
73
+ ],
74
+ "source": [
75
+ "from omegaconf import OmegaConf\n",
76
+ "from ldm.models.autoencoder import AutoencoderKL\n",
77
+ "from ldm.models.diffusion.ddpm import LatentDiffusion\n",
78
+ "\n",
79
+ "import os\n",
80
+ "import torch\n",
81
+ "import utils\n",
82
+ "import utils_model\n",
83
+ "import utils_img\n",
84
+ "import torch.nn as nn\n",
85
+ "import numpy as np\n",
86
+ "from copy import deepcopy\n",
87
+ "from torchvision import transforms\n",
88
+ "import os\n",
89
+ "import pandas as pd\n",
90
+ "from torchvision.utils import save_image\n",
91
+ "from accelerate import Accelerator\n",
92
+ "accelerator = Accelerator()\n",
93
+ "\n",
94
+ "\n",
95
+ "apply_dlwt = False\n",
96
+ "ckpt_prefix = \"EW-LoRA_dlwt\" if apply_dlwt else \"EW-LoRA_fix_weights\"\n",
97
+ "exps_num = \"002-exps\"\n",
98
+ "\n",
99
+ "img_size = 256\n",
100
+ "batch_size = 4\n",
101
+ "seed = 0\n",
102
+ "ldm_config = \"/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml\"\n",
103
+ "ldm_ckpt = \"/pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\"\n",
104
+ "msg_decoder_path = \"/pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt\"\n",
105
+ "val_dir = \"/pubdata/ldd/Datasets/coco2017/val2017\"\n",
106
+ "\n",
107
+ "torch.manual_seed(seed)\n",
108
+ "torch.cuda.manual_seed_all(seed)\n",
109
+ "np.random.seed(seed)\n",
110
+ "\n",
111
+ "# Loads LDM auto-encoder models\n",
112
+ "print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')\n",
113
+ "config = OmegaConf.load(f\"{ldm_config}\")\n",
114
+ "ldm_ae: LatentDiffusion = utils_model.load_model_from_config(config, ldm_ckpt)\n",
115
+ "ldm_ae: AutoencoderKL = ldm_ae.first_stage_model\n",
116
+ "ldm_ae.eval()\n",
117
+ "ldm_ae.to(accelerator.device)\n",
118
+ "\n",
119
+ "# Loads hidden decoder\n",
120
+ "print(f'>>> Building hidden decoder with weights from {msg_decoder_path}...')\n",
121
+ "if 'torchscript' in msg_decoder_path: \n",
122
+ " msg_decoder = torch.jit.load(msg_decoder_path)\n",
123
+ "\n",
124
+ "msg_decoder.eval()\n",
125
+ "nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(accelerator.device)).shape[-1]\n",
126
+ "\n",
127
+ "# Freeze LDM and hidden decoder\n",
128
+ "for param in [*msg_decoder.parameters(), *ldm_ae.parameters()]:\n",
129
+ " param.requires_grad = False\n",
130
+ "\n",
131
+ "vqgan_transform = transforms.Compose([\n",
132
+ " transforms.Resize(img_size),\n",
133
+ " transforms.CenterCrop(img_size),\n",
134
+ " transforms.ToTensor(),\n",
135
+ " utils_img.normalize_vqgan,\n",
136
+ "])\n",
137
+ "val_loader = utils.get_dataloader(val_dir, vqgan_transform, 1, num_imgs=1000, shuffle=False, num_workers=4, collate_fn=None)\n",
138
+ "vqgan_to_imnet = transforms.Compose([utils_img.unnormalize_vqgan, utils_img.normalize_img])\n",
139
+ "\n",
140
+ "key = torch.randint(0, 2, (1, nbit), dtype=torch.float32, device=accelerator.device)\n",
141
+ "key_str = \"\".join([ str(int(ii)) for ii in key.tolist()[0]])\n",
142
+ "print(f'Key: {key_str}')\n",
143
+ "\n",
144
+ "# Copy the LDM decoder and finetune the copy\n",
145
+ "ldm_decoder = deepcopy(ldm_ae)\n",
146
+ "ldm_decoder.encoder = nn.Identity()\n",
147
+ "ldm_decoder.quant_conv = nn.Identity()\n",
148
+ "# ldm_decoder.to(device)\n",
149
+ "for param in ldm_decoder.parameters():\n",
150
+ " param.requires_grad = False\n",
151
+ "\n",
152
+ "from peft import LoraConfig, get_peft_model\n",
153
+ "\n",
154
+ "wm_target = \"upsample.conv\"\n",
155
+ "rank = 4\n",
156
+ "lora_alpha = 4\n",
157
+ "\n",
158
+ "# Select the lora target model\n",
159
+ "def find_layers(model, wm_target=None):\n",
160
+ " layers = []\n",
161
+ " for name, layer in model.named_modules():\n",
162
+ " if wm_target in name.lower():\n",
163
+ " layers.append(name)\n",
164
+ " all_layers = [name for name, _ in model.named_modules()]\n",
165
+ " return layers, all_layers\n",
166
+ "wm_target, _ = find_layers(ldm_decoder.decoder, wm_target)\n",
167
+ "\n",
168
+ "vae_lora_config = LoraConfig(\n",
169
+ " r=rank,\n",
170
+ " lora_alpha=lora_alpha,\n",
171
+ " init_lora_weights=\"gaussian\",\n",
172
+ " target_modules=wm_target,\n",
173
+ ")\n",
174
+ "vae_decoder_copy = get_peft_model(ldm_decoder.decoder, vae_lora_config)\n",
175
+ "trainable_params, all_param = vae_decoder_copy.get_nb_trainable_parameters()\n",
176
+ "print(f\"Parameters for PEFT watermarking: \"\n",
177
+ " f\"Trainable params: {trainable_params/1e6:.5f}M || \"\n",
178
+ " f\"PEFT Model size: {trainable_params*4/(1024*1024):.5f}M || \"\n",
179
+ " f\"All params: {all_param/1e6:.5f}M || \"\n",
180
+ " f\"Trainable%: {100 * trainable_params / all_param:.5f}\"\n",
181
+ ")\n",
182
+ "ldm_decoder.decoder = vae_decoder_copy\n",
183
+ "\n",
184
+ "saveimgs_dir = f'/pubdata/ldd/projects/EW-LoRA/experiments/evals/save_imgs_{ckpt_prefix}'\n",
185
+ "os.makedirs(saveimgs_dir, exist_ok=True)\n",
186
+ "vae_decoder_ckpt_dir = f'/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/outputs/train_{ckpt_prefix}/{exps_num}/checkpoints/checkpoint-latest'\n",
187
+ "\n",
188
+ "msg_decoder, ldm_decoder, val_loader, key = accelerator.prepare(\n",
189
+ " msg_decoder, ldm_decoder, val_loader, key\n",
190
+ ")\n",
191
+ "accelerator.load_state(os.path.join(vae_decoder_ckpt_dir)) # Load the LoRA watermark checkpoint\n",
192
+ "print(f\"Loaded the Stable Signature checkpoint from {vae_decoder_ckpt_dir}\")\n",
193
+ "for param in ldm_decoder.parameters():\n",
194
+ " param.requires_grad = False\n",
195
+ "\n",
196
+ "df_EWLoRA = pd.DataFrame(columns=[\n",
197
+ " \"iteration\",\n",
198
+ " \"psnr\",\n",
199
+ " \"bit_acc_avg\",\n",
200
+ "])\n",
201
+ "attacks = {\n",
202
+ " 'none': lambda x: x,\n",
203
+ " 'crop_01': lambda x: utils_img.center_crop(x, 0.1),\n",
204
+ " 'crop_05': lambda x: utils_img.center_crop(x, 0.5),\n",
205
+ " 'rot_25': lambda x: utils_img.rotate(x, 25),\n",
206
+ " 'rot_90': lambda x: utils_img.rotate(x, 90),\n",
207
+ " 'resize_03': lambda x: utils_img.resize(x, 0.3),\n",
208
+ " 'resize_07': lambda x: utils_img.resize(x, 0.7),\n",
209
+ " 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5),\n",
210
+ " 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2),\n",
211
+ " 'contrast_1p5': lambda x: utils_img.adjust_contrast(x, 1.5),\n",
212
+ " 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2),\n",
213
+ " 'sharpness_1p5': lambda x: utils_img.adjust_sharpness(x, 1.5),\n",
214
+ " 'sharpness_2': lambda x: utils_img.adjust_sharpness(x, 2),\n",
215
+ " 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80),\n",
216
+ " 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50),\n",
217
+ "}\n",
218
+ "for ii, imgs in enumerate(val_loader):\n",
219
+ " imgs = imgs.to(accelerator.device)\n",
220
+ " keys = key.repeat(imgs.shape[0], 1)\n",
221
+ "\n",
222
+ " imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f\n",
223
+ " imgs_z = imgs_z.mode()\n",
224
+ "\n",
225
+ " # decode latents with original and finetuned decoder\n",
226
+ " imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w\n",
227
+ " imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w\n",
228
+ "\n",
229
+ " # extract watermark\n",
230
+ " decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k\n",
231
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
232
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
233
+ "\n",
234
+ " log_stats = {\n",
235
+ " \"iteration\": ii,\n",
236
+ " \"psnr\": utils_img.psnr(imgs_w, imgs_d0).mean().item(),\n",
237
+ " \"bit_acc_avg\": torch.mean(bit_accs).item(),\n",
238
+ " }\n",
239
+ " \n",
240
+ " for name, attack in attacks.items():\n",
241
+ " imgs_aug = attack(vqgan_to_imnet(imgs_w))\n",
242
+ " decoded = msg_decoder(imgs_aug) # b c h w -> b k\n",
243
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
244
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
245
+ " word_accs = (bit_accs == 1) # b\n",
246
+ " log_stats[f'bit_acc_{name}'] = torch.mean(bit_accs).item()\n",
247
+ "\n",
248
+ " df_EWLoRA = df_EWLoRA._append(log_stats, ignore_index=True)\n",
249
+ " save_image(utils_img.unnormalize_vqgan(imgs_w), os.path.join(saveimgs_dir, f'{ii:03}_wm_orig.png'))\n",
250
+ "df_EWLoRA.to_csv(os.path.join(saveimgs_dir, 'bitacc.csv'), index=False)"
251
+ ]
252
+ }
253
+ ],
254
+ "metadata": {
255
+ "kernelspec": {
256
+ "display_name": "ldm",
257
+ "language": "python",
258
+ "name": "python3"
259
+ },
260
+ "language_info": {
261
+ "codemirror_mode": {
262
+ "name": "ipython",
263
+ "version": 3
264
+ },
265
+ "file_extension": ".py",
266
+ "mimetype": "text/x-python",
267
+ "name": "python",
268
+ "nbconvert_exporter": "python",
269
+ "pygments_lexer": "ipython3",
270
+ "version": "3.8.18"
271
+ }
272
+ },
273
+ "nbformat": 4,
274
+ "nbformat_minor": 2
275
+ }
watermarker/LaWa/examples/gen_wmimgs_SS_dlwt.ipynb ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "No module 'xformers'. Proceeding without it.\n"
21
+ ]
22
+ },
23
+ {
24
+ "name": "stderr",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.\n",
28
+ " rank_zero_deprecation(\n"
29
+ ]
30
+ },
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ ">>> Building LDM model with config /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml and weights from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt...\n",
36
+ "Loading model from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\n",
37
+ "Global Step: 470000\n",
38
+ "LatentDiffusion: Running in eps-prediction mode\n",
39
+ "DiffusionWrapper has 859.52 M params.\n",
40
+ "making attention of type 'vanilla' with 512 in_channels\n",
41
+ "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
42
+ "making attention of type 'vanilla' with 512 in_channels\n",
43
+ ">>> Building hidden decoder with weights from /pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt...\n",
44
+ "Key: 111010110101000001010111010011010100010000100111\n",
45
+ "Loaded the Stable Signature checkpoint from /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/outputs/train_SS_dlwt/005-exps/checkpoints/checkpoint-latest\n"
46
+ ]
47
+ },
48
+ {
49
+ "ename": "",
50
+ "evalue": "",
51
+ "output_type": "error",
52
+ "traceback": [
53
+ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
54
+ ]
55
+ }
56
+ ],
57
+ "source": [
58
+ "from omegaconf import OmegaConf\n",
59
+ "from ldm.models.autoencoder import AutoencoderKL\n",
60
+ "from ldm.models.diffusion.ddpm import LatentDiffusion\n",
61
+ "\n",
62
+ "import os\n",
63
+ "import torch\n",
64
+ "import utils\n",
65
+ "import utils_model\n",
66
+ "import utils_img\n",
67
+ "import torch.nn as nn\n",
68
+ "import numpy as np\n",
69
+ "from copy import deepcopy\n",
70
+ "from torchvision import transforms\n",
71
+ "import os\n",
72
+ "import pandas as pd\n",
73
+ "from torchvision.utils import save_image\n",
74
+ "from accelerate import Accelerator\n",
75
+ "accelerator = Accelerator()\n",
76
+ "\n",
77
+ "\n",
78
+ "apply_dlwt = True\n",
79
+ "ckpt_prefix = \"SS_dlwt\" if apply_dlwt else \"SS_fix_weights\"\n",
80
+ "exps_num = \"005-exps\"\n",
81
+ "\n",
82
+ "img_size = 256\n",
83
+ "batch_size = 4\n",
84
+ "seed = 0\n",
85
+ "\n",
86
+ "ldm_config = \"/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml\"\n",
87
+ "ldm_ckpt = \"/pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\"\n",
88
+ "msg_decoder_path = \"/pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt\"\n",
89
+ "val_dir = \"/pubdata/ldd/Datasets/coco2017/val2017\"\n",
90
+ "\n",
91
+ "torch.manual_seed(seed)\n",
92
+ "torch.cuda.manual_seed_all(seed)\n",
93
+ "np.random.seed(seed)\n",
94
+ "\n",
95
+ "# Loads LDM auto-encoder models\n",
96
+ "print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')\n",
97
+ "config = OmegaConf.load(f\"{ldm_config}\")\n",
98
+ "ldm_ae: LatentDiffusion = utils_model.load_model_from_config(config, ldm_ckpt)\n",
99
+ "ldm_ae: AutoencoderKL = ldm_ae.first_stage_model\n",
100
+ "ldm_ae.eval()\n",
101
+ "ldm_ae.to(accelerator.device)\n",
102
+ "\n",
103
+ "# Loads hidden decoder\n",
104
+ "print(f'>>> Building hidden decoder with weights from {msg_decoder_path}...')\n",
105
+ "if 'torchscript' in msg_decoder_path: \n",
106
+ " msg_decoder = torch.jit.load(msg_decoder_path)\n",
107
+ "\n",
108
+ "msg_decoder.eval()\n",
109
+ "nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(accelerator.device)).shape[-1]\n",
110
+ "\n",
111
+ "# Freeze LDM and hidden decoder\n",
112
+ "for param in [*msg_decoder.parameters(), *ldm_ae.parameters()]:\n",
113
+ " param.requires_grad = False\n",
114
+ "\n",
115
+ "vqgan_transform = transforms.Compose([\n",
116
+ " transforms.Resize(img_size),\n",
117
+ " transforms.CenterCrop(img_size),\n",
118
+ " transforms.ToTensor(),\n",
119
+ " utils_img.normalize_vqgan,\n",
120
+ "])\n",
121
+ "val_loader = utils.get_dataloader(val_dir, vqgan_transform, 1, num_imgs=1000, shuffle=False, num_workers=4, collate_fn=None)\n",
122
+ "vqgan_to_imnet = transforms.Compose([utils_img.unnormalize_vqgan, utils_img.normalize_img])\n",
123
+ "\n",
124
+ "key = torch.randint(0, 2, (1, nbit), dtype=torch.float32, device=accelerator.device)\n",
125
+ "key_str = \"\".join([ str(int(ii)) for ii in key.tolist()[0]])\n",
126
+ "print(f'Key: {key_str}')\n",
127
+ "\n",
128
+ "# Copy the LDM decoder and finetune the copy\n",
129
+ "ldm_decoder = deepcopy(ldm_ae)\n",
130
+ "ldm_decoder.encoder = nn.Identity()\n",
131
+ "ldm_decoder.quant_conv = nn.Identity()\n",
132
+ "# ldm_decoder.to(device)\n",
133
+ "for param in ldm_decoder.parameters():\n",
134
+ " param.requires_grad = False\n",
135
+ "\n",
136
+ "saveimgs_dir_SS = f'/pubdata/ldd/projects/EW-LoRA/experiments/evals/save_imgs_{ckpt_prefix}'\n",
137
+ "os.makedirs(saveimgs_dir_SS, exist_ok=True)\n",
138
+ "vae_decoder_ss_ckpt_dir = f'/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/outputs/train_{ckpt_prefix}/{exps_num}/checkpoints/checkpoint-latest'\n",
139
+ "\n",
140
+ "msg_decoder, ldm_decoder, val_loader, key = accelerator.prepare(\n",
141
+ " msg_decoder, ldm_decoder, val_loader, key\n",
142
+ ")\n",
143
+ "accelerator.load_state(os.path.join(vae_decoder_ss_ckpt_dir)) # Load the LoRA watermark checkpoint\n",
144
+ "print(f\"Loaded the Stable Signature checkpoint from {vae_decoder_ss_ckpt_dir}\")\n",
145
+ "\n",
146
+ "df_SS = pd.DataFrame(columns=[\n",
147
+ " \"iteration\",\n",
148
+ " \"psnr\",\n",
149
+ " \"bit_acc_avg\",\n",
150
+ "])\n",
151
+ "attacks = {\n",
152
+ " 'none': lambda x: x,\n",
153
+ " 'crop_01': lambda x: utils_img.center_crop(x, 0.1),\n",
154
+ " 'crop_05': lambda x: utils_img.center_crop(x, 0.5),\n",
155
+ " 'rot_25': lambda x: utils_img.rotate(x, 25),\n",
156
+ " 'rot_90': lambda x: utils_img.rotate(x, 90),\n",
157
+ " 'resize_03': lambda x: utils_img.resize(x, 0.3),\n",
158
+ " 'resize_07': lambda x: utils_img.resize(x, 0.7),\n",
159
+ " 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5),\n",
160
+ " 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2),\n",
161
+ " 'contrast_1p5': lambda x: utils_img.adjust_contrast(x, 1.5),\n",
162
+ " 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2),\n",
163
+ " 'sharpness_1p5': lambda x: utils_img.adjust_sharpness(x, 1.5),\n",
164
+ " 'sharpness_2': lambda x: utils_img.adjust_sharpness(x, 2),\n",
165
+ " 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80),\n",
166
+ " 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50),\n",
167
+ "}\n",
168
+ "\n",
169
+ "for ii, imgs in enumerate(val_loader):\n",
170
+ " imgs = imgs.to(accelerator.device)\n",
171
+ " keys = key.repeat(imgs.shape[0], 1)\n",
172
+ "\n",
173
+ " imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f\n",
174
+ " imgs_z = imgs_z.mode()\n",
175
+ "\n",
176
+ " # decode latents with original and finetuned decoder\n",
177
+ " imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w\n",
178
+ " imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w\n",
179
+ "\n",
180
+ " # extract watermark\n",
181
+ " decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k\n",
182
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
183
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
184
+ "\n",
185
+ " log_stats = {\n",
186
+ " \"iteration\": ii,\n",
187
+ " \"psnr\": utils_img.psnr(imgs_w, imgs_d0).mean().item(),\n",
188
+ " \"bit_acc_avg\": torch.mean(bit_accs).item(),\n",
189
+ " }\n",
190
+ " for name, attack in attacks.items():\n",
191
+ " imgs_aug = attack(vqgan_to_imnet(imgs_w))\n",
192
+ " decoded = msg_decoder(imgs_aug) # b c h w -> b k\n",
193
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
194
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
195
+ " word_accs = (bit_accs == 1) # b\n",
196
+ " log_stats[f'bit_acc_{name}'] = torch.mean(bit_accs).item()\n",
197
+ "\n",
198
+ " df_SS = df_SS._append(log_stats, ignore_index=True)\n",
199
+ " save_image(utils_img.unnormalize_vqgan(imgs_w), os.path.join(saveimgs_dir_SS, f'{ii:03}_wm_orig.png'))\n",
200
+ "df_SS.to_csv(os.path.join(saveimgs_dir_SS, 'bitacc.csv'), index=False)"
201
+ ]
202
+ }
203
+ ],
204
+ "metadata": {
205
+ "kernelspec": {
206
+ "display_name": "ldm",
207
+ "language": "python",
208
+ "name": "python3"
209
+ },
210
+ "language_info": {
211
+ "codemirror_mode": {
212
+ "name": "ipython",
213
+ "version": 3
214
+ },
215
+ "file_extension": ".py",
216
+ "mimetype": "text/x-python",
217
+ "name": "python",
218
+ "nbconvert_exporter": "python",
219
+ "pygments_lexer": "ipython3",
220
+ "version": "3.8.18"
221
+ }
222
+ },
223
+ "nbformat": 4,
224
+ "nbformat_minor": 2
225
+ }
watermarker/LaWa/examples/gen_wmimgs_SS_fix_weights.ipynb ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "No module 'xformers'. Proceeding without it.\n"
21
+ ]
22
+ },
23
+ {
24
+ "name": "stderr",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.\n",
28
+ " rank_zero_deprecation(\n"
29
+ ]
30
+ },
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ ">>> Building LDM model with config /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml and weights from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt...\n",
36
+ "Loading model from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\n",
37
+ "Global Step: 470000\n",
38
+ "LatentDiffusion: Running in eps-prediction mode\n",
39
+ "DiffusionWrapper has 859.52 M params.\n",
40
+ "making attention of type 'vanilla' with 512 in_channels\n",
41
+ "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
42
+ "making attention of type 'vanilla' with 512 in_channels\n",
43
+ ">>> Building hidden decoder with weights from /pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt...\n",
44
+ "Key: 111010110101000001010111010011010100010000100111\n"
45
+ ]
46
+ },
47
+ {
48
+ "ename": "ValueError",
49
+ "evalue": "Tried to find /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/outputs/train_SS_fix_weights/005-exps/checkpoints/checkpoint-latest but folder does not exist",
50
+ "output_type": "error",
51
+ "traceback": [
52
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
53
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
54
+ "Cell \u001b[0;32mIn[1], line 86\u001b[0m\n\u001b[1;32m 81\u001b[0m vae_decoder_ss_ckpt_dir \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/outputs/train_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mckpt_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexps_num\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/checkpoints/checkpoint-latest\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 83\u001b[0m msg_decoder, ldm_decoder, val_loader, key \u001b[38;5;241m=\u001b[39m accelerator\u001b[38;5;241m.\u001b[39mprepare(\n\u001b[1;32m 84\u001b[0m msg_decoder, ldm_decoder, val_loader, key\n\u001b[1;32m 85\u001b[0m )\n\u001b[0;32m---> 86\u001b[0m \u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvae_decoder_ss_ckpt_dir\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Load the LoRA watermark checkpoint\u001b[39;00m\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoaded the Stable Signature checkpoint from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvae_decoder_ss_ckpt_dir\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 89\u001b[0m df_SS \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(columns\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miteration\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpsnr\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 92\u001b[0m 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbit_acc_avg\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 93\u001b[0m ])\n",
55
+ "File \u001b[0;32m~/miniconda3/envs/ldm/lib/python3.8/site-packages/accelerate/accelerator.py:2851\u001b[0m, in \u001b[0;36mAccelerator.load_state\u001b[0;34m(self, input_dir, **load_model_func_kwargs)\u001b[0m\n\u001b[1;32m 2849\u001b[0m input_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mexpanduser(input_dir)\n\u001b[1;32m 2850\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(input_dir):\n\u001b[0;32m-> 2851\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTried to find \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minput_dir\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m but folder does not exist\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 2852\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproject_configuration\u001b[38;5;241m.\u001b[39mautomatic_checkpoint_naming:\n\u001b[1;32m 2853\u001b[0m \u001b[38;5;66;03m# Pick up from automatic checkpoint naming\u001b[39;00m\n\u001b[1;32m 2854\u001b[0m input_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproject_dir, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcheckpoints\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
56
+ "\u001b[0;31mValueError\u001b[0m: Tried to find /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/outputs/train_SS_fix_weights/005-exps/checkpoints/checkpoint-latest but folder does not exist"
57
+ ]
58
+ },
59
+ {
60
+ "ename": "",
61
+ "evalue": "",
62
+ "output_type": "error",
63
+ "traceback": [
64
+ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
65
+ ]
66
+ }
67
+ ],
68
+ "source": [
69
+ "from omegaconf import OmegaConf\n",
70
+ "from ldm.models.autoencoder import AutoencoderKL\n",
71
+ "from ldm.models.diffusion.ddpm import LatentDiffusion\n",
72
+ "\n",
73
+ "import os\n",
74
+ "import torch\n",
75
+ "import utils\n",
76
+ "import utils_model\n",
77
+ "import utils_img\n",
78
+ "import torch.nn as nn\n",
79
+ "import numpy as np\n",
80
+ "from copy import deepcopy\n",
81
+ "from torchvision import transforms\n",
82
+ "import os\n",
83
+ "import pandas as pd\n",
84
+ "from torchvision.utils import save_image\n",
85
+ "from accelerate import Accelerator\n",
86
+ "accelerator = Accelerator()\n",
87
+ "\n",
88
+ "\n",
89
+ "apply_dlwt = False\n",
90
+ "ckpt_prefix = \"SS_dlwt\" if apply_dlwt else \"SS_fix_weights\"\n",
91
+ "exps_num = \"002-exps\"\n",
92
+ "\n",
93
+ "img_size = 256\n",
94
+ "batch_size = 4\n",
95
+ "seed = 0\n",
96
+ "\n",
97
+ "ldm_config = \"/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml\"\n",
98
+ "ldm_ckpt = \"/pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\"\n",
99
+ "msg_decoder_path = \"/pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt\"\n",
100
+ "val_dir = \"/pubdata/ldd/Datasets/coco2017/val2017\"\n",
101
+ "\n",
102
+ "torch.manual_seed(seed)\n",
103
+ "torch.cuda.manual_seed_all(seed)\n",
104
+ "np.random.seed(seed)\n",
105
+ "\n",
106
+ "# Loads LDM auto-encoder models\n",
107
+ "print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')\n",
108
+ "config = OmegaConf.load(f\"{ldm_config}\")\n",
109
+ "ldm_ae: LatentDiffusion = utils_model.load_model_from_config(config, ldm_ckpt)\n",
110
+ "ldm_ae: AutoencoderKL = ldm_ae.first_stage_model\n",
111
+ "ldm_ae.eval()\n",
112
+ "ldm_ae.to(accelerator.device)\n",
113
+ "\n",
114
+ "# Loads hidden decoder\n",
115
+ "print(f'>>> Building hidden decoder with weights from {msg_decoder_path}...')\n",
116
+ "if 'torchscript' in msg_decoder_path: \n",
117
+ " msg_decoder = torch.jit.load(msg_decoder_path)\n",
118
+ "\n",
119
+ "msg_decoder.eval()\n",
120
+ "nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(accelerator.device)).shape[-1]\n",
121
+ "\n",
122
+ "# Freeze LDM and hidden decoder\n",
123
+ "for param in [*msg_decoder.parameters(), *ldm_ae.parameters()]:\n",
124
+ " param.requires_grad = False\n",
125
+ "\n",
126
+ "vqgan_transform = transforms.Compose([\n",
127
+ " transforms.Resize(img_size),\n",
128
+ " transforms.CenterCrop(img_size),\n",
129
+ " transforms.ToTensor(),\n",
130
+ " utils_img.normalize_vqgan,\n",
131
+ "])\n",
132
+ "val_loader = utils.get_dataloader(val_dir, vqgan_transform, 1, num_imgs=1000, shuffle=False, num_workers=4, collate_fn=None)\n",
133
+ "vqgan_to_imnet = transforms.Compose([utils_img.unnormalize_vqgan, utils_img.normalize_img])\n",
134
+ "\n",
135
+ "key = torch.randint(0, 2, (1, nbit), dtype=torch.float32, device=accelerator.device)\n",
136
+ "key_str = \"\".join([ str(int(ii)) for ii in key.tolist()[0]])\n",
137
+ "print(f'Key: {key_str}')\n",
138
+ "\n",
139
+ "# Copy the LDM decoder and finetune the copy\n",
140
+ "ldm_decoder = deepcopy(ldm_ae)\n",
141
+ "ldm_decoder.encoder = nn.Identity()\n",
142
+ "ldm_decoder.quant_conv = nn.Identity()\n",
143
+ "# ldm_decoder.to(device)\n",
144
+ "for param in ldm_decoder.parameters():\n",
145
+ " param.requires_grad = False\n",
146
+ "\n",
147
+ "saveimgs_dir_SS = f'/pubdata/ldd/projects/EW-LoRA/experiments/evals/save_imgs_{ckpt_prefix}_fix_weights'\n",
148
+ "os.makedirs(saveimgs_dir_SS, exist_ok=True)\n",
149
+ "vae_decoder_ss_ckpt_dir = f'/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/outputs/train_{ckpt_prefix}/{exps_num}/checkpoints/checkpoint-latest'\n",
150
+ "\n",
151
+ "msg_decoder, ldm_decoder, val_loader, key = accelerator.prepare(\n",
152
+ " msg_decoder, ldm_decoder, val_loader, key\n",
153
+ ")\n",
154
+ "accelerator.load_state(os.path.join(vae_decoder_ss_ckpt_dir)) # Load the LoRA watermark checkpoint\n",
155
+ "print(f\"Loaded the Stable Signature checkpoint from {vae_decoder_ss_ckpt_dir}\")\n",
156
+ "\n",
157
+ "df_SS = pd.DataFrame(columns=[\n",
158
+ " \"iteration\",\n",
159
+ " \"psnr\",\n",
160
+ " \"bit_acc_avg\",\n",
161
+ "])\n",
162
+ "attacks = {\n",
163
+ " 'none': lambda x: x,\n",
164
+ " 'crop_01': lambda x: utils_img.center_crop(x, 0.1),\n",
165
+ " 'crop_05': lambda x: utils_img.center_crop(x, 0.5),\n",
166
+ " 'rot_25': lambda x: utils_img.rotate(x, 25),\n",
167
+ " 'rot_90': lambda x: utils_img.rotate(x, 90),\n",
168
+ " 'resize_03': lambda x: utils_img.resize(x, 0.3),\n",
169
+ " 'resize_07': lambda x: utils_img.resize(x, 0.7),\n",
170
+ " 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5),\n",
171
+ " 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2),\n",
172
+ " 'contrast_1p5': lambda x: utils_img.adjust_contrast(x, 1.5),\n",
173
+ " 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2),\n",
174
+ " 'sharpness_1p5': lambda x: utils_img.adjust_sharpness(x, 1.5),\n",
175
+ " 'sharpness_2': lambda x: utils_img.adjust_sharpness(x, 2),\n",
176
+ " 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80),\n",
177
+ " 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50),\n",
178
+ "}\n",
179
+ "\n",
180
+ "for ii, imgs in enumerate(val_loader):\n",
181
+ " imgs = imgs.to(accelerator.device)\n",
182
+ " keys = key.repeat(imgs.shape[0], 1)\n",
183
+ "\n",
184
+ " imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f\n",
185
+ " imgs_z = imgs_z.mode()\n",
186
+ "\n",
187
+ " # decode latents with original and finetuned decoder\n",
188
+ " imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w\n",
189
+ " imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w\n",
190
+ "\n",
191
+ " # extract watermark\n",
192
+ " decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k\n",
193
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
194
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
195
+ "\n",
196
+ " log_stats = {\n",
197
+ " \"iteration\": ii,\n",
198
+ " \"psnr\": utils_img.psnr(imgs_w, imgs_d0).mean().item(),\n",
199
+ " \"bit_acc_avg\": torch.mean(bit_accs).item(),\n",
200
+ " }\n",
201
+ " for name, attack in attacks.items():\n",
202
+ " imgs_aug = attack(vqgan_to_imnet(imgs_w))\n",
203
+ " decoded = msg_decoder(imgs_aug) # b c h w -> b k\n",
204
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
205
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
206
+ " word_accs = (bit_accs == 1) # b\n",
207
+ " log_stats[f'bit_acc_{name}'] = torch.mean(bit_accs).item()\n",
208
+ "\n",
209
+ " df_SS = df_SS._append(log_stats, ignore_index=True)\n",
210
+ " save_image(utils_img.unnormalize_vqgan(imgs_w), os.path.join(saveimgs_dir_SS, f'{ii:03}_wm_orig.png'))\n",
211
+ "df_SS.to_csv(os.path.join(saveimgs_dir_SS, 'bitacc.csv'), index=False)"
212
+ ]
213
+ }
214
+ ],
215
+ "metadata": {
216
+ "kernelspec": {
217
+ "display_name": "ldm",
218
+ "language": "python",
219
+ "name": "python3"
220
+ },
221
+ "language_info": {
222
+ "codemirror_mode": {
223
+ "name": "ipython",
224
+ "version": 3
225
+ },
226
+ "file_extension": ".py",
227
+ "mimetype": "text/x-python",
228
+ "name": "python",
229
+ "nbconvert_exporter": "python",
230
+ "pygments_lexer": "ipython3",
231
+ "version": "3.8.18"
232
+ }
233
+ },
234
+ "nbformat": 4,
235
+ "nbformat_minor": 2
236
+ }
watermarker/LaWa/examples/gen_wmimgs_WMA_dlwt.ipynb ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "No module 'xformers'. Proceeding without it.\n"
21
+ ]
22
+ },
23
+ {
24
+ "name": "stderr",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.\n",
28
+ " rank_zero_deprecation(\n"
29
+ ]
30
+ },
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ ">>> Building LDM model with config /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml and weights from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt...\n",
36
+ "Loading model from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\n",
37
+ "Global Step: 470000\n",
38
+ "LatentDiffusion: Running in eps-prediction mode\n",
39
+ "DiffusionWrapper has 859.52 M params.\n",
40
+ "making attention of type 'vanilla' with 512 in_channels\n",
41
+ "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
42
+ "making attention of type 'vanilla' with 512 in_channels\n",
43
+ ">>> Building hidden decoder with weights from /pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt...\n",
44
+ "Key: 111010110101000001010111010011010100010000100111\n",
45
+ "Loaded the Stable Signature checkpoint from /pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/outputs/train_WMA_dlwt/004-exps/checkpoints/checkpoint-latest\n"
46
+ ]
47
+ },
48
+ {
49
+ "ename": "",
50
+ "evalue": "",
51
+ "output_type": "error",
52
+ "traceback": [
53
+ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
54
+ ]
55
+ }
56
+ ],
57
+ "source": [
58
+ "from omegaconf import OmegaConf\n",
59
+ "from ldm.models.autoencoder import AutoencoderKL\n",
60
+ "from ldm.models.diffusion.ddpm import LatentDiffusion\n",
61
+ "\n",
62
+ "import os\n",
63
+ "import torch\n",
64
+ "import utils\n",
65
+ "import utils_model\n",
66
+ "import utils_img\n",
67
+ "import torch.nn as nn\n",
68
+ "import numpy as np\n",
69
+ "from copy import deepcopy\n",
70
+ "from torchvision import transforms\n",
71
+ "import os\n",
72
+ "import pandas as pd\n",
73
+ "from torchvision.utils import save_image\n",
74
+ "from accelerate import Accelerator\n",
75
+ "accelerator = Accelerator()\n",
76
+ "\n",
77
+ "\n",
78
+ "apply_dlwt = True\n",
79
+ "ckpt_prefix = \"WMA_dlwt\" if apply_dlwt else \"WMA_fix_weights\"\n",
80
+ "exps_num = \"004-exps\"\n",
81
+ "\n",
82
+ "img_size = 256\n",
83
+ "batch_size = 4\n",
84
+ "seed = 0\n",
85
+ "ldm_config = \"/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml\"\n",
86
+ "ldm_ckpt = \"/pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\"\n",
87
+ "msg_decoder_path = \"/pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt\"\n",
88
+ "val_dir = \"/pubdata/ldd/Datasets/coco2017/val2017\"\n",
89
+ "\n",
90
+ "torch.manual_seed(seed)\n",
91
+ "torch.cuda.manual_seed_all(seed)\n",
92
+ "np.random.seed(seed)\n",
93
+ "\n",
94
+ "# Loads LDM auto-encoder models\n",
95
+ "print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')\n",
96
+ "config = OmegaConf.load(f\"{ldm_config}\")\n",
97
+ "ldm_ae: LatentDiffusion = utils_model.load_model_from_config(config, ldm_ckpt)\n",
98
+ "ldm_ae: AutoencoderKL = ldm_ae.first_stage_model\n",
99
+ "ldm_ae.eval()\n",
100
+ "ldm_ae.to(accelerator.device)\n",
101
+ "\n",
102
+ "# Loads hidden decoder\n",
103
+ "print(f'>>> Building hidden decoder with weights from {msg_decoder_path}...')\n",
104
+ "if 'torchscript' in msg_decoder_path: \n",
105
+ " msg_decoder = torch.jit.load(msg_decoder_path)\n",
106
+ "\n",
107
+ "msg_decoder.eval()\n",
108
+ "nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(accelerator.device)).shape[-1]\n",
109
+ "\n",
110
+ "# Freeze LDM and hidden decoder\n",
111
+ "for param in [*msg_decoder.parameters(), *ldm_ae.parameters()]:\n",
112
+ " param.requires_grad = False\n",
113
+ "\n",
114
+ "vqgan_transform = transforms.Compose([\n",
115
+ " transforms.Resize(img_size),\n",
116
+ " transforms.CenterCrop(img_size),\n",
117
+ " transforms.ToTensor(),\n",
118
+ " utils_img.normalize_vqgan,\n",
119
+ "])\n",
120
+ "val_loader = utils.get_dataloader(val_dir, vqgan_transform, 1, num_imgs=1000, shuffle=False, num_workers=4, collate_fn=None)\n",
121
+ "vqgan_to_imnet = transforms.Compose([utils_img.unnormalize_vqgan, utils_img.normalize_img])\n",
122
+ "\n",
123
+ "key = torch.randint(0, 2, (1, nbit), dtype=torch.float32, device=accelerator.device)\n",
124
+ "key_str = \"\".join([ str(int(ii)) for ii in key.tolist()[0]])\n",
125
+ "print(f'Key: {key_str}')\n",
126
+ "\n",
127
+ "# Copy the LDM decoder and finetune the copy\n",
128
+ "ldm_decoder = deepcopy(ldm_ae)\n",
129
+ "ldm_decoder.encoder = nn.Identity()\n",
130
+ "ldm_decoder.quant_conv = nn.Identity()\n",
131
+ "# ldm_decoder.to(device)\n",
132
+ "for param in ldm_decoder.parameters():\n",
133
+ " param.requires_grad = False\n",
134
+ "\n",
135
+ "import wmadapter.wmadapter as wmadapter\n",
136
+ "\n",
137
+ "wm_adapter = wmadapter.Fuser(img_channels_list=[4, 512, 512, 256, 512, 512], watermark_bits=key)\n",
138
+ "vae_with_adapter = wmadapter.VAEWithAdapter(ldm_ae.decoder, wm_adapter)\n",
139
+ "ldm_decoder.decoder = vae_with_adapter\n",
140
+ "\n",
141
+ "for param in ldm_decoder.parameters():\n",
142
+ " param.requires_grad = False\n",
143
+ "\n",
144
+ "saveimgs_dir = f'/pubdata/ldd/projects/EW-LoRA/experiments/evals/save_imgs_{ckpt_prefix}'\n",
145
+ "os.makedirs(saveimgs_dir, exist_ok=True)\n",
146
+ "vae_decoder_ckpt_dir = f'/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/outputs/train_{ckpt_prefix}/{exps_num}/checkpoints/checkpoint-latest'\n",
147
+ "\n",
148
+ "msg_decoder, ldm_decoder, val_loader, key = accelerator.prepare(\n",
149
+ " msg_decoder, ldm_decoder, val_loader, key\n",
150
+ ")\n",
151
+ "accelerator.load_state(os.path.join(vae_decoder_ckpt_dir)) # Load the LoRA watermark checkpoint\n",
152
+ "print(f\"Loaded the Stable Signature checkpoint from {vae_decoder_ckpt_dir}\")\n",
153
+ "\n",
154
+ "df_WMA = pd.DataFrame(columns=[\n",
155
+ " \"iteration\",\n",
156
+ " \"psnr\",\n",
157
+ " \"bit_acc_avg\",\n",
158
+ "])\n",
159
+ "attacks = {\n",
160
+ " 'none': lambda x: x,\n",
161
+ " 'crop_01': lambda x: utils_img.center_crop(x, 0.1),\n",
162
+ " 'crop_05': lambda x: utils_img.center_crop(x, 0.5),\n",
163
+ " 'rot_25': lambda x: utils_img.rotate(x, 25),\n",
164
+ " 'rot_90': lambda x: utils_img.rotate(x, 90),\n",
165
+ " 'resize_03': lambda x: utils_img.resize(x, 0.3),\n",
166
+ " 'resize_07': lambda x: utils_img.resize(x, 0.7),\n",
167
+ " 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5),\n",
168
+ " 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2),\n",
169
+ " 'contrast_1p5': lambda x: utils_img.adjust_contrast(x, 1.5),\n",
170
+ " 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2),\n",
171
+ " 'sharpness_1p5': lambda x: utils_img.adjust_sharpness(x, 1.5),\n",
172
+ " 'sharpness_2': lambda x: utils_img.adjust_sharpness(x, 2),\n",
173
+ " 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80),\n",
174
+ " 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50),\n",
175
+ "}\n",
176
+ "for ii, imgs in enumerate(val_loader):\n",
177
+ " imgs = imgs.to(accelerator.device)\n",
178
+ " keys = key.repeat(imgs.shape[0], 1)\n",
179
+ "\n",
180
+ " imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f\n",
181
+ " imgs_z = imgs_z.mode()\n",
182
+ "\n",
183
+ " # decode latents with original and finetuned decoder\n",
184
+ " imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w\n",
185
+ " imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w\n",
186
+ "\n",
187
+ " # extract watermark\n",
188
+ " decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k\n",
189
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
190
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
191
+ "\n",
192
+ " log_stats = {\n",
193
+ " \"iteration\": ii,\n",
194
+ " \"psnr\": utils_img.psnr(imgs_w, imgs_d0).mean().item(),\n",
195
+ " \"bit_acc_avg\": torch.mean(bit_accs).item(),\n",
196
+ " }\n",
197
+ " \n",
198
+ " for name, attack in attacks.items():\n",
199
+ " imgs_aug = attack(vqgan_to_imnet(imgs_w))\n",
200
+ " decoded = msg_decoder(imgs_aug) # b c h w -> b k\n",
201
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
202
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
203
+ " word_accs = (bit_accs == 1) # b\n",
204
+ " log_stats[f'bit_acc_{name}'] = torch.mean(bit_accs).item()\n",
205
+ "\n",
206
+ " df_WMA = df_WMA._append(log_stats, ignore_index=True)\n",
207
+ " save_image(utils_img.unnormalize_vqgan(imgs_w), os.path.join(saveimgs_dir, f'{ii:03}_wm_orig.png'))\n",
208
+ "df_WMA.to_csv(os.path.join(saveimgs_dir, 'bitacc.csv'), index=False)"
209
+ ]
210
+ }
211
+ ],
212
+ "metadata": {
213
+ "kernelspec": {
214
+ "display_name": "ldm",
215
+ "language": "python",
216
+ "name": "python3"
217
+ },
218
+ "language_info": {
219
+ "codemirror_mode": {
220
+ "name": "ipython",
221
+ "version": 3
222
+ },
223
+ "file_extension": ".py",
224
+ "mimetype": "text/x-python",
225
+ "name": "python",
226
+ "nbconvert_exporter": "python",
227
+ "pygments_lexer": "ipython3",
228
+ "version": "3.8.18"
229
+ }
230
+ },
231
+ "nbformat": 4,
232
+ "nbformat_minor": 2
233
+ }
watermarker/LaWa/examples/gen_wmimgs_WMA_fix_weights.ipynb ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "No module 'xformers'. Proceeding without it.\n"
21
+ ]
22
+ },
23
+ {
24
+ "name": "stderr",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "/home/ldd/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.\n",
28
+ " rank_zero_deprecation(\n"
29
+ ]
30
+ },
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ ">>> Building LDM model with config /pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml and weights from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt...\n",
36
+ "Loading model from /pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\n",
37
+ "Global Step: 470000\n",
38
+ "LatentDiffusion: Running in eps-prediction mode\n",
39
+ "DiffusionWrapper has 859.52 M params.\n",
40
+ "making attention of type 'vanilla' with 512 in_channels\n",
41
+ "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
42
+ "making attention of type 'vanilla' with 512 in_channels\n",
43
+ ">>> Building hidden decoder with weights from /pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt...\n",
44
+ "Key: 111010110101000001010111010011010100010000100111\n",
45
+ "Loaded the Stable Signature checkpoint from /pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/outputs/train_WMA_fix_weights/002-exps/checkpoints/checkpoint-latest\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "from omegaconf import OmegaConf\n",
51
+ "from ldm.models.autoencoder import AutoencoderKL\n",
52
+ "from ldm.models.diffusion.ddpm import LatentDiffusion\n",
53
+ "\n",
54
+ "import os\n",
55
+ "import torch\n",
56
+ "import utils\n",
57
+ "import utils_model\n",
58
+ "import utils_img\n",
59
+ "import torch.nn as nn\n",
60
+ "import numpy as np\n",
61
+ "from copy import deepcopy\n",
62
+ "from torchvision import transforms\n",
63
+ "import os\n",
64
+ "import pandas as pd\n",
65
+ "from torchvision.utils import save_image\n",
66
+ "from accelerate import Accelerator\n",
67
+ "accelerator = Accelerator()\n",
68
+ "\n",
69
+ "\n",
70
+ "apply_dlwt = False\n",
71
+ "ckpt_prefix = \"WMA_dlwt\" if apply_dlwt else \"WMA_fix_weights\"\n",
72
+ "exps_num = \"002-exps\"\n",
73
+ "\n",
74
+ "img_size = 256\n",
75
+ "batch_size = 4\n",
76
+ "seed = 0\n",
77
+ "ldm_config = \"/pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml\"\n",
78
+ "ldm_ckpt = \"/pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt\"\n",
79
+ "msg_decoder_path = \"/pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt\"\n",
80
+ "val_dir = \"/pubdata/ldd/Datasets/coco2017/val2017\"\n",
81
+ "\n",
82
+ "torch.manual_seed(seed)\n",
83
+ "torch.cuda.manual_seed_all(seed)\n",
84
+ "np.random.seed(seed)\n",
85
+ "\n",
86
+ "# Loads LDM auto-encoder models\n",
87
+ "print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')\n",
88
+ "config = OmegaConf.load(f\"{ldm_config}\")\n",
89
+ "ldm_ae: LatentDiffusion = utils_model.load_model_from_config(config, ldm_ckpt)\n",
90
+ "ldm_ae: AutoencoderKL = ldm_ae.first_stage_model\n",
91
+ "ldm_ae.eval()\n",
92
+ "ldm_ae.to(accelerator.device)\n",
93
+ "\n",
94
+ "# Loads hidden decoder\n",
95
+ "print(f'>>> Building hidden decoder with weights from {msg_decoder_path}...')\n",
96
+ "if 'torchscript' in msg_decoder_path: \n",
97
+ " msg_decoder = torch.jit.load(msg_decoder_path)\n",
98
+ "\n",
99
+ "msg_decoder.eval()\n",
100
+ "nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(accelerator.device)).shape[-1]\n",
101
+ "\n",
102
+ "# Freeze LDM and hidden decoder\n",
103
+ "for param in [*msg_decoder.parameters(), *ldm_ae.parameters()]:\n",
104
+ " param.requires_grad = False\n",
105
+ "\n",
106
+ "vqgan_transform = transforms.Compose([\n",
107
+ " transforms.Resize(img_size),\n",
108
+ " transforms.CenterCrop(img_size),\n",
109
+ " transforms.ToTensor(),\n",
110
+ " utils_img.normalize_vqgan,\n",
111
+ "])\n",
112
+ "val_loader = utils.get_dataloader(val_dir, vqgan_transform, 1, num_imgs=1000, shuffle=False, num_workers=4, collate_fn=None)\n",
113
+ "vqgan_to_imnet = transforms.Compose([utils_img.unnormalize_vqgan, utils_img.normalize_img])\n",
114
+ "\n",
115
+ "key = torch.randint(0, 2, (1, nbit), dtype=torch.float32, device=accelerator.device)\n",
116
+ "key_str = \"\".join([ str(int(ii)) for ii in key.tolist()[0]])\n",
117
+ "print(f'Key: {key_str}')\n",
118
+ "\n",
119
+ "# Copy the LDM decoder and finetune the copy\n",
120
+ "ldm_decoder = deepcopy(ldm_ae)\n",
121
+ "ldm_decoder.encoder = nn.Identity()\n",
122
+ "ldm_decoder.quant_conv = nn.Identity()\n",
123
+ "# ldm_decoder.to(device)\n",
124
+ "for param in ldm_decoder.parameters():\n",
125
+ " param.requires_grad = False\n",
126
+ "\n",
127
+ "import wmadapter.wmadapter as wmadapter\n",
128
+ "\n",
129
+ "wm_adapter = wmadapter.Fuser(img_channels_list=[4, 512, 512, 256, 512, 512], watermark_bits=key)\n",
130
+ "vae_with_adapter = wmadapter.VAEWithAdapter(ldm_ae.decoder, wm_adapter)\n",
131
+ "ldm_decoder.decoder = vae_with_adapter\n",
132
+ "\n",
133
+ "for param in ldm_decoder.parameters():\n",
134
+ " param.requires_grad = False\n",
135
+ "\n",
136
+ "saveimgs_dir = f'/pubdata/ldd/projects/EW-LoRA/experiments/evals/save_imgs_{ckpt_prefix}'\n",
137
+ "os.makedirs(saveimgs_dir, exist_ok=True)\n",
138
+ "vae_decoder_ckpt_dir = f'/pubdata/ldd/projects/EW-LoRA/Watermarker/stable_signature/outputs/train_{ckpt_prefix}/{exps_num}/checkpoints/checkpoint-latest'\n",
139
+ "\n",
140
+ "msg_decoder, ldm_decoder, val_loader, key = accelerator.prepare(\n",
141
+ " msg_decoder, ldm_decoder, val_loader, key\n",
142
+ ")\n",
143
+ "accelerator.load_state(os.path.join(vae_decoder_ckpt_dir)) # Load the LoRA watermark checkpoint\n",
144
+ "print(f\"Loaded the Stable Signature checkpoint from {vae_decoder_ckpt_dir}\")\n",
145
+ "\n",
146
+ "df_WMA = pd.DataFrame(columns=[\n",
147
+ " \"iteration\",\n",
148
+ " \"psnr\",\n",
149
+ " \"bit_acc_avg\",\n",
150
+ "])\n",
151
+ "attacks = {\n",
152
+ " 'none': lambda x: x,\n",
153
+ " 'crop_01': lambda x: utils_img.center_crop(x, 0.1),\n",
154
+ " 'crop_05': lambda x: utils_img.center_crop(x, 0.5),\n",
155
+ " 'rot_25': lambda x: utils_img.rotate(x, 25),\n",
156
+ " 'rot_90': lambda x: utils_img.rotate(x, 90),\n",
157
+ " 'resize_03': lambda x: utils_img.resize(x, 0.3),\n",
158
+ " 'resize_07': lambda x: utils_img.resize(x, 0.7),\n",
159
+ " 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5),\n",
160
+ " 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2),\n",
161
+ " 'contrast_1p5': lambda x: utils_img.adjust_contrast(x, 1.5),\n",
162
+ " 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2),\n",
163
+ " 'sharpness_1p5': lambda x: utils_img.adjust_sharpness(x, 1.5),\n",
164
+ " 'sharpness_2': lambda x: utils_img.adjust_sharpness(x, 2),\n",
165
+ " 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80),\n",
166
+ " 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50),\n",
167
+ "}\n",
168
+ "for ii, imgs in enumerate(val_loader):\n",
169
+ " imgs = imgs.to(accelerator.device)\n",
170
+ " keys = key.repeat(imgs.shape[0], 1)\n",
171
+ "\n",
172
+ " imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f\n",
173
+ " imgs_z = imgs_z.mode()\n",
174
+ "\n",
175
+ " # decode latents with original and finetuned decoder\n",
176
+ " imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w\n",
177
+ " imgs_w = ldm_decoder.decode(imgs_z) # b z h/f w/f -> b c h w\n",
178
+ "\n",
179
+ " # extract watermark\n",
180
+ " decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k\n",
181
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
182
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
183
+ "\n",
184
+ " log_stats = {\n",
185
+ " \"iteration\": ii,\n",
186
+ " \"psnr\": utils_img.psnr(imgs_w, imgs_d0).mean().item(),\n",
187
+ " \"bit_acc_avg\": torch.mean(bit_accs).item(),\n",
188
+ " }\n",
189
+ " \n",
190
+ " for name, attack in attacks.items():\n",
191
+ " imgs_aug = attack(vqgan_to_imnet(imgs_w))\n",
192
+ " decoded = msg_decoder(imgs_aug) # b c h w -> b k\n",
193
+ " diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k\n",
194
+ " bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b\n",
195
+ " word_accs = (bit_accs == 1) # b\n",
196
+ " log_stats[f'bit_acc_{name}'] = torch.mean(bit_accs).item()\n",
197
+ "\n",
198
+ " df_WMA = df_WMA._append(log_stats, ignore_index=True)\n",
199
+ " save_image(utils_img.unnormalize_vqgan(imgs_w), os.path.join(saveimgs_dir, f'{ii:03}_wm_orig.png'))\n",
200
+ "df_WMA.to_csv(os.path.join(saveimgs_dir, 'bitacc.csv'), index=False)"
201
+ ]
202
+ }
203
+ ],
204
+ "metadata": {
205
+ "kernelspec": {
206
+ "display_name": "ldm",
207
+ "language": "python",
208
+ "name": "python3"
209
+ },
210
+ "language_info": {
211
+ "codemirror_mode": {
212
+ "name": "ipython",
213
+ "version": 3
214
+ },
215
+ "file_extension": ".py",
216
+ "mimetype": "text/x-python",
217
+ "name": "python",
218
+ "nbconvert_exporter": "python",
219
+ "pygments_lexer": "ipython3",
220
+ "version": "3.8.18"
221
+ }
222
+ },
223
+ "nbformat": 4,
224
+ "nbformat_minor": 2
225
+ }
watermarker/LaWa/gen_wm_imgs.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ from ldm.models.autoencoder import AutoencoderKL
3
+ from ldm.models.diffusion.ddpm import LatentDiffusion
4
+ import utils as utils
5
+ import utils_model as utils_model
6
+ import utils_img as utils_img
7
+
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import argparse
13
+ from copy import deepcopy
14
+ from torchvision import transforms
15
+ import os
16
+ import pandas as pd
17
+ from torchvision.utils import save_image
18
+ from accelerate import Accelerator
19
+ accelerator = Accelerator()
20
+
21
+ from ldm.util import instantiate_from_config
22
+
23
+
24
+ def main(args):
25
+ # args.apply_dlwt = True
26
+ # args.ckpt_prefix = "SS_dlwt" if args.apply_dlwt else "SS_fix_weights"
27
+ # args.exps_num = "005-exps"
28
+
29
+ # args.img_size = 256
30
+ # args.batch_size = 4
31
+ # args.seed = 0
32
+
33
+ # args.ldm_config = "/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml"
34
+ # args.ldm_ckpt = "/pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt"
35
+ # args.msg_decoder_path = "/pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt"
36
+ # args.val_dir = "/pubdata/ldd/Datasets/coco2017/val2017"
37
+
38
+ # Loads LDM auto-encoder models
39
+ print(f'>>> Building LDM model with config {args.ldm_config} and weights from {args.ldm_ckpt}...')
40
+ config = OmegaConf.load(f"{args.ldm_config}")
41
+ ldm_ae: LatentDiffusion = utils_model.load_model_from_config(config, args.ldm_ckpt)
42
+ ldm_ae: AutoencoderKL = ldm_ae.first_stage_model
43
+ ldm_ae.eval()
44
+ ldm_ae.to(accelerator.device)
45
+
46
+ saveimgs_dir = f'/pubdata/ldd/projects/EW-LoRA/experiments/evals/save_imgs_{args.ckpt_prefix}'
47
+ os.makedirs(saveimgs_dir, exist_ok=True)
48
+ vae_decoder_ckpt_dir = f'/pubdata/ldd/projects/EW-LoRA/watermarker/LaWa/outputs/train_{args.ckpt_prefix}/checkpoints/epoch=000000-step=000024999.ckpt'
49
+
50
+ ### Load the pre-trained modified decoder model
51
+ config = OmegaConf.load(args.ldm_wm_config).model
52
+ message_len = config.params.decoder_config.params.message_len
53
+ if int(args.message_len) != message_len:
54
+ raise Exception(f"Provided message_len argument does not match the message length in the config file!")
55
+ ldm_decoder = instantiate_from_config(config)
56
+ # print(ldm_decoder.decoder)
57
+ state_dict = torch.load(vae_decoder_ckpt_dir, map_location=torch.device('cpu'))
58
+ if 'global_step' in state_dict:
59
+ print(f'Global step: {state_dict["global_step"]}, epoch: {state_dict["epoch"]}')
60
+ if 'state_dict' in state_dict:
61
+ state_dict = state_dict['state_dict']
62
+ misses, ignores = ldm_decoder.load_state_dict(state_dict, strict=False)
63
+ print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
64
+ ldm_decoder.eval()
65
+ ldm_decoder.to(accelerator.device)
66
+
67
+ # Loads hidden decoder
68
+ print(f'>>> Building hidden decoder with weights from {args.msg_decoder_path}...')
69
+ if 'torchscript' in args.msg_decoder_path:
70
+ msg_decoder = torch.jit.load(args.msg_decoder_path)
71
+ msg_decoder.eval()
72
+ nbit = msg_decoder(torch.zeros(1, 3, 128, 128).to(accelerator.device)).shape[-1]
73
+ msg_decoder.to(accelerator.device)
74
+
75
+ # Freeze LDM and hidden decoder
76
+ for param in [*msg_decoder.parameters(), *ldm_ae.parameters(), *ldm_decoder.parameters()]:
77
+ param.requires_grad = False
78
+
79
+ vqgan_transform = transforms.Compose([
80
+ transforms.Resize(args.img_size),
81
+ transforms.CenterCrop(args.img_size),
82
+ transforms.ToTensor(),
83
+ utils_img.normalize_vqgan,
84
+ ])
85
+ val_loader = utils.get_dataloader(args.val_dir, vqgan_transform, 1, num_imgs=1000, shuffle=False, num_workers=4, collate_fn=None)
86
+ vqgan_to_imnet = transforms.Compose([utils_img.unnormalize_vqgan, utils_img.normalize_img])
87
+
88
+ torch.manual_seed(args.seed)
89
+ torch.cuda.manual_seed_all(args.seed)
90
+ np.random.seed(args.seed)
91
+
92
+ key = torch.randint(0, 2, (1, nbit), dtype=torch.float32, device=accelerator.device)
93
+ key_str = "".join([ str(int(ii)) for ii in key.tolist()[0]])
94
+ print(f'Key: {key_str}')
95
+
96
+ df_SS = pd.DataFrame(columns=[
97
+ "iteration",
98
+ "psnr",
99
+ "bit_acc_avg",
100
+ ])
101
+ attacks = {
102
+ 'none': lambda x: x,
103
+ 'crop_01': lambda x: utils_img.center_crop(x, 0.1),
104
+ 'crop_05': lambda x: utils_img.center_crop(x, 0.5),
105
+ 'rot_25': lambda x: utils_img.rotate(x, 25),
106
+ 'rot_90': lambda x: utils_img.rotate(x, 90),
107
+ 'resize_03': lambda x: utils_img.resize(x, 0.3),
108
+ 'resize_07': lambda x: utils_img.resize(x, 0.7),
109
+ 'brightness_1p5': lambda x: utils_img.adjust_brightness(x, 1.5),
110
+ 'brightness_2': lambda x: utils_img.adjust_brightness(x, 2),
111
+ 'contrast_1p5': lambda x: utils_img.adjust_contrast(x, 1.5),
112
+ 'contrast_2': lambda x: utils_img.adjust_contrast(x, 2),
113
+ 'sharpness_1p5': lambda x: utils_img.adjust_sharpness(x, 1.5),
114
+ 'sharpness_2': lambda x: utils_img.adjust_sharpness(x, 2),
115
+ 'jpeg_80': lambda x: utils_img.jpeg_compress(x, 80),
116
+ 'jpeg_50': lambda x: utils_img.jpeg_compress(x, 50),
117
+ }
118
+
119
+ for ii, imgs in enumerate(val_loader):
120
+ imgs = imgs.to(accelerator.device)
121
+ keys = key.repeat(imgs.shape[0], 1).to(accelerator.device)
122
+
123
+ imgs_z = ldm_ae.encode(imgs) # b c h w -> b z h/f w/f
124
+ imgs_z = imgs_z.mode()
125
+
126
+ # decode latents with original and finetuned decoder
127
+ imgs_d0 = ldm_ae.decode(imgs_z) # b z h/f w/f -> b c h w
128
+ post_quant_noise = ldm_decoder.ae.post_quant_conv(imgs_z)
129
+ _, imgs_w = ldm_decoder(post_quant_noise, None, (2*keys-1)) # b z h/f w/f -> b c h w #TODO: Must do the stupid op. to get the correct message
130
+
131
+ # extract watermark
132
+ decoded = ldm_decoder.decoder(imgs_w.to("cuda"))
133
+ # decoded = msg_decoder(vqgan_to_imnet(imgs_w)) # b c h w -> b k
134
+ diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k
135
+ bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b
136
+
137
+ log_stats = {
138
+ "iteration": ii,
139
+ "psnr": utils_img.psnr(imgs_w, imgs_d0).mean().item(),
140
+ "bit_acc_avg": torch.mean(bit_accs).item(),
141
+ }
142
+ for name, attack in attacks.items():
143
+ imgs_aug = attack(imgs_w)
144
+ # decoded = msg_decoder(imgs_aug) # b c h w -> b k
145
+ decoded = ldm_decoder.decoder(imgs_aug.to("cuda"))
146
+ diff = (~torch.logical_xor(decoded>0, keys>0)) # b k -> b k
147
+ bit_accs = torch.sum(diff, dim=-1) / diff.shape[-1] # b k -> b
148
+ word_accs = (bit_accs == 1) # b
149
+ log_stats[f'bit_acc_{name}'] = torch.mean(bit_accs).item()
150
+
151
+ df_SS = df_SS._append(log_stats, ignore_index=True)
152
+ save_image(utils_img.unnormalize_vqgan(imgs_w), os.path.join(saveimgs_dir, f'{ii:03}_wm_orig.png'))
153
+ df_SS.to_csv(os.path.join(saveimgs_dir, 'bitacc.csv'), index=False)
154
+
155
+ def get_parser():
156
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
157
+ parser.add_argument("--batch_size", type=int, default=4, help="Batch size.")
158
+ parser.add_argument("--img_size", type=int, default=256, help="Image size.")
159
+ parser.add_argument("--seed", type=int, default=0, help="Seed.")
160
+ parser.add_argument("--message_len", type=int, default=48, help="Message length.")
161
+ parser.add_argument("--ldm_config", type=str, default="/pubdata/ldd/projects/EW-LoRA/watermarker/stable_signature/configs/stable-diffusion/v1-inference.yaml", help="LDM config.")
162
+ parser.add_argument("--ldm_wm_config", type=str, default="/pubdata/ldd/projects/EW-LoRA/watermarker/LaWa/configs/SD14_LaWa.yaml", help="LDM config.")
163
+ parser.add_argument("--ldm_ckpt", type=str, default="/pubdata/ldd/models/ldm_ckpts/sd-v1-4-full-ema.ckpt", help="LDM checkpoint.")
164
+ parser.add_argument("--msg_decoder_path", type=str, default="/pubdata/ldd/models/wm_encdec/hidden/ckpts/dec_48b_whit.torchscript.pt", help="Message decoder path.")
165
+ parser.add_argument("--val_dir", type=str, default="/pubdata/ldd/Datasets/coco2017/val2017", help="Validation directory.")
166
+ # parser.add_argument("--apply_dlwt", action="store_true", help="Apply DLWT.")
167
+ parser.add_argument("--ckpt_prefix", type=str, default="SS_dlwt", help="Checkpoint prefix.")
168
+ parser.add_argument("--exps_num", type=str, default="005-exps", help="Experiments number.")
169
+ return parser
170
+
171
+ if __name__ == '__main__':
172
+ # generate parser / parse parameters
173
+ parser = get_parser()
174
+ args = parser.parse_args()
175
+
176
+ # run experiment
177
+ main(args)
watermarker/LaWa/lawa_dataset/train_100k.csv ADDED
The diff for this file is too large to render. See raw diff
 
watermarker/LaWa/lawa_dataset/train_200k.csv ADDED
The diff for this file is too large to render. See raw diff
 
watermarker/LaWa/lawa_dataset/val_10k.csv ADDED
The diff for this file is too large to render. See raw diff
 
watermarker/LaWa/lawa_dataset/val_1k.csv ADDED
@@ -0,0 +1,1001 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ id,path
2
+ 99000,9/99000.jpg
3
+ 99001,9/99001.jpg
4
+ 99002,9/99002.jpg
5
+ 99003,9/99003.jpg
6
+ 99004,9/99004.jpg
7
+ 99005,9/99005.jpg
8
+ 99006,9/99006.jpg
9
+ 99007,9/99007.jpg
10
+ 99008,9/99008.jpg
11
+ 99009,9/99009.jpg
12
+ 99010,9/99010.jpg
13
+ 99011,9/99011.jpg
14
+ 99012,9/99012.jpg
15
+ 99013,9/99013.jpg
16
+ 99014,9/99014.jpg
17
+ 99015,9/99015.jpg
18
+ 99016,9/99016.jpg
19
+ 99017,9/99017.jpg
20
+ 99018,9/99018.jpg
21
+ 99019,9/99019.jpg
22
+ 99020,9/99020.jpg
23
+ 99021,9/99021.jpg
24
+ 99022,9/99022.jpg
25
+ 99023,9/99023.jpg
26
+ 99024,9/99024.jpg
27
+ 99025,9/99025.jpg
28
+ 99026,9/99026.jpg
29
+ 99027,9/99027.jpg
30
+ 99028,9/99028.jpg
31
+ 99029,9/99029.jpg
32
+ 99030,9/99030.jpg
33
+ 99031,9/99031.jpg
34
+ 99032,9/99032.jpg
35
+ 99033,9/99033.jpg
36
+ 99034,9/99034.jpg
37
+ 99035,9/99035.jpg
38
+ 99036,9/99036.jpg
39
+ 99037,9/99037.jpg
40
+ 99038,9/99038.jpg
41
+ 99039,9/99039.jpg
42
+ 99040,9/99040.jpg
43
+ 99041,9/99041.jpg
44
+ 99042,9/99042.jpg
45
+ 99043,9/99043.jpg
46
+ 99044,9/99044.jpg
47
+ 99045,9/99045.jpg
48
+ 99046,9/99046.jpg
49
+ 99047,9/99047.jpg
50
+ 99048,9/99048.jpg
51
+ 99049,9/99049.jpg
52
+ 99050,9/99050.jpg
53
+ 99051,9/99051.jpg
54
+ 99052,9/99052.jpg
55
+ 99053,9/99053.jpg
56
+ 99054,9/99054.jpg
57
+ 99055,9/99055.jpg
58
+ 99056,9/99056.jpg
59
+ 99057,9/99057.jpg
60
+ 99058,9/99058.jpg
61
+ 99059,9/99059.jpg
62
+ 99060,9/99060.jpg
63
+ 99061,9/99061.jpg
64
+ 99062,9/99062.jpg
65
+ 99063,9/99063.jpg
66
+ 99064,9/99064.jpg
67
+ 99065,9/99065.jpg
68
+ 99066,9/99066.jpg
69
+ 99067,9/99067.jpg
70
+ 99068,9/99068.jpg
71
+ 99069,9/99069.jpg
72
+ 99070,9/99070.jpg
73
+ 99071,9/99071.jpg
74
+ 99072,9/99072.jpg
75
+ 99073,9/99073.jpg
76
+ 99074,9/99074.jpg
77
+ 99075,9/99075.jpg
78
+ 99076,9/99076.jpg
79
+ 99077,9/99077.jpg
80
+ 99078,9/99078.jpg
81
+ 99079,9/99079.jpg
82
+ 99080,9/99080.jpg
83
+ 99081,9/99081.jpg
84
+ 99082,9/99082.jpg
85
+ 99083,9/99083.jpg
86
+ 99084,9/99084.jpg
87
+ 99085,9/99085.jpg
88
+ 99086,9/99086.jpg
89
+ 99087,9/99087.jpg
90
+ 99088,9/99088.jpg
91
+ 99089,9/99089.jpg
92
+ 99090,9/99090.jpg
93
+ 99091,9/99091.jpg
94
+ 99092,9/99092.jpg
95
+ 99093,9/99093.jpg
96
+ 99094,9/99094.jpg
97
+ 99095,9/99095.jpg
98
+ 99096,9/99096.jpg
99
+ 99097,9/99097.jpg
100
+ 99098,9/99098.jpg
101
+ 99099,9/99099.jpg
102
+ 99100,9/99100.jpg
103
+ 99101,9/99101.jpg
104
+ 99102,9/99102.jpg
105
+ 99103,9/99103.jpg
106
+ 99104,9/99104.jpg
107
+ 99105,9/99105.jpg
108
+ 99106,9/99106.jpg
109
+ 99107,9/99107.jpg
110
+ 99108,9/99108.jpg
111
+ 99109,9/99109.jpg
112
+ 99110,9/99110.jpg
113
+ 99111,9/99111.jpg
114
+ 99112,9/99112.jpg
115
+ 99113,9/99113.jpg
116
+ 99114,9/99114.jpg
117
+ 99115,9/99115.jpg
118
+ 99116,9/99116.jpg
119
+ 99117,9/99117.jpg
120
+ 99118,9/99118.jpg
121
+ 99119,9/99119.jpg
122
+ 99120,9/99120.jpg
123
+ 99121,9/99121.jpg
124
+ 99122,9/99122.jpg
125
+ 99123,9/99123.jpg
126
+ 99124,9/99124.jpg
127
+ 99125,9/99125.jpg
128
+ 99126,9/99126.jpg
129
+ 99127,9/99127.jpg
130
+ 99128,9/99128.jpg
131
+ 99129,9/99129.jpg
132
+ 99130,9/99130.jpg
133
+ 99131,9/99131.jpg
134
+ 99132,9/99132.jpg
135
+ 99133,9/99133.jpg
136
+ 99134,9/99134.jpg
137
+ 99135,9/99135.jpg
138
+ 99136,9/99136.jpg
139
+ 99137,9/99137.jpg
140
+ 99138,9/99138.jpg
141
+ 99139,9/99139.jpg
142
+ 99140,9/99140.jpg
143
+ 99141,9/99141.jpg
144
+ 99142,9/99142.jpg
145
+ 99143,9/99143.jpg
146
+ 99144,9/99144.jpg
147
+ 99145,9/99145.jpg
148
+ 99146,9/99146.jpg
149
+ 99147,9/99147.jpg
150
+ 99148,9/99148.jpg
151
+ 99149,9/99149.jpg
152
+ 99150,9/99150.jpg
153
+ 99151,9/99151.jpg
154
+ 99152,9/99152.jpg
155
+ 99153,9/99153.jpg
156
+ 99154,9/99154.jpg
157
+ 99155,9/99155.jpg
158
+ 99156,9/99156.jpg
159
+ 99157,9/99157.jpg
160
+ 99158,9/99158.jpg
161
+ 99159,9/99159.jpg
162
+ 99160,9/99160.jpg
163
+ 99161,9/99161.jpg
164
+ 99162,9/99162.jpg
165
+ 99163,9/99163.jpg
166
+ 99164,9/99164.jpg
167
+ 99165,9/99165.jpg
168
+ 99166,9/99166.jpg
169
+ 99167,9/99167.jpg
170
+ 99168,9/99168.jpg
171
+ 99169,9/99169.jpg
172
+ 99170,9/99170.jpg
173
+ 99171,9/99171.jpg
174
+ 99172,9/99172.jpg
175
+ 99173,9/99173.jpg
176
+ 99174,9/99174.jpg
177
+ 99175,9/99175.jpg
178
+ 99176,9/99176.jpg
179
+ 99177,9/99177.jpg
180
+ 99178,9/99178.jpg
181
+ 99179,9/99179.jpg
182
+ 99180,9/99180.jpg
183
+ 99181,9/99181.jpg
184
+ 99182,9/99182.jpg
185
+ 99183,9/99183.jpg
186
+ 99184,9/99184.jpg
187
+ 99185,9/99185.jpg
188
+ 99186,9/99186.jpg
189
+ 99187,9/99187.jpg
190
+ 99188,9/99188.jpg
191
+ 99189,9/99189.jpg
192
+ 99190,9/99190.jpg
193
+ 99191,9/99191.jpg
194
+ 99192,9/99192.jpg
195
+ 99193,9/99193.jpg
196
+ 99194,9/99194.jpg
197
+ 99195,9/99195.jpg
198
+ 99196,9/99196.jpg
199
+ 99197,9/99197.jpg
200
+ 99198,9/99198.jpg
201
+ 99199,9/99199.jpg
202
+ 99200,9/99200.jpg
203
+ 99201,9/99201.jpg
204
+ 99202,9/99202.jpg
205
+ 99203,9/99203.jpg
206
+ 99204,9/99204.jpg
207
+ 99205,9/99205.jpg
208
+ 99206,9/99206.jpg
209
+ 99207,9/99207.jpg
210
+ 99208,9/99208.jpg
211
+ 99209,9/99209.jpg
212
+ 99210,9/99210.jpg
213
+ 99211,9/99211.jpg
214
+ 99212,9/99212.jpg
215
+ 99213,9/99213.jpg
216
+ 99214,9/99214.jpg
217
+ 99215,9/99215.jpg
218
+ 99216,9/99216.jpg
219
+ 99217,9/99217.jpg
220
+ 99218,9/99218.jpg
221
+ 99219,9/99219.jpg
222
+ 99220,9/99220.jpg
223
+ 99221,9/99221.jpg
224
+ 99222,9/99222.jpg
225
+ 99223,9/99223.jpg
226
+ 99224,9/99224.jpg
227
+ 99225,9/99225.jpg
228
+ 99226,9/99226.jpg
229
+ 99227,9/99227.jpg
230
+ 99228,9/99228.jpg
231
+ 99229,9/99229.jpg
232
+ 99230,9/99230.jpg
233
+ 99231,9/99231.jpg
234
+ 99232,9/99232.jpg
235
+ 99233,9/99233.jpg
236
+ 99234,9/99234.jpg
237
+ 99235,9/99235.jpg
238
+ 99236,9/99236.jpg
239
+ 99237,9/99237.jpg
240
+ 99238,9/99238.jpg
241
+ 99239,9/99239.jpg
242
+ 99240,9/99240.jpg
243
+ 99241,9/99241.jpg
244
+ 99242,9/99242.jpg
245
+ 99243,9/99243.jpg
246
+ 99244,9/99244.jpg
247
+ 99245,9/99245.jpg
248
+ 99246,9/99246.jpg
249
+ 99247,9/99247.jpg
250
+ 99248,9/99248.jpg
251
+ 99249,9/99249.jpg
252
+ 99250,9/99250.jpg
253
+ 99251,9/99251.jpg
254
+ 99252,9/99252.jpg
255
+ 99253,9/99253.jpg
256
+ 99254,9/99254.jpg
257
+ 99255,9/99255.jpg
258
+ 99256,9/99256.jpg
259
+ 99257,9/99257.jpg
260
+ 99258,9/99258.jpg
261
+ 99259,9/99259.jpg
262
+ 99260,9/99260.jpg
263
+ 99261,9/99261.jpg
264
+ 99262,9/99262.jpg
265
+ 99263,9/99263.jpg
266
+ 99264,9/99264.jpg
267
+ 99265,9/99265.jpg
268
+ 99266,9/99266.jpg
269
+ 99267,9/99267.jpg
270
+ 99268,9/99268.jpg
271
+ 99269,9/99269.jpg
272
+ 99270,9/99270.jpg
273
+ 99271,9/99271.jpg
274
+ 99272,9/99272.jpg
275
+ 99273,9/99273.jpg
276
+ 99274,9/99274.jpg
277
+ 99275,9/99275.jpg
278
+ 99276,9/99276.jpg
279
+ 99277,9/99277.jpg
280
+ 99278,9/99278.jpg
281
+ 99279,9/99279.jpg
282
+ 99280,9/99280.jpg
283
+ 99281,9/99281.jpg
284
+ 99282,9/99282.jpg
285
+ 99283,9/99283.jpg
286
+ 99284,9/99284.jpg
287
+ 99285,9/99285.jpg
288
+ 99286,9/99286.jpg
289
+ 99287,9/99287.jpg
290
+ 99288,9/99288.jpg
291
+ 99289,9/99289.jpg
292
+ 99290,9/99290.jpg
293
+ 99291,9/99291.jpg
294
+ 99292,9/99292.jpg
295
+ 99293,9/99293.jpg
296
+ 99294,9/99294.jpg
297
+ 99295,9/99295.jpg
298
+ 99296,9/99296.jpg
299
+ 99297,9/99297.jpg
300
+ 99298,9/99298.jpg
301
+ 99299,9/99299.jpg
302
+ 99300,9/99300.jpg
303
+ 99301,9/99301.jpg
304
+ 99302,9/99302.jpg
305
+ 99303,9/99303.jpg
306
+ 99304,9/99304.jpg
307
+ 99305,9/99305.jpg
308
+ 99306,9/99306.jpg
309
+ 99307,9/99307.jpg
310
+ 99308,9/99308.jpg
311
+ 99309,9/99309.jpg
312
+ 99310,9/99310.jpg
313
+ 99311,9/99311.jpg
314
+ 99312,9/99312.jpg
315
+ 99313,9/99313.jpg
316
+ 99314,9/99314.jpg
317
+ 99315,9/99315.jpg
318
+ 99316,9/99316.jpg
319
+ 99317,9/99317.jpg
320
+ 99318,9/99318.jpg
321
+ 99319,9/99319.jpg
322
+ 99320,9/99320.jpg
323
+ 99321,9/99321.jpg
324
+ 99322,9/99322.jpg
325
+ 99323,9/99323.jpg
326
+ 99324,9/99324.jpg
327
+ 99325,9/99325.jpg
328
+ 99326,9/99326.jpg
329
+ 99327,9/99327.jpg
330
+ 99328,9/99328.jpg
331
+ 99329,9/99329.jpg
332
+ 99330,9/99330.jpg
333
+ 99331,9/99331.jpg
334
+ 99332,9/99332.jpg
335
+ 99333,9/99333.jpg
336
+ 99334,9/99334.jpg
337
+ 99335,9/99335.jpg
338
+ 99336,9/99336.jpg
339
+ 99337,9/99337.jpg
340
+ 99338,9/99338.jpg
341
+ 99339,9/99339.jpg
342
+ 99340,9/99340.jpg
343
+ 99341,9/99341.jpg
344
+ 99342,9/99342.jpg
345
+ 99343,9/99343.jpg
346
+ 99344,9/99344.jpg
347
+ 99345,9/99345.jpg
348
+ 99346,9/99346.jpg
349
+ 99347,9/99347.jpg
350
+ 99348,9/99348.jpg
351
+ 99349,9/99349.jpg
352
+ 99350,9/99350.jpg
353
+ 99351,9/99351.jpg
354
+ 99352,9/99352.jpg
355
+ 99353,9/99353.jpg
356
+ 99354,9/99354.jpg
357
+ 99355,9/99355.jpg
358
+ 99356,9/99356.jpg
359
+ 99357,9/99357.jpg
360
+ 99358,9/99358.jpg
361
+ 99359,9/99359.jpg
362
+ 99360,9/99360.jpg
363
+ 99361,9/99361.jpg
364
+ 99362,9/99362.jpg
365
+ 99363,9/99363.jpg
366
+ 99364,9/99364.jpg
367
+ 99365,9/99365.jpg
368
+ 99366,9/99366.jpg
369
+ 99367,9/99367.jpg
370
+ 99368,9/99368.jpg
371
+ 99369,9/99369.jpg
372
+ 99370,9/99370.jpg
373
+ 99371,9/99371.jpg
374
+ 99372,9/99372.jpg
375
+ 99373,9/99373.jpg
376
+ 99374,9/99374.jpg
377
+ 99375,9/99375.jpg
378
+ 99376,9/99376.jpg
379
+ 99377,9/99377.jpg
380
+ 99378,9/99378.jpg
381
+ 99379,9/99379.jpg
382
+ 99380,9/99380.jpg
383
+ 99381,9/99381.jpg
384
+ 99382,9/99382.jpg
385
+ 99383,9/99383.jpg
386
+ 99384,9/99384.jpg
387
+ 99385,9/99385.jpg
388
+ 99386,9/99386.jpg
389
+ 99387,9/99387.jpg
390
+ 99388,9/99388.jpg
391
+ 99389,9/99389.jpg
392
+ 99390,9/99390.jpg
393
+ 99391,9/99391.jpg
394
+ 99392,9/99392.jpg
395
+ 99393,9/99393.jpg
396
+ 99394,9/99394.jpg
397
+ 99395,9/99395.jpg
398
+ 99396,9/99396.jpg
399
+ 99397,9/99397.jpg
400
+ 99398,9/99398.jpg
401
+ 99399,9/99399.jpg
402
+ 99400,9/99400.jpg
403
+ 99401,9/99401.jpg
404
+ 99402,9/99402.jpg
405
+ 99403,9/99403.jpg
406
+ 99404,9/99404.jpg
407
+ 99405,9/99405.jpg
408
+ 99406,9/99406.jpg
409
+ 99407,9/99407.jpg
410
+ 99408,9/99408.jpg
411
+ 99409,9/99409.jpg
412
+ 99410,9/99410.jpg
413
+ 99411,9/99411.jpg
414
+ 99412,9/99412.jpg
415
+ 99413,9/99413.jpg
416
+ 99414,9/99414.jpg
417
+ 99415,9/99415.jpg
418
+ 99416,9/99416.jpg
419
+ 99417,9/99417.jpg
420
+ 99418,9/99418.jpg
421
+ 99419,9/99419.jpg
422
+ 99420,9/99420.jpg
423
+ 99421,9/99421.jpg
424
+ 99422,9/99422.jpg
425
+ 99423,9/99423.jpg
426
+ 99424,9/99424.jpg
427
+ 99425,9/99425.jpg
428
+ 99426,9/99426.jpg
429
+ 99427,9/99427.jpg
430
+ 99428,9/99428.jpg
431
+ 99429,9/99429.jpg
432
+ 99430,9/99430.jpg
433
+ 99431,9/99431.jpg
434
+ 99432,9/99432.jpg
435
+ 99433,9/99433.jpg
436
+ 99434,9/99434.jpg
437
+ 99435,9/99435.jpg
438
+ 99436,9/99436.jpg
439
+ 99437,9/99437.jpg
440
+ 99438,9/99438.jpg
441
+ 99439,9/99439.jpg
442
+ 99440,9/99440.jpg
443
+ 99441,9/99441.jpg
444
+ 99442,9/99442.jpg
445
+ 99443,9/99443.jpg
446
+ 99444,9/99444.jpg
447
+ 99445,9/99445.jpg
448
+ 99446,9/99446.jpg
449
+ 99447,9/99447.jpg
450
+ 99448,9/99448.jpg
451
+ 99449,9/99449.jpg
452
+ 99450,9/99450.jpg
453
+ 99451,9/99451.jpg
454
+ 99452,9/99452.jpg
455
+ 99453,9/99453.jpg
456
+ 99454,9/99454.jpg
457
+ 99455,9/99455.jpg
458
+ 99456,9/99456.jpg
459
+ 99457,9/99457.jpg
460
+ 99458,9/99458.jpg
461
+ 99459,9/99459.jpg
462
+ 99460,9/99460.jpg
463
+ 99461,9/99461.jpg
464
+ 99462,9/99462.jpg
465
+ 99463,9/99463.jpg
466
+ 99464,9/99464.jpg
467
+ 99465,9/99465.jpg
468
+ 99466,9/99466.jpg
469
+ 99467,9/99467.jpg
470
+ 99468,9/99468.jpg
471
+ 99469,9/99469.jpg
472
+ 99470,9/99470.jpg
473
+ 99471,9/99471.jpg
474
+ 99472,9/99472.jpg
475
+ 99473,9/99473.jpg
476
+ 99474,9/99474.jpg
477
+ 99475,9/99475.jpg
478
+ 99476,9/99476.jpg
479
+ 99477,9/99477.jpg
480
+ 99478,9/99478.jpg
481
+ 99479,9/99479.jpg
482
+ 99480,9/99480.jpg
483
+ 99481,9/99481.jpg
484
+ 99482,9/99482.jpg
485
+ 99483,9/99483.jpg
486
+ 99484,9/99484.jpg
487
+ 99485,9/99485.jpg
488
+ 99486,9/99486.jpg
489
+ 99487,9/99487.jpg
490
+ 99488,9/99488.jpg
491
+ 99489,9/99489.jpg
492
+ 99490,9/99490.jpg
493
+ 99491,9/99491.jpg
494
+ 99492,9/99492.jpg
495
+ 99493,9/99493.jpg
496
+ 99494,9/99494.jpg
497
+ 99495,9/99495.jpg
498
+ 99496,9/99496.jpg
499
+ 99497,9/99497.jpg
500
+ 99498,9/99498.jpg
501
+ 99499,9/99499.jpg
502
+ 99500,9/99500.jpg
503
+ 99501,9/99501.jpg
504
+ 99502,9/99502.jpg
505
+ 99503,9/99503.jpg
506
+ 99504,9/99504.jpg
507
+ 99505,9/99505.jpg
508
+ 99506,9/99506.jpg
509
+ 99507,9/99507.jpg
510
+ 99508,9/99508.jpg
511
+ 99509,9/99509.jpg
512
+ 99510,9/99510.jpg
513
+ 99511,9/99511.jpg
514
+ 99512,9/99512.jpg
515
+ 99513,9/99513.jpg
516
+ 99514,9/99514.jpg
517
+ 99515,9/99515.jpg
518
+ 99516,9/99516.jpg
519
+ 99517,9/99517.jpg
520
+ 99518,9/99518.jpg
521
+ 99519,9/99519.jpg
522
+ 99520,9/99520.jpg
523
+ 99521,9/99521.jpg
524
+ 99522,9/99522.jpg
525
+ 99523,9/99523.jpg
526
+ 99524,9/99524.jpg
527
+ 99525,9/99525.jpg
528
+ 99526,9/99526.jpg
529
+ 99527,9/99527.jpg
530
+ 99528,9/99528.jpg
531
+ 99529,9/99529.jpg
532
+ 99530,9/99530.jpg
533
+ 99531,9/99531.jpg
534
+ 99532,9/99532.jpg
535
+ 99533,9/99533.jpg
536
+ 99534,9/99534.jpg
537
+ 99535,9/99535.jpg
538
+ 99536,9/99536.jpg
539
+ 99537,9/99537.jpg
540
+ 99538,9/99538.jpg
541
+ 99539,9/99539.jpg
542
+ 99540,9/99540.jpg
543
+ 99541,9/99541.jpg
544
+ 99542,9/99542.jpg
545
+ 99543,9/99543.jpg
546
+ 99544,9/99544.jpg
547
+ 99545,9/99545.jpg
548
+ 99546,9/99546.jpg
549
+ 99547,9/99547.jpg
550
+ 99548,9/99548.jpg
551
+ 99549,9/99549.jpg
552
+ 99550,9/99550.jpg
553
+ 99551,9/99551.jpg
554
+ 99552,9/99552.jpg
555
+ 99553,9/99553.jpg
556
+ 99554,9/99554.jpg
557
+ 99555,9/99555.jpg
558
+ 99556,9/99556.jpg
559
+ 99557,9/99557.jpg
560
+ 99558,9/99558.jpg
561
+ 99559,9/99559.jpg
562
+ 99560,9/99560.jpg
563
+ 99561,9/99561.jpg
564
+ 99562,9/99562.jpg
565
+ 99563,9/99563.jpg
566
+ 99564,9/99564.jpg
567
+ 99565,9/99565.jpg
568
+ 99566,9/99566.jpg
569
+ 99567,9/99567.jpg
570
+ 99568,9/99568.jpg
571
+ 99569,9/99569.jpg
572
+ 99570,9/99570.jpg
573
+ 99571,9/99571.jpg
574
+ 99572,9/99572.jpg
575
+ 99573,9/99573.jpg
576
+ 99574,9/99574.jpg
577
+ 99575,9/99575.jpg
578
+ 99576,9/99576.jpg
579
+ 99577,9/99577.jpg
580
+ 99578,9/99578.jpg
581
+ 99579,9/99579.jpg
582
+ 99580,9/99580.jpg
583
+ 99581,9/99581.jpg
584
+ 99582,9/99582.jpg
585
+ 99583,9/99583.jpg
586
+ 99584,9/99584.jpg
587
+ 99585,9/99585.jpg
588
+ 99586,9/99586.jpg
589
+ 99587,9/99587.jpg
590
+ 99588,9/99588.jpg
591
+ 99589,9/99589.jpg
592
+ 99590,9/99590.jpg
593
+ 99591,9/99591.jpg
594
+ 99592,9/99592.jpg
595
+ 99593,9/99593.jpg
596
+ 99594,9/99594.jpg
597
+ 99595,9/99595.jpg
598
+ 99596,9/99596.jpg
599
+ 99597,9/99597.jpg
600
+ 99598,9/99598.jpg
601
+ 99599,9/99599.jpg
602
+ 99600,9/99600.jpg
603
+ 99601,9/99601.jpg
604
+ 99602,9/99602.jpg
605
+ 99603,9/99603.jpg
606
+ 99604,9/99604.jpg
607
+ 99605,9/99605.jpg
608
+ 99606,9/99606.jpg
609
+ 99607,9/99607.jpg
610
+ 99608,9/99608.jpg
611
+ 99609,9/99609.jpg
612
+ 99610,9/99610.jpg
613
+ 99611,9/99611.jpg
614
+ 99612,9/99612.jpg
615
+ 99613,9/99613.jpg
616
+ 99614,9/99614.jpg
617
+ 99615,9/99615.jpg
618
+ 99616,9/99616.jpg
619
+ 99617,9/99617.jpg
620
+ 99618,9/99618.jpg
621
+ 99619,9/99619.jpg
622
+ 99620,9/99620.jpg
623
+ 99621,9/99621.jpg
624
+ 99622,9/99622.jpg
625
+ 99623,9/99623.jpg
626
+ 99624,9/99624.jpg
627
+ 99625,9/99625.jpg
628
+ 99626,9/99626.jpg
629
+ 99627,9/99627.jpg
630
+ 99628,9/99628.jpg
631
+ 99629,9/99629.jpg
632
+ 99630,9/99630.jpg
633
+ 99631,9/99631.jpg
634
+ 99632,9/99632.jpg
635
+ 99633,9/99633.jpg
636
+ 99634,9/99634.jpg
637
+ 99635,9/99635.jpg
638
+ 99636,9/99636.jpg
639
+ 99637,9/99637.jpg
640
+ 99638,9/99638.jpg
641
+ 99639,9/99639.jpg
642
+ 99640,9/99640.jpg
643
+ 99641,9/99641.jpg
644
+ 99642,9/99642.jpg
645
+ 99643,9/99643.jpg
646
+ 99644,9/99644.jpg
647
+ 99645,9/99645.jpg
648
+ 99646,9/99646.jpg
649
+ 99647,9/99647.jpg
650
+ 99648,9/99648.jpg
651
+ 99649,9/99649.jpg
652
+ 99650,9/99650.jpg
653
+ 99651,9/99651.jpg
654
+ 99652,9/99652.jpg
655
+ 99653,9/99653.jpg
656
+ 99654,9/99654.jpg
657
+ 99655,9/99655.jpg
658
+ 99656,9/99656.jpg
659
+ 99657,9/99657.jpg
660
+ 99658,9/99658.jpg
661
+ 99659,9/99659.jpg
662
+ 99660,9/99660.jpg
663
+ 99661,9/99661.jpg
664
+ 99662,9/99662.jpg
665
+ 99663,9/99663.jpg
666
+ 99664,9/99664.jpg
667
+ 99665,9/99665.jpg
668
+ 99666,9/99666.jpg
669
+ 99667,9/99667.jpg
670
+ 99668,9/99668.jpg
671
+ 99669,9/99669.jpg
672
+ 99670,9/99670.jpg
673
+ 99671,9/99671.jpg
674
+ 99672,9/99672.jpg
675
+ 99673,9/99673.jpg
676
+ 99674,9/99674.jpg
677
+ 99675,9/99675.jpg
678
+ 99676,9/99676.jpg
679
+ 99677,9/99677.jpg
680
+ 99678,9/99678.jpg
681
+ 99679,9/99679.jpg
682
+ 99680,9/99680.jpg
683
+ 99681,9/99681.jpg
684
+ 99682,9/99682.jpg
685
+ 99683,9/99683.jpg
686
+ 99684,9/99684.jpg
687
+ 99685,9/99685.jpg
688
+ 99686,9/99686.jpg
689
+ 99687,9/99687.jpg
690
+ 99688,9/99688.jpg
691
+ 99689,9/99689.jpg
692
+ 99690,9/99690.jpg
693
+ 99691,9/99691.jpg
694
+ 99692,9/99692.jpg
695
+ 99693,9/99693.jpg
696
+ 99694,9/99694.jpg
697
+ 99695,9/99695.jpg
698
+ 99696,9/99696.jpg
699
+ 99697,9/99697.jpg
700
+ 99698,9/99698.jpg
701
+ 99699,9/99699.jpg
702
+ 99700,9/99700.jpg
703
+ 99701,9/99701.jpg
704
+ 99702,9/99702.jpg
705
+ 99703,9/99703.jpg
706
+ 99704,9/99704.jpg
707
+ 99705,9/99705.jpg
708
+ 99706,9/99706.jpg
709
+ 99707,9/99707.jpg
710
+ 99708,9/99708.jpg
711
+ 99709,9/99709.jpg
712
+ 99710,9/99710.jpg
713
+ 99711,9/99711.jpg
714
+ 99712,9/99712.jpg
715
+ 99713,9/99713.jpg
716
+ 99714,9/99714.jpg
717
+ 99715,9/99715.jpg
718
+ 99716,9/99716.jpg
719
+ 99717,9/99717.jpg
720
+ 99718,9/99718.jpg
721
+ 99719,9/99719.jpg
722
+ 99720,9/99720.jpg
723
+ 99721,9/99721.jpg
724
+ 99722,9/99722.jpg
725
+ 99723,9/99723.jpg
726
+ 99724,9/99724.jpg
727
+ 99725,9/99725.jpg
728
+ 99726,9/99726.jpg
729
+ 99727,9/99727.jpg
730
+ 99728,9/99728.jpg
731
+ 99729,9/99729.jpg
732
+ 99730,9/99730.jpg
733
+ 99731,9/99731.jpg
734
+ 99732,9/99732.jpg
735
+ 99733,9/99733.jpg
736
+ 99734,9/99734.jpg
737
+ 99735,9/99735.jpg
738
+ 99736,9/99736.jpg
739
+ 99737,9/99737.jpg
740
+ 99738,9/99738.jpg
741
+ 99739,9/99739.jpg
742
+ 99740,9/99740.jpg
743
+ 99741,9/99741.jpg
744
+ 99742,9/99742.jpg
745
+ 99743,9/99743.jpg
746
+ 99744,9/99744.jpg
747
+ 99745,9/99745.jpg
748
+ 99746,9/99746.jpg
749
+ 99747,9/99747.jpg
750
+ 99748,9/99748.jpg
751
+ 99749,9/99749.jpg
752
+ 99750,9/99750.jpg
753
+ 99751,9/99751.jpg
754
+ 99752,9/99752.jpg
755
+ 99753,9/99753.jpg
756
+ 99754,9/99754.jpg
757
+ 99755,9/99755.jpg
758
+ 99756,9/99756.jpg
759
+ 99757,9/99757.jpg
760
+ 99758,9/99758.jpg
761
+ 99759,9/99759.jpg
762
+ 99760,9/99760.jpg
763
+ 99761,9/99761.jpg
764
+ 99762,9/99762.jpg
765
+ 99763,9/99763.jpg
766
+ 99764,9/99764.jpg
767
+ 99765,9/99765.jpg
768
+ 99766,9/99766.jpg
769
+ 99767,9/99767.jpg
770
+ 99768,9/99768.jpg
771
+ 99769,9/99769.jpg
772
+ 99770,9/99770.jpg
773
+ 99771,9/99771.jpg
774
+ 99772,9/99772.jpg
775
+ 99773,9/99773.jpg
776
+ 99774,9/99774.jpg
777
+ 99775,9/99775.jpg
778
+ 99776,9/99776.jpg
779
+ 99777,9/99777.jpg
780
+ 99778,9/99778.jpg
781
+ 99779,9/99779.jpg
782
+ 99780,9/99780.jpg
783
+ 99781,9/99781.jpg
784
+ 99782,9/99782.jpg
785
+ 99783,9/99783.jpg
786
+ 99784,9/99784.jpg
787
+ 99785,9/99785.jpg
788
+ 99786,9/99786.jpg
789
+ 99787,9/99787.jpg
790
+ 99788,9/99788.jpg
791
+ 99789,9/99789.jpg
792
+ 99790,9/99790.jpg
793
+ 99791,9/99791.jpg
794
+ 99792,9/99792.jpg
795
+ 99793,9/99793.jpg
796
+ 99794,9/99794.jpg
797
+ 99795,9/99795.jpg
798
+ 99796,9/99796.jpg
799
+ 99797,9/99797.jpg
800
+ 99798,9/99798.jpg
801
+ 99799,9/99799.jpg
802
+ 99800,9/99800.jpg
803
+ 99801,9/99801.jpg
804
+ 99802,9/99802.jpg
805
+ 99803,9/99803.jpg
806
+ 99804,9/99804.jpg
807
+ 99805,9/99805.jpg
808
+ 99806,9/99806.jpg
809
+ 99807,9/99807.jpg
810
+ 99808,9/99808.jpg
811
+ 99809,9/99809.jpg
812
+ 99810,9/99810.jpg
813
+ 99811,9/99811.jpg
814
+ 99812,9/99812.jpg
815
+ 99813,9/99813.jpg
816
+ 99814,9/99814.jpg
817
+ 99815,9/99815.jpg
818
+ 99816,9/99816.jpg
819
+ 99817,9/99817.jpg
820
+ 99818,9/99818.jpg
821
+ 99819,9/99819.jpg
822
+ 99820,9/99820.jpg
823
+ 99821,9/99821.jpg
824
+ 99822,9/99822.jpg
825
+ 99823,9/99823.jpg
826
+ 99824,9/99824.jpg
827
+ 99825,9/99825.jpg
828
+ 99826,9/99826.jpg
829
+ 99827,9/99827.jpg
830
+ 99828,9/99828.jpg
831
+ 99829,9/99829.jpg
832
+ 99830,9/99830.jpg
833
+ 99831,9/99831.jpg
834
+ 99832,9/99832.jpg
835
+ 99833,9/99833.jpg
836
+ 99834,9/99834.jpg
837
+ 99835,9/99835.jpg
838
+ 99836,9/99836.jpg
839
+ 99837,9/99837.jpg
840
+ 99838,9/99838.jpg
841
+ 99839,9/99839.jpg
842
+ 99840,9/99840.jpg
843
+ 99841,9/99841.jpg
844
+ 99842,9/99842.jpg
845
+ 99843,9/99843.jpg
846
+ 99844,9/99844.jpg
847
+ 99845,9/99845.jpg
848
+ 99846,9/99846.jpg
849
+ 99847,9/99847.jpg
850
+ 99848,9/99848.jpg
851
+ 99849,9/99849.jpg
852
+ 99850,9/99850.jpg
853
+ 99851,9/99851.jpg
854
+ 99852,9/99852.jpg
855
+ 99853,9/99853.jpg
856
+ 99854,9/99854.jpg
857
+ 99855,9/99855.jpg
858
+ 99856,9/99856.jpg
859
+ 99857,9/99857.jpg
860
+ 99858,9/99858.jpg
861
+ 99859,9/99859.jpg
862
+ 99860,9/99860.jpg
863
+ 99861,9/99861.jpg
864
+ 99862,9/99862.jpg
865
+ 99863,9/99863.jpg
866
+ 99864,9/99864.jpg
867
+ 99865,9/99865.jpg
868
+ 99866,9/99866.jpg
869
+ 99867,9/99867.jpg
870
+ 99868,9/99868.jpg
871
+ 99869,9/99869.jpg
872
+ 99870,9/99870.jpg
873
+ 99871,9/99871.jpg
874
+ 99872,9/99872.jpg
875
+ 99873,9/99873.jpg
876
+ 99874,9/99874.jpg
877
+ 99875,9/99875.jpg
878
+ 99876,9/99876.jpg
879
+ 99877,9/99877.jpg
880
+ 99878,9/99878.jpg
881
+ 99879,9/99879.jpg
882
+ 99880,9/99880.jpg
883
+ 99881,9/99881.jpg
884
+ 99882,9/99882.jpg
885
+ 99883,9/99883.jpg
886
+ 99884,9/99884.jpg
887
+ 99885,9/99885.jpg
888
+ 99886,9/99886.jpg
889
+ 99887,9/99887.jpg
890
+ 99888,9/99888.jpg
891
+ 99889,9/99889.jpg
892
+ 99890,9/99890.jpg
893
+ 99891,9/99891.jpg
894
+ 99892,9/99892.jpg
895
+ 99893,9/99893.jpg
896
+ 99894,9/99894.jpg
897
+ 99895,9/99895.jpg
898
+ 99896,9/99896.jpg
899
+ 99897,9/99897.jpg
900
+ 99898,9/99898.jpg
901
+ 99899,9/99899.jpg
902
+ 99900,9/99900.jpg
903
+ 99901,9/99901.jpg
904
+ 99902,9/99902.jpg
905
+ 99903,9/99903.jpg
906
+ 99904,9/99904.jpg
907
+ 99905,9/99905.jpg
908
+ 99906,9/99906.jpg
909
+ 99907,9/99907.jpg
910
+ 99908,9/99908.jpg
911
+ 99909,9/99909.jpg
912
+ 99910,9/99910.jpg
913
+ 99911,9/99911.jpg
914
+ 99912,9/99912.jpg
915
+ 99913,9/99913.jpg
916
+ 99914,9/99914.jpg
917
+ 99915,9/99915.jpg
918
+ 99916,9/99916.jpg
919
+ 99917,9/99917.jpg
920
+ 99918,9/99918.jpg
921
+ 99919,9/99919.jpg
922
+ 99920,9/99920.jpg
923
+ 99921,9/99921.jpg
924
+ 99922,9/99922.jpg
925
+ 99923,9/99923.jpg
926
+ 99924,9/99924.jpg
927
+ 99925,9/99925.jpg
928
+ 99926,9/99926.jpg
929
+ 99927,9/99927.jpg
930
+ 99928,9/99928.jpg
931
+ 99929,9/99929.jpg
932
+ 99930,9/99930.jpg
933
+ 99931,9/99931.jpg
934
+ 99932,9/99932.jpg
935
+ 99933,9/99933.jpg
936
+ 99934,9/99934.jpg
937
+ 99935,9/99935.jpg
938
+ 99936,9/99936.jpg
939
+ 99937,9/99937.jpg
940
+ 99938,9/99938.jpg
941
+ 99939,9/99939.jpg
942
+ 99940,9/99940.jpg
943
+ 99941,9/99941.jpg
944
+ 99942,9/99942.jpg
945
+ 99943,9/99943.jpg
946
+ 99944,9/99944.jpg
947
+ 99945,9/99945.jpg
948
+ 99946,9/99946.jpg
949
+ 99947,9/99947.jpg
950
+ 99948,9/99948.jpg
951
+ 99949,9/99949.jpg
952
+ 99950,9/99950.jpg
953
+ 99951,9/99951.jpg
954
+ 99952,9/99952.jpg
955
+ 99953,9/99953.jpg
956
+ 99954,9/99954.jpg
957
+ 99955,9/99955.jpg
958
+ 99956,9/99956.jpg
959
+ 99957,9/99957.jpg
960
+ 99958,9/99958.jpg
961
+ 99959,9/99959.jpg
962
+ 99960,9/99960.jpg
963
+ 99961,9/99961.jpg
964
+ 99962,9/99962.jpg
965
+ 99963,9/99963.jpg
966
+ 99964,9/99964.jpg
967
+ 99965,9/99965.jpg
968
+ 99966,9/99966.jpg
969
+ 99967,9/99967.jpg
970
+ 99968,9/99968.jpg
971
+ 99969,9/99969.jpg
972
+ 99970,9/99970.jpg
973
+ 99971,9/99971.jpg
974
+ 99972,9/99972.jpg
975
+ 99973,9/99973.jpg
976
+ 99974,9/99974.jpg
977
+ 99975,9/99975.jpg
978
+ 99976,9/99976.jpg
979
+ 99977,9/99977.jpg
980
+ 99978,9/99978.jpg
981
+ 99979,9/99979.jpg
982
+ 99980,9/99980.jpg
983
+ 99981,9/99981.jpg
984
+ 99982,9/99982.jpg
985
+ 99983,9/99983.jpg
986
+ 99984,9/99984.jpg
987
+ 99985,9/99985.jpg
988
+ 99986,9/99986.jpg
989
+ 99987,9/99987.jpg
990
+ 99988,9/99988.jpg
991
+ 99989,9/99989.jpg
992
+ 99990,9/99990.jpg
993
+ 99991,9/99991.jpg
994
+ 99992,9/99992.jpg
995
+ 99993,9/99993.jpg
996
+ 99994,9/99994.jpg
997
+ 99995,9/99995.jpg
998
+ 99996,9/99996.jpg
999
+ 99997,9/99997.jpg
1000
+ 99998,9/99998.jpg
1001
+ 99999,9/99999.jpg
watermarker/LaWa/ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (6.59 kB). View file
 
watermarker/LaWa/ldm/data/__init__.py ADDED
File without changes
watermarker/LaWa/ldm/data/util.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ldm.modules.midas.api import load_midas_transform
4
+
5
+
6
+ class AddMiDaS(object):
7
+ def __init__(self, model_type):
8
+ super().__init__()
9
+ self.transform = load_midas_transform(model_type)
10
+
11
+ def pt2np(self, x):
12
+ x = ((x + 1.0) * .5).detach().cpu().numpy()
13
+ return x
14
+
15
+ def np2pt(self, x):
16
+ x = torch.from_numpy(x) * 2 - 1.
17
+ return x
18
+
19
+ def __call__(self, sample):
20
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
21
+ x = self.pt2np(sample['jpg'])
22
+ x = self.transform({"image": x})["image"]
23
+ sample['midas_in'] = x
24
+ return sample
watermarker/LaWa/ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (14.8 kB). View file
 
watermarker/LaWa/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
7
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8
+
9
+ from ldm.util import instantiate_from_config
10
+ from ldm.modules.ema import LitEma
11
+
12
+
13
+ class AutoencoderKL(pl.LightningModule):
14
+ def __init__(self,
15
+ ddconfig,
16
+ lossconfig,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ image_key="image",
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ ema_decay=None,
24
+ learn_logvar=False
25
+ ):
26
+ super().__init__()
27
+ self.learn_logvar = learn_logvar
28
+ self.image_key = image_key
29
+ self.encoder = Encoder(**ddconfig)
30
+ self.decoder = Decoder(**ddconfig)
31
+ self.loss = instantiate_from_config(lossconfig)
32
+ assert ddconfig["double_z"]
33
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
+ self.embed_dim = embed_dim
36
+ if colorize_nlabels is not None:
37
+ assert type(colorize_nlabels)==int
38
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ self.use_ema = ema_decay is not None
43
+ if self.use_ema:
44
+ self.ema_decay = ema_decay
45
+ assert 0. < ema_decay < 1.
46
+ self.model_ema = LitEma(self, decay=ema_decay)
47
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48
+
49
+ if ckpt_path is not None:
50
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51
+
52
+ def init_from_ckpt(self, path, ignore_keys=list()):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ keys = list(sd.keys())
55
+ for k in keys:
56
+ for ik in ignore_keys:
57
+ if k.startswith(ik):
58
+ print("Deleting key {} from state_dict.".format(k))
59
+ del sd[k]
60
+ self.load_state_dict(sd, strict=False)
61
+ print(f"Restored from {path}")
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ if self.use_ema:
80
+ self.model_ema(self)
81
+
82
+ def encode(self, x):
83
+ h = self.encoder(x)
84
+ moments = self.quant_conv(h)
85
+ posterior = DiagonalGaussianDistribution(moments)
86
+ return posterior
87
+
88
+ def decode(self, z):
89
+ z = self.post_quant_conv(z)
90
+ dec = self.decoder(z)
91
+ return dec
92
+
93
+ def forward(self, input, sample_posterior=True):
94
+ posterior = self.encode(input)
95
+ if sample_posterior:
96
+ z = posterior.sample()
97
+ else:
98
+ z = posterior.mode()
99
+ dec = self.decode(z)
100
+ return dec, posterior
101
+
102
+ def get_input(self, batch, k):
103
+ x = batch[k]
104
+ if len(x.shape) == 3:
105
+ x = x[..., None]
106
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107
+ return x
108
+
109
+ def training_step(self, batch, batch_idx, optimizer_idx):
110
+ inputs = self.get_input(batch, self.image_key)
111
+ reconstructions, posterior = self(inputs)
112
+
113
+ if optimizer_idx == 0:
114
+ # train encoder+decoder+logvar
115
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116
+ last_layer=self.get_last_layer(), split="train")
117
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119
+ return aeloss
120
+
121
+ if optimizer_idx == 1:
122
+ # train the discriminator
123
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124
+ last_layer=self.get_last_layer(), split="train")
125
+
126
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128
+ return discloss
129
+
130
+ def validation_step(self, batch, batch_idx):
131
+ log_dict = self._validation_step(batch, batch_idx)
132
+ with self.ema_scope():
133
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134
+ return log_dict
135
+
136
+ def _validation_step(self, batch, batch_idx, postfix=""):
137
+ inputs = self.get_input(batch, self.image_key)
138
+ reconstructions, posterior = self(inputs)
139
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140
+ last_layer=self.get_last_layer(), split="val"+postfix)
141
+
142
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143
+ last_layer=self.get_last_layer(), split="val"+postfix)
144
+
145
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146
+ self.log_dict(log_dict_ae)
147
+ self.log_dict(log_dict_disc)
148
+ return self.log_dict
149
+
150
+ def configure_optimizers(self):
151
+ lr = self.learning_rate
152
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154
+ if self.learn_logvar:
155
+ print(f"{self.__class__.__name__}: Learning logvar")
156
+ ae_params_list.append(self.loss.logvar)
157
+ opt_ae = torch.optim.Adam(ae_params_list,
158
+ lr=lr, betas=(0.5, 0.9))
159
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160
+ lr=lr, betas=(0.5, 0.9))
161
+ return [opt_ae, opt_disc], []
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.conv_out.weight
165
+
166
+ @torch.no_grad()
167
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168
+ log = dict()
169
+ x = self.get_input(batch, self.image_key)
170
+ x = x.to(self.device)
171
+ if not only_inputs:
172
+ xrec, posterior = self(x)
173
+ if x.shape[1] > 3:
174
+ # colorize with random projection
175
+ assert xrec.shape[1] > 3
176
+ x = self.to_rgb(x)
177
+ xrec = self.to_rgb(xrec)
178
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179
+ log["reconstructions"] = xrec
180
+ if log_ema or self.use_ema:
181
+ with self.ema_scope():
182
+ xrec_ema, posterior_ema = self(x)
183
+ if x.shape[1] > 3:
184
+ # colorize with random projection
185
+ assert xrec_ema.shape[1] > 3
186
+ xrec_ema = self.to_rgb(xrec_ema)
187
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188
+ log["reconstructions_ema"] = xrec_ema
189
+ log["inputs"] = x
190
+ return log
191
+
192
+ def to_rgb(self, x):
193
+ assert self.image_key == "segmentation"
194
+ if not hasattr(self, "colorize"):
195
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196
+ x = F.conv2d(x, weight=self.colorize)
197
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198
+ return x
199
+
200
+
201
+ class IdentityFirstStage(torch.nn.Module):
202
+ def __init__(self, *args, vq_interface=False, **kwargs):
203
+ self.vq_interface = vq_interface
204
+ super().__init__()
205
+
206
+ def encode(self, x, *args, **kwargs):
207
+ return x
208
+
209
+ def decode(self, x, *args, **kwargs):
210
+ return x
211
+
212
+ def quantize(self, x, *args, **kwargs):
213
+ if self.vq_interface:
214
+ return x, None, [None, None, None]
215
+ return x
216
+
217
+ def forward(self, x, *args, **kwargs):
218
+ return x
219
+
220
+
221
+ class VQModel(pl.LightningModule):
222
+ def __init__(self,
223
+ ddconfig,
224
+ lossconfig,
225
+ n_embed,
226
+ embed_dim,
227
+ ckpt_path=None,
228
+ ignore_keys=[],
229
+ image_key="image",
230
+ colorize_nlabels=None,
231
+ monitor=None,
232
+ batch_resize_range=None,
233
+ scheduler_config=None,
234
+ lr_g_factor=1.0,
235
+ remap=None,
236
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
237
+ use_ema=False
238
+ ):
239
+ super().__init__()
240
+
241
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
242
+
243
+ self.embed_dim = embed_dim
244
+ self.n_embed = n_embed
245
+ self.image_key = image_key
246
+ self.encoder = Encoder(**ddconfig)
247
+ self.decoder = Decoder(**ddconfig)
248
+ self.loss = instantiate_from_config(lossconfig)
249
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
250
+ remap=remap,
251
+ sane_index_shape=sane_index_shape)
252
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
253
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
254
+ if colorize_nlabels is not None:
255
+ assert type(colorize_nlabels)==int
256
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
257
+ if monitor is not None:
258
+ self.monitor = monitor
259
+ self.batch_resize_range = batch_resize_range
260
+ if self.batch_resize_range is not None:
261
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
262
+
263
+ self.use_ema = use_ema
264
+ if self.use_ema:
265
+ self.model_ema = LitEma(self)
266
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
267
+
268
+ if ckpt_path is not None:
269
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
270
+ self.scheduler_config = scheduler_config
271
+ self.lr_g_factor = lr_g_factor
272
+
273
+ @contextmanager
274
+ def ema_scope(self, context=None):
275
+ if self.use_ema:
276
+ self.model_ema.store(self.parameters())
277
+ self.model_ema.copy_to(self)
278
+ if context is not None:
279
+ print(f"{context}: Switched to EMA weights")
280
+ try:
281
+ yield None
282
+ finally:
283
+ if self.use_ema:
284
+ self.model_ema.restore(self.parameters())
285
+ if context is not None:
286
+ print(f"{context}: Restored training weights")
287
+
288
+ def init_from_ckpt(self, path, ignore_keys=list()):
289
+ sd = torch.load(path, map_location="cpu")["state_dict"]
290
+ keys = list(sd.keys())
291
+ for k in keys:
292
+ for ik in ignore_keys:
293
+ if k.startswith(ik):
294
+ print("Deleting key {} from state_dict.".format(k))
295
+ del sd[k]
296
+ missing, unexpected = self.load_state_dict(sd, strict=False)
297
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
298
+ if len(missing) > 0:
299
+ print(f"Missing Keys: {missing}")
300
+ print(f"Unexpected Keys: {unexpected}")
301
+
302
+ def on_train_batch_end(self, *args, **kwargs):
303
+ if self.use_ema:
304
+ self.model_ema(self)
305
+
306
+ def encode(self, x):
307
+ h = self.encoder(x)
308
+ h = self.quant_conv(h)
309
+ quant, emb_loss, info = self.quantize(h)
310
+ return quant, emb_loss, info
311
+
312
+ def encode_to_prequant(self, x):
313
+ h = self.encoder(x)
314
+ h = self.quant_conv(h)
315
+ return h
316
+
317
+ def decode(self, quant):
318
+ quant = self.post_quant_conv(quant)
319
+ dec = self.decoder(quant)
320
+ return dec
321
+
322
+ def decode_code(self, code_b):
323
+ quant_b = self.quantize.embed_code(code_b)
324
+ dec = self.decode(quant_b)
325
+ return dec
326
+
327
+ def forward(self, input, return_pred_indices=False):
328
+ quant, diff, (_,_,ind) = self.encode(input)
329
+ dec = self.decode(quant)
330
+ if return_pred_indices:
331
+ return dec, diff, ind
332
+ return dec, diff
333
+
334
+ def get_input(self, batch, k):
335
+ x = batch[k]
336
+ if len(x.shape) == 3:
337
+ x = x[..., None]
338
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
339
+ if self.batch_resize_range is not None:
340
+ lower_size = self.batch_resize_range[0]
341
+ upper_size = self.batch_resize_range[1]
342
+ if self.global_step <= 4:
343
+ # do the first few batches with max size to avoid later oom
344
+ new_resize = upper_size
345
+ else:
346
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
347
+ if new_resize != x.shape[2]:
348
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
349
+ x = x.detach()
350
+ return x
351
+
352
+ def training_step(self, batch, batch_idx, optimizer_idx):
353
+ # https://github.com/pytorch/pytorch/issues/37142
354
+ # try not to fool the heuristics
355
+ x = self.get_input(batch, self.image_key)
356
+ xrec, qloss, ind = self(x, return_pred_indices=True)
357
+
358
+ if optimizer_idx == 0:
359
+ # autoencode
360
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
361
+ last_layer=self.get_last_layer(), split="train",
362
+ predicted_indices=ind)
363
+
364
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
365
+ return aeloss
366
+
367
+ if optimizer_idx == 1:
368
+ # discriminator
369
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
370
+ last_layer=self.get_last_layer(), split="train")
371
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
372
+ return discloss
373
+
374
+ def validation_step(self, batch, batch_idx):
375
+ log_dict = self._validation_step(batch, batch_idx)
376
+ with self.ema_scope():
377
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
378
+ return log_dict
379
+
380
+ def _validation_step(self, batch, batch_idx, suffix=""):
381
+ x = self.get_input(batch, self.image_key)
382
+ xrec, qloss, ind = self(x, return_pred_indices=True)
383
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
384
+ self.global_step,
385
+ last_layer=self.get_last_layer(),
386
+ split="val"+suffix,
387
+ predicted_indices=ind
388
+ )
389
+
390
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
391
+ self.global_step,
392
+ last_layer=self.get_last_layer(),
393
+ split="val"+suffix,
394
+ predicted_indices=ind
395
+ )
396
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
397
+ self.log(f"val{suffix}/rec_loss", rec_loss,
398
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
399
+ self.log(f"val{suffix}/aeloss", aeloss,
400
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
401
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
402
+ del log_dict_ae[f"val{suffix}/rec_loss"]
403
+ self.log_dict(log_dict_ae)
404
+ self.log_dict(log_dict_disc)
405
+ return self.log_dict
406
+
407
+ def configure_optimizers(self):
408
+ lr_d = self.learning_rate
409
+ lr_g = self.lr_g_factor*self.learning_rate
410
+ print("lr_d", lr_d)
411
+ print("lr_g", lr_g)
412
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
413
+ list(self.decoder.parameters())+
414
+ list(self.quantize.parameters())+
415
+ list(self.quant_conv.parameters())+
416
+ list(self.post_quant_conv.parameters()),
417
+ lr=lr_g, betas=(0.5, 0.9))
418
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
419
+ lr=lr_d, betas=(0.5, 0.9))
420
+
421
+ if self.scheduler_config is not None:
422
+ scheduler = instantiate_from_config(self.scheduler_config)
423
+
424
+ print("Setting up LambdaLR scheduler...")
425
+ scheduler = [
426
+ {
427
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
428
+ 'interval': 'step',
429
+ 'frequency': 1
430
+ },
431
+ {
432
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
433
+ 'interval': 'step',
434
+ 'frequency': 1
435
+ },
436
+ ]
437
+ return [opt_ae, opt_disc], scheduler
438
+ return [opt_ae, opt_disc], []
439
+
440
+ def get_last_layer(self):
441
+ return self.decoder.conv_out.weight
442
+
443
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
444
+ log = dict()
445
+ x = self.get_input(batch, self.image_key)
446
+ x = x.to(self.device)
447
+ if only_inputs:
448
+ log["inputs"] = x
449
+ return log
450
+ xrec, _ = self(x)
451
+ if x.shape[1] > 3:
452
+ # colorize with random projection
453
+ assert xrec.shape[1] > 3
454
+ x = self.to_rgb(x)
455
+ xrec = self.to_rgb(xrec)
456
+ log["inputs"] = x
457
+ log["reconstructions"] = xrec
458
+ if plot_ema:
459
+ with self.ema_scope():
460
+ xrec_ema, _ = self(x)
461
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
462
+ log["reconstructions_ema"] = xrec_ema
463
+ return log
464
+
465
+ def to_rgb(self, x):
466
+ assert self.image_key == "segmentation"
467
+ if not hasattr(self, "colorize"):
468
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
469
+ x = F.conv2d(x, weight=self.colorize)
470
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
471
+ return x
472
+
473
+
474
+ class VQModelInterface(VQModel):
475
+ def __init__(self, embed_dim, *args, **kwargs):
476
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
477
+ self.embed_dim = embed_dim
478
+
479
+ def encode(self, x):
480
+ h = self.encoder(x)
481
+ h = self.quant_conv(h)
482
+ return h
483
+
484
+ def decode(self, h, force_not_quantize=False):
485
+ # also go through quantization layer
486
+ if not force_not_quantize:
487
+ quant, emb_loss, info = self.quantize(h)
488
+ else:
489
+ quant = h
490
+ quant = self.post_quant_conv(quant)
491
+ dec = self.decoder(quant)
492
+ return dec
watermarker/LaWa/ldm/models/diffusion/__init__.py ADDED
File without changes
watermarker/LaWa/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (170 Bytes). View file
 
watermarker/LaWa/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (9.38 kB). View file
 
watermarker/LaWa/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc ADDED
Binary file (53 kB). View file
 
watermarker/LaWa/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc ADDED
Binary file (7.53 kB). View file
 
watermarker/LaWa/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc ADDED
Binary file (1.07 kB). View file
 
watermarker/LaWa/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda"):
20
+ attr = attr.to(torch.device("cuda"))
21
+ setattr(self, name, attr)
22
+
23
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
+ alphas_cumprod = self.model.alphas_cumprod
27
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
+
30
+ self.register_buffer('betas', to_torch(self.model.betas))
31
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
+
34
+ # calculations for diffusion q(x_t | x_{t-1}) and others
35
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
+
41
+ # ddim sampling parameters
42
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
+ ddim_timesteps=self.ddim_timesteps,
44
+ eta=ddim_eta,verbose=verbose)
45
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
46
+ self.register_buffer('ddim_alphas', ddim_alphas)
47
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
+
54
+ @torch.no_grad()
55
+ def sample(self,
56
+ S,
57
+ batch_size,
58
+ shape,
59
+ conditioning=None,
60
+ callback=None,
61
+ normals_sequence=None,
62
+ img_callback=None,
63
+ quantize_x0=False,
64
+ eta=0.,
65
+ mask=None,
66
+ x0=None,
67
+ temperature=1.,
68
+ noise_dropout=0.,
69
+ score_corrector=None,
70
+ corrector_kwargs=None,
71
+ verbose=True,
72
+ x_T=None,
73
+ log_every_t=100,
74
+ unconditional_guidance_scale=1.,
75
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
76
+ dynamic_threshold=None,
77
+ ucg_schedule=None,
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ ctmp = conditioning[list(conditioning.keys())[0]]
83
+ while isinstance(ctmp, list): ctmp = ctmp[0]
84
+ cbs = ctmp.shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+
88
+ elif isinstance(conditioning, list):
89
+ for ctmp in conditioning:
90
+ if ctmp.shape[0] != batch_size:
91
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
92
+
93
+ else:
94
+ if conditioning.shape[0] != batch_size:
95
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
96
+
97
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
98
+ # sampling
99
+ C, H, W = shape
100
+ size = (batch_size, C, H, W)
101
+ if verbose:
102
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
103
+
104
+ samples, intermediates = self.ddim_sampling(conditioning, size,
105
+ callback=callback,
106
+ img_callback=img_callback,
107
+ quantize_denoised=quantize_x0,
108
+ mask=mask, x0=x0,
109
+ ddim_use_original_steps=False,
110
+ noise_dropout=noise_dropout,
111
+ temperature=temperature,
112
+ score_corrector=score_corrector,
113
+ corrector_kwargs=corrector_kwargs,
114
+ x_T=x_T,
115
+ log_every_t=log_every_t,
116
+ unconditional_guidance_scale=unconditional_guidance_scale,
117
+ unconditional_conditioning=unconditional_conditioning,
118
+ dynamic_threshold=dynamic_threshold,
119
+ ucg_schedule=ucg_schedule, verbose=verbose
120
+ )
121
+ return samples, intermediates
122
+
123
    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None, verbose=True, **kwargs):
        """Inner DDIM loop: iterate `p_sample_ddim` from x_T down to x_0.

        :param cond: conditioning passed through to `p_sample_ddim`.
        :param shape: full latent shape (B, C, H, W).
        :param x_T: starting latent; pure Gaussian noise if None.
        :param mask/x0: optional inpainting pair -- masked region is pinned to
            the forward-diffused x0 at every step.
        :return: (img, intermediates) where intermediates logs latents and
            x0-predictions every `log_every_t` indices.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # Truncate the precomputed DDIM schedule to at most `timesteps` entries.
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # Walk the schedule from most-noised to least-noised.
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        if verbose:
            print(f"Running DDIM Sampling with {total_steps} timesteps")
            iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, miniters=total_steps // 5, mininterval=300)
        else:
            iterator = time_range

        for i, step in enumerate(iterator):
            # `index` indexes the ddim_* buffers (0 = least noised).
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                # Inpainting: re-noise x0 to the current step and keep the
                # masked region fixed to it.
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                # Per-step classifier-free guidance scale.
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates
182
+
183
    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """Single DDIM update: map x_t at schedule position `index` to x_{t-1}.

        :param x: current latent batch.
        :param c: conditioning (tensor, list, or dict of tensors/lists).
        :param t: per-sample timestep tensor matching `index`.
        :param index: position into the (DDIM or original) schedule buffers.
        :return: (x_prev, pred_x0).
        """
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            # Classifier-free guidance: run unconditional and conditional
            # branches as one doubled batch, then extrapolate between them.
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [torch.cat([
                            unconditional_conditioning[k][i],
                            c[k][i]]) for i in range(len(c[k]))]
                    else:
                        c_in[k] = torch.cat([
                            unconditional_conditioning[k],
                            c[k]])
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        # Convert the network output to an eps (noise) prediction.
        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # DDIM update: combine the x0 estimate, the deterministic direction,
        # and (for eta > 0) a stochastic noise term.
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
255
+
256
    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        """Deterministic DDIM inversion: push x0 forward for `t_enc` steps.

        :param x0: clean latent batch to invert.
        :param c: conditioning for the model.
        :param t_enc: number of encoding steps (<= reference schedule length).
        :param return_intermediates: if set, also collect roughly that many
            intermediate latents.
        :return: (x_next, out) with out['x_encoded'] and, optionally,
            out['intermediates'] / out['intermediate_steps'].
        """
        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                # Classifier-free guidance, batched into one forward pass.
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            # Reversed DDIM update: re-weight the latent and add the scaled
            # noise-prediction delta to step toward higher noise levels.
            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            # Log intermediates on a stride derived from return_intermediates,
            # plus the final two steps.
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback: callback(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out
302
+
303
+ @torch.no_grad()
304
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
305
+ # fast, but does not allow for exact reconstruction
306
+ # t serves as an index to gather the correct alphas
307
+ if use_original_steps:
308
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
309
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
310
+ else:
311
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
312
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
313
+
314
+ if noise is None:
315
+ noise = torch.randn_like(x0)
316
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
317
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
318
+
319
    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):
        """Decode a (noised/inverted) latent back toward x_0.

        Runs `t_start` reverse DDIM steps of `p_sample_ddim` starting from
        `x_latent` and returns the final latent.
        """

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        # Iterate from the most-noised retained step down to 0.
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec
watermarker/LaWa/ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1798 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager, nullcontext
16
+ from functools import partial
17
+ import itertools
18
+ from tqdm import tqdm
19
+ from torchvision.utils import make_grid
20
+ from pytorch_lightning.utilities.distributed import rank_zero_only
21
+ # from pytorch_lightning.utilities.rank_zero import rank_zero_only
22
+ from omegaconf import ListConfig
23
+
24
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
25
+ from ldm.modules.ema import LitEma
26
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
27
+ from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
28
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
29
+ from ldm.models.diffusion.ddim import DDIMSampler
30
+
31
+
32
# Maps a conditioning mode to the batch key under which its tensor is passed.
__conditioning_keys__ = {
    'concat': 'c_concat',
    'crossattn': 'c_crossattn',
    'adm': 'y',
}
35
+
36
+
37
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    # Deliberately ignore `mode`: the module is frozen in its current state.
    return self
41
+
42
+
43
def uniform_on_device(r1, r2, shape, device):
    """Sample a tensor of `shape` uniformly between `r2` and `r1` on `device`."""
    u = torch.rand(*shape, device=device)
    return r2 + (r1 - r2) * u
45
+
46
+
47
+ class DDPM(pl.LightningModule):
48
+ # classic DDPM with Gaussian diffusion, in image space
49
+ def __init__(self,
50
+ unet_config,
51
+ timesteps=1000,
52
+ beta_schedule="linear",
53
+ loss_type="l2",
54
+ ckpt_path=None,
55
+ ignore_keys=[],
56
+ load_only_unet=False,
57
+ monitor="val/loss",
58
+ use_ema=True,
59
+ first_stage_key="image",
60
+ image_size=256,
61
+ channels=3,
62
+ log_every_t=100,
63
+ clip_denoised=True,
64
+ linear_start=1e-4,
65
+ linear_end=2e-2,
66
+ cosine_s=8e-3,
67
+ given_betas=None,
68
+ original_elbo_weight=0.,
69
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
70
+ l_simple_weight=1.,
71
+ conditioning_key=None,
72
+ parameterization="eps", # all assuming fixed variance schedules
73
+ scheduler_config=None,
74
+ use_positional_encodings=False,
75
+ learn_logvar=False,
76
+ logvar_init=0.,
77
+ make_it_fit=False,
78
+ ucg_training=None,
79
+ reset_ema=False,
80
+ reset_num_ema_updates=False,
81
+ ):
82
+ super().__init__()
83
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
84
+ self.parameterization = parameterization
85
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
86
+ self.cond_stage_model = None
87
+ self.clip_denoised = clip_denoised
88
+ self.log_every_t = log_every_t
89
+ self.first_stage_key = first_stage_key
90
+ self.image_size = image_size # try conv?
91
+ self.channels = channels
92
+ self.use_positional_encodings = use_positional_encodings
93
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
94
+ count_params(self.model, verbose=True)
95
+ self.use_ema = use_ema
96
+ if self.use_ema:
97
+ self.model_ema = LitEma(self.model)
98
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
99
+
100
+ self.use_scheduler = scheduler_config is not None
101
+ if self.use_scheduler:
102
+ self.scheduler_config = scheduler_config
103
+
104
+ self.v_posterior = v_posterior
105
+ self.original_elbo_weight = original_elbo_weight
106
+ self.l_simple_weight = l_simple_weight
107
+
108
+ if monitor is not None:
109
+ self.monitor = monitor
110
+ self.make_it_fit = make_it_fit
111
+ if reset_ema: assert exists(ckpt_path)
112
+ if ckpt_path is not None:
113
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
114
+ if reset_ema:
115
+ assert self.use_ema
116
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
117
+ self.model_ema = LitEma(self.model)
118
+ if reset_num_ema_updates:
119
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
120
+ assert self.use_ema
121
+ self.model_ema.reset_num_updates()
122
+
123
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
124
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
125
+
126
+ self.loss_type = loss_type
127
+
128
+ self.learn_logvar = learn_logvar
129
+ logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
130
+ if self.learn_logvar:
131
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
132
+ else:
133
+ self.register_buffer('logvar', logvar)
134
+
135
+ self.ucg_training = ucg_training or dict()
136
+ if self.ucg_training:
137
+ self.ucg_prng = np.random.RandomState()
138
+
139
    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Build the beta/alpha noise schedule and register derived buffers.

        All derived quantities (cumulative alphas, posterior coefficients,
        per-timestep VLB weights) become non-trainable buffers on the module.
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        # Per-timestep weights for the VLB term of the loss.
        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            # NOTE(review): `2. * 1 - ...` evaluates as `(2. * 1) - ...`;
            # kept verbatim to preserve the original weighting -- verify intent.
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        else:
            raise NotImplementedError("mu not supported")
        # The t=0 weight is degenerate (posterior variance 0); reuse t=1's.
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        # NOTE(review): `.all()` only rejects an all-NaN tensor; `.any()` would
        # be stricter -- kept as-is to preserve original behavior.
        assert not torch.isnan(self.lvlb_weights).all()
194
+
195
    @contextmanager
    def ema_scope(self, context=None):
        """Context manager that temporarily swaps in the EMA weights.

        On entry, the live model parameters are stashed and replaced with the
        EMA shadow copies; on exit they are restored. No-op when use_ema=False.
        `context` is an optional label printed on switch/restore.
        """
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            # Restore even if the body raised, so training weights are never lost.
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")
209
+
210
    @torch.no_grad()
    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        """Load weights from a checkpoint, with optional key filtering.

        :param path: checkpoint path, loaded via torch.load (map_location='cpu').
        :param ignore_keys: drop any state_dict entry whose name starts with
            one of these prefixes.
        :param only_model: if True, load into self.model only rather than the
            whole LightningModule.
        """
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        if self.make_it_fit:
            # Adapt checkpoint tensors whose shapes mismatch the current model
            # by tiling the old values cyclically along the first two axes and
            # renormalizing by how often each old column was reused.
            n_params = len([name for name, _ in
                            itertools.chain(self.named_parameters(),
                                            self.named_buffers())])
            for name, param in tqdm(
                    itertools.chain(self.named_parameters(),
                                    self.named_buffers()),
                    desc="Fitting old weights to new weights",
                    total=n_params
            ):
                if not name in sd:
                    continue
                old_shape = sd[name].shape
                new_shape = param.shape
                assert len(old_shape) == len(new_shape)
                if len(new_shape) > 2:
                    # we only modify first two axes
                    assert new_shape[2:] == old_shape[2:]
                # assumes first axis corresponds to output dim
                if not new_shape == old_shape:
                    new_param = param.clone()
                    old_param = sd[name]
                    if len(new_shape) == 1:
                        for i in range(new_param.shape[0]):
                            new_param[i] = old_param[i % old_shape[0]]
                    elif len(new_shape) >= 2:
                        for i in range(new_param.shape[0]):
                            for j in range(new_param.shape[1]):
                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]

                        # Count how many times each old input column was reused
                        # so the tiled weights can be rescaled to keep the
                        # effective fan-in contribution comparable.
                        n_used_old = torch.ones(old_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_old[j % old_shape[1]] += 1
                        n_used_new = torch.zeros(new_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_new[j] = n_used_old[j % old_shape[1]]

                        n_used_new = n_used_new[None, :]
                        while len(n_used_new.shape) < len(new_shape):
                            n_used_new = n_used_new.unsqueeze(-1)
                        new_param /= n_used_new

                    sd[name] = new_param

        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys:\n {missing}")
        if len(unexpected) > 0:
            print(f"\nUnexpected Keys:\n {unexpected}")
272
+
273
+ def q_mean_variance(self, x_start, t):
274
+ """
275
+ Get the distribution q(x_t | x_0).
276
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
277
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
278
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
279
+ """
280
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
281
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
282
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
283
+ return mean, variance, log_variance
284
+
285
+ def predict_start_from_noise(self, x_t, t, noise):
286
+ return (
287
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
288
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
289
+ )
290
+
291
+ def predict_start_from_z_and_v(self, x_t, t, v):
292
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
293
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
294
+ return (
295
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
296
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
297
+ )
298
+
299
+ def predict_eps_from_z_and_v(self, x_t, t, v):
300
+ return (
301
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
302
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
303
+ )
304
+
305
+ def q_posterior(self, x_start, x_t, t):
306
+ posterior_mean = (
307
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
308
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
309
+ )
310
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
311
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
312
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
313
+
314
+ def p_mean_variance(self, x, t, clip_denoised: bool):
315
+ model_out = self.model(x, t)
316
+ if self.parameterization == "eps":
317
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
318
+ elif self.parameterization == "x0":
319
+ x_recon = model_out
320
+ if clip_denoised:
321
+ x_recon.clamp_(-1., 1.)
322
+
323
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
324
+ return model_mean, posterior_variance, posterior_log_variance
325
+
326
+ @torch.no_grad()
327
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
328
+ b, *_, device = *x.shape, x.device
329
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
330
+ noise = noise_like(x.shape, device, repeat_noise)
331
+ # no noise when t == 0
332
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
333
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
334
+
335
+ @torch.no_grad()
336
+ def p_sample_loop(self, shape, return_intermediates=False):
337
+ device = self.betas.device
338
+ b = shape[0]
339
+ img = torch.randn(shape, device=device)
340
+ intermediates = [img]
341
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
342
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
343
+ clip_denoised=self.clip_denoised)
344
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
345
+ intermediates.append(img)
346
+ if return_intermediates:
347
+ return img, intermediates
348
+ return img
349
+
350
+ @torch.no_grad()
351
+ def sample(self, batch_size=16, return_intermediates=False):
352
+ image_size = self.image_size
353
+ channels = self.channels
354
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
355
+ return_intermediates=return_intermediates)
356
+
357
+ def q_sample(self, x_start, t, noise=None):
358
+ noise = default(noise, lambda: torch.randn_like(x_start))
359
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
360
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
361
+
362
+ def get_v(self, x, noise, t):
363
+ return (
364
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
365
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
366
+ )
367
+
368
+ def get_loss(self, pred, target, mean=True):
369
+ if self.loss_type == 'l1':
370
+ loss = (target - pred).abs()
371
+ if mean:
372
+ loss = loss.mean()
373
+ elif self.loss_type == 'l2':
374
+ if mean:
375
+ loss = torch.nn.functional.mse_loss(target, pred)
376
+ else:
377
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
378
+ else:
379
+ raise NotImplementedError("unknown loss type '{loss_type}'")
380
+
381
+ return loss
382
+
383
    def p_losses(self, x_start, t, noise=None):
        """Training loss for a batch of clean inputs at timesteps `t`.

        Diffuses x_start to x_t, runs the model, and combines the simple
        regression loss with the ELBO (VLB) term. Returns (loss, loss_dict).
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        # The regression target depends on the parameterization.
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "v":
            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        # Per-sample loss, averaged over channel/spatial dims only.
        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        # Timestep-weighted ELBO term.
        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict
413
+
414
def forward(self, x, *args, **kwargs):
    """Draw one uniform random timestep per sample and evaluate the diffusion loss."""
    batch_size = x.shape[0]
    t = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device).long()
    return self.p_losses(x, t, *args, **kwargs)
419
+
420
def get_input(self, batch, k):
    """Fetch key ``k`` from ``batch`` and return it as a contiguous float BCHW tensor."""
    x = batch[k]
    if len(x.shape) == 3:
        # (b, h, w) input: append a trailing channel axis
        x = x[..., None]
    # channels-last -> channels-first: (b, h, w, c) -> (b, c, h, w)
    x = x.permute(0, 3, 1, 2)
    x = x.to(memory_format=torch.contiguous_format).float()
    return x
427
+
428
def shared_step(self, batch):
    """Common train/val step: fetch the image tensor and score it with forward()."""
    inputs = self.get_input(batch, self.first_stage_key)
    return self(inputs)
432
+
433
def training_step(self, batch, batch_idx):
    """Lightning training step with optional conditioning dropout.

    For each key in ``self.ucg_training``, batch entries are replaced by
    ``val`` with probability ``p`` (classifier-free-guidance-style dropout,
    driven by ``self.ucg_prng``).  Logs the loss dict, global step and
    (if a scheduler is used) the current learning rate.
    """
    for k in self.ucg_training:
        p = self.ucg_training[k]["p"]
        val = self.ucg_training[k]["val"]
        if val is None:
            val = ""
        for i in range(len(batch[k])):
            # choice(2, p=[1-p, p]) returns 1 with probability p
            if self.ucg_prng.choice(2, p=[1 - p, p]):
                batch[k][i] = val

    loss, loss_dict = self.shared_step(batch)

    self.log_dict(loss_dict, prog_bar=True,
                  logger=True, on_step=True, on_epoch=True)

    self.log("global_step", self.global_step,
             prog_bar=True, logger=True, on_step=True, on_epoch=False)

    if self.use_scheduler:
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

    return loss
456
+
457
@torch.no_grad()
def validation_step(self, batch, batch_idx):
    """Validation step, logged twice: with raw weights and with EMA weights."""
    _, loss_dict_no_ema = self.shared_step(batch)
    with self.ema_scope():
        _, loss_dict_ema = self.shared_step(batch)
        # suffix the keys so EMA metrics don't collide with the raw ones
        loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
    self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
    self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
465
+
466
def on_train_batch_end(self, *args, **kwargs):
    """After each optimizer step, fold the current weights into the EMA copy."""
    if not self.use_ema:
        return
    self.model_ema(self.model)
469
+
470
def _get_rows_from_list(self, samples):
    """Stack a list of image batches into one grid with a column per list entry."""
    n_per_row = len(samples)
    grid = rearrange(samples, 'n b c h w -> b n c h w')
    grid = rearrange(grid, 'b n c h w -> (b n) c h w')
    return make_grid(grid, nrow=n_per_row)
476
+
477
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
    """Build a dict of visualization images: inputs, a forward-diffusion row
    and (optionally) EMA samples with their denoising row.

    :param N: max number of images to log.
    :param n_row: number of rows in the diffusion grid.
    :param return_keys: if given, restrict the returned dict to these keys
        (when any of them are present).
    """
    log = dict()
    x = self.get_input(batch, self.first_stage_key)
    N = min(x.shape[0], N)
    n_row = min(x.shape[0], n_row)
    x = x.to(self.device)[:N]
    log["inputs"] = x

    # get diffusion row
    diffusion_row = list()
    x_start = x[:n_row]

    for t in range(self.num_timesteps):
        if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
            # broadcast the scalar timestep over the row
            t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
            t = t.to(self.device).long()
            noise = torch.randn_like(x_start)
            x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
            diffusion_row.append(x_noisy)

    log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

    if sample:
        # get denoise row (sampled under EMA weights)
        with self.ema_scope("Plotting"):
            samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

        log["samples"] = samples
        log["denoise_row"] = self._get_rows_from_list(denoise_row)

    if return_keys:
        if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
            return log
        else:
            return {key: log[key] for key in return_keys}
    return log
514
+
515
def configure_optimizers(self):
    """Build an AdamW optimizer over the diffusion model (plus logvar if learned)."""
    trainable = list(self.model.parameters())
    if self.learn_logvar:
        trainable.append(self.logvar)
    return torch.optim.AdamW(trainable, lr=self.learning_rate)
522
+
523
+
524
+ class LatentDiffusion(DDPM):
525
+ """main class"""
526
+
527
def __init__(self,
             first_stage_config,
             cond_stage_config,
             num_timesteps_cond=None,
             cond_stage_key="image",
             cond_stage_trainable=False,
             concat_mode=True,
             cond_stage_forward=None,
             conditioning_key=None,
             scale_factor=1.0,
             scale_by_std=False,
             force_null_conditioning=False,
             *args, **kwargs):
    """Latent diffusion model: DDPM operating in a first-stage latent space.

    :param first_stage_config: config for the (frozen) autoencoder.
    :param cond_stage_config: config for the conditioning model, or the
        sentinels '__is_first_stage__' / '__is_unconditional__'.
    :param scale_factor: multiplier applied to first-stage latents.
    :param scale_by_std: if True, ``scale_factor`` becomes a buffer set from
        the latent std of the first training batch.
    """
    self.force_null_conditioning = force_null_conditioning
    self.num_timesteps_cond = default(num_timesteps_cond, 1)
    self.scale_by_std = scale_by_std
    assert self.num_timesteps_cond <= kwargs['timesteps']
    # for backwards compatibility after implementation of DiffusionWrapper
    if conditioning_key is None:
        conditioning_key = 'concat' if concat_mode else 'crossattn'
    if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
        conditioning_key = None
    ckpt_path = kwargs.pop("ckpt_path", None)
    reset_ema = kwargs.pop("reset_ema", False)
    reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
    ignore_keys = kwargs.pop("ignore_keys", [])
    super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
    self.concat_mode = concat_mode
    self.cond_stage_trainable = cond_stage_trainable
    self.cond_stage_key = cond_stage_key
    try:
        self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
    except Exception:
        # Bug fix: was a bare ``except:``, which would also swallow
        # KeyboardInterrupt/SystemExit. Configs without ddconfig fall back to 0.
        self.num_downs = 0
    if not scale_by_std:
        self.scale_factor = scale_factor
    else:
        self.register_buffer('scale_factor', torch.tensor(scale_factor))
    self.instantiate_first_stage(first_stage_config)
    self.instantiate_cond_stage(cond_stage_config)
    self.cond_stage_forward = cond_stage_forward
    self.clip_denoised = False
    self.bbox_tokenizer = None

    self.restarted_from_ckpt = False
    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys)
        self.restarted_from_ckpt = True
        if reset_ema:
            assert self.use_ema
            print(
                "Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
            self.model_ema = LitEma(self.model)
    if reset_num_ema_updates:
        print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
        assert self.use_ema
        self.model_ema.reset_num_updates()
584
+
585
def make_cond_schedule(self):
    """Map each diffusion timestep to a conditioning timestep id.

    The first ``num_timesteps_cond`` entries are evenly spaced over the full
    range; the rest all point at the final timestep.
    """
    ids = torch.full((self.num_timesteps,), self.num_timesteps - 1, dtype=torch.long)
    spaced = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
    ids[:self.num_timesteps_cond] = spaced
    self.cond_ids = ids
589
+
590
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
    """On the very first training batch, optionally rescale latents to unit std.

    Active only when ``scale_by_std`` is set and training starts from scratch
    (not restored from a checkpoint).
    """
    # only for very first batch
    if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
        assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
        # set rescale weight to 1./std of encodings
        print("### USING STD-RESCALING ###")
        x = super().get_input(batch, self.first_stage_key)
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()
        # replace the buffer registered in __init__ with the measured value
        del self.scale_factor
        self.register_buffer('scale_factor', 1. / z.flatten().std())
        print(f"setting self.scale_factor to {self.scale_factor}")
        print("### USING STD-RESCALING ###")
606
+
607
def register_schedule(self,
                      given_betas=None, beta_schedule="linear", timesteps=1000,
                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Register the beta schedule, then build the shortened conditioning schedule if needed."""
    super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

    # only shorten when fewer conditioning steps than diffusion steps are requested
    self.shorten_cond_schedule = self.num_timesteps_cond > 1
    if self.shorten_cond_schedule:
        self.make_cond_schedule()
615
+
616
def instantiate_first_stage(self, config):
    """Build the first-stage autoencoder and freeze it permanently."""
    model = instantiate_from_config(config)
    self.first_stage_model = model.eval()
    # make .train() a no-op so the model stays in eval mode forever
    self.first_stage_model.train = disabled_train
    for p in self.first_stage_model.parameters():
        p.requires_grad = False
622
+
623
def instantiate_cond_stage(self, config):
    """Build the conditioning model; freeze it unless it is trainable."""
    if self.cond_stage_trainable:
        # a trainable cond stage must be a real config, not a sentinel
        assert config != '__is_first_stage__'
        assert config != '__is_unconditional__'
        self.cond_stage_model = instantiate_from_config(config)
        return
    if config == "__is_first_stage__":
        print("Using first stage also as cond stage.")
        self.cond_stage_model = self.first_stage_model
    elif config == "__is_unconditional__":
        print(f"Training {self.__class__.__name__} as an unconditional model.")
        self.cond_stage_model = None
        # self.be_unconditional = True
    else:
        model = instantiate_from_config(config)
        self.cond_stage_model = model.eval()
        self.cond_stage_model.train = disabled_train
        for p in self.cond_stage_model.parameters():
            p.requires_grad = False
643
+
644
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
    """Decode a list of latent batches and arrange them as a single image grid."""
    denoise_row = []
    for zd in tqdm(samples, desc=desc):
        denoise_row.append(self.decode_first_stage(zd.to(self.device),
                                                   force_not_quantize=force_no_decoder_quantization))
    n_imgs_per_row = len(denoise_row)
    denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
    denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
    denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
    denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
    return denoise_grid
655
+
656
def get_first_stage_encoding(self, encoder_posterior):
    """Turn an encoder output (distribution or tensor) into a scaled latent."""
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        z = encoder_posterior.sample()
    elif torch.is_tensor(encoder_posterior):
        z = encoder_posterior
    else:
        raise NotImplementedError(
            f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
    return self.scale_factor * z
664
+
665
def get_learned_conditioning(self, c):
    """Encode raw conditioning ``c`` with the conditioning model.

    If ``cond_stage_forward`` names a method, that method is used; otherwise
    ``encode`` is preferred when available, falling back to calling the model.
    """
    model = self.cond_stage_model
    if self.cond_stage_forward is not None:
        # an explicit forward method name was configured
        assert hasattr(model, self.cond_stage_forward)
        return getattr(model, self.cond_stage_forward)(c)
    if hasattr(model, 'encode') and callable(model.encode):
        encoded = model.encode(c)
        if isinstance(encoded, DiagonalGaussianDistribution):
            encoded = encoded.mode()
        return encoded
    return model(c)
677
+
678
def meshgrid(self, h, w):
    """Return an (h, w, 2) integer grid of (row, col) coordinates."""
    rows = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
    cols = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
    return torch.cat([rows, cols], dim=-1)
684
+
685
def delta_border(self, h, w):
    """Normalized distance to the nearest image border.

    :param h: height
    :param w: width
    :return: (h, w) tensor with min distance = 0 at the border and
        max dist = 0.5 at the image center.
    """
    corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    grid = self.meshgrid(h, w) / corner
    d_topleft = torch.min(grid, dim=-1, keepdims=True)[0]
    d_bottomright = torch.min(1 - grid, dim=-1, keepdims=True)[0]
    return torch.min(torch.cat([d_topleft, d_bottomright], dim=-1), dim=-1)[0]
698
+
699
def get_weighting(self, h, w, Ly, Lx, device):
    """Per-pixel blending weights for overlapping crops (used by get_fold_unfold).

    Pixels near a crop border get lower weight so overlapping crops blend
    smoothly; weights are clipped to configured min/max values.
    """
    weighting = self.delta_border(h, w)
    weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
                           self.split_input_params["clip_max_weight"], )
    weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

    if self.split_input_params["tie_braker"]:
        # extra weighting over the crop grid itself to break ties between crops
        L_weighting = self.delta_border(Ly, Lx)
        L_weighting = torch.clip(L_weighting,
                                 self.split_input_params["clip_min_tie_weight"],
                                 self.split_input_params["clip_max_tie_weight"])

        L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
        weighting = weighting * L_weighting
    return weighting
714
+
715
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
    """Build fold/unfold operators plus blending weights for tiled processing.

    :param x: img of size (bs, c, h, w)
    :param uf: upscale factor applied between unfold and fold (>1 for decoding).
    :param df: downscale factor applied between unfold and fold (>1 for encoding).
    :return: (fold, unfold, normalization, weighting); ``normalization``
        divides out overlapping-crop weight sums.
    Exactly one of ``uf`` / ``df`` may be > 1.
    """
    bs, nc, h, w = x.shape

    # number of crops in image
    Ly = (h - kernel_size[0]) // stride[0] + 1
    Lx = (w - kernel_size[1]) // stride[1] + 1

    if uf == 1 and df == 1:
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

        weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

    elif uf > 1 and df == 1:
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        # NOTE(review): kernel_size[0] is used for both dims here — for
        # non-square kernels this looks wrong; confirm against upstream.
        fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                            dilation=1, padding=0,
                            stride=(stride[0] * uf, stride[1] * uf))
        fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

        weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

    elif df > 1 and uf == 1:
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        # NOTE(review): same kernel_size[0]-twice pattern as the uf branch.
        fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                            dilation=1, padding=0,
                            stride=(stride[0] // df, stride[1] // df))
        fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

        weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

    else:
        raise NotImplementedError

    return fold, unfold, normalization, weighting
766
+
767
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
              cond_key=None, return_original_cond=False, bs=None, return_x=False):
    """Encode the batch image into a latent ``z`` and gather conditioning ``c``.

    :param k: batch key of the image tensor.
    :param bs: optional batch-size cap applied to inputs and conditioning.
    :return: ``[z, c]`` plus, depending on flags, the raw input ``x``, its
        first-stage reconstruction ``xrec`` and the raw conditioning ``xc``.
    """
    x = super().get_input(batch, k)
    if bs is not None:
        x = x[:bs]
    x = x.to(self.device)
    encoder_posterior = self.encode_first_stage(x)
    z = self.get_first_stage_encoding(encoder_posterior).detach()

    if self.model.conditioning_key is not None and not self.force_null_conditioning:
        if cond_key is None:
            cond_key = self.cond_stage_key
        if cond_key != self.first_stage_key:
            if cond_key in ['caption', 'coordinates_bbox', "txt"]:
                xc = batch[cond_key]
            elif cond_key in ['class_label', 'cls']:
                # class-conditional models consume the whole batch dict
                xc = batch
            else:
                xc = super().get_input(batch, cond_key).to(self.device)
        else:
            # conditioning on the image itself
            xc = x
        if not self.cond_stage_trainable or force_c_encode:
            if isinstance(xc, dict) or isinstance(xc, list):
                c = self.get_learned_conditioning(xc)
            else:
                c = self.get_learned_conditioning(xc.to(self.device))
        else:
            # trainable cond stage: encode later, inside forward()
            c = xc
        if bs is not None:
            c = c[:bs]

        if self.use_positional_encodings:
            pos_x, pos_y = self.compute_latent_shifts(batch)
            ckey = __conditioning_keys__[self.model.conditioning_key]
            c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

    else:
        c = None
        xc = None
        if self.use_positional_encodings:
            pos_x, pos_y = self.compute_latent_shifts(batch)
            c = {'pos_x': pos_x, 'pos_y': pos_y}
    out = [z, c]
    if return_first_stage_outputs:
        xrec = self.decode_first_stage(z)
        out.extend([x, xrec])
    if return_x:
        out.extend([x])
    if return_original_cond:
        out.append(xc)
    return out
819
+
820
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """Map a latent ``z`` back to image space via the first-stage decoder."""
    if predict_cids:
        if z.dim() == 4:
            # logits over codebook entries -> hard codebook indices
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    # undo the latent scaling applied by get_first_stage_encoding
    z = 1. / self.scale_factor * z
    return self.first_stage_model.decode(z)
830
+
831
@torch.no_grad()
def encode_first_stage(self, x):
    """Encode an image batch with the frozen first-stage model."""
    posterior = self.first_stage_model.encode(x)
    return posterior
834
+
835
def shared_step(self, batch, **kwargs):
    """Common train/val step: fetch latent + conditioning, then compute the loss."""
    z, cond = self.get_input(batch, self.first_stage_key)
    return self(z, cond)
839
+
840
def forward(self, x, c, *args, **kwargs):
    """Sample a random timestep, (optionally) encode conditioning, compute the loss."""
    t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    if self.model.conditioning_key is not None:
        assert c is not None
        if self.cond_stage_trainable:
            # conditioning was passed through raw from get_input; encode it now
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            # noise the conditioning to its mapped conditioning timestep
            tc = self.cond_ids[t].to(self.device)
            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
    return self.p_losses(x, c, t, *args, **kwargs)
850
+
851
def apply_model(self, x_noisy, t, cond, return_ids=False):
    """Run the wrapped diffusion model after normalizing ``cond`` to a dict.

    Non-dict conditioning is wrapped in a list and keyed as 'c_concat' or
    'c_crossattn' based on the model's conditioning key.
    """
    if isinstance(cond, dict):
        # hybrid case, cond is expected to be a dict
        pass
    else:
        if not isinstance(cond, list):
            cond = [cond]
        key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
        cond = {key: cond}

    x_recon = self.model(x_noisy, t, **cond)

    if isinstance(x_recon, tuple) and not return_ids:
        # some wrappers return (prediction, ids); keep only the prediction
        return x_recon[0]
    else:
        return x_recon
867
+
868
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    """Recover the implied noise eps from x_t and a predicted x_0."""
    recip = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recipm1 = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return (recip * x_t - pred_xstart) / recipm1
871
+
872
def _prior_bpd(self, x_start):
    """
    Get the prior KL term for the variational lower-bound, measured in
    bits-per-dim.
    This term can't be optimized, as it only depends on the encoder.
    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    batch_size = x_start.shape[0]
    t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
    kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
    # divide by log(2) to convert nats to bits
    return mean_flat(kl_prior) / np.log(2.0)
885
+
886
def p_losses(self, x_start, cond, t, noise=None):
    """Conditional diffusion loss in latent space; returns ``(loss, loss_dict)``.

    Combines the simple reconstruction loss (optionally reweighted by a
    learned per-timestep log-variance) with the VLB term.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_output = self.apply_model(x_noisy, t, cond)

    loss_dict = {}
    prefix = 'train' if self.training else 'val'

    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    elif self.parameterization == "v":
        target = self.get_v(x_start, noise, t)
    else:
        raise NotImplementedError()

    # per-sample loss over channel/spatial dims
    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
    loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

    # per-timestep learned (or fixed) log-variance reweighting
    logvar_t = self.logvar[t].to(self.device)
    loss = loss_simple / torch.exp(logvar_t) + logvar_t
    # loss = loss_simple / torch.exp(self.logvar) + self.logvar
    if self.learn_logvar:
        loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
        loss_dict.update({'logvar': self.logvar.data.mean()})

    loss = self.l_simple_weight * loss.mean()

    loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
    loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
    loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
    loss += (self.original_elbo_weight * loss_vlb)
    loss_dict.update({f'{prefix}/loss': loss})

    return loss, loss_dict
922
+
923
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                    return_x0=False, score_corrector=None, corrector_kwargs=None):
    """Posterior mean/variance of p(x_{t-1} | x_t) under the conditional model."""
    t_in = t
    model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

    if score_corrector is not None:
        assert self.parameterization == "eps"
        model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

    if return_codebook_ids:
        model_out, logits = model_out

    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    else:
        # NOTE(review): the "v" parameterization handled in p_losses is not
        # supported here — sampling with it would raise.
        raise NotImplementedError()

    if clip_denoised:
        x_recon.clamp_(-1., 1.)
    if quantize_denoised:
        x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    if return_codebook_ids:
        return model_mean, posterior_variance, posterior_log_variance, logits
    elif return_x0:
        return model_mean, posterior_variance, posterior_log_variance, x_recon
    else:
        return model_mean, posterior_variance, posterior_log_variance
953
+
954
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
             return_codebook_ids=False, quantize_denoised=False, return_x0=False,
             temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
    """Draw x_{t-1} from p(x_{t-1} | x_t) — a single reverse diffusion step."""
    b, *_, device = *x.shape, x.device
    outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                   return_codebook_ids=return_codebook_ids,
                                   quantize_denoised=quantize_denoised,
                                   return_x0=return_x0,
                                   score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
    if return_codebook_ids:
        raise DeprecationWarning("Support dropped.")
        # NOTE(review): unreachable after the raise above.
        model_mean, _, model_log_variance, logits = outputs
    elif return_x0:
        model_mean, _, model_log_variance, x0 = outputs
    else:
        model_mean, _, model_log_variance = outputs

    noise = noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

    if return_codebook_ids:
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
    if return_x0:
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
    else:
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
984
+
985
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                          img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                          score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                          log_every_t=None):
    """Ancestral sampling loop that also collects intermediate x0 predictions.

    Returns ``(img, intermediates)`` where ``intermediates`` holds the
    predicted x0 at every ``log_every_t`` steps.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    timesteps = self.num_timesteps
    if batch_size is not None:
        b = batch_size if batch_size is not None else shape[0]
        # shape is per-sample here; prepend the batch dim
        shape = [batch_size] + list(shape)
    else:
        b = batch_size = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=self.device)
    else:
        img = x_T
    intermediates = []
    if cond is not None:
        if isinstance(cond, dict):
            # trim every conditioning entry to the batch size
            cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
            list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
        else:
            cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

    if start_T is not None:
        timesteps = min(timesteps, start_T)
    iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                    total=timesteps) if verbose else reversed(
        range(0, timesteps))
    if type(temperature) == float:
        # broadcast a single temperature over all timesteps
        temperature = [temperature] * timesteps

    for i in iterator:
        ts = torch.full((b,), i, device=self.device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img, x0_partial = self.p_sample(img, cond, ts,
                                        clip_denoised=self.clip_denoised,
                                        quantize_denoised=quantize_denoised, return_x0=True,
                                        temperature=temperature[i], noise_dropout=noise_dropout,
                                        score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if mask is not None:
            # inpainting: keep the masked region pinned to the noised original
            assert x0 is not None
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(x0_partial)
        if callback: callback(i)
        if img_callback: img_callback(img, i)
    return img, intermediates
1040
+
1041
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
                  x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, start_T=None,
                  log_every_t=None):
    """Full ancestral DDPM sampling loop in latent space.

    :param x_T: optional starting noise; drawn fresh if None.
    :param mask: optional inpainting mask (1 = keep original ``x0`` region).
    :return: the final sample, plus intermediates if requested.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    device = self.betas.device
    b = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=device)
    else:
        img = x_T

    intermediates = [img]
    if timesteps is None:
        timesteps = self.num_timesteps

    if start_T is not None:
        timesteps = min(timesteps, start_T)
    iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
        range(0, timesteps))

    if mask is not None:
        assert x0 is not None
        # NOTE(review): [2:3] compares only the height dim, though the comment
        # says "spatial size" — confirm whether [2:] was intended.
        assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

    for i in iterator:
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img = self.p_sample(img, cond, ts,
                            clip_denoised=self.clip_denoised,
                            quantize_denoised=quantize_denoised)
        if mask is not None:
            # keep the masked region pinned to the (noised) original
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(img)
        if callback: callback(i)
        if img_callback: img_callback(img, i)

    if return_intermediates:
        return img, intermediates
    return img
1091
+
1092
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
           verbose=True, timesteps=None, quantize_denoised=False,
           mask=None, x0=None, shape=None, **kwargs):
    """Convenience wrapper around p_sample_loop with batch-trimmed conditioning."""
    if shape is None:
        shape = (batch_size, self.channels, self.image_size, self.image_size)
    if cond is not None:
        if isinstance(cond, dict):
            # trim every conditioning entry to the batch size
            cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
            list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
        else:
            cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
    return self.p_sample_loop(cond,
                              shape,
                              return_intermediates=return_intermediates, x_T=x_T,
                              verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                              mask=mask, x0=x0)
1109
+
1110
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
    """Sample latents either with DDIM or with the ancestral DDPM loop."""
    if ddim:
        sampler = DDIMSampler(self)
        latent_shape = (self.channels, self.image_size, self.image_size)
        samples, intermediates = sampler.sample(ddim_steps, batch_size,
                                                latent_shape, cond, verbose=False, **kwargs)
    else:
        samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                             return_intermediates=True, **kwargs)
    return samples, intermediates
1123
+
1124
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
    """Build the unconditional ("null") conditioning used for classifier-free guidance.

    With a ``null_label``, it is encoded and broadcast to ``batch_size``;
    otherwise only class-conditional models are supported.
    """
    if null_label is not None:
        xc = null_label
        if isinstance(xc, ListConfig):
            xc = list(xc)
        if isinstance(xc, dict) or isinstance(xc, list):
            c = self.get_learned_conditioning(xc)
        else:
            if hasattr(xc, "to"):
                xc = xc.to(self.device)
            c = self.get_learned_conditioning(xc)
    else:
        if self.cond_stage_key in ["class_label", "cls"]:
            xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
            return self.get_learned_conditioning(xc)
        else:
            raise NotImplementedError("todo")
    if isinstance(c, list):  # in case the encoder gives us a list
        for i in range(len(c)):
            c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
    else:
        c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
    return c
1148
+
1149
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
               quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
               plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
               use_ema_scope=True,
               **kwargs):
    """Assemble a dict of tensors for image logging callbacks.

    Depending on the flags, the returned dict can contain: the inputs and
    their first-stage reconstruction, a rendering of the conditioning, a
    forward-diffusion row, plain samples, x0-quantized samples,
    classifier-free-guided samples, center-square inpainting/outpainting
    results and a progressive denoising row.

    Args:
        batch: raw dataloader batch.
        N: maximum number of examples to visualize.
        n_row: maximum number of examples per grid row.
        sample: draw samples from the model.
        ddim_steps: number of DDIM steps; None disables DDIM sampling.
        ddim_eta: DDIM eta parameter.
        return_keys: if given and present in the log, restrict output to these keys.
        quantize_denoised: additionally sample with quantized x0 predictions
            (skipped for AutoencoderKL / IdentityFirstStage first stages).
        inpaint: additionally log inpainting and outpainting of a center square.
        plot_denoise_rows / plot_progressive_rows / plot_diffusion_rows:
            toggle the corresponding grid visualizations.
        unconditional_guidance_scale: CFG scale; values > 1 add a guided sample.
        unconditional_guidance_label: label used for the unconditional branch.
        use_ema_scope: run sampling under EMA weights when True.

    Returns:
        dict mapping log names to image tensors (or grids).
    """
    ema_scope = self.ema_scope if use_ema_scope else nullcontext
    use_ddim = ddim_steps is not None

    log = dict()
    z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                       return_first_stage_outputs=True,
                                       force_c_encode=True,
                                       return_original_cond=True,
                                       bs=N)
    # Clamp to the actual batch size in case the batch is smaller than N.
    N = min(x.shape[0], N)
    n_row = min(x.shape[0], n_row)
    log["inputs"] = x
    log["reconstruction"] = xrec
    if self.model.conditioning_key is not None:
        # Render the conditioning into an image form suitable for logging.
        if hasattr(self.cond_stage_model, "decode"):
            xc = self.cond_stage_model.decode(c)
            log["conditioning"] = xc
        elif self.cond_stage_key in ["caption", "txt"]:
            # Text conditioning: rasterize the caption strings into images.
            xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
            log["conditioning"] = xc
        elif self.cond_stage_key in ['class_label', "cls"]:
            try:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                log['conditioning'] = xc
            except KeyError:
                # probably no "human_label" in batch
                pass
        elif isimage(xc):
            log["conditioning"] = xc
        if ismap(xc):
            log["original_conditioning"] = self.to_rgb(xc)

    if plot_diffusion_rows:
        # get diffusion row: decode progressively noisier latents at logged timesteps
        diffusion_row = list()
        z_start = z[:n_row]
        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(z_start)
                z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                diffusion_row.append(self.decode_first_stage(z_noisy))

        diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
        diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
        diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
        diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
        log["diffusion_row"] = diffusion_grid

    if sample:
        # get denoise row
        with ema_scope("Sampling"):
            samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
        # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
        x_samples = self.decode_first_stage(samples)
        log["samples"] = x_samples
        if plot_denoise_rows:
            denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
            log["denoise_row"] = denoise_grid

        if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                self.first_stage_model, IdentityFirstStage):
            # also display when quantizing x0 while sampling
            with ema_scope("Plotting Quantized Denoised"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta,
                                                         quantize_denoised=True)
            # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
            #                                      quantize_denoised=True)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_x0_quantized"] = x_samples

    if unconditional_guidance_scale > 1.0:
        uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
        if self.model.conditioning_key == "crossattn-adm":
            # keep the adm conditioning from c; only the cross-attn part is "unconditional"
            uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
        with ema_scope("Sampling with classifier-free guidance"):
            samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uc,
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

    if inpaint:
        # make a simple center square
        b, h, w = z.shape[0], z.shape[2], z.shape[3]
        mask = torch.ones(N, h, w).to(self.device)
        # zeros will be filled in
        mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
        mask = mask[:, None, ...]
        with ema_scope("Plotting Inpaint"):
            samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                         ddim_steps=ddim_steps, x0=z[:N], mask=mask)
        x_samples = self.decode_first_stage(samples.to(self.device))
        log["samples_inpainting"] = x_samples
        log["mask"] = mask

        # outpaint: invert the mask so the border is regenerated instead
        mask = 1. - mask
        with ema_scope("Plotting Outpaint"):
            samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                         ddim_steps=ddim_steps, x0=z[:N], mask=mask)
        x_samples = self.decode_first_stage(samples.to(self.device))
        log["samples_outpainting"] = x_samples

    if plot_progressive_rows:
        with ema_scope("Plotting Progressives"):
            img, progressives = self.progressive_denoising(c,
                                                           shape=(self.channels, self.image_size, self.image_size),
                                                           batch_size=N)
        prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
        log["progressive_row"] = prog_row

    if return_keys:
        # If none of the requested keys exist, return everything instead.
        if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
            return log
        else:
            return {key: log[key] for key in return_keys}
    return log
1278
+
1279
def configure_optimizers(self):
    """Build the AdamW optimizer (and optional LambdaLR scheduler).

    Optimizes the diffusion model parameters, optionally extended with the
    conditioning-stage parameters and the learned per-timestep log-variance.

    Returns:
        Either the optimizer alone, or ``([optimizer], [scheduler_dict])``
        when ``self.use_scheduler`` is set (PyTorch Lightning convention).
    """
    base_lr = self.learning_rate
    trainable = list(self.model.parameters())
    if self.cond_stage_trainable:
        print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
        trainable = trainable + list(self.cond_stage_model.parameters())
    if self.learn_logvar:
        print('Diffusion model optimizing logvar')
        trainable.append(self.logvar)
    optimizer = torch.optim.AdamW(trainable, lr=base_lr)

    if not self.use_scheduler:
        return optimizer

    assert 'target' in self.scheduler_config
    schedule = instantiate_from_config(self.scheduler_config)

    print("Setting up LambdaLR scheduler...")
    lr_schedulers = [
        {
            'scheduler': LambdaLR(optimizer, lr_lambda=schedule.schedule),
            'interval': 'step',
            'frequency': 1
        }]
    return [optimizer], lr_schedulers
1302
+
1303
@torch.no_grad()
def to_rgb(self, x):
    """Project a multi-channel feature map to a 3-channel image in [-1, 1].

    A fixed random 1x1 convolution is created lazily on first use and cached
    on ``self.colorize`` so later calls produce consistent colors.
    """
    x = x.float()
    if not hasattr(self, "colorize"):
        # Lazily create the random projection; reused across calls.
        self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
    projected = nn.functional.conv2d(x, weight=self.colorize)
    lo, hi = projected.min(), projected.max()
    # Min-max normalize to [-1, 1] for visualization.
    return 2. * (projected - lo) / (hi - lo) - 1.
1311
+
1312
+
1313
class DiffusionWrapper(pl.LightningModule):
    """Routes conditioning inputs into the wrapped diffusion model.

    ``conditioning_key`` selects how ``c_concat`` / ``c_crossattn`` / ``c_adm``
    are combined with the noisy input before calling the UNet.
    """

    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        # "sequential_crossattn" is consumed here and must not reach the UNet config.
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        key = self.conditioning_key
        if key is None:
            # Unconditional model.
            return self.diffusion_model(x, t)
        if key == 'concat':
            return self.diffusion_model(torch.cat([x] + c_concat, dim=1), t)
        if key == 'crossattn':
            if self.sequential_cross_attn:
                context = c_crossattn
            else:
                context = torch.cat(c_crossattn, 1)
            return self.diffusion_model(x, t, context=context)
        if key == 'hybrid':
            # Channel-concat conditioning combined with cross-attention context.
            stacked = torch.cat([x] + c_concat, dim=1)
            return self.diffusion_model(stacked, t, context=torch.cat(c_crossattn, 1))
        if key == 'hybrid-adm':
            assert c_adm is not None
            stacked = torch.cat([x] + c_concat, dim=1)
            return self.diffusion_model(stacked, t, context=torch.cat(c_crossattn, 1), y=c_adm)
        if key == 'crossattn-adm':
            assert c_adm is not None
            return self.diffusion_model(x, t, context=torch.cat(c_crossattn, 1), y=c_adm)
        if key == 'adm':
            # Class-conditional style: first cross-attn entry is the label embedding.
            return self.diffusion_model(x, t, y=c_crossattn[0])
        raise NotImplementedError()
1353
+
1354
+
1355
class LatentUpscaleDiffusion(LatentDiffusion):
    """Latent diffusion conditioned on a (noise-augmented) low-resolution image.

    A frozen low-scale model encodes the LR image and reports a noise level,
    which is fed to the UNet as channel-concat plus adm conditioning.
    """

    def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
        super().__init__(*args, **kwargs)
        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
        assert not self.cond_stage_trainable
        self.instantiate_low_stage(low_scale_config)
        self.low_scale_key = low_scale_key
        self.noise_level_key = noise_level_key

    def instantiate_low_stage(self, config):
        """Build the low-scale model and freeze it (eval mode, no grads)."""
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        # Override .train so Lightning cannot flip it back to training mode.
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
        """Fetch latents plus the LR conditioning dict.

        Returns ``(z, all_conds)`` normally; in ``log_mode`` additionally
        returns inputs, reconstructions, raw conditioning, the LR image,
        its reconstruction from the low-scale model, and the noise level.
        """
        if not log_mode:
            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
        else:
            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                                  force_c_encode=True, return_original_cond=True, bs=bs)
        x_low = batch[self.low_scale_key][:bs]
        # Dataloader delivers HWC; the models expect CHW.
        x_low = rearrange(x_low, 'b h w c -> b c h w')
        x_low = x_low.to(memory_format=torch.contiguous_format).float()
        zx, noise_level = self.low_scale_model(x_low)
        if self.noise_level_key is not None:
            # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
            raise NotImplementedError('TODO')

        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
        if log_mode:
            # TODO: maybe disable if too expensive
            x_low_rec = self.low_scale_model.decode(zx)
            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
        return z, all_conds

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
                   **kwargs):
        """Log inputs, reconstructions, LR conditioning, diffusion rows,
        samples and (optionally) CFG samples and progressive rows."""
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
                                                                          log_mode=True)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        log["x_lr"] = x_low
        # Embed the per-sample noise levels into the log key for traceability.
        log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
                log["conditioning"] = xc
            elif self.cond_stage_key in ['class_label', 'cls']:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row: decode progressively noisier latents
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
            # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            # TODO explore better "unconditional" choices for the other keys
            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
            uc = dict()
            for k in c:
                if k == "c_crossattn":
                    assert isinstance(c[k], list) and len(c[k]) == 1
                    uc[k] = [uc_tmp]
                elif k == "c_adm":  # todo: only run with text-based guidance?
                    assert isinstance(c[k], torch.Tensor)
                    #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                    uc[k] = c[k]
                elif isinstance(c[k], list):
                    uc[k] = [c[k][i] for i in range(len(c[k]))]
                else:
                    uc[k] = c[k]

            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        return log
1491
+
1492
+
1493
class LatentFinetuneDiffusion(LatentDiffusion):
    """
    Basis for different finetunes, such as inpainting or depth2image
    To disable finetuning mode, set finetune_keys to None
    """

    def __init__(self,
                 concat_keys: tuple,
                 # Default finetune keys: the UNet's first conv weight, in both the
                 # live model and its EMA copy (EMA key names have dots stripped).
                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
                                "model_ema.diffusion_modelinput_blocks00weight"
                                ),
                 keep_finetune_dims=4,
                 # if model was trained without concat mode before and we would like to keep these channels
                 c_concat_log_start=None, # to log reconstruction of c_concat codes
                 c_concat_log_end=None,
                 *args, **kwargs
                 ):
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", list())
        super().__init__(*args, **kwargs)
        self.finetune_keys = finetune_keys
        self.concat_keys = concat_keys
        self.keep_dims = keep_finetune_dims
        self.c_concat_log_start = c_concat_log_start
        self.c_concat_log_end = c_concat_log_end
        if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
        if exists(ckpt_path):
            self.init_from_ckpt(ckpt_path, ignore_keys)

    # NOTE(review): mutable default `ignore_keys=list()` is shared across calls;
    # harmless here because it is never mutated, but worth confirming.
    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        """Load a checkpoint, dropping ignored keys and zero-padding the
        finetune keys with extra input channels beyond ``self.keep_dims``."""
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]

            # make it explicit, finetune by including extra input channels
            if exists(self.finetune_keys) and k in self.finetune_keys:
                new_entry = None
                for name, param in self.named_parameters():
                    if name in self.finetune_keys:
                        print(
                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
                        new_entry = torch.zeros_like(param)  # zero init
                assert exists(new_entry), 'did not find matching parameter to modify'
                # Copy the pretrained channels; the extra channels stay zero.
                new_entry[:, :self.keep_dims, ...] = sd[k]
                sd[k] = new_entry

        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        """Log inputs, reconstructions, conditioning, diffusion rows, samples
        and optional CFG samples for concat-conditioned finetuned models."""
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
        # Split the conditioning dict into its concat and cross-attn parts.
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
                log["conditioning"] = xc
            elif self.cond_stage_key in ['class_label', 'cls']:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
            # Decode a slice of the concat conditioning for inspection.
            log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                         batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
            # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            # Keep the concat conditioning in the unconditional branch.
            uc_cat = c_cat
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log
1633
+
1634
+
1635
class LatentInpaintDiffusion(LatentFinetuneDiffusion):
    """
    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
    e.g. mask as concat and text via cross-attn.
    To disable finetuning mode, set finetune_keys to None
    """

    def __init__(self,
                 concat_keys=("mask", "masked_image"),
                 masked_image_key="masked_image",
                 *args, **kwargs
                 ):
        super().__init__(concat_keys, *args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
        """Fetch latents plus the inpainting conditioning.

        Masks are resized to the latent resolution; the masked image is
        encoded through the first stage. Both are channel-concatenated into
        ``c_concat``.
        """
        # note: restricted to non-trainable encoders currently
        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                              force_c_encode=True, return_original_cond=True, bs=bs)

        assert exists(self.concat_keys)
        c_cat = list()
        for ck in self.concat_keys:
            cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
            if bs is not None:
                cc = cc[:bs]
            cc = cc.to(self.device)
            bchw = z.shape
            if ck != self.masked_image_key:
                # Mask: just resize to the latent spatial size.
                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
            else:
                # Masked image: encode into the latent space.
                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        """Extend the base logging with the raw masked image from the batch."""
        log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
        log["masked_image"] = rearrange(args[0]["masked_image"],
                                        'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
        return log
1683
+
1684
+
1685
class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
    """
    condition on monocular depth estimation
    """

    def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
        super().__init__(concat_keys=concat_keys, *args, **kwargs)
        # Depth estimator (e.g. MiDaS-style) producing the concat conditioning.
        self.depth_model = instantiate_from_config(depth_stage_config)
        self.depth_stage_key = concat_keys[0]

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
        """Fetch latents plus a depth map resized and min-max normalized to
        [-1, 1], used as channel-concat conditioning."""
        # note: restricted to non-trainable encoders currently
        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                              force_c_encode=True, return_original_cond=True, bs=bs)

        assert exists(self.concat_keys)
        assert len(self.concat_keys) == 1
        c_cat = list()
        for ck in self.concat_keys:
            cc = batch[ck]
            if bs is not None:
                cc = cc[:bs]
            cc = cc.to(self.device)
            cc = self.depth_model(cc)
            # Resize the predicted depth to the latent resolution.
            cc = torch.nn.functional.interpolate(
                cc,
                size=z.shape[2:],
                mode="bicubic",
                align_corners=False,
            )

            depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
                                                                                           keepdim=True)
            # Per-sample min-max normalization to [-1, 1] (0.001 guards div-by-zero).
            cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        """Extend the base logging with the normalized predicted depth map."""
        log = super().log_images(*args, **kwargs)
        depth = self.depth_model(args[0][self.depth_stage_key])
        depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
            torch.amax(depth, dim=[1, 2, 3], keepdim=True)
        # NOTE(review): unlike get_input, no epsilon here — a constant depth
        # map would produce NaNs; presumably never happens in practice.
        log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
        return log
1736
+
1737
+
1738
class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
    """
    condition on low-res image (and optionally on some spatial noise augmentation)
    """
    def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
                 low_scale_config=None, low_scale_key=None, *args, **kwargs):
        super().__init__(concat_keys=concat_keys, *args, **kwargs)
        # Optional pixel-unshuffle patch size applied to the LR conditioning.
        self.reshuffle_patch_size = reshuffle_patch_size
        self.low_scale_model = None
        if low_scale_config is not None:
            print("Initializing a low-scale model")
            assert exists(low_scale_key)
            self.instantiate_low_stage(low_scale_config)
            self.low_scale_key = low_scale_key

    def instantiate_low_stage(self, config):
        """Build the low-scale (noise augmentation) model and freeze it."""
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        # Override .train so Lightning cannot flip it back to training mode.
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
        """Fetch latents plus LR-image concat conditioning, optionally
        noise-augmented by the low-scale model (which then also contributes
        a ``c_adm`` noise level)."""
        # note: restricted to non-trainable encoders currently
        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                              force_c_encode=True, return_original_cond=True, bs=bs)

        assert exists(self.concat_keys)
        assert len(self.concat_keys) == 1
        # optionally make spatial noise_level here
        c_cat = list()
        noise_level = None
        for ck in self.concat_keys:
            cc = batch[ck]
            cc = rearrange(cc, 'b h w c -> b c h w')
            if exists(self.reshuffle_patch_size):
                assert isinstance(self.reshuffle_patch_size, int)
                # Pixel-unshuffle: trade spatial resolution for channels.
                cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
                               p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
            if bs is not None:
                cc = cc[:bs]
            cc = cc.to(self.device)
            if exists(self.low_scale_model) and ck == self.low_scale_key:
                cc, noise_level = self.low_scale_model(cc)
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        if exists(noise_level):
            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
        else:
            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        """Extend the base logging with the raw low-res conditioning image."""
        log = super().log_images(*args, **kwargs)
        log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
        return log
watermarker/LaWa/ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampler import DPMSolverSampler
watermarker/LaWa/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (211 Bytes). View file
 
watermarker/LaWa/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc ADDED
Binary file (51.6 kB). View file
 
watermarker/LaWa/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc ADDED
Binary file (2.79 kB). View file
 
watermarker/LaWa/ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from tqdm import tqdm
5
+
6
+
7
+ class NoiseScheduleVP:
8
+ def __init__(
9
+ self,
10
+ schedule='discrete',
11
+ betas=None,
12
+ alphas_cumprod=None,
13
+ continuous_beta_0=0.1,
14
+ continuous_beta_1=20.,
15
+ ):
16
+ """Create a wrapper class for the forward SDE (VP type).
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
19
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
22
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
23
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
24
+ log_alpha_t = self.marginal_log_mean_coeff(t)
25
+ sigma_t = self.marginal_std(t)
26
+ lambda_t = self.marginal_lambda(t)
27
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
28
+ t = self.inverse_lambda(lambda_t)
29
+ ===============================================================
30
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
31
+ 1. For discrete-time DPMs:
32
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
33
+ t_i = (i + 1) / N
34
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
35
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
36
+ Args:
37
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
38
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
39
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
40
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
41
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
42
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
43
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
44
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
45
+ and
46
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
47
+ 2. For continuous-time DPMs:
48
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
49
+ schedule are the default settings in DDPM and improved-DDPM:
50
+ Args:
51
+ beta_min: A `float` number. The smallest beta for the linear schedule.
52
+ beta_max: A `float` number. The largest beta for the linear schedule.
53
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
54
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
55
+ T: A `float` number. The ending time of the forward process.
56
+ ===============================================================
57
+ Args:
58
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
59
+ 'linear' or 'cosine' for continuous-time DPMs.
60
+ Returns:
61
+ A wrapper object of the forward SDE (VP type).
62
+
63
+ ===============================================================
64
+ Example:
65
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
66
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
67
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
68
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
69
+ # For continuous-time DPMs (VPSDE), linear schedule:
70
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
71
+ """
72
+
73
+ if schedule not in ['discrete', 'linear', 'cosine']:
74
+ raise ValueError(
75
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
76
+ schedule))
77
+
78
+ self.schedule = schedule
79
+ if schedule == 'discrete':
80
+ if betas is not None:
81
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
82
+ else:
83
+ assert alphas_cumprod is not None
84
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
85
+ self.total_N = len(log_alphas)
86
+ self.T = 1.
87
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
88
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
89
+ else:
90
+ self.total_N = 1000
91
+ self.beta_0 = continuous_beta_0
92
+ self.beta_1 = continuous_beta_1
93
+ self.cosine_s = 0.008
94
+ self.cosine_beta_max = 999.
95
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
96
+ 1. + self.cosine_s) / math.pi - self.cosine_s
97
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
98
+ self.schedule = schedule
99
+ if schedule == 'cosine':
100
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
101
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
102
+ self.T = 0.9946
103
+ else:
104
+ self.T = 1.
105
+
106
def marginal_log_mean_coeff(self, t):
    """Compute log(alpha_t) for a continuous-time label t in [0, T].

    For the 'discrete' schedule this piecewise-linearly interpolates the
    precomputed log-alpha table; for 'linear' and 'cosine' it evaluates
    the closed-form VPSDE expressions.
    """
    schedule = self.schedule
    if schedule == 'discrete':
        t_col = t.reshape((-1, 1))
        return interpolate_fn(t_col,
                              self.t_array.to(t.device),
                              self.log_alpha_array.to(t.device)).reshape((-1))
    elif schedule == 'linear':
        # Integral of -0.5 * beta(t) with beta(t) = beta_0 + t * (beta_1 - beta_0).
        return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    elif schedule == 'cosine':
        log_cos = torch.log(torch.cos((t + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
        # Normalize so that log(alpha_0) = 0.
        return log_cos - self.cosine_log_alpha_0
119
+
120
def marginal_alpha(self, t):
    """Compute alpha_t = exp(log(alpha_t)) for a continuous-time label t in [0, T]."""
    log_alpha = self.marginal_log_mean_coeff(t)
    return log_alpha.exp()
125
+
126
def marginal_std(self, t):
    """Compute sigma_t = sqrt(1 - alpha_t^2) for a continuous-time label t in [0, T]."""
    two_log_alpha = 2. * self.marginal_log_mean_coeff(t)
    return torch.sqrt(1. - torch.exp(two_log_alpha))
131
+
132
def marginal_lambda(self, t):
    """Compute the half-logSNR lambda_t = log(alpha_t) - log(sigma_t) for t in [0, T]."""
    log_alpha = self.marginal_log_mean_coeff(t)
    # log(sigma_t) = 0.5 * log(1 - alpha_t^2)
    log_sigma = 0.5 * torch.log(1. - torch.exp(2. * log_alpha))
    return log_alpha - log_sigma
139
+
140
def inverse_lambda(self, lamb):
    """Compute the time t in [0, T] whose half-logSNR equals `lamb`.

    This is the inverse of `marginal_lambda`; the branch taken depends on
    the configured schedule ('linear', 'discrete' or cosine).
    """
    if self.schedule == 'linear':
        # Solve the quadratic in t implied by the linear beta schedule,
        # written in a numerically stable rationalized form.
        tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
        Delta = self.beta_0 ** 2 + tmp
        return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
    elif self.schedule == 'discrete':
        # Invert the piecewise-linear log-alpha table (flipped so it is increasing).
        log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
        t = interpolate_fn(log_alpha.reshape((-1, 1)),
                           torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                           torch.flip(self.t_array.to(lamb.device), [1]))
        return t.reshape((-1,))
    else:
        # Cosine schedule: invert the closed-form cosine expression.
        log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
        t = torch.arccos(torch.exp(log_alpha + self.cosine_log_alpha_0)) * 2. * (
                1. + self.cosine_s) / math.pi - self.cosine_s
        return t
159
+
160
+
161
def model_wrapper(
        model,
        noise_schedule,
        model_type="noise",
        model_kwargs=None,
        guidance_type="uncond",
        condition=None,
        unconditional_condition=None,
        guidance_scale=1.,
        classifier_fn=None,
        classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver solves continuous-time diffusion ODEs, so the raw model is wrapped
    into a function that accepts the continuous time `t_continuous` (in [epsilon, T])
    and always returns the predicted noise.

    Supported `model_type` parameterizations:
        - "noise":   the model predicts epsilon directly.
        - "x_start": the model predicts x_0; converted via (x - alpha_t * x0) / sigma_t.
        - "v":       velocity prediction (Progressive Distillation [1], Imagen-Video [2]).
        - "score":   marginal score function; noise(x_t, t) = -sigma_t * score(x_t, t).

    Supported `guidance_type` values:
        - "uncond":          unconditional sampling; model(x, t_input, **model_kwargs).
        - "classifier":      classifier guidance [3]; requires `classifier_fn`, whose
                             log-probability gradient steers the predicted noise.
        - "classifier-free": classifier-free guidance [4]; the conditional model is
                             evaluated on both `condition` and `unconditional_condition`
                             and the two outputs are mixed by `guidance_scale`.

    [1] Salimans & Ho, "Progressive distillation for fast sampling of diffusion models", 2022.
    [2] Ho et al., "Imagen Video", 2022.
    [3] Dhariwal & Nichol, "Diffusion models beat GANs on image synthesis", NeurIPS 2021.
    [4] Ho & Salimans, "Classifier-free diffusion guidance", 2022.

    Args:
        model: the diffusion model; call signature depends on `guidance_type`.
        noise_schedule: a noise schedule object, such as NoiseScheduleVP.
        model_type: "noise" | "x_start" | "v" | "score".
        model_kwargs: extra keyword arguments forwarded to `model` (default: {}).
        guidance_type: "uncond" | "classifier" | "classifier-free".
        condition: conditioning tensor for the guided sampling types.
        unconditional_condition: the "null" condition for classifier-free guidance.
        guidance_scale: guidance strength for the guided sampling types.
        classifier_fn: classifier used for "classifier" guidance.
        classifier_kwargs: extra keyword arguments for `classifier_fn` (default: {}).

    Returns:
        A noise prediction function `model_fn(x, t_continuous) -> noise` for DPM-Solver.
    """
    # Use None sentinels instead of mutable {} defaults (shared-state pitfall);
    # behavior is unchanged for every caller.
    if model_kwargs is None:
        model_kwargs = {}
    if classifier_kwargs is None:
        classifier_kwargs = {}

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        # Broadcast a scalar time to the batch dimension.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        # Convert every parameterization to a noise prediction.
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Batch the conditional and unconditional passes into one forward call.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    # Bug fix: "score" is documented in the docstring and implemented in
    # noise_pred_fn above, but the original assert rejected it.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
317
+
318
+
319
+ class DPM_Solver:
320
def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
    """Construct a DPM-Solver.

    Supports both the noise prediction model ("predicting epsilon",
    `predict_x0=False`, classic DPM-Solver) and the data prediction model
    ("predicting x0", `predict_x0=True`, DPM-Solver++). In the latter case,
    the "dynamic thresholding" of Imagen [1] may additionally be enabled via
    `thresholding`, which greatly improves sample quality for pixel-space
    DPMs with large guidance scales.

    Args:
        model_fn: a noise prediction function taking (x, t_continuous) with
            t in [epsilon, T] and returning the predicted noise.
        noise_schedule: a noise schedule object, such as NoiseScheduleVP.
        predict_x0: use the data prediction formulation if True, otherwise
            the noise prediction formulation.
        thresholding: apply dynamic thresholding (only meaningful when
            `predict_x0` is True).
        max_val: clamp magnitude used by the thresholding (only meaningful
            when both `predict_x0` and `thresholding` are True).

    [1] Saharia et al., "Photorealistic text-to-image diffusion models with
        deep language understanding", arXiv:2205.11487, 2022.
    """
    self.noise_schedule = noise_schedule
    self.model = model_fn
    self.predict_x0 = predict_x0
    self.thresholding = thresholding
    self.max_val = max_val
345
+
346
def noise_prediction_fn(self, x, t):
    """Evaluate the wrapped noise prediction model at (x, t) and return its output."""
    return self.model(x, t)
351
+
352
def data_prediction_fn(self, x, t):
    """Return the x0 prediction derived from the noise model, with optional
    dynamic thresholding applied when `self.thresholding` is set."""
    ns = self.noise_schedule
    noise = self.noise_prediction_fn(x, t)
    dims = x.dim()
    alpha_t = ns.marginal_alpha(t)
    sigma_t = ns.marginal_std(t)
    # x = alpha_t * x0 + sigma_t * noise  =>  solve for x0.
    x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
    if self.thresholding:
        p = 0.995  # dynamic-thresholding percentile from the "Imagen" paper
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        # Clamp to at least max_val, then rescale into [-1, 1].
        s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
    return x0
366
+
367
def model_fn(self, x, t):
    """Dispatch to the data-prediction or noise-prediction parameterization."""
    if self.predict_x0:
        return self.data_prediction_fn(x, t)
    return self.noise_prediction_fn(x, t)
375
+
376
def get_time_steps(self, skip_type, t_T, t_0, N, device):
    """Compute the N+1 intermediate sampling times from t_T down to t_0.

    Args:
        skip_type: spacing of the time steps:
            - 'logSNR': uniform in logSNR (via the schedule's lambda function).
            - 'time_uniform': uniform in time (recommended for high-resolution data).
            - 'time_quadratic': quadratic in time (used by DDIM for low-resolution data).
        t_T: starting time of the sampling (default T).
        t_0: ending time of the sampling (default epsilon).
        N: number of intervals between time steps.
        device: torch device for the returned tensor.

    Returns:
        A tensor of shape (N + 1,) with the sampling times.

    Raises:
        ValueError: if `skip_type` is not one of the supported values.
    """
    if skip_type == 'logSNR':
        # Space uniformly in half-logSNR, then map back to time.
        lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
        logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
        return self.noise_schedule.inverse_lambda(logSNR_steps)
    if skip_type == 'time_uniform':
        return torch.linspace(t_T, t_0, N + 1).to(device)
    if skip_type == 'time_quadratic':
        t_order = 2
        return torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
    raise ValueError(
        "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
404
+
405
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
    """
    Get the order of each step for sampling by the singlestep DPM-Solver.
    We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
    Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
        - If order == 1:
            We take `steps` of DPM-Solver-1 (i.e. DDIM).
        - If order == 2:
            - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
            - If steps % 2 == 0, we use K steps of DPM-Solver-2.
            - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
        - If order == 3:
            - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
            - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
            - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
            - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.

    Args:
        steps: A `int`. The total number of function evaluations (NFE).
        order: A `int`. The max order for the solver (1, 2 or 3).
        skip_type: A `str`. Spacing of the time steps: 'logSNR', 'time_uniform' or 'time_quadratic'.
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        device: A torch device.
    Returns:
        timesteps_outer: the outer time steps, one boundary per solver step.
        orders: A list of the solver order of each step.
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 3:
        K = steps // 3 + 1
        if steps % 3 == 0:
            orders = [3, ] * (K - 2) + [2, 1]
        elif steps % 3 == 1:
            orders = [3, ] * (K - 1) + [1]
        else:
            orders = [3, ] * (K - 1) + [2]
    elif order == 2:
        if steps % 2 == 0:
            K = steps // 2
            orders = [2, ] * K
        else:
            K = steps // 2 + 1
            orders = [2, ] * (K - 1) + [1]
    elif order == 1:
        K = 1
        orders = [1, ] * steps
    else:
        raise ValueError("'order' must be '1' or '2' or '3'.")
    if skip_type == 'logSNR':
        # To reproduce the results in the DPM-Solver paper.
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
    else:
        # Bug fix: torch.cumsum requires an explicit `dim` argument; the original
        # call omitted it and raised TypeError for any non-logSNR skip_type.
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
            torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)]
    return timesteps_outer, orders
462
+
463
def denoise_to_zero_fn(self, x, s):
    """Denoise at the final step: equivalent to solving the ODE from lambda_s
    to +infinity with a first-order discretization (returns the x0 prediction)."""
    return self.data_prediction_fn(x, s)
468
+
469
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
    """
    DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
        t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    """
    ns = self.noise_schedule
    dims = x.dim()
    lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
    h = lambda_t - lambda_s
    log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
    sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    if model_s is None:
        model_s = self.model_fn(x, s)

    if self.predict_x0:
        # Data-prediction (DPM-Solver++) first-order update.
        phi_1 = torch.expm1(-h)
        x_t = (
            expand_dims(sigma_t / sigma_s, dims) * x
            - expand_dims(alpha_t * phi_1, dims) * model_s
        )
    else:
        # Noise-prediction (classic DPM-Solver) first-order update.
        phi_1 = torch.expm1(h)
        x_t = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
            - expand_dims(sigma_t * phi_1, dims) * model_s
        )

    if return_intermediate:
        return x_t, {'model_s': model_s}
    return x_t
514
+
515
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
                                        solver_type='dpm_solver'):
    """
    Singlestep solver DPM-Solver-2 from time `s` to time `t`.
    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
        t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
        r1: A `float`. The hyperparameter of the second-order solver.
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
        solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `solver_type` is not 'dpm_solver' or 'taylor'.
    """
    if solver_type not in ['dpm_solver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
    if r1 is None:
        r1 = 0.5
    ns = self.noise_schedule
    dims = x.dim()
    # Step size in half-logSNR space; s1 is the intermediate time at lambda_s + r1 * h.
    lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
    h = lambda_t - lambda_s
    lambda_s1 = lambda_s + r1 * h
    s1 = ns.inverse_lambda(lambda_s1)
    log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
        s1), ns.marginal_log_mean_coeff(t)
    sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
    alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)

    if self.predict_x0:
        # Data-prediction (DPM-Solver++) branch: phi functions use exp(-h).
        phi_11 = torch.expm1(-r1 * h)
        phi_1 = torch.expm1(-h)

        if model_s is None:
            model_s = self.model_fn(x, s)
        # First-order step to the intermediate time s1, then re-evaluate the model there.
        x_s1 = (
            expand_dims(sigma_s1 / sigma_s, dims) * x
            - expand_dims(alpha_s1 * phi_11, dims) * model_s
        )
        model_s1 = self.model_fn(x_s1, s1)
        if solver_type == 'dpm_solver':
            x_t = (
                expand_dims(sigma_t / sigma_s, dims) * x
                - expand_dims(alpha_t * phi_1, dims) * model_s
                - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
            )
        elif solver_type == 'taylor':
            x_t = (
                expand_dims(sigma_t / sigma_s, dims) * x
                - expand_dims(alpha_t * phi_1, dims) * model_s
                + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
                    model_s1 - model_s)
            )
    else:
        # Noise-prediction (classic DPM-Solver) branch: phi functions use exp(h).
        phi_11 = torch.expm1(r1 * h)
        phi_1 = torch.expm1(h)

        if model_s is None:
            model_s = self.model_fn(x, s)
        # First-order step to the intermediate time s1, then re-evaluate the model there.
        x_s1 = (
            expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
            - expand_dims(sigma_s1 * phi_11, dims) * model_s
        )
        model_s1 = self.model_fn(x_s1, s1)
        if solver_type == 'dpm_solver':
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                - expand_dims(sigma_t * phi_1, dims) * model_s
                - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
            )
        elif solver_type == 'taylor':
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                - expand_dims(sigma_t * phi_1, dims) * model_s
                - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
            )
    if return_intermediate:
        return x_t, {'model_s': model_s, 'model_s1': model_s1}
    else:
        return x_t
598
+
599
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
                                       return_intermediate=False, solver_type='dpm_solver'):
    """
    Singlestep solver DPM-Solver-3 from time `s` to time `t`.
    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
        t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
        r1: A `float`. The hyperparameter of the third-order solver.
        r2: A `float`. The hyperparameter of the third-order solver.
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
            If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
        solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `solver_type` is not 'dpm_solver' or 'taylor'.
    """
    if solver_type not in ['dpm_solver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
    if r1 is None:
        r1 = 1. / 3.
    if r2 is None:
        r2 = 2. / 3.
    ns = self.noise_schedule
    dims = x.dim()
    # Step size in half-logSNR space; s1 and s2 are the intermediate times at
    # lambda_s + r1 * h and lambda_s + r2 * h respectively.
    lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
    h = lambda_t - lambda_s
    lambda_s1 = lambda_s + r1 * h
    lambda_s2 = lambda_s + r2 * h
    s1 = ns.inverse_lambda(lambda_s1)
    s2 = ns.inverse_lambda(lambda_s2)
    log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
        s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
    sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
        s2), ns.marginal_std(t)
    alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)

    if self.predict_x0:
        # Data-prediction (DPM-Solver++) branch: phi functions use exp(-h).
        phi_11 = torch.expm1(-r1 * h)
        phi_12 = torch.expm1(-r2 * h)
        phi_1 = torch.expm1(-h)
        phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
        phi_2 = phi_1 / h + 1.
        phi_3 = phi_2 / h - 0.5

        if model_s is None:
            model_s = self.model_fn(x, s)
        if model_s1 is None:
            # First-order step to s1, then re-evaluate the model there.
            x_s1 = (
                expand_dims(sigma_s1 / sigma_s, dims) * x
                - expand_dims(alpha_s1 * phi_11, dims) * model_s
            )
            model_s1 = self.model_fn(x_s1, s1)
        # Second-order step to s2 using the s and s1 evaluations.
        x_s2 = (
            expand_dims(sigma_s2 / sigma_s, dims) * x
            - expand_dims(alpha_s2 * phi_12, dims) * model_s
            + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
        )
        model_s2 = self.model_fn(x_s2, s2)
        if solver_type == 'dpm_solver':
            x_t = (
                expand_dims(sigma_t / sigma_s, dims) * x
                - expand_dims(alpha_t * phi_1, dims) * model_s
                + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
            )
        elif solver_type == 'taylor':
            # Finite-difference estimates of the first and second derivatives
            # of the model output with respect to lambda.
            D1_0 = (1. / r1) * (model_s1 - model_s)
            D1_1 = (1. / r2) * (model_s2 - model_s)
            D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
            D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
            x_t = (
                expand_dims(sigma_t / sigma_s, dims) * x
                - expand_dims(alpha_t * phi_1, dims) * model_s
                + expand_dims(alpha_t * phi_2, dims) * D1
                - expand_dims(alpha_t * phi_3, dims) * D2
            )
    else:
        # Noise-prediction (classic DPM-Solver) branch: phi functions use exp(h).
        phi_11 = torch.expm1(r1 * h)
        phi_12 = torch.expm1(r2 * h)
        phi_1 = torch.expm1(h)
        phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
        phi_2 = phi_1 / h - 1.
        phi_3 = phi_2 / h - 0.5

        if model_s is None:
            model_s = self.model_fn(x, s)
        if model_s1 is None:
            # First-order step to s1, then re-evaluate the model there.
            x_s1 = (
                expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
                - expand_dims(sigma_s1 * phi_11, dims) * model_s
            )
            model_s1 = self.model_fn(x_s1, s1)
        # Second-order step to s2 using the s and s1 evaluations.
        x_s2 = (
            expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
            - expand_dims(sigma_s2 * phi_12, dims) * model_s
            - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
        )
        model_s2 = self.model_fn(x_s2, s2)
        if solver_type == 'dpm_solver':
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                - expand_dims(sigma_t * phi_1, dims) * model_s
                - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
            )
        elif solver_type == 'taylor':
            # Finite-difference estimates of the first and second derivatives
            # of the model output with respect to lambda.
            D1_0 = (1. / r1) * (model_s1 - model_s)
            D1_1 = (1. / r2) * (model_s2 - model_s)
            D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
            D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                - expand_dims(sigma_t * phi_1, dims) * model_s
                - expand_dims(sigma_t * phi_2, dims) * D1
                - expand_dims(sigma_t * phi_3, dims) * D2
            )

    if return_intermediate:
        return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
    else:
        return x_t
722
+
723
    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
        """
        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensor. The previous computed model values.
            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ['dpm_solver', 'taylor']:
            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
        ns = self.noise_schedule
        dims = x.dim()
        model_prev_1, model_prev_0 = model_prev_list
        t_prev_1, t_prev_0 = t_prev_list
        # Half-logSNR (lambda) values at the two previous steps and the target step.
        lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
            t_prev_0), ns.marginal_lambda(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        # Step sizes in lambda space; D1_0 is the first-order finite difference of
        # the model output w.r.t. lambda at the most recent step.
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
        if self.predict_x0:
            # Data-prediction (x0) parameterization of the exponential integrator.
            if solver_type == 'dpm_solver':
                x_t = (
                    expand_dims(sigma_t / sigma_prev_0, dims) * x
                    - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
                    - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
                )
            elif solver_type == 'taylor':
                x_t = (
                    expand_dims(sigma_t / sigma_prev_0, dims) * x
                    - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
                    + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
                )
        else:
            # Noise-prediction (epsilon) parameterization.
            if solver_type == 'dpm_solver':
                x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                    - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
                    - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
                )
            elif solver_type == 'taylor':
                x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                    - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
                    - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
                )
        return x_t
779
+
780
    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
        """
        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensor. The previous computed model values.
            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        dims = x.dim()
        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
        # Half-logSNR (lambda) values for the three previous steps and the target step.
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
            t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        # Lambda-space step sizes; D1_0/D1_1 are first-order finite differences of
        # the model output, combined into a first (D1) and second (D2) derivative
        # estimate at the most recent step.
        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
        D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
        D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
        # NOTE(review): `solver_type` is accepted for API symmetry with the other
        # multistep updates but does not affect the third-order formula below.
        if self.predict_x0:
            # Data-prediction (x0) parameterization.
            x_t = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
                + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
                - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
            )
        else:
            # Noise-prediction (epsilon) parameterization.
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
                - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
                - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
            )
        return x_t
826
+
827
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
828
+ r2=None):
829
+ """
830
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
831
+ Args:
832
+ x: A pytorch tensor. The initial value at time `s`.
833
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
834
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
835
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
836
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
837
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
838
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
839
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
840
+ r2: A `float`. The hyperparameter of the third-order solver.
841
+ Returns:
842
+ x_t: A pytorch tensor. The approximated solution at time `t`.
843
+ """
844
+ if order == 1:
845
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
846
+ elif order == 2:
847
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
848
+ solver_type=solver_type, r1=r1)
849
+ elif order == 3:
850
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
851
+ solver_type=solver_type, r1=r1, r2=r2)
852
+ else:
853
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
854
+
855
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
856
+ """
857
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
858
+ Args:
859
+ x: A pytorch tensor. The initial value at time `s`.
860
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
861
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
862
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
863
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
864
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
865
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
866
+ Returns:
867
+ x_t: A pytorch tensor. The approximated solution at time `t`.
868
+ """
869
+ if order == 1:
870
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
871
+ elif order == 2:
872
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
873
+ elif order == 3:
874
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
875
+ else:
876
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
877
+
878
    def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
                            solver_type='dpm_solver'):
        """
        The adaptive step size solver based on singlestep DPM-Solver.
        Args:
            x: A pytorch tensor. The initial value at time `t_T`.
            order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            h_init: A `float`. The initial step size (for logSNR).
            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_0: A pytorch tensor. The approximated solution at time `t_0`.
        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
        """
        ns = self.noise_schedule
        s = t_T * torch.ones((x.shape[0],)).to(x)
        lambda_s = ns.marginal_lambda(s)
        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
        h = h_init * torch.ones_like(s).to(x)
        x_prev = x
        nfe = 0  # cumulative number of model function evaluations
        # Pair a lower-order and a higher-order update; their difference drives
        # the local error estimate used to accept/reject each step.
        if order == 2:
            r1 = 0.5
            lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
                                                                                              solver_type=solver_type,
                                                                                              **kwargs)
        elif order == 3:
            r1, r2 = 1. / 3., 2. / 3.
            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
                                                                                   return_intermediate=True,
                                                                                   solver_type=solver_type)
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
                                                                                             solver_type=solver_type,
                                                                                             **kwargs)
        else:
            raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
        while torch.abs((s - t_0)).mean() > t_err:
            t = ns.inverse_lambda(lambda_s + h)
            # The lower-order update also returns its intermediate model
            # evaluations so the higher-order update can reuse them.
            x_lower, lower_noise_kwargs = lower_update(x, s, t)
            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
            # Mixed absolute/relative tolerance, per [1].
            delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
            norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
            E = norm_fn((x_higher - x_lower) / delta).max()
            if torch.all(E <= 1.):
                # Step accepted: advance with the higher-order solution.
                x = x_higher
                s = t
                x_prev = x_lower
                lambda_s = ns.marginal_lambda(s)
            # Shrink/grow the step size (clamped so we never overshoot lambda_0);
            # on rejection the step is simply retried with the smaller h.
            h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
            nfe += order
        print('adaptive solver nfe', nfe)
        return x
938
+
939
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
940
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
941
+ atol=0.0078, rtol=0.05,
942
+ ):
943
+ """
944
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
945
+ =====================================================
946
+ We support the following algorithms for both noise prediction model and data prediction model:
947
+ - 'singlestep':
948
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
949
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
950
+ The total number of function evaluations (NFE) == `steps`.
951
+ Given a fixed NFE == `steps`, the sampling procedure is:
952
+ - If `order` == 1:
953
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
954
+ - If `order` == 2:
955
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
956
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
957
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
958
+ - If `order` == 3:
959
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
960
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
961
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
962
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
963
+ - 'multistep':
964
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
965
+ We initialize the first `order` values by lower order multistep solvers.
966
+ Given a fixed NFE == `steps`, the sampling procedure is:
967
+ Denote K = steps.
968
+ - If `order` == 1:
969
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
970
+ - If `order` == 2:
971
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
972
+ - If `order` == 3:
973
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
974
+ - 'singlestep_fixed':
975
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
976
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
977
+ - 'adaptive':
978
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
979
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
980
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
981
+ (NFE) and the sample quality.
982
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
983
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
984
+ =====================================================
985
+ Some advices for choosing the algorithm:
986
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
987
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
988
+ e.g.
989
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
990
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
991
+ skip_type='time_uniform', method='singlestep')
992
+ - For **guided sampling with large guidance scale** by DPMs:
993
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
994
+ e.g.
995
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
996
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
997
+ skip_type='time_uniform', method='multistep')
998
+ We support three types of `skip_type`:
999
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1000
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1001
+ - 'time_quadratic': quadratic time for the time steps.
1002
+ =====================================================
1003
+ Args:
1004
+ x: A pytorch tensor. The initial value at time `t_start`
1005
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1006
+ steps: A `int`. The total number of function evaluations (NFE).
1007
+ t_start: A `float`. The starting time of the sampling.
1008
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1009
+ t_end: A `float`. The ending time of the sampling.
1010
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1011
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1012
+ For discrete-time DPMs:
1013
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1014
+ For continuous-time DPMs:
1015
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1016
+ order: A `int`. The order of DPM-Solver.
1017
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1018
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1019
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1020
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1021
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1022
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1023
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
1024
+ (such as CIFAR-10). However, we observed that such trick does not matter for
1025
+ high-resolutional images. As it needs an additional NFE, we do not recommend
1026
+ it for high-resolutional images.
1027
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1028
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1029
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1030
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1031
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1032
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1033
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1034
+ Returns:
1035
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1036
+ """
1037
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1038
+ t_T = self.noise_schedule.T if t_start is None else t_start
1039
+ device = x.device
1040
+ if method == 'adaptive':
1041
+ with torch.no_grad():
1042
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1043
+ solver_type=solver_type)
1044
+ elif method == 'multistep':
1045
+ assert steps >= order
1046
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1047
+ assert timesteps.shape[0] - 1 == steps
1048
+ with torch.no_grad():
1049
+ vec_t = timesteps[0].expand((x.shape[0]))
1050
+ model_prev_list = [self.model_fn(x, vec_t)]
1051
+ t_prev_list = [vec_t]
1052
+ # Init the first `order` values by lower order multistep DPM-Solver.
1053
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
1054
+ vec_t = timesteps[init_order].expand(x.shape[0])
1055
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1056
+ solver_type=solver_type)
1057
+ model_prev_list.append(self.model_fn(x, vec_t))
1058
+ t_prev_list.append(vec_t)
1059
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1060
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1061
+ vec_t = timesteps[step].expand(x.shape[0])
1062
+ if lower_order_final and steps < 15:
1063
+ step_order = min(order, steps + 1 - step)
1064
+ else:
1065
+ step_order = order
1066
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1067
+ solver_type=solver_type)
1068
+ for i in range(order - 1):
1069
+ t_prev_list[i] = t_prev_list[i + 1]
1070
+ model_prev_list[i] = model_prev_list[i + 1]
1071
+ t_prev_list[-1] = vec_t
1072
+ # We do not need to evaluate the final model value.
1073
+ if step < steps:
1074
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1075
+ elif method in ['singlestep', 'singlestep_fixed']:
1076
+ if method == 'singlestep':
1077
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1078
+ skip_type=skip_type,
1079
+ t_T=t_T, t_0=t_0,
1080
+ device=device)
1081
+ elif method == 'singlestep_fixed':
1082
+ K = steps // order
1083
+ orders = [order, ] * K
1084
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1085
+ for i, order in enumerate(orders):
1086
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1087
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1088
+ N=order, device=device)
1089
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1090
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1091
+ h = lambda_inner[-1] - lambda_inner[0]
1092
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1093
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1094
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1095
+ if denoise_to_zero:
1096
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1097
+ return x
1098
+
1099
+
1100
+ #############################################################
1101
+ # other utility functions
1102
+ #############################################################
1103
+
1104
def interpolate_fn(x, xp, yp):
    """
    Piecewise linear interpolation y = f(x) with keypoints (xp, yp).

    Implemented with differentiable torch ops so it can be used under autograd.
    For x outside the range of xp, the outermost segment is linearly extended
    (i.e. f is defined on the whole axis).
    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    num_queries, num_keypoints = x.shape[0], xp.shape[1]
    # Sort each query together with the keypoints; the rank of x tells which
    # segment it falls into without any data-dependent python branching.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((num_queries, 1, 1))], dim=2)
    sorted_combined, sort_order = torch.sort(combined, dim=2)
    rank_of_x = torch.argmin(sort_order, dim=2)
    prev_kp = rank_of_x - 1
    # Segment start index inside the *sorted* array, clamped so the first/last
    # keypoint segments are used for out-of-range queries.
    seg_start = torch.where(
        torch.eq(rank_of_x, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(rank_of_x, num_keypoints), torch.tensor(num_keypoints - 2, device=x.device), prev_kp,
        ),
    )
    seg_end = torch.where(torch.eq(seg_start, prev_kp), seg_start + 2, seg_start + 1)
    seg_x0 = torch.gather(sorted_combined, dim=2, index=seg_start.unsqueeze(2)).squeeze(2)
    seg_x1 = torch.gather(sorted_combined, dim=2, index=seg_end.unsqueeze(2)).squeeze(2)
    # Segment start index inside the *original* keypoint array for the y values.
    y_start_idx = torch.where(
        torch.eq(rank_of_x, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(rank_of_x, num_keypoints), torch.tensor(num_keypoints - 2, device=x.device), prev_kp,
        ),
    )
    yp_expanded = yp.unsqueeze(0).expand(num_queries, -1, -1)
    seg_y0 = torch.gather(yp_expanded, dim=2, index=y_start_idx.unsqueeze(2)).squeeze(2)
    seg_y1 = torch.gather(yp_expanded, dim=2, index=(y_start_idx + 1).unsqueeze(2)).squeeze(2)
    # Standard two-point line through (x0, y0) and (x1, y1).
    return seg_y0 + (x - seg_x0) * (seg_y1 - seg_y0) / (seg_x1 - seg_x0)
1143
+
1144
+
1145
def expand_dims(v, dims):
    """
    Append trailing singleton axes to `v` until it has `dims` dimensions.
    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the target number of dimensions.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    out = v
    for _ in range(dims - 1):
        out = out.unsqueeze(-1)
    return out
watermarker/LaWa/ldm/models/diffusion/dpm_solver/sampler.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+ import torch
3
+
4
+ from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5
+
6
+
7
# Map the LDM model's `parameterization` name to the `model_type` string that
# `model_wrapper` expects ("noise" = epsilon-prediction, "v" = v-prediction).
MODEL_TYPES = {
    "eps": "noise",
    "v": "v"
}
11
+
12
+
13
class DPMSolverSampler(object):
    """Draw samples from a latent diffusion model with the multistep DPM-Solver."""

    def __init__(self, model, **kwargs):
        """
        Args:
            model: an LDM model exposing `device`, `alphas_cumprod`, `betas`,
                `parameterization` and `apply_model`.
        """
        super().__init__()
        self.model = model
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        """Attach `attr` to self as an attribute, moving tensors to CUDA when possible.

        Fix: the previous version unconditionally called `.to("cuda")`, which
        raised on CPU-only hosts; we now only move when CUDA is available.
        """
        if type(attr) == torch.Tensor:
            if torch.cuda.is_available() and attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Sample `batch_size` latents of `shape` (C, H, W) with `S` DPM-Solver steps.

        Returns:
            (samples, None) — the second element mirrors other samplers'
            `(samples, intermediates)` API; intermediates are not collected here.
        Note: most keyword arguments are accepted only for API compatibility with
        the DDIM/PLMS samplers and are not used by DPM-Solver.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')

        device = self.model.betas.device
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        # Discrete VP noise schedule built from the model's alphas_cumprod.
        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        # Wrap the LDM UNet as a noise-prediction function with classifier-free guidance.
        model_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            ns,
            model_type=MODEL_TYPES[self.model.parameterization],
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
        )

        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)

        return x.to(device), None
watermarker/LaWa/ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+ from ldm.models.diffusion.sampling_util import norm_thresholding
10
+
11
+
12
+ class PLMSSampler(object):
13
    def __init__(self, model, schedule="linear", **kwargs):
        """Store the diffusion model and schedule name for later PLMS sampling.

        Args:
            model: an LDM diffusion model exposing `num_timesteps`, `betas`,
                `alphas_cumprod`, etc.
            schedule: beta-schedule name; stored but not otherwise used here.
        """
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
18
+
19
+ def register_buffer(self, name, attr):
20
+ if type(attr) == torch.Tensor:
21
+ if attr.device != torch.device("cuda"):
22
+ attr = attr.to(torch.device("cuda"))
23
+ setattr(self, name, attr)
24
+
25
    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute and register all DDIM/PLMS sampling coefficients as buffers.

        Args:
            ddim_num_steps: number of sampling timesteps to discretize to.
            ddim_discretize: timestep discretization method name.
            ddim_eta: DDIM eta; must be 0 for PLMS (deterministic sampling).
            verbose: forwarded to the schedule helpers for logging.
        Raises:
            ValueError: if `ddim_eta` is nonzero (invalid for PLMS).
        """
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        # (computed on CPU via numpy, then converted back with to_torch)
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # With ddim_eta forced to 0 above, these sigmas evaluate to zero; kept
        # for interface parity with the DDIM sampler.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57
+
58
    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        """Run PLMS sampling: build the schedule, then delegate to `plms_sampling`.

        Args:
            S: number of sampling steps.
            batch_size: number of samples to draw.
            shape: latent shape (C, H, W) without the batch dimension.
            conditioning: conditioning tensor or dict of tensors (batch-first);
                a mismatch with `batch_size` only triggers a warning.
            eta: DDIM eta, forwarded to `make_schedule` (must be 0 for PLMS).
            x_T: optional starting noise; sampled internally if None.
            unconditional_guidance_scale / unconditional_conditioning:
                classifier-free guidance parameters.
        Returns:
            (samples, intermediates) as produced by `plms_sampling`.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                # Use the first entry's batch dimension as representative.
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates
116
+
117
    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """Run the full PLMS (pseudo linear multistep) denoising loop.

        Starts from Gaussian noise (or `x_T` when provided) and walks the
        timestep schedule in reverse, calling `p_sample_plms` once per step and
        carrying a short history of epsilon predictions (`old_eps`) that the
        multistep update reuses.

        Returns (img, intermediates): the final latent, and a dict with
        periodic snapshots of x_t ('x_inter') and the x0 prediction ('pred_x0').
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # Truncate the precomputed DDIM schedule to the requested step count.
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # Iterate from the noisiest timestep down to the cleanest.
        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []  # history of eps predictions consumed by the multistep formula

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # t_next feeds the 2nd-order bootstrap step inside p_sample_plms;
            # on the last iteration it is clamped to the final timestep.
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                # Inpainting: pin the masked region to the (re-noised) reference x0.
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            # Keep at most the 3 most recent eps values; the 4th-order PLMS
            # update needs 3 past predictions plus the current one.
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates
176
+
177
    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                      dynamic_threshold=None):
        """Perform one PLMS update: x_t -> x_{t-1}.

        Combines the current eps prediction with the history in `old_eps` using
        Adams-Bashforth-style linear multistep coefficients; with an empty
        history it bootstraps via a 2nd-order pseudo improved-Euler step that
        evaluates the model a second time at `t_next`.

        Returns (x_prev, pred_x0, e_t): the next latent, the current x0
        estimate, and the raw eps prediction (to be appended to `old_eps`).
        """
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # eps prediction, with classifier-free guidance when an
            # unconditional conditioning and a scale != 1 are supplied.
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                # Run conditional and unconditional passes in one batch.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                # Score correction is only defined for eps-parameterized models.
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        # Schedule tensors: full DDPM schedule or the subsampled DDIM one.
        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order): extra model evaluation at t_next.
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
watermarker/LaWa/ldm/models/diffusion/sampling_util.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def append_dims(x, target_dims):
    """Right-pad a tensor's shape with singleton dims until it has `target_dims` dims.

    Adapted from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py
    """
    extra = target_dims - x.ndim
    if extra < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # Indexing with trailing None adds size-1 axes without copying data.
    idx = (Ellipsis,) + (None,) * extra
    return x[idx]
12
+
13
+
14
def norm_thresholding(x0, value):
    """Dynamic thresholding: rescale `x0` so each sample's RMS norm is at most `value`."""
    # Per-sample RMS over all non-batch dims, floored at `value`, then
    # broadcast back to x0's rank so the division is elementwise per sample.
    rms = x0.pow(2).flatten(1).mean(1).sqrt()
    s = append_dims(rms.clamp(min=value), x0.ndim)
    return x0 * (value / s)
17
+
18
+
19
def spatial_norm_thresholding(x0, value):
    """Per-pixel dynamic thresholding on a (b, c, h, w) tensor."""
    # Channel-wise RMS at every spatial location, floored at `value`.
    norm = torch.sqrt(x0.pow(2).mean(1, keepdim=True)).clamp(min=value)
    return x0 * (value / norm)
watermarker/LaWa/ldm/modules/__pycache__/attention.cpython-38.pyc ADDED
Binary file (10.5 kB). View file
 
watermarker/LaWa/ldm/modules/__pycache__/ema.cpython-38.pyc ADDED
Binary file (3.2 kB). View file
 
watermarker/LaWa/ldm/modules/__pycache__/x_transformer.cpython-38.pyc ADDED
Binary file (18.3 kB). View file