XavierJiezou commited on
Commit
921503d
·
verified ·
1 Parent(s): f8ca548

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +29 -7
  3. README.md +16 -24
  4. datasets/celeba.py +209 -0
  5. datasets/cityscapes.py +303 -0
  6. docs/gcdp.png +3 -0
  7. environment.yaml +66 -0
  8. example.py +25 -0
  9. imagen_pytorch/__init__.py +26 -0
  10. imagen_pytorch/cli.py +52 -0
  11. imagen_pytorch/configs.py +181 -0
  12. imagen_pytorch/data.py +73 -0
  13. imagen_pytorch/elucidated_imagen.py +846 -0
  14. imagen_pytorch/imagen_pytorch.py +2515 -0
  15. imagen_pytorch/imagen_video/__init__.py +1 -0
  16. imagen_pytorch/imagen_video/imagen_video.py +1662 -0
  17. imagen_pytorch/joint_imagen.py +1942 -0
  18. imagen_pytorch/t5.py +119 -0
  19. imagen_pytorch/trainer.py +1782 -0
  20. imagen_pytorch/utils.py +61 -0
  21. imagen_pytorch/version.py +1 -0
  22. pyproject.toml +3 -0
  23. repaint/LICENSES/LICENSE +13 -0
  24. repaint/LICENSES/LICENSE_guided_diffusion +21 -0
  25. repaint/LICENSES/README.md +11 -0
  26. repaint/README.md +205 -0
  27. repaint/conf_mgt/__init__.py +18 -0
  28. repaint/conf_mgt/conf_base.py +128 -0
  29. repaint/confs/face_example.yml +87 -0
  30. repaint/confs/test_c256_ev2li.yml +86 -0
  31. repaint/confs/test_c256_ex64.yml +86 -0
  32. repaint/confs/test_c256_genhalf.yml +86 -0
  33. repaint/confs/test_c256_nn2.yml +86 -0
  34. repaint/confs/test_c256_thick.yml +86 -0
  35. repaint/confs/test_c256_thin.yml +86 -0
  36. repaint/confs/test_inet256_ev2li.yml +87 -0
  37. repaint/confs/test_inet256_ex64.yml +87 -0
  38. repaint/confs/test_inet256_genhalf.yml +87 -0
  39. repaint/confs/test_inet256_nn2.yml +87 -0
  40. repaint/confs/test_inet256_thick.yml +87 -0
  41. repaint/confs/test_inet256_thin.yml +87 -0
  42. repaint/confs/test_p256_ev2li.yml +86 -0
  43. repaint/confs/test_p256_ex64.yml +86 -0
  44. repaint/confs/test_p256_genhalf.yml +86 -0
  45. repaint/confs/test_p256_nn2.yml +86 -0
  46. repaint/confs/test_p256_thick.yml +86 -0
  47. repaint/confs/test_p256_thin.yml +86 -0
  48. repaint/download.sh +19 -0
  49. repaint/guided_diffusion/__init__.py +19 -0
  50. repaint/guided_diffusion/dist_util.py +43 -0
.gitattributes CHANGED
@@ -47,3 +47,4 @@ tedigan/ext/experiment/inference_coupled/input_label.png filter=lfs diff=lfs mer
47
  tedigan/ext/experiment/inference_results/input_label.png filter=lfs diff=lfs merge=lfs -text
48
  uniteandconquer/utils/faces.png filter=lfs diff=lfs merge=lfs -text
49
  uniteandconquer/utils/natural.png filter=lfs diff=lfs merge=lfs -text
 
 
47
  tedigan/ext/experiment/inference_results/input_label.png filter=lfs diff=lfs merge=lfs -text
48
  uniteandconquer/utils/faces.png filter=lfs diff=lfs merge=lfs -text
49
  uniteandconquer/utils/natural.png filter=lfs diff=lfs merge=lfs -text
50
+ docs/gcdp.png filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,3 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
3
  *.py[cod]
@@ -106,10 +132,8 @@ ipython_config.py
106
  #pdm.lock
107
  # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
  # in version control.
109
- # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
  .pdm.toml
111
- .pdm-python
112
- .pdm-build/
113
 
114
  # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
  __pypackages__/
@@ -161,7 +185,5 @@ cython_debug/
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
  #.idea/
163
 
164
- # push to github
165
- *.pt
166
- *.pth
167
- *.ckpt
 
1
+ logs
2
+ debug
3
+ wandb_dir
4
+ checkpoints
5
+ squeue.txt
6
+ results
7
+
8
+ # Created by https://www.toptal.com/developers/gitignore/api/python,linux
9
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
10
+
11
+ ### Linux ###
12
+ *~
13
+
14
+ # temporary files which can be created if a process still has a handle open of a deleted file
15
+ .fuse_hidden*
16
+
17
+ # KDE directory preferences
18
+ .directory
19
+
20
+ # Linux trash folder which might appear on any partition or disk
21
+ .Trash-*
22
+
23
+ # .nfs files are created when an open file is removed but is still being accessed
24
+ .nfs*
25
+
26
+ ### Python ###
27
  # Byte-compiled / optimized / DLL files
28
  __pycache__/
29
  *.py[cod]
 
132
  #pdm.lock
133
  # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
134
  # in version control.
135
+ # https://pdm.fming.dev/#use-with-ide
136
  .pdm.toml
 
 
137
 
138
  # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
139
  __pypackages__/
 
185
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
186
  #.idea/
187
 
188
+ # End of https://www.toptal.com/developers/gitignore/api/python,linux
189
+ _baselines/datasetGAN/StyleGAN.pytorch/outputs/cityscapes/2023-03-05/log.txt
 
 
README.md CHANGED
@@ -1,24 +1,16 @@
1
- # MMFace
2
-
3
- ## TODO
4
-
5
- - [ ] Diffusion-driven GAN Inversion Reproduction
6
- - [ ] Datasets Download
7
-
8
- ## Installation
9
-
10
- ## Datasets
11
-
12
- ## Training
13
-
14
- ## Evaluation
15
-
16
- ## Methods
17
-
18
- - [x] [TediGAN (CVPR 2021)](https://github.com/IIGROUP/TediGAN)
19
- - [x] [UniteandConquer (CVPR 2023)](https://github.com/Nithin-GK/UniteandConquer)
20
- - [x] [Collaborative-Diffusion (CVPR 2023)](https://github.com/ziqihuangg/Collaborative-Diffusion)
21
- - [x] [GCDP (ICCV 2023)](https://github.com/pmh9960/GCDP) (Text2Image)
22
- - [x] [PixelFace+ (MM 2023)](https://github.com/qazwsx671713/PixelFace-Plus)
23
- - [ ] [Diffusion-driven GAN Inversion (CVPR 2024)](https://github.com/1211sh/Diffusion-driven_GAN-Inversion/)
24
- - [ ] [MM2Latent (ECCVW 2024)](https://github.com/Open-Debin/MM2Latent)
 
1
+ # Evaluation
2
+
3
+ ```bash
4
+ CUDA_VISIBLE_DEVICES=4 python test.py --model_type=base_128x128 \
5
+ --checkpoint_path checkpoints/celeba/base_128x128_flip_100/checkpoint.500000.pt \
6
+ --end_sample_idx=1 \
7
+ --test_batch_size=1 \
8
+ --dataset celeba \
9
+ --num_classes 19 \
10
+ --save_path=results/celeba/base.png \
11
+ --test_captions "The woman wears earrings. She has wavy hair. She is attractive."
12
+ ```
13
+
14
+ ```bash
15
+ CUDA_VISIBLE_DEVICES=4 python test.py --conf_path confs/face_example.yml
16
+ ```
 
 
 
 
 
 
 
 
datasets/celeba.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import random
4
+ from collections import namedtuple
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+ from torchvision.transforms import Compose, InterpolationMode, RandomCrop, RandomHorizontalFlip, Resize, ToTensor
14
+
15
+ CelebaClass = namedtuple('CelebaClass', ['name', 'id', 'color'])
16
+ # autopep8: off
17
+ classes = [
18
+ CelebaClass('background', 0, ( 0, 0, 0)),
19
+ CelebaClass('skin', 1, (204, 0, 0)),
20
+ CelebaClass('nose', 2, ( 76, 153, 0)),
21
+ CelebaClass('eye_g', 3, (204, 204, 0)),
22
+ CelebaClass('l_eye', 4, ( 51, 51, 255)),
23
+ CelebaClass('r_eye', 5, (204, 0, 204)),
24
+ CelebaClass('l_brow', 6, ( 0, 255, 255)),
25
+ CelebaClass('r_brow', 7, (255, 204, 204)),
26
+ CelebaClass('l_ear', 8, (102, 51, 0)),
27
+ CelebaClass('r_ear', 9, (255, 0, 0)),
28
+ CelebaClass('mouth', 10, (102, 204, 0)),
29
+ CelebaClass('u_lip', 11, (255, 255, 0)),
30
+ CelebaClass('l_lip', 12, ( 0, 0, 153)),
31
+ CelebaClass('hair', 13, ( 0, 0, 204)),
32
+ CelebaClass('hat', 14, (255, 51, 153)),
33
+ CelebaClass('ear_r', 15, ( 0, 204, 204)),
34
+ CelebaClass('neck_l', 16, ( 0, 51, 0)),
35
+ CelebaClass('neck', 17, (255, 153, 51)),
36
+ CelebaClass('cloth', 18, ( 0, 204, 0)),
37
+ ]
38
+ # autopep8: on
39
+ num_classes = 19
40
+ mapping_id = torch.tensor([x.id for x in classes])
41
+ colors = torch.tensor([cls.color for cls in classes])
42
+
43
+
44
def normalize_to_neg_one_to_one(img):
    """Rescale values from [0, 1] into [-1, 1]."""
    doubled = img * 2
    return doubled - 1
46
+
47
+
48
def unnormalize_to_zero_to_one(img):
    """Rescale values from [-1, 1] back into [0, 1]."""
    shifted = img + 1
    return shifted * 0.5
50
+
51
+
52
def unnormalize_and_clamp_to_zero_to_one(img):
    """Unnormalize from [-1, 1] to [0, 1] on the CPU and clamp into range."""
    restored = unnormalize_to_zero_to_one(img.cpu())
    return torch.clamp(restored, 0, 1)
54
+
55
+
56
def exists(val):
    """Return True when *val* is not None."""
    return not (val is None)
58
+
59
+
60
def default(val, d):
    """Return *val* if it is set; otherwise *d*, calling it when callable."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
64
+
65
+
66
class ToTensorNoNorm():
    """Convert a PIL image / array-like to a (C, H, W) tensor without scaling.

    Unlike torchvision's ``ToTensor``, the values are kept as-is (no /255
    normalization), which is what integer segmentation label maps need.
    """

    def __call__(self, X_i):
        X_i = np.array(X_i)

        if len(X_i.shape) == 2:
            # Add channel dim.
            X_i = X_i[:, :, None]

        # X_i is already an ndarray at this point, so the original second
        # `np.array(X_i, copy=False)` pass was redundant; hand it straight
        # to torch and move channels first.
        return torch.from_numpy(X_i).permute(2, 0, 1)
75
+
76
+
77
def interpolate_3d(x, *args, **kwargs):
    """Apply ``F.interpolate`` to a 3-D (C, H, W) tensor by batching/unbatching."""
    batched = x.unsqueeze(0)
    resized = F.interpolate(batched, *args, **kwargs)
    return resized.squeeze(0)
79
+
80
+
81
class RandomResize(nn.Module):
    """Rescale a (C, H, W) tensor by a factor drawn uniformly from ``scale``."""

    def __init__(self, scale=(0.5, 2.0), mode='nearest'):
        super().__init__()
        self.scale = scale
        self.mode = mode

    def get_random_scale(self):
        """Draw one scale factor uniformly from [scale[0], scale[1]]."""
        low, high = self.scale
        return random.uniform(low, high)

    def forward(self, x):
        factor = self.get_random_scale()
        return interpolate_3d(x, scale_factor=factor, mode=self.mode)
94
+
95
+
96
def read_jsonl(jsonl_path):
    """Read a JSON-Lines file and return a list with one parsed object per line.

    Parses with the stdlib ``json`` module (one JSON document per line)
    instead of requiring the third-party ``jsonlines`` package; blank lines
    are skipped for robustness.
    """
    import json
    lines = []
    with open(jsonl_path, 'r') as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if stripped:
                lines.append(json.loads(stripped))
    return lines
103
+
104
+
105
class CelebaDataset(Dataset):
    """CelebA-HQ dataset yielding (image, segmentation label, caption) triples.

    Images are read from ``<root>/CelebA-HQ-img`` and label maps from
    ``<root>/CelebAMask-HQ-mask-anno/preprocessed``; captions come from
    ``<caption_list_dir>/<split>_captions.jsonl``. Image and label are
    resized to a common 1024x1024 canvas, concatenated, jointly augmented,
    then split back into (3, side_x, side_y) and (1, side_x, side_y).
    """

    def __init__(
        self,
        root="",
        split='train',
        side_x=128,
        side_y=128,
        caption_list_dir='',
        augmentation_type='flip',
    ):
        """
        Args:
            root: dataset root containing the image and mask folders.
            split: selects ``{split}_captions.jsonl`` in caption_list_dir.
            side_x, side_y: output spatial size after augmentation.
            caption_list_dir: directory holding the caption jsonl files.
            augmentation_type: 'none', 'flip', or 'resizedCrop_<min>_<max>'.
        """
        super().__init__()
        self.root = Path(root)
        self.image_dir = osp.join(self.root, 'CelebA-HQ-img')
        self.label_dir = osp.join(self.root, 'CelebAMask-HQ-mask-anno', 'preprocessed')
        self.split = split
        self.side_x = side_x
        self.side_y = side_y

        self.caption_list_dir = caption_list_dir
        captions_jsonl = read_jsonl(osp.join(self.caption_list_dir, f'{split}_captions.jsonl'))
        # Map file stem -> 'text' field of its caption record
        # (random.choice over it later implies the field is a sequence).
        self.caption_dict = {}
        for caption_jsonl in captions_jsonl:
            self.caption_dict[osp.splitext(caption_jsonl['file_name'])[0]] = caption_jsonl['text']

        if augmentation_type == 'none':
            self.augmentation = Compose([
                Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
                # ToTensor(),
            ])
        elif augmentation_type == 'flip':
            self.augmentation = Compose([
                Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
                RandomHorizontalFlip(p=0.5),
                # ToTensor(),
            ])
        elif 'resizedCrop' in augmentation_type:
            # e.g. 'resizedCrop_0.5_2.0' -> random scale range [0.5, 2.0]
            scale = [float(s) for s in augmentation_type.split('_')[1:]]
            assert len(scale) == 2, scale
            self.augmentation = Compose([
                RandomResize(scale=scale, mode='nearest'),
                RandomCrop((1024, 1024)),
                Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
                RandomHorizontalFlip(p=0.5),
                # ToTensor(),
            ])
        else:
            raise NotImplementedError(augmentation_type)

        # verification
        # Keep only files that have a caption entry; sorting both lists keeps
        # images and labels aligned by file stem (checked by the asserts below).
        self.images = sorted([osp.join(self.image_dir, file) for file in os.listdir(self.image_dir)
                              if osp.splitext(file)[0] in self.caption_dict.keys()])
        self.labels = sorted([osp.join(self.label_dir, file) for file in os.listdir(self.label_dir)
                              if osp.splitext(file)[0] in self.caption_dict.keys()])

        assert len(self.images) == len(self.labels), f'{len(self.images)} != {len(self.labels)}'
        for img, lbl in zip(self.images, self.labels):
            assert osp.splitext(osp.basename(img))[0] == osp.splitext(osp.basename(lbl))[0]

    def __len__(self):
        return len(self.images)

    def random_sample(self):
        """Return a uniformly random item."""
        return self.__getitem__(random.randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        """Return the item after *ind*, wrapping around to index 0."""
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        # Fallback used when the file at `ind` cannot be loaded.
        return self.sequential_sample(ind=ind)

    def get_caption_list_objects(self, idx):
        """Pick one caption at random from the caption entry of image *idx*."""
        filename = osp.splitext(osp.basename(self.images[idx]))[0]
        caption = random.choice(self.caption_dict[filename])
        return caption

    def __getitem__(self, idx):
        """Return (image[3, H, W], label[1, H, W], caption); skips unreadable files."""
        # load image label
        try:
            original_pil_image = Image.open(self.images[idx]).convert("RGB")
            original_pil_target = Image.open(self.labels[idx])
        except (OSError, ValueError) as e:
            print(f"An exception occurred trying to load file {self.images[idx]}.")
            print(f"Skipping index {idx}")
            return self.skip_sample(idx)

        # Transforms
        # Resize both to the same 1024x1024 canvas so they can be concatenated
        # and augmented jointly; the label is cast to float for the cat.
        image = Resize((1024, 1024), InterpolationMode.NEAREST)(ToTensor()(original_pil_image))
        label = Resize((1024, 1024), InterpolationMode.NEAREST)(ToTensorNoNorm()(original_pil_target).float())
        img_lbl = self.augmentation(torch.cat([image, label]))

        caption = self.get_caption_list_objects(idx)

        # First 3 channels are the image, the rest the label map.
        return img_lbl[:3], img_lbl[3:], caption
200
+
201
+
202
def transform_lbl(lbl: torch.Tensor, *args, **kwargs):
    """Colorize a batch of CelebA class-index maps into RGB images in [0, 1]."""
    index_map = lbl.long()
    if index_map.size(1) == 1:
        # Drop the singleton channel axis so palette indexing yields (B, H, W, 3).
        index_map = index_map.squeeze(1)
    rgb = colors[index_map]
    return rgb.permute(0, 3, 1, 2) / 255.
datasets/cityscapes.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ from collections import namedtuple
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from cityscapesscripts.helpers.labels import trainId2label
12
+ from PIL import Image
13
+ from torch.utils.data import Dataset
14
+ from torchvision.transforms import Compose, InterpolationMode, RandomCrop, RandomHorizontalFlip, Resize, ToTensor
15
+
16
+ CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
17
+ 'has_instances', 'ignore_in_eval', 'color'])
18
+ # autopep8: off
19
+ classes = [
20
+ CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, ( 0, 0, 0)),
21
+ CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, ( 0, 0, 0)),
22
+ CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, ( 0, 0, 0)),
23
+ CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, ( 0, 0, 0)),
24
+ CityscapesClass('static', 4, 255, 'void', 0, False, True, ( 0, 0, 0)),
25
+ CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
26
+ CityscapesClass('ground', 6, 255, 'void', 0, False, True, ( 81, 0, 81)),
27
+ CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
28
+ CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
29
+ CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
30
+ CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
31
+ CityscapesClass('building', 11, 2, 'construction', 2, False, False, ( 70, 70, 70)),
32
+ CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
33
+ CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
34
+ CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
35
+ CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
36
+ CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
37
+ CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
38
+ CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
39
+ CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
40
+ CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
41
+ CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
42
+ CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
43
+ CityscapesClass('sky', 23, 10, 'sky', 5, False, False, ( 70, 130, 180)),
44
+ CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
45
+ CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
46
+ CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, ( 0, 0, 142)),
47
+ CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, ( 0, 0, 70)),
48
+ CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, ( 0, 60, 100)),
49
+ CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, ( 0, 0, 90)),
50
+ CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, ( 0, 0, 110)),
51
+ CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, ( 0, 80, 100)),
52
+ CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, ( 0, 0, 230)),
53
+ CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
54
+ CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, ( 0, 0, 142)),
55
+ ]
56
+ # autopep8: on
57
+
58
+ map_id_to_id = torch.tensor([x.id for x in classes])
59
+ map_id_to_category_id = torch.tensor([x.category_id for x in classes])
60
+ map_id_to_train_id = torch.tensor([x.train_id for x in classes])
61
+ id_type_to_classes = dict(
62
+ id=dict(num_classes=34,
63
+ map_fn=torch.tensor([x if x not in (-1, ) else 0 for x in map_id_to_id]),
64
+ names=[cls.name for cls in classes][:-1]),
65
+ category_id=dict(num_classes=8,
66
+ map_fn=map_id_to_category_id,
67
+ names=[cls.name for cls in classes][:-1]), # TODO it is wrong
68
+ train_id=dict(num_classes=20,
69
+ map_fn=torch.tensor([x if x not in (-1, 255) else 19 for x in map_id_to_train_id]),
70
+ names=[i.name for i in classes if i.train_id != 255][:-1] + ['unlabeled']),
71
+ )
72
+
73
+
74
def normalize_to_neg_one_to_one(img):
    """Rescale values from [0, 1] into [-1, 1]."""
    doubled = img * 2
    return doubled - 1
76
+
77
+
78
def unnormalize_to_zero_to_one(img):
    """Rescale values from [-1, 1] back into [0, 1]."""
    shifted = img + 1
    return shifted * 0.5
80
+
81
+
82
def unnormalize_and_clamp_to_zero_to_one(img):
    """Unnormalize from [-1, 1] to [0, 1] on the CPU and clamp into range."""
    restored = unnormalize_to_zero_to_one(img.cpu())
    return torch.clamp(restored, 0, 1)
84
+
85
+
86
def exists(val):
    """Return True when *val* is not None."""
    return not (val is None)
88
+
89
+
90
def default(val, d):
    """Return *val* if it is set; otherwise *d*, calling it when callable."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
94
+
95
+
96
class ToTensorNoNorm():
    """Convert a PIL image / array-like to a (C, H, W) tensor without scaling.

    Unlike torchvision's ``ToTensor``, the values are kept as-is (no /255
    normalization), which is what integer segmentation label maps need.
    """

    def __call__(self, X_i):
        X_i = np.array(X_i)

        if len(X_i.shape) == 2:
            # Add channel dim.
            X_i = X_i[:, :, None]

        # X_i is already an ndarray at this point, so the original second
        # `np.array(X_i, copy=False)` pass was redundant; hand it straight
        # to torch and move channels first.
        return torch.from_numpy(X_i).permute(2, 0, 1)
105
+
106
+
107
def interpolate_3d(x, *args, **kwargs):
    """Apply ``F.interpolate`` to a 3-D (C, H, W) tensor by batching/unbatching."""
    batched = x.unsqueeze(0)
    resized = F.interpolate(batched, *args, **kwargs)
    return resized.squeeze(0)
109
+
110
+
111
class RandomResize(nn.Module):
    """Rescale a (C, H, W) tensor by a factor drawn uniformly from ``scale``."""

    def __init__(self, scale=(0.5, 2.0), mode='nearest'):
        super().__init__()
        self.scale = scale
        self.mode = mode

    def get_random_scale(self):
        """Draw one scale factor uniformly from [scale[0], scale[1]]."""
        low, high = self.scale
        return random.uniform(low, high)

    def forward(self, x):
        factor = self.get_random_scale()
        return interpolate_3d(x, scale_factor=factor, mode=self.mode)
124
+
125
+
126
def read_jsonl(jsonl_path):
    """Read a JSON-Lines file and return a list with one parsed object per line.

    Parses with the stdlib ``json`` module (one JSON document per line)
    instead of requiring the third-party ``jsonlines`` package; blank lines
    are skipped for robustness.
    """
    import json
    lines = []
    with open(jsonl_path, 'r') as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if stripped:
                lines.append(json.loads(stripped))
    return lines
133
+
134
+
135
class CityscapesDataset(Dataset):
    """Cityscapes dataset yielding (image, train-id label, caption) triples.

    Images live under ``<root>/leftImg8bit`` and fine labels under
    ``<root>/gtFine``; captions come from
    ``<caption_list_dir>/<split>_captions.jsonl``. Raw label ids are remapped
    through ``id_type_to_classes[id_type]['map_fn']`` before augmentation.
    """

    def __init__(
        self,
        root="",
        split='train',
        side_x=64,
        side_y=64,
        shuffle=False,
        caption_list_dir='',
        id_type='train_id',
        augmentation_type='flip',
    ):
        """
        Args:
            root: dataset root containing leftImg8bit/ and gtFine/.
            split: selects ``{split}_captions.jsonl`` in caption_list_dir.
            side_x, side_y: output spatial size after augmentation.
            shuffle: if True, skip_sample falls back to a random item.
            caption_list_dir: directory holding the caption jsonl files.
            id_type: only 'train_id' is supported (asserted below).
            augmentation_type: 'none', 'flip', or 'resizedCrop_<min>_<max>'.
        """
        super().__init__()
        self.root = Path(root)
        self.image_dir = os.path.join(self.root, 'leftImg8bit')
        self.label_dir = os.path.join(self.root, 'gtFine')
        self.split = split
        self.metadata = read_jsonl(os.path.join(caption_list_dir, f'{split}_captions.jsonl'))
        self.metadata = sorted(self.metadata, key=lambda line: line['file_name'])

        assert id_type == 'train_id'
        self.map_fn = id_type_to_classes[id_type]['map_fn']
        self.class_names = id_type_to_classes[id_type]['names']
        self.num_classes = id_type_to_classes[id_type]['num_classes']

        # self.text_ctx_len = text_ctx_len
        self.shuffle = shuffle
        self.side_x = side_x
        self.side_y = side_y

        if augmentation_type == 'none':
            self.augmentation = Compose([
                Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
                # ToTensor(),
            ])
        elif augmentation_type == 'flip':
            self.augmentation = Compose([
                Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
                RandomHorizontalFlip(p=0.5),
                # ToTensor(),
            ])
        elif 'resizedCrop' in augmentation_type:
            # e.g. 'resizedCrop_0.5_2.0' -> random scale range [0.5, 2.0]
            scale = [float(s) for s in augmentation_type.split('_')[1:]]
            assert len(scale) == 2, scale
            self.augmentation = Compose([
                RandomResize(scale=scale, mode='nearest'),
                RandomCrop((1024, 2048)),
                Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
                RandomHorizontalFlip(p=0.5),
                # ToTensor(),
            ])
        else:
            raise NotImplementedError(augmentation_type)

        # filenames of images and labels
        self.images = []
        self.labels = []
        for line in self.metadata:
            cityname = line['file_name'].split('_')[0]
            # NOTE(review): the on-disk split is re-derived from the city name
            # here, independent of the `split` ctor argument — confirm intended.
            split = 'val' if cityname in ['frankfurt', 'lindau', 'munster'] else 'train'
            img_dir = os.path.join(self.image_dir, split, cityname, line['file_name'])
            lbl_dir = os.path.join(self.label_dir, split, cityname,
                                   line['file_name'].replace('leftImg8bit.png', 'gtFine_labelIds.png'))
            assert os.path.isfile(img_dir), img_dir
            assert os.path.isfile(lbl_dir), lbl_dir
            self.images.append(img_dir)
            self.labels.append(lbl_dir)

    def __len__(self):
        return len(self.images)

    def random_sample(self):
        """Return a uniformly random item."""
        return self.__getitem__(random.randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        """Return the item after *ind*, wrapping around to index 0."""
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        # Fallback used when the file at `ind` cannot be loaded.
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def get_caption_list_objects(self, idx):
        """Pick one caption at random from the 'text' field of record *idx*."""
        caption = random.choice(self.metadata[idx]['text'])
        return caption

    def _load_json(self, path):
        """Load and return a JSON file."""
        with open(path, 'r') as file:
            data = json.load(file)
        return data

    def __getitem__(self, idx):
        """Return (image[3, H, W], label[1, H, W], caption); skips unreadable files."""
        # load image
        try:
            original_pil_image = Image.open(self.images[idx]).convert("RGB")
            original_pil_target = Image.open(self.labels[idx])
        except (OSError, ValueError) as e:
            print(f"An exception occurred trying to load file {self.images[idx]}.")
            print(f"Skipping index {idx}")
            return self.skip_sample(idx)

        # Transforms
        image = ToTensor()(original_pil_image)
        label = ToTensorNoNorm()(original_pil_target)
        # Remap raw label ids to train ids via the lookup table.
        label = self.map_fn[label.long()]
        # NOTE(review): `label` stays integer-typed here; torch.cat with the
        # float image relies on type promotion — confirm on the target torch
        # version (celeba.py casts its label to float explicitly).
        img_lbl = self.augmentation(torch.cat([image, label]))

        caption = self.get_caption_list_objects(idx)

        # First 3 channels are the image, the rest the label map.
        return img_lbl[:3], img_lbl[3:], caption
247
+
248
+
249
def indices_segmentation_to_img(indices, colors):
    """Map a batch of class-index maps to RGB images in [0, 1] via a palette.

    The train-id ignore index (255) is remapped to class 19 before lookup.
    """
    if indices.size(1) == 1:
        # Drop the singleton channel axis.
        indices = indices.squeeze(1)
    # for train_id: fold the ignore index into the last palette entry
    indices = torch.where(indices == 255, torch.full_like(indices, 19), indices)
    rgb = colors[indices]
    return rgb.permute(0, 3, 1, 2) / 255.
258
+
259
+
260
def get_colors_from_id_type(id_type):
    """Build a (num_classes, 3) RGB palette for the given id mapping.

    Walks the raw-id -> mapped-id table and keeps the color of the first raw
    class that maps to each id; the ignore value 255 is folded into id 19.
    """
    num_classes = len(id_type_to_classes[id_type]['map_fn'].unique())
    colors = torch.zeros((num_classes, 3))
    exist_ids = []
    for idx, cls in enumerate(id_type_to_classes[id_type]['map_fn']):
        if cls == 255:
            cls = 19
        if cls not in exist_ids:
            # First raw class seen for this mapped id wins.
            colors[cls] = torch.tensor(classes[idx].color)
            exist_ids.append(cls)
    return colors
271
+
272
+
273
def transform_lbl(lbl, id_type='id'):
    """Colorize a label map using the palette belonging to *id_type*."""
    palette = get_colors_from_id_type(id_type)
    return indices_segmentation_to_img(lbl, palette)
276
+
277
+
278
def transform_img_lbl(x, id_type='id', unnorm=True):
    """Split a concatenated image+label tensor and colorize the label part.

    Args:
        x: tensor of shape (b, 3 + L, h, w) or (3 + L, h, w); the first three
           channels are the image, the rest the label indices.
        id_type: which palette to use (key of ``id_type_to_classes``).
        unnorm: if True, map the image channels from [-1, 1] back to [0, 1].

    Returns:
        Tensor stacking images then colorized labels along the batch dim.
    """
    colors = get_colors_from_id_type(id_type)

    x = x.detach().cpu()
    # Promote a single sample to a batch of one.
    x = x.unsqueeze(0) if x.dim() == 3 else x

    # b, _, h, w = x.shape
    img = x[:, :3]
    lbl = x[:, 3:].long()
    img = unnormalize_to_zero_to_one(img) if unnorm else img
    saved_img = torch.cat([img, indices_segmentation_to_img(lbl, colors)])  # b * 2, 3, h ,w
    return saved_img
290
+
291
+
292
def trainId2label_fn(train_id_map):
    """Convert a train-id map back to raw Cityscapes label ids.

    Classes with ``ignoreInEval`` are skipped, so their pixels keep the
    initial value 0.
    """
    saved_label_id = torch.zeros_like(train_id_map)
    for t_id, label in trainId2label.items():
        if label.ignoreInEval:
            continue
        saved_label_id[train_id_map == t_id] = label.id
    return saved_label_id
299
+
300
+
301
def change_19_to_255(id_map):
    """Map class 19 back to the ignore index 255, in place (returns id_map)."""
    id_map.masked_fill_(id_map == 19, 255)
    return id_map
docs/gcdp.png ADDED

Git LFS Details

  • SHA256: c6a2ae1b22b793b7746bd8cf888e3273c195624320a2b411def38402dc174e52
  • Pointer size: 132 Bytes
  • Size of remote file: 3.99 MB
environment.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ channels:
2
+ - pytorch
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - cudatoolkit=11.3.1
7
+ - python=3.9.17
8
+ - pytorch=1.10.1
9
+ - torchvision=0.11.2
10
+ - pip:
11
+ - accelerate==0.21.0
12
+ - annotated-types==0.5.0
13
+ - appdirs==1.4.4
14
+ - attrs==23.1.0
15
+ - autopep8==2.0.2
16
+ - certifi==2023.7.22
17
+ - charset-normalizer==3.2.0
18
+ - cityscapesscripts==2.2.2
19
+ - click==8.1.6
20
+ - coloredlogs==15.0.1
21
+ - contourpy==1.1.0
22
+ - cycler==0.11.0
23
+ - docker-pycreds==0.4.0
24
+ - einops==0.6.1
25
+ - einops-exts==0.0.4
26
+ - ema-pytorch==0.2.3
27
+ - filelock==3.12.2
28
+ - fonttools==4.41.1
29
+ - fsspec==2023.6.0
30
+ - gitdb==4.0.10
31
+ - gitpython==3.1.32
32
+ - huggingface-hub==0.16.4
33
+ - humanfriendly==10.0
34
+ - idna==3.4
35
+ - importlib-resources==6.0.0
36
+ - jsonlines==3.1.0
37
+ - kiwisolver==1.4.4
38
+ - kornia==0.6.12
39
+ - matplotlib==3.7.2
40
+ - packaging==23.1
41
+ - pathtools==0.1.2
42
+ - protobuf==4.23.4
43
+ - psutil==5.9.5
44
+ - pycodestyle==2.11.0
45
+ - pydantic==2.1.1
46
+ - pydantic-core==2.4.0
47
+ - pyparsing==3.0.9
48
+ - pyquaternion==0.9.9
49
+ - python-dateutil==2.8.2
50
+ - pytorch-warmup==0.1.1
51
+ - pyyaml==6.0.1
52
+ - regex==2023.6.3
53
+ - requests==2.31.0
54
+ - safetensors==0.3.1
55
+ - sentencepiece==0.1.99
56
+ - sentry-sdk==1.28.1
57
+ - setproctitle==1.3.2
58
+ - smmap==5.0.0
59
+ - tokenizers==0.13.3
60
+ - tomli==2.0.1
61
+ - tqdm==4.65.0
62
+ - transformers==4.31.0
63
+ - typing==3.7.4.3
64
+ - urllib3==2.0.4
65
+ - wandb==0.15.7
66
+ - zipp==3.16.2
example.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from PIL import Image
import numpy as np

# Load the image (hard-coded path to a RePaint keep-mask; adjust as needed)
image_path = '/home/zouxuechao/mmface/gcdp/repaint/data/datasets/gt_keep_masks/gcdp/base_0_0.png'
image = Image.open(image_path)

# Convert the image to a numpy array
image_array = np.array(image)

# Check if the image is 256x128 with 3 channels (RGB)
# (numpy shape is (height, width, channels): 128 rows x 256 columns)
if image_array.shape == (128, 256, 3):
    # Modify the left 128x128 section to be 0 (black) for all 3 channels
    image_array[:, :128, :] = 0

    # Modify the right 128x128 section to be 255 (white) for all 3 channels
    image_array[:, 128:, :] = 255

    # Convert back to an image
    modified_image = Image.fromarray(image_array)

    # Save the modified image (written to the current working directory)
    modified_image.save('modified_image.png')
else:
    print("The image does not have the required size (256x128 with 3 channels).")
imagen_pytorch/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from imagen_pytorch.imagen_pytorch import Imagen, Unet
2
+ from imagen_pytorch.imagen_pytorch import NullUnet
3
+ from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
4
+ from imagen_pytorch.trainer import ImagenTrainer
5
+ from imagen_pytorch.version import __version__
6
+
7
+ # imagen using the elucidated ddpm from Tero Karras' new paper
8
+
9
+ from imagen_pytorch.elucidated_imagen import ElucidatedImagen
10
+
11
+ # config driven creation of imagen instances
12
+
13
+ from imagen_pytorch.configs import UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig
14
+
15
+ # utils
16
+
17
+ from imagen_pytorch.utils import load_imagen_from_checkpoint
18
+
19
+ # video
20
+
21
+ from imagen_pytorch.imagen_video import Unet3D
22
+
23
+ # joint
24
+
25
+ from imagen_pytorch.joint_imagen import BaseJointUnet, JointImagen, SRJointUnet
26
+ from imagen_pytorch.trainer import ImagenTrainer, JointImagenTrainer
imagen_pytorch/cli.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from imagen_pytorch import load_imagen_from_checkpoint
6
+ from imagen_pytorch.version import __version__
7
+ from imagen_pytorch.utils import safeget
8
+
9
def exists(val):
    """Return True when *val* is not None."""
    return not (val is None)
11
+
12
def simple_slugify(text, max_length = 255):
    """Turn free text into a filesystem-friendly slug, truncated to max_length."""
    slug = text.replace("-", "_")
    slug = slug.replace(",", "")
    slug = slug.replace(" ", "_")
    slug = slug.replace("|", "--")
    return slug.strip('-_')[:max_length]
14
+
15
def main():
    # Placeholder entry point; the actual CLI is the `imagen` click command below.
    pass
17
+
18
@click.command()
@click.option('--model', default = './imagen.pt', help = 'path to trained Imagen model')
@click.option('--cond_scale', default = 5, help = 'conditioning scale (classifier free guidance) in decoder')
@click.option('--load_ema', default = True, help = 'load EMA version of unets if available')
@click.argument('text')
def imagen(
    model,
    cond_scale,
    load_ema,
    text
):
    # CLI: sample one image from a trained Imagen checkpoint for TEXT and
    # save it as ./<slug-of-text>.png. (Comments used instead of a docstring
    # so click's --help output stays unchanged.)
    model_path = Path(model)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'model not found at {full_model_path}'
    loaded = torch.load(str(model_path))

    # get version

    version = safeget(loaded, 'version')
    print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')

    # get imagen parameters and type

    imagen = load_imagen_from_checkpoint(str(model_path), load_ema_if_available = load_ema)
    # NOTE(review): requires a CUDA device; there is no CPU fallback here.
    imagen.cuda()

    # generate image

    pil_image = imagen.sample(text, cond_scale = cond_scale, return_pil_images = True)

    # save under a slug of the prompt text
    image_path = f'./{simple_slugify(text)}.png'
    pil_image[0].save(image_path)

    print(f'image saved to {str(image_path)}')
    return
imagen_pytorch/configs.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pydantic import BaseModel, validator, root_validator
3
+ from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
4
+ from enum import Enum
5
+
6
+ from imagen_pytorch.imagen_pytorch import Imagen, Unet, Unet3D, NullUnet
7
+ from imagen_pytorch.trainer import ImagenTrainer
8
+ from imagen_pytorch.elucidated_imagen import ElucidatedImagen
9
+ from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim
10
+
11
+ # helper functions
12
+
13
def exists(val):
    """True when *val* is not None."""
    return not (val is None)

def default(val, d):
    """Return *val* unless it is None, in which case return fallback *d*."""
    return d if val is None else val

def ListOrTuple(inner_type):
    """Typing alias accepting either a list or a tuple of *inner_type* (pydantic-friendly)."""
    return Union[List[inner_type], Tuple[inner_type]]

def SingleOrList(inner_type):
    """Typing alias accepting a bare *inner_type* or a list/tuple of it."""
    return Union[inner_type, ListOrTuple(inner_type)]
24
+
25
+ # noise schedule
26
+
27
class NoiseSchedule(Enum):
    """Diffusion noise schedules accepted by the Imagen configs."""
    cosine = 'cosine'
    linear = 'linear'
30
+
31
class AllowExtraBaseModel(BaseModel):
    """Config base: keeps unknown keys (forwarded later as kwargs) and serializes enums by value."""
    class Config:
        extra = "allow"          # unknown fields are retained rather than rejected
        use_enum_values = True   # .dict() emits the enum's value, not the Enum member
35
+
36
+ # imagen pydantic classes
37
+
38
class NullUnetConfig(BaseModel):
    """Marker config for a placeholder (no-op) unet slot in the cascade."""
    is_null: bool  # presence flag only; the value itself is never inspected below

    def create(self):
        """Instantiate the no-op unet."""
        return NullUnet()
43
+
44
class UnetConfig(AllowExtraBaseModel):
    """Schema for a 2d Unet; extra keys pass straight through to `Unet(**...)`."""
    dim: int
    dim_mults: ListOrTuple(int)
    text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
    cond_dim: int = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    def create(self):
        """Build the Unet from all declared plus extra fields."""
        return Unet(**self.dict())
55
+
56
class Unet3DConfig(AllowExtraBaseModel):
    """Schema for a 3d (video) Unet; extra keys pass straight through to `Unet3D(**...)`."""
    dim: int
    dim_mults: ListOrTuple(int)
    text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
    cond_dim: int = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    def create(self):
        """Build the Unet3D from all declared plus extra fields."""
        return Unet3D(**self.dict())
67
+
68
class ImagenConfig(AllowExtraBaseModel):
    """Declarative config that materializes an `Imagen` cascade from plain dicts."""
    unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    image_sizes: ListOrTuple(int)
    video: bool = False
    timesteps: SingleOrList(int) = 1000
    noise_schedules: SingleOrList(NoiseSchedule) = 'cosine'
    text_encoder_name: str = DEFAULT_T5_NAME
    channels: int = 3
    loss_type: str = 'l2'
    cond_drop_prob: float = 0.5

    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
        # each unet in the cascade needs exactly one target resolution
        unets = values.get('unets')
        if len(image_sizes) != len(unets):
            raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
        return image_sizes

    def create(self):
        """Instantiate the Imagen; unet class is chosen per-slot (null / video / image)."""
        decoder_kwargs = self.dict()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        unets = []

        # `self.unets` gives the parsed config objects (for type dispatch),
        # `unets_kwargs` the serialized dicts actually fed to the constructors
        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet

            unets.append(unet_klass(**unet_kwargs))

        imagen = Imagen(unets, **decoder_kwargs)

        # stash the originating config on the model so checkpoints can round-trip it
        imagen._config = self.dict().copy()
        return imagen
107
+
108
class ElucidatedImagenConfig(AllowExtraBaseModel):
    """Declarative config that materializes an `ElucidatedImagen` (EDM-style sampler) cascade."""
    unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    image_sizes: ListOrTuple(int)
    video: bool = False
    text_encoder_name: str = DEFAULT_T5_NAME
    channels: int = 3
    cond_drop_prob: float = 0.5
    num_sample_steps: SingleOrList(int) = 32
    sigma_min: SingleOrList(float) = 0.002
    sigma_max: SingleOrList(int) = 80
    sigma_data: SingleOrList(float) = 0.5
    rho: SingleOrList(int) = 7
    P_mean: SingleOrList(float) = -1.2
    P_std: SingleOrList(float) = 1.2
    S_churn: SingleOrList(int) = 80
    S_tmin: SingleOrList(float) = 0.05
    S_tmax: SingleOrList(int) = 50
    S_noise: SingleOrList(float) = 1.003

    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
        # each unet in the cascade needs exactly one target resolution
        unets = values.get('unets')
        if len(image_sizes) != len(unets):
            raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
        return image_sizes

    def create(self):
        """Instantiate the ElucidatedImagen; unet class is chosen per-slot (null / video / image)."""
        decoder_kwargs = self.dict()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        # note: the previous pre-loop `unet_klass = Unet3D if is_video else Unet`
        # assignment was dead code (always overwritten inside the loop) and was removed

        unets = []

        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet

            unets.append(unet_klass(**unet_kwargs))

        imagen = ElucidatedImagen(unets, **decoder_kwargs)

        # stash the originating config on the model so checkpoints can round-trip it
        imagen._config = self.dict().copy()
        return imagen
157
+
158
class ImagenTrainerConfig(AllowExtraBaseModel):
    """Config that builds an `ImagenTrainer` (and the Imagen it wraps) from plain dicts."""
    imagen: dict
    elucidated: bool = False
    video: bool = False
    use_ema: bool = True
    lr: SingleOrList(float) = 1e-4
    eps: SingleOrList(float) = 1e-8
    beta1: float = 0.9
    beta2: float = 0.99
    max_grad_norm: Optional[float] = None
    group_wd_params: bool = True
    warmup_steps: SingleOrList(Optional[int]) = None
    cosine_decay_max_steps: SingleOrList(Optional[int]) = None

    def create(self):
        """Instantiate the trainer, delegating model construction to the imagen config."""
        trainer_kwargs = self.dict()

        imagen_config = trainer_kwargs.pop('imagen')
        elucidated = trainer_kwargs.pop('elucidated')
        # bugfix: `video` was previously referenced without ever being defined
        # (NameError at runtime); it belongs to the imagen config, not the
        # trainer, so pop it out of the trainer kwargs here
        video = trainer_kwargs.pop('video')

        imagen_config_klass = ElucidatedImagenConfig if elucidated else ImagenConfig
        imagen = imagen_config_klass(**{**imagen_config, 'video': video}).create()

        return ImagenTrainer(imagen, **trainer_kwargs)
imagen_pytorch/data.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from functools import partial
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from torchvision import transforms as T, utils
8
+
9
+ from PIL import Image
10
+
11
+ # helpers functions
12
+
13
def exists(val):
    """True when *val* is not None."""
    return not (val is None)

def cycle(dl):
    """Yield items from *dl* forever, restarting iteration each time it is exhausted."""
    while True:
        yield from dl

def convert_image_to(img_type, image):
    """Convert a PIL image to *img_type* mode; no-op when already in that mode."""
    return image.convert(img_type) if image.mode != img_type else image
25
+
26
+ # dataset and dataloader
27
+
28
class Dataset(Dataset):
    # NOTE(review): this class name shadows the `Dataset` imported from
    # torch.utils.data above; it still subclasses the torch Dataset because
    # the base expression is evaluated before the name is rebound
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],  # NOTE(review): mutable default, but only ever read
        convert_image_to_type = None
    ):
        """Recursively index images under *folder* and prepare the train-time transform."""
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # recursive glob for every requested extension
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # optional mode conversion (e.g. to 'RGB'); identity when not requested
        convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        """Number of indexed image files."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load and transform the image at *index* into a tensor."""
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)
58
+
59
def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = False,
    pin_memory = True
):
    """Build a DataLoader over the images in *folder*, optionally wrapped to iterate forever."""
    dataset = Dataset(folder, image_size)
    loader = DataLoader(
        dataset,
        batch_size = batch_size,
        shuffle = shuffle,
        pin_memory = pin_memory
    )
    return cycle(loader) if cycle_dl else loader
imagen_pytorch/elucidated_imagen.py ADDED
@@ -0,0 +1,846 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import sqrt
2
+ from random import random
3
+ from functools import partial
4
+ from contextlib import contextmanager, nullcontext
5
+ from typing import List, Union
6
+ from collections import namedtuple
7
+ from tqdm.auto import tqdm
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn, einsum
12
+ from torch.cuda.amp import autocast
13
+ from torch.nn.parallel import DistributedDataParallel
14
+ import torchvision.transforms as T
15
+
16
+ import kornia.augmentation as K
17
+
18
+ from einops import rearrange, repeat, reduce
19
+ from einops_exts import rearrange_many
20
+
21
+ from imagen_pytorch.imagen_pytorch import (
22
+ GaussianDiffusionContinuousTimes,
23
+ Unet,
24
+ NullUnet,
25
+ first,
26
+ exists,
27
+ identity,
28
+ maybe,
29
+ default,
30
+ cast_tuple,
31
+ cast_uint8_images_to_float,
32
+ is_float_dtype,
33
+ eval_decorator,
34
+ check_shape,
35
+ pad_tuple_to_length,
36
+ resize_image_to,
37
+ right_pad_dims_to,
38
+ module_device,
39
+ normalize_neg_one_to_one,
40
+ unnormalize_zero_to_one,
41
+ )
42
+
43
+ from imagen_pytorch.imagen_video.imagen_video import (
44
+ Unet3D,
45
+ resize_video_to
46
+ )
47
+
48
+ from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
49
+
50
+ # constants
51
+
52
# ordered hyperparameter names for the elucidated (EDM-style) sampler, one set per unet
Hparams_fields = [
    'num_sample_steps',
    'sigma_min',
    'sigma_max',
    'sigma_data',
    'rho',
    'P_mean',
    'P_std',
    'S_churn',
    'S_tmin',
    'S_tmax',
    'S_noise'
]

Hparams = namedtuple('Hparams', Hparams_fields)

# helper functions

def log(t, eps = 1e-20):
    """Numerically-safe natural log: clamps *t* to at least *eps* before taking the log."""
    return t.clamp(min = eps).log()
72
+
73
+ # main class
74
+
75
+ class ElucidatedImagen(nn.Module):
76
+ def __init__(
77
+ self,
78
+ unets,
79
+ *,
80
+ image_sizes, # for cascading ddpm, image size at each stage
81
+ text_encoder_name = DEFAULT_T5_NAME,
82
+ text_embed_dim = None,
83
+ channels = 3,
84
+ cond_drop_prob = 0.1,
85
+ random_crop_sizes = None,
86
+ lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
87
+ per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
88
+ condition_on_text = True,
89
+ auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
90
+ dynamic_thresholding = True,
91
+ dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
92
+ only_train_unet_number = None,
93
+ lowres_noise_schedule = 'linear',
94
+ num_sample_steps = 32, # number of sampling steps
95
+ sigma_min = 0.002, # min noise level
96
+ sigma_max = 80, # max noise level
97
+ sigma_data = 0.5, # standard deviation of data distribution
98
+ rho = 7, # controls the sampling schedule
99
+ P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
100
+ P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
101
+ S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper
102
+ S_tmin = 0.05,
103
+ S_tmax = 50,
104
+ S_noise = 1.003,
105
+ ):
106
+ super().__init__()
107
+
108
+ self.only_train_unet_number = only_train_unet_number
109
+
110
+ # conditioning hparams
111
+
112
+ self.condition_on_text = condition_on_text
113
+ self.unconditional = not condition_on_text
114
+
115
+ # channels
116
+
117
+ self.channels = channels
118
+
119
+ # automatically take care of ensuring that first unet is unconditional
120
+ # while the rest of the unets are conditioned on the low resolution image produced by previous unet
121
+
122
+ unets = cast_tuple(unets)
123
+ num_unets = len(unets)
124
+
125
+ # randomly cropping for upsampler training
126
+
127
+ self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
128
+ assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
129
+
130
+ # lowres augmentation noise schedule
131
+
132
+ self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)
133
+
134
+ # get text encoder
135
+
136
+ self.text_encoder_name = text_encoder_name
137
+ self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))
138
+
139
+ self.encode_text = partial(t5_encode_text, name = text_encoder_name)
140
+
141
+ # construct unets
142
+
143
+ self.unets = nn.ModuleList([])
144
+ self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
145
+
146
+ for ind, one_unet in enumerate(unets):
147
+ assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
148
+ is_first = ind == 0
149
+
150
+ one_unet = one_unet.cast_model_parameters(
151
+ lowres_cond = not is_first,
152
+ cond_on_text = self.condition_on_text,
153
+ text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
154
+ channels = self.channels,
155
+ channels_out = self.channels
156
+ )
157
+
158
+ self.unets.append(one_unet)
159
+
160
+ # determine whether we are training on images or video
161
+
162
+ is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
163
+ self.is_video = is_video
164
+
165
+ self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
166
+ self.resize_to = resize_video_to if is_video else resize_image_to
167
+
168
+ # unet image sizes
169
+
170
+ self.image_sizes = cast_tuple(self.image_sizes)
171
+ assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}'
172
+
173
+ self.sample_channels = cast_tuple(self.channels, num_unets)
174
+
175
+ # cascading ddpm related stuff
176
+
177
+ lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
178
+ assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
179
+
180
+ self.lowres_sample_noise_level = lowres_sample_noise_level
181
+ self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level
182
+
183
+ # classifier free guidance
184
+
185
+ self.cond_drop_prob = cond_drop_prob
186
+ self.can_classifier_guidance = cond_drop_prob > 0.
187
+
188
+ # normalize and unnormalize image functions
189
+
190
+ self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
191
+ self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
192
+ self.input_image_range = (0. if auto_normalize_img else -1., 1.)
193
+
194
+ # dynamic thresholding
195
+
196
+ self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
197
+ self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
198
+
199
+ # elucidating parameters
200
+
201
+ hparams = [
202
+ num_sample_steps,
203
+ sigma_min,
204
+ sigma_max,
205
+ sigma_data,
206
+ rho,
207
+ P_mean,
208
+ P_std,
209
+ S_churn,
210
+ S_tmin,
211
+ S_tmax,
212
+ S_noise,
213
+ ]
214
+
215
+ hparams = [cast_tuple(hp, num_unets) for hp in hparams]
216
+ self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)]
217
+
218
+ # one temp parameter for keeping track of device
219
+
220
+ self.register_buffer('_temp', torch.tensor([0.]), persistent = False)
221
+
222
+ # default to device of unets passed in
223
+
224
+ self.to(next(self.unets.parameters()).device)
225
+
226
+ def force_unconditional_(self):
227
+ self.condition_on_text = False
228
+ self.unconditional = True
229
+
230
+ for unet in self.unets:
231
+ unet.cond_on_text = False
232
+
233
    @property
    def device(self):
        # device is tracked via the non-persistent `_temp` buffer registered in __init__
        return self._temp.device
236
+
237
    def get_unet(self, unet_number):
        """Return the 1-indexed unet, moving it alone to this module's device (others to cpu)."""
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1

        # demote the ModuleList to a plain python list so that moving one unet
        # does not drag the whole registered module tree along
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.unets]
            delattr(self, 'unets')
            self.unets = unets_list

        # only reshuffle devices when switching to a different unet
        if index != self.unet_being_trained_index:
            for unet_index, unet in enumerate(self.unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.unet_being_trained_index = index
        return self.unets[index]
252
+
253
    def reset_unets_all_one_device(self, device = None):
        """Re-wrap the unets in a ModuleList and move all of them onto one device."""
        device = default(device, self.device)
        self.unets = nn.ModuleList([*self.unets])
        self.unets.to(device)

        # -1 sentinel: no single unet is currently pinned for training
        self.unet_being_trained_index = -1
259
+
260
+ @contextmanager
261
+ def one_unet_in_gpu(self, unet_number = None, unet = None):
262
+ assert exists(unet_number) ^ exists(unet)
263
+
264
+ if exists(unet_number):
265
+ unet = self.unets[unet_number - 1]
266
+
267
+ devices = [module_device(unet) for unet in self.unets]
268
+ self.unets.cpu()
269
+ unet.to(self.device)
270
+
271
+ yield
272
+
273
+ for unet, device in zip(self.unets, devices):
274
+ unet.to(device)
275
+
276
+ # overriding state dict functions
277
+
278
    def state_dict(self, *args, **kwargs):
        # consolidate unets back into a ModuleList so their params appear in the state dict
        self.reset_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        # same consolidation before loading, so parameter keys line up
        self.reset_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)
285
+
286
+ # dynamic thresholding
287
+
288
+ def threshold_x_start(self, x_start, dynamic_threshold = True):
289
+ if not dynamic_threshold:
290
+ return x_start.clamp(-1., 1.)
291
+
292
+ s = torch.quantile(
293
+ rearrange(x_start, 'b ... -> b (...)').abs(),
294
+ self.dynamic_thresholding_percentile,
295
+ dim = -1
296
+ )
297
+
298
+ s.clamp_(min = 1.)
299
+ s = right_pad_dims_to(x_start, s)
300
+ return x_start.clamp(-s, s) / s
301
+
302
+ # derived preconditioning params - Table 1
303
+
304
+ def c_skip(self, sigma_data, sigma):
305
+ return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2)
306
+
307
+ def c_out(self, sigma_data, sigma):
308
+ return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5
309
+
310
+ def c_in(self, sigma_data, sigma):
311
+ return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5
312
+
313
+ def c_noise(self, sigma):
314
+ return log(sigma) * 0.25
315
+
316
+ # preconditioned network output
317
+ # equation (7) in the paper
318
+
319
    def preconditioned_network_forward(
        self,
        unet_forward,
        noised_images,
        sigma,
        *,
        sigma_data,
        clamp = False,
        dynamic_threshold = True,
        **kwargs
    ):
        """Denoiser D(x; sigma) per equation (7) of the EDM paper: c_skip*x + c_out*F(c_in*x, c_noise)."""
        batch, device = noised_images.shape[0], noised_images.device

        # allow a python float sigma; broadcast it over the batch
        if isinstance(sigma, float):
            sigma = torch.full((batch,), sigma, device = device)

        # pad sigma to image (or video) rank for elementwise scaling
        padded_sigma = self.right_pad_dims_to_datatype(sigma)

        net_out = unet_forward(
            self.c_in(sigma_data, padded_sigma) * noised_images,
            self.c_noise(sigma),
            **kwargs
        )

        out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out

        if not clamp:
            return out

        return self.threshold_x_start(out, dynamic_threshold)
349
+
350
+ # sampling
351
+
352
+ # sample schedule
353
+ # equation (5) in the paper
354
+
355
    def sample_schedule(
        self,
        num_sample_steps,
        rho,
        sigma_min,
        sigma_max
    ):
        """Karras sigma schedule, equation (5) of the EDM paper, with a trailing 0 appended."""
        N = num_sample_steps
        inv_rho = 1 / rho

        steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
        # interpolate between sigma_max and sigma_min in rho-warped space
        sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho

        sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
        return sigmas
370
+
371
    @torch.no_grad()
    def one_unet_sample(
        self,
        unet,
        shape,
        *,
        unet_number,
        clamp = True,
        dynamic_threshold = True,
        cond_scale = 1.,
        use_tqdm = True,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        sigma_min = None,
        sigma_max = None,
        **kwargs
    ):
        """Stochastic EDM sampler (Algorithm 2) for a single unet stage, with optional
        repaint-style inpainting and image initialization."""
        # get specific sampling hyperparameters for unet

        hp = self.hparams[unet_number - 1]

        sigma_min = default(sigma_min, hp.sigma_min)
        sigma_max = default(sigma_max, hp.sigma_max)

        # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma

        sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max)

        # churn (gamma) is only applied inside the [S_tmin, S_tmax] sigma band
        gammas = torch.where(
            (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax),
            min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1),
            0.
        )

        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

        # images is noise at the beginning

        init_sigma = sigmas[0]

        images = init_sigma * torch.randn(shape, device = self.device)

        # initializing with an image

        if exists(init_images):
            images += init_images

        # keeping track of x0, for self conditioning if needed

        x_start = None

        # prepare inpainting images and mask

        has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
        resample_times = inpaint_resample_times if has_inpainting else 1

        if has_inpainting:
            inpaint_images = self.normalize_img(inpaint_images)
            inpaint_images = self.resize_to(inpaint_images, shape[-1])
            inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()

        # unet kwargs

        unet_kwargs = dict(
            sigma_data = hp.sigma_data,
            clamp = clamp,
            dynamic_threshold = dynamic_threshold,
            cond_scale = cond_scale,
            **kwargs
        )

        # gradually denoise

        initial_step = default(skip_steps, 0)
        sigmas_and_gammas = sigmas_and_gammas[initial_step:]

        total_steps = len(sigmas_and_gammas)

        for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm):
            is_last_timestep = ind == (total_steps - 1)

            sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))

            # inner loop implements repaint-style resampling (r > 0 only when inpainting)
            for r in reversed(range(resample_times)):
                is_last_resample_step = r == 0

                eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling

                # raise the noise level slightly (churn), then inject matching noise
                sigma_hat = sigma + gamma * sigma
                added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps

                images_hat = images + added_noise

                self_cond = x_start if unet.self_cond else None

                if has_inpainting:
                    images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks

                model_output = self.preconditioned_network_forward(
                    unet.forward_with_cond_scale,
                    images_hat,
                    sigma_hat,
                    self_cond = self_cond,
                    **unet_kwargs
                )

                # Euler step along d = (x - D(x)) / sigma
                denoised_over_sigma = (images_hat - model_output) / sigma_hat

                images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

                # second order correction, if not the last timestep

                if sigma_next != 0:
                    self_cond = model_output if unet.self_cond else None

                    model_output_next = self.preconditioned_network_forward(
                        unet.forward_with_cond_scale,
                        images_next,
                        sigma_next,
                        self_cond = self_cond,
                        **unet_kwargs
                    )

                    # Heun (trapezoidal) correction averaging the two slopes
                    denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
                    images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

                images = images_next

                if has_inpainting and not (is_last_resample_step or is_last_timestep):
                    # renoise in repaint and then resample
                    repaint_noise = torch.randn(shape, device = self.device)
                    images = images + (sigma - sigma_next) * repaint_noise

            x_start = model_output # save model output for self conditioning

        images = images.clamp(-1., 1.)

        if has_inpainting:
            # paste the known (masked) region back in at the end
            images = images * ~inpaint_masks + inpaint_images * inpaint_masks

        return self.unnormalize_img(images)
515
+
516
    @torch.no_grad()
    @eval_decorator
    def sample(
        self,
        texts: List[str] = None,
        text_masks = None,
        text_embeds = None,
        cond_images = None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        sigma_min = None,
        sigma_max = None,
        video_frames = None,
        batch_size = 1,
        cond_scale = 1.,
        lowres_sample_noise_level = None,
        start_at_unet_number = 1,
        start_image_or_video = None,
        stop_at_unet_number = None,
        return_all_unet_outputs = False,
        return_pil_images = False,
        use_tqdm = True,
        device = None,
    ):
        """Run the full cascade: encode text (if any), then sample each unet in turn,
        feeding each stage's output as the low-resolution conditioning of the next."""
        device = default(device, self.device)
        self.reset_unets_all_one_device(device = device)

        cond_images = maybe(cast_uint8_images_to_float)(cond_images)

        # encode raw text on the fly when embeddings were not supplied
        if exists(texts) and not exists(text_embeds) and not self.unconditional:
            assert all([*map(len, texts)]), 'text cannot be empty'

            with autocast(enabled = False):
                text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

            text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))

        if not self.unconditional:
            assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'

            # derive a mask from non-zero embedding rows when none was given
            text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
            batch_size = text_embeds.shape[0]

        if exists(inpaint_images):
            if self.unconditional:
                if batch_size == 1: # assume researcher wants to broadcast along inpainted images
                    batch_size = inpaint_images.shape[0]

            assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
            assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'

        assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
        assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
        assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

        assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'

        outputs = []

        is_cuda = next(self.parameters()).is_cuda
        device = next(self.parameters()).device

        lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)

        num_unets = len(self.unets)
        cond_scale = cast_tuple(cond_scale, num_unets)

        # handle video and frame dimension

        assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

        frame_dims = (video_frames,) if self.is_video else tuple()

        # initializing with an image or video

        init_images = cast_tuple(init_images, num_unets)
        init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]

        skip_steps = cast_tuple(skip_steps, num_unets)

        sigma_min = cast_tuple(sigma_min, num_unets)
        sigma_max = cast_tuple(sigma_max, num_unets)

        # handle starting at a unet greater than 1, for training only-upscaler training

        if start_at_unet_number > 1:
            assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
            assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
            assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

            # pretend the supplied image came from the previous stage
            prev_image_size = self.image_sizes[start_at_unet_number - 2]
            img = self.resize_to(start_image_or_video, prev_image_size)

        # go through each unet in cascade

        for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
            if unet_number < start_at_unet_number:
                continue

            assert not isinstance(unet, NullUnet), 'cannot sample from null unet'

            # on cuda, swap only the active unet onto the gpu for this stage
            context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()

            with context:
                lowres_cond_img = lowres_noise_times = None

                shape = (batch_size, channel, *frame_dims, image_size, image_size)

                if unet.lowres_cond:
                    # noise the previous stage's output to the fixed conditioning level
                    lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

                    lowres_cond_img = self.resize_to(img, image_size)
                    lowres_cond_img = self.normalize_img(lowres_cond_img)

                    lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

                if exists(unet_init_images):
                    unet_init_images = self.resize_to(unet_init_images, image_size)

                shape = (batch_size, self.channels, *frame_dims, image_size, image_size)

                img = self.one_unet_sample(
                    unet,
                    shape,
                    unet_number = unet_number,
                    text_embeds = text_embeds,
                    text_mask = text_masks,
                    cond_images = cond_images,
                    inpaint_images = inpaint_images,
                    inpaint_masks = inpaint_masks,
                    inpaint_resample_times = inpaint_resample_times,
                    init_images = unet_init_images,
                    skip_steps = unet_skip_steps,
                    sigma_min = unet_sigma_min,
                    sigma_max = unet_sigma_max,
                    cond_scale = unet_cond_scale,
                    lowres_cond_img = lowres_cond_img,
                    lowres_noise_times = lowres_noise_times,
                    dynamic_threshold = dynamic_threshold,
                    use_tqdm = use_tqdm
                )

                outputs.append(img)

            if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
                break

        output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs

        if not return_pil_images:
            return outputs[output_index]

        if not return_all_unet_outputs:
            outputs = outputs[-1:]

        assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet'

        pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))

        return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
679
+
680
+ # training
681
+
682
def loss_weight(self, sigma_data, sigma):
    """EDM (Karras et al. 2022) loss weighting: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    numerator = sigma ** 2 + sigma_data ** 2
    scale = (sigma * sigma_data) ** -2
    return numerator * scale
684
+
685
def noise_distribution(self, P_mean, P_std, batch_size):
    """Sample per-example noise levels: sigma = exp(P_mean + P_std * N(0, 1))."""
    gaussian = torch.randn((batch_size,), device = self.device)
    return torch.exp(P_mean + P_std * gaussian)
687
+
688
def forward(
    self,
    images,
    unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
    texts: List[str] = None,
    text_embeds = None,
    text_masks = None,
    unet_number = None,
    cond_images = None
):
    """One training step for `unet_number`; returns the mean EDM-weighted MSE loss.

    Fixes over the original:
    - the "only train unet" assertion message was missing its f-string prefix
    - the square-image assertion message indexed shape[2] (frames, for video)
      instead of shape[-2] (height)
    - `self_cond` read the unet module itself (always truthy) in the non-DDP
      branch instead of its `.self_cond` flag, enabling self-conditioning 50%
      of the time even for unets configured without it
    """
    assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[-2]}, {images.shape[-1]}'
    assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
    unet_number = default(unet_number, 1)
    assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'

    images = cast_uint8_images_to_float(images)
    cond_images = maybe(cast_uint8_images_to_float)(cond_images)

    assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'

    unet_index = unet_number - 1

    unet = default(unet, lambda: self.get_unet(unet_number))

    assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'

    # per-unet hyperparameters and target resolution for this stage of the cascade
    target_image_size = self.image_sizes[unet_index]
    random_crop_size = self.random_crop_sizes[unet_index]
    prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
    hp = self.hparams[unet_index]

    batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)

    frames = images.shape[2] if is_video else None

    check_shape(images, 'b c ...', c = self.channels)

    assert h >= target_image_size and w >= target_image_size

    # encode raw text captions when embeddings were not precomputed
    if exists(texts) and not exists(text_embeds) and not self.unconditional:
        assert all([*map(len, texts)]), 'text cannot be empty'
        assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

        with autocast(enabled = False):
            text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

        text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))

    if not self.unconditional:
        text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))

    assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
    assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'

    assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

    # build the low-resolution conditioning image by downsampling then re-upsampling
    lowres_cond_img = lowres_aug_times = None
    if exists(prev_image_size):
        lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
        lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)

        if self.per_sample_random_aug_noise_level:
            lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
        else:
            lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
            lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)

    images = self.resize_to(images, target_image_size)

    # normalize to [-1, 1]

    images = self.normalize_img(images)
    lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

    # random cropping during training
    # for upsamplers

    if exists(random_crop_size):
        aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)

        if is_video:
            images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h w -> (b f) c h w')

        # make sure low res conditioner and image both get augmented the same way
        # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
        images = aug(images)
        lowres_cond_img = aug(lowres_cond_img, params = aug._params)

        if is_video:
            images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h w -> b c f h w', f = frames)

    # noise the lowres conditioning image
    # at sample time, they then fix the noise level of 0.1 - 0.3

    lowres_cond_img_noisy = None
    if exists(lowres_cond_img):
        lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))

    # get the sigmas

    sigmas = self.noise_distribution(hp.P_mean, hp.P_std, batch_size)
    padded_sigmas = self.right_pad_dims_to_datatype(sigmas)

    # noise

    noise = torch.randn_like(images)
    noised_images = images + padded_sigmas * noise # alphas are 1. in the paper

    # unet kwargs

    unet_kwargs = dict(
        sigma_data = hp.sigma_data,
        text_embeds = text_embeds,
        text_mask = text_masks,
        cond_images = cond_images,
        lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
        lowres_cond_img = lowres_cond_img_noisy,
        cond_drop_prob = self.cond_drop_prob,
    )

    # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower

    # Because 'unet' can be an instance of DistributedDataParallel coming from the
    # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
    # access the member 'module' of the wrapped unet instance.
    # FIX: the non-DDP branch must read the boolean flag, not the module itself
    self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond

    if self_cond and random() < 0.5:
        with torch.no_grad():
            pred_x0 = self.preconditioned_network_forward(
                unet.forward,
                noised_images,
                sigmas,
                **unet_kwargs
            ).detach()

        unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0}

    # get prediction

    denoised_images = self.preconditioned_network_forward(
        unet.forward,
        noised_images,
        sigmas,
        **unet_kwargs
    )

    # losses

    losses = F.mse_loss(denoised_images, images, reduction = 'none')
    losses = reduce(losses, 'b ... -> b', 'mean')

    # loss weighting

    losses = losses * self.loss_weight(hp.sigma_data, sigmas)

    # return average loss

    return losses.mean()
imagen_pytorch/imagen_pytorch.py ADDED
@@ -0,0 +1,2515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import copy
3
+ from random import random
4
+ from typing import List, Union
5
+ from tqdm.auto import tqdm
6
+ from functools import partial, wraps
7
+ from contextlib import contextmanager, nullcontext
8
+ from collections import namedtuple
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.nn.parallel import DistributedDataParallel
14
+ from torch import nn, einsum
15
+ from torch.cuda.amp import autocast
16
+ from torch.special import expm1
17
+ import torchvision.transforms as T
18
+
19
+ import kornia.augmentation as K
20
+
21
+ from einops import rearrange, repeat, reduce
22
+ from einops.layers.torch import Rearrange, Reduce
23
+ from einops_exts import rearrange_many, repeat_many, check_shape
24
+ from einops_exts.torch import EinopsToAndFrom
25
+
26
+ from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
27
+
28
+ from imagen_pytorch.imagen_video.imagen_video import Unet3D, resize_video_to
29
+
30
+ # helper functions
31
+
32
def exists(val):
    """True for any value except None (falsy values like 0 and '' still exist)."""
    return not (val is None)
34
+
35
def identity(t, *args, **kwargs):
    """Return the first argument unchanged; extra arguments are accepted and ignored
    so this can stand in for any callable."""
    return t
37
+
38
def first(arr, d = None):
    """First element of `arr`, or the default `d` when `arr` is empty."""
    return arr[0] if len(arr) != 0 else d
42
+
43
def maybe(fn):
    """Wrap a unary `fn` so that None passes straight through instead of being applied."""
    @wraps(fn)
    def inner(x):
        return fn(x) if exists(x) else x
    return inner
50
+
51
def once(fn):
    """Return a wrapper that invokes `fn` only on the first call; later calls are no-ops."""
    has_run = False

    @wraps(fn)
    def inner(x):
        nonlocal has_run
        if has_run:
            return
        has_run = True
        return fn(x)

    return inner
61
+
62
+ print_once = once(print)
63
+
64
def default(val, d):
    """Return `val` when it exists; otherwise the default `d`, calling it first if callable
    (lets callers pass lazily-evaluated defaults)."""
    if not exists(val):
        return d() if callable(d) else d
    return val
68
+
69
def cast_tuple(val, length = None):
    """Coerce `val` to a tuple; scalars are repeated `length` times (default 1),
    and an explicit `length` is asserted against the result."""
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        output = (val,) * default(length, 1)

    if exists(length):
        assert len(output) == length

    return output
79
+
80
def is_float_dtype(dtype):
    """True for any floating-point torch dtype this codebase supports."""
    float_dtypes = (torch.float64, torch.float32, torch.float16, torch.bfloat16)
    return dtype in float_dtypes
82
+
83
def cast_uint8_images_to_float(images):
    """Convert uint8 images in [0, 255] to floats in [0, 1]; any other dtype passes through."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
87
+
88
def module_device(module):
    """Device a module lives on, inferred from its first parameter."""
    first_param = next(module.parameters())
    return first_param.device
90
+
91
def zero_init_(m):
    """In-place: zero a layer's weight, and its bias when the layer has one."""
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)
95
+
96
def eval_decorator(fn):
    """Decorator: run `fn` with the model switched to eval mode, then restore the
    model's prior training mode.

    Fix over the original: `functools.wraps` (already imported at file top)
    preserves the wrapped function's name/docstring so decorated sampling
    methods remain introspectable.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner
104
+
105
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad tuple `t` with `fillvalue` up to `length`; no-op when already long enough."""
    missing = length - len(t)
    if missing <= 0:
        return t
    return t + (fillvalue,) * missing
110
+
111
+ # helper classes
112
+
113
class Identity(nn.Module):
    """nn.Module that returns its input unchanged; constructor and forward accept
    (and discard) arbitrary extra arguments so it can replace any layer."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x
119
+
120
+ # tensor helpers
121
+
122
def log(t, eps: float = 1e-12):
    """Numerically safe log: clamps the input below at `eps` before taking the log."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)
124
+
125
def l2norm(t):
    """Normalize `t` to unit L2 norm along its last dimension."""
    return F.normalize(t, p = 2, dim = -1)
127
+
128
def right_pad_dims_to(x, t):
    """Append trailing singleton dims to `t` until it has as many dims as `x`
    (so `t` broadcasts against `x`); returns `t` untouched when already wide enough."""
    missing = x.ndim - t.ndim
    if missing <= 0:
        return t
    return t.view(*t.shape, *((1,) * missing))
133
+
134
def masked_mean(t, *, dim, mask = None):
    """Mean of `t` over `dim`; when a boolean `mask` (b n) is given, only
    masked-in positions contribute to the average."""
    if not exists(mask):
        return t.mean(dim = dim)

    expanded_mask = rearrange(mask, 'b n -> b n 1')
    zeroed = t.masked_fill(~expanded_mask, 0.)
    denom = mask.sum(dim = dim, keepdim = True)

    # clamp keeps fully-masked rows from dividing by zero
    return zeroed.sum(dim = dim) / denom.clamp(min = 1e-5)
143
+
144
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None
):
    """Nearest-neighbor resize of a square image batch to `target_image_size`,
    optionally clamped to `clamp_range`; returns the input untouched when
    already at the target size."""
    if image.shape[-1] == target_image_size:
        return image

    resized = F.interpolate(image, target_image_size, mode = 'nearest')

    if exists(clamp_range):
        resized = resized.clamp(*clamp_range)

    return resized
160
+
161
+ # image normalization functions
162
+ # ddpms expect images to be in the range of -1 to 1
163
+
164
def normalize_neg_one_to_one(img):
    """Map an image from [0, 1] to the [-1, 1] range ddpms expect."""
    return 2 * img - 1
166
+
167
def unnormalize_zero_to_one(normed_img):
    """Map an image from [-1, 1] back to [0, 1]."""
    return 0.5 * (normed_img + 1)
169
+
170
+ # classifier free guidance functions
171
+
172
def prob_mask_like(shape, prob, device):
    """Boolean mask of `shape` whose entries are True with probability `prob`
    (used for classifier-free guidance condition dropout)."""
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    # uniform(0, 1) < prob gives Bernoulli(prob) samples
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
179
+
180
+ # gaussian diffusion with continuous time helper functions and classes
181
+ # large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
182
+
183
# log SNR for the "linear" beta schedule in continuous time, t in [0, 1]
# (TorchScript-compiled, so only jit-scriptable constructs may be used here)
@torch.jit.script
def beta_linear_log_snr(t):
    return -torch.log(expm1(1e-4 + 10 * (t ** 2)))
186
+
187
# log SNR for the cosine schedule; `s` is the small offset from the cosine-schedule paper
# (TorchScript-compiled, so only jit-scriptable constructs may be used here)
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
190
+
191
def log_snr_to_alpha_sigma(log_snr):
    """Convert a log SNR into (alpha, sigma) via sigmoid; alpha^2 + sigma^2 = 1."""
    alpha = torch.sqrt(torch.sigmoid(log_snr))
    sigma = torch.sqrt(torch.sigmoid(-log_snr))
    return alpha, sigma
193
+
194
class GaussianDiffusionContinuousTimes(nn.Module):
    """Continuous-time Gaussian diffusion in the variational-diffusion-model style.

    Times live in [0, 1] (1 = pure noise, 0 = data); the schedule is expressed
    through a log SNR function selected by `noise_schedule`.
    """

    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        # number of discretization steps used at sampling time only
        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        # constant time vector, e.g. to fix the lowres conditioning noise level
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
        # uniform training times in [0, max_thres)
        return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres)

    def get_condition(self, times):
        # map times to the log SNR the unet is conditioned on (None passes through)
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        # list of (t, t_next) pairs walking from 1 down to 0 in num_timesteps steps
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        # posterior q(x_{t_next} | x_t, x_start); t_next defaults to one step earlier
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        # forward process: x_t = alpha(t) * x_0 + sigma(t) * noise; also returns log SNR
        dtype = x_start.dtype

        if isinstance(t, float):
            batch = x_start.shape[0]
            t = torch.full((batch,), t, device = x_start.device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t).type(dtype)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        return alpha * x_start + sigma * noise, log_snr

    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        # jump directly from noise level from_t to noise level to_t in one step
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_from))

        log_snr = self.log_snr(from_t)
        log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        log_snr_to = self.log_snr(to_t)
        log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
        alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to)

        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    def predict_start_from_noise(self, x_t, t, noise):
        # invert the forward process: x_0 = (x_t - sigma * noise) / alpha
        # (alpha clamped to avoid division blowup near t = 1)
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
284
+
285
+ # norms and residuals
286
+
287
class LayerNorm(nn.Module):
    # bias-free layer norm with an optional "stable" variant that first divides by
    # the max (detached) for numerical stability; `dim` selects the feature axis
    # (-1 for token features, -3 for channel-first images via ChanLayerNorm below)
    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        # gain shaped (feats, 1, ..., 1) so it broadcasts over the dims after `dim`
        self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1))))

    def forward(self, x):
        dtype, dim = x.dtype, self.dim

        if self.stable:
            x = x / x.amax(dim = dim, keepdim = True).detach()

        # larger eps in half precision to avoid rsqrt blowups
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = dim, keepdim = True)

        return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype)
306
+
307
+ ChanLayerNorm = partial(LayerNorm, dim = -3)
308
+
309
class Always():
    """Callable that ignores all arguments and always returns the fixed value it
    was constructed with."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val
315
+
316
class Residual(nn.Module):
    """Skip connection around `fn`: forward(x, **kwargs) = fn(x, **kwargs) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return x + self.fn(x, **kwargs)
323
+
324
class Parallel(nn.Module):
    """Runs every branch module on the same input and returns the sum of their outputs."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum(fn(x) for fn in self.fns)
332
+
333
+ # attention pooling
334
+
335
class PerceiverAttention(nn.Module):
    # cross attention from a small set of latent tokens to a (text) sequence,
    # used by PerceiverResampler for attention pooling of text embeddings
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        cosine_sim_attn = False
    ):
        super().__init__()
        # cosine-sim attention replaces the usual 1/sqrt(d) scale with a fixed scale of 16
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        # x: (b, n, dim) sequence; latents: (b, m, dim); mask: (b, n) bool over x
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities and masking

        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            # pad the mask with True for the appended latent key positions
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention (softmax in float32 for stability, then cast back)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
402
+
403
class PerceiverResampler(nn.Module):
    # attention pooling: compresses a variable-length embedding sequence into a
    # fixed number of latent tokens via stacked perceiver cross-attention blocks
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            # projects the mean-pooled sequence into extra latent tokens
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
                FeedForward(dim = dim, mult = ff_mult)  # NOTE: FeedForward is defined later in this module
            ]))

    def forward(self, x, mask = None):
        # x: (b, n, dim); returns (b, num_latents_mean_pooled + num_latents, dim)
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        if exists(self.to_latents_from_mean_pooled_seq):
            # prepend a few latents summarizing the mean-pooled sequence
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
456
+
457
+ # attention
458
+
459
class Attention(nn.Module):
    # multi-head queries with single-head (shared) keys/values — note to_kv outputs
    # only dim_head * 2 and the einsum pattern carries no head axis on k/v;
    # includes a learned null key/value for classifier-free guidance
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        # cosine-sim attention swaps the 1/sqrt(d) scale for a fixed scale of 16
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        # x: (b, n, dim); context: optional extra (text) tokens prepended to k/v;
        # mask: (b, n) bool over x; attn_bias: additive bias on the similarity logits
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # the (1, 0) pad accounts for the one null key; NOTE(review): it does NOT
            # cover context keys, so mask + context together would broadcast-mismatch —
            # callers appear to pass one or the other; confirm against call sites
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention (softmax in float32 for stability, then cast back)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
548
+
549
+ # decoder
550
+
551
def Upsample(dim, dim_out = None):
    """2x nearest-neighbor upsample followed by a 3x3 conv mapping dim -> dim_out
    (dim_out defaults to dim)."""
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1),
    )
558
+
559
class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        # 1x1 conv produces 4x channels which PixelShuffle folds into 2x spatial size
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        # ICNR-style init: each group of 4 output channels shares one kernel, so the
        # pixel shuffle initially behaves like nearest-neighbor upsampling
        # (this is the checkerboard-artifact fix referenced above)
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
588
+
589
def Downsample(dim, dim_out = None):
    """2x downsampling via pixel-unshuffle then a 1x1 conv (SP-conv).

    https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
    named SP-conv in the paper, but basically a pixel unshuffle
    """
    out_channels = default(dim_out, dim)

    # fold each 2x2 spatial patch into channels, then project 4*dim -> out_channels
    unshuffle = Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2)
    project = nn.Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(unshuffle, project)
597
+
598
class SinusoidalPosEmb(nn.Module):
    """Fixed (non-learned) transformer-style sinusoidal position embedding.

    forward(x) takes a 1-D tensor of positions/timesteps of shape (n,) and
    returns a (n, dim) embedding of concatenated sines and cosines.

    Fixes vs. original:
    - guards the frequency-scale denominator with max(half_dim - 1, 1), so
      dim == 2 (half_dim == 1) no longer raises ZeroDivisionError
    - the outer product is done with plain broadcasting instead of einops
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # log-spaced frequencies from 1 down to 1/10000
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        # outer product: (n, 1) * (1, half_dim) -> (n, half_dim)
        emb = x[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim = -1)
609
+
610
class LearnedSinusoidalPosEmb(nn.Module):
    """Learned sinusoidal position embedding: random (trainable) frequencies.

    following @crowsonkb 's lead with learned sinusoidal pos emb
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    forward(x) maps positions of shape (b,) to embeddings of shape (b, dim + 1):
    the raw position concatenated with dim//2 sines and dim//2 cosines.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # equivalent to concatenating x with (sin | cos) in one step
        return torch.cat((x, freqs.sin(), freqs.cos()), dim = -1)
626
+
627
class Block(nn.Module):
    """Basic conv block: GroupNorm -> optional FiLM (scale, shift) -> SiLU -> 3x3 conv."""
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        h = self.groupnorm(x)

        if exists(scale_shift):
            # FiLM-style conditioning; scale is a residual around the identity
            scale, shift = scale_shift
            h = h * (scale + 1) + shift

        return self.project(self.activation(h))
649
+
650
class ResnetBlock(nn.Module):
    """Residual block with optional FiLM time conditioning and optional cross
    attention over conditioning tokens (text / time tokens).

    dim -> dim_out; the residual path uses a 1x1 conv when dims differ.
    """
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        # projects the pooled time embedding to per-channel (scale, shift) for block2
        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        # optional cross attention to `cond` tokens, applied between block1 and block2
        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            # flatten spatial dims to a token sequence around the attention module
            self.cross_attn = EinopsToAndFrom(
                'b c h w',
                'b (h w) c',
                attn_klass(
                    dim = dim_out,
                    context_dim = cond_dim,
                    **attn_kwargs
                )
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        # global-context gating; Always(1) is a no-op multiplier when disabled
        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()

        # NOTE(review): `squeeze_excite` is accepted but never used in this block

    def forward(self, x, time_emb = None, cond = None):
        # time_emb: pooled time conditioning; cond: token sequence for cross attention

        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            # split into (scale, shift) along channels; consumed by block2
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            # residual cross attention over flattened spatial positions
            h = self.cross_attn(h, context = cond) + h

        h = self.block2(h, scale_shift = scale_shift)

        # channel-wise gating from global context (or 1 when disabled)
        h = h * self.gca(h)

        return h + self.res_conv(x)
716
+
717
class CrossAttention(nn.Module):
    """Multi-head cross attention from a query sequence `x` to a `context`
    sequence, with a learned null key/value prepended for classifier-free
    guidance. Optionally uses cosine-similarity (l2-normalized) attention.
    """
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        cosine_sim_attn = False
    ):
        super().__init__()
        # with cosine-sim attention the usual 1/sqrt(d) scale is replaced by a fixed 16
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        # learned null key/value pair, shared across heads (shape: 2 x dim_head)
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        # x: (b, n, dim) queries; context: (b, m, context_dim); mask: (b, m) bool
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # pad mask with True at position 0 so the null key is always attendable
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax in float32 for numerical stability, then cast back
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
793
+
794
class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of CrossAttention (softmax over feature dims
    instead of pairwise similarities). Reuses the parent's projections; heads
    are folded into the batch dimension.
    """
    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # heads folded into batch: shapes become ((b h), n, d)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            # pad with True so the prepended null key/value survive the mask
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b n -> b n 1')
            # NOTE(review): mask is (b, n+1, 1) while k/v are ((b*h), n+1, d);
            # broadcasting looks like it only works when b == 1 or heads == 1 — verify
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        # softmax q over features, k over sequence — the linear-attention kernel trick
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
833
+
834
class LinearAttention(nn.Module):
    """Linear-complexity self attention over a 2D feature map, with depthwise
    3x3 convs after each q/k/v projection and optional extra key/values derived
    from a `context` token sequence (e.g. text embeddings).
    """
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        # each projection: dropout -> 1x1 conv -> depthwise 3x3 conv
        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        # projects context tokens to extra (key, value) pairs, when context_dim given
        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        # fmap: (b, dim, x, y); context: optional (b, n, context_dim) tokens
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        # heads folded into batch; spatial positions flattened to a sequence
        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)

        if exists(context):
            assert exists(self.to_context)
            # append context-derived keys/values to the spatial ones
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        # linear-attention kernel trick: softmax q over features, k over sequence
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
902
+
903
class GlobalContext(nn.Module):
    """Attention-esque squeeze-excitation: a single-head spatial attention map
    pools the feature map to one vector per channel, which a small bottleneck
    MLP turns into per-channel sigmoid gates (shape (b, dim_out, 1, 1))."""

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # produces one spatial attention logit per position
        self.to_k = nn.Conv2d(dim_in, 1, 1)
        inner_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(dim_in, inner_dim, 1),
            nn.SiLU(),
            nn.Conv2d(inner_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        scores = self.to_k(x)
        # flatten spatial dims of both the feature map and the attention logits
        x, scores = rearrange_many((x, scores), 'b n ... -> b n (...)')
        # attention-weighted spatial pooling -> (b, c, 1)
        pooled = einsum('b i n, b c n -> b c i', scores.softmax(dim = -1), x)
        pooled = rearrange(pooled, '... -> ... 1')
        return self.net(pooled)
929
+
930
def FeedForward(dim, mult = 2):
    """Token-wise feedforward: LayerNorm -> Linear (dim -> dim*mult) -> GELU
    -> LayerNorm -> Linear back to dim. No biases on the linears."""
    inner_dim = int(dim * mult)

    layers = (
        LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        LayerNorm(inner_dim),
        nn.Linear(inner_dim, dim, bias = False),
    )
    return nn.Sequential(*layers)
939
+
940
def ChanFeedForward(dim, mult = 2):
    """Channel-wise feedforward over 2D maps, built from 1x1 convs.

    in paper, it seems for self attention layers they did feedforwards with
    twice channel width
    """
    inner_dim = int(dim * mult)

    layers = (
        ChanLayerNorm(dim),
        nn.Conv2d(dim, inner_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(inner_dim),
        nn.Conv2d(inner_dim, dim, 1, bias = False),
    )
    return nn.Sequential(*layers)
949
+
950
class TransformerBlock(nn.Module):
    """Stack of `depth` residual (full self-attention, channel feedforward)
    pairs over a 2D feature map. EinopsToAndFrom flattens 'b c h w' to
    'b (h w) c' around the attention so it operates on spatial tokens.
    """
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                EinopsToAndFrom('b c h w', 'b (h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        # residual attention then residual feedforward, repeated `depth` times
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
976
+
977
class LinearAttentionTransformerBlock(nn.Module):
    """Like TransformerBlock but with linear-complexity attention; extra
    keyword arguments are accepted and ignored."""
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        # residual attention then residual feedforward per layer
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
1003
+
1004
class CrossEmbedLayer(nn.Module):
    """Multi-scale convolutional embedding: several parallel convs with
    different kernel sizes (same stride) whose outputs are concatenated along
    channels. Larger kernels get fewer channels (halving per scale)."""
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        # kernel and stride parity must match so padding keeps outputs aligned
        assert all((k % 2) == (stride % 2) for k in kernel_sizes)
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # channel budget halves per scale; the last scale absorbs the remainder
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([
            nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)
            for kernel, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        return torch.cat([conv(x) for conv in self.convs], dim = 1)
1030
+
1031
class UpsampleCombiner(nn.Module):
    """Optionally combines intermediate upsample-path feature maps with the
    final feature map (resized to match, conv-projected, channel-concatenated).
    When disabled, acts as identity and reports dim_out == dim.
    """
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            # pass-through: downstream layers size themselves off self.dim_out
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # resize each saved fmap to x's spatial size, project, then concat on channels
        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        return torch.cat((x, *outs), dim = 1)
1064
+
1065
+ class Unet(nn.Module):
1066
    def __init__(
        self,
        *,
        dim,
        image_embed_dim = 1024,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
        layer_attns = True,
        layer_attns_depth = 1,
        layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
        attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        resnet_groups = 8,
        init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
        init_cross_embed = True,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3,
        cosine_sim_attn = False,
        self_cond = False,
        combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
        pixel_shuffle_upsample = True # may address checkboard artifacts
    ):
        """Build the Imagen U-net: initial conv, time/text conditioning heads,
        down path, bottleneck, up path, and final projection to channels_out.

        NOTE(review): image_embed_dim, num_image_tokens, out_dim, dropout and
        layer_attns_add_text_cond are accepted but not referenced anywhere in
        this constructor — confirm against the rest of the class.
        """
        super().__init__()

        # guide researchers

        assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'

        if dim < 128:
            print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')

        # save locals to take care of some hyperparameters for cascading DDPM
        # (captured kwargs are reused by cast_model_parameters / to_config_and_state_dict)

        self._locals = locals()
        self._locals.pop('self', None)
        self._locals.pop('__class__', None)

        # determine dimensions

        self.channels = channels
        self.channels_out = default(channels_out, channels)

        # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
        # (2) in self conditioning, one appends the predict x0 (x_start)
        init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
        init_dim = default(init_dim, dim)

        self.self_cond = self_cond

        # optional image conditioning

        self.has_cond_image = cond_images_channels > 0
        self.cond_images_channels = cond_images_channels

        init_channels += cond_images_channels

        # initial convolution

        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time conditioning

        cond_dim = default(cond_dim, dim)
        time_cond_dim = dim * 4 * (2 if lowres_cond else 1)

        # embedding time for log(snr) noise from continuous version

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
        sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1

        self.to_time_hiddens = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
            nn.SiLU()
        )

        self.to_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )

        # project to time tokens as well as time hiddens

        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

        # low res aug noise conditioning

        self.lowres_cond = lowres_cond

        if lowres_cond:
            self.to_lowres_time_hiddens = nn.Sequential(
                LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
                nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
                nn.SiLU()
            )

            self.to_lowres_time_cond = nn.Sequential(
                nn.Linear(time_cond_dim, time_cond_dim)
            )

            self.to_lowres_time_tokens = nn.Sequential(
                nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
                Rearrange('b (r d) -> b r d', r = num_time_tokens)
            )

        # normalizations

        self.norm_cond = nn.LayerNorm(cond_dim)

        # text encoding conditioning (optional)

        self.text_to_cond = None

        if cond_on_text:
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)

        # finer control over whether to condition on text encodings

        self.cond_on_text = cond_on_text

        # attention pooling

        self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None

        # for classifier free guidance

        self.max_text_len = max_text_len

        # learned null embeddings substituted when text conditioning is dropped
        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
        self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))

        # for non-attention based text conditioning at all points in the network where time is also conditioned

        self.to_text_non_attn_cond = None

        if cond_on_text:
            self.to_text_non_attn_cond = nn.Sequential(
                nn.LayerNorm(cond_dim),
                nn.Linear(cond_dim, time_cond_dim),
                nn.SiLU(),
                nn.Linear(time_cond_dim, time_cond_dim)
            )

        # attention related params

        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)

        num_layers = len(in_out)

        # resnet block klass

        # per-resolution hyperparameters; scalars broadcast to one value per layer
        num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
        resnet_groups = cast_tuple(resnet_groups, num_layers)

        resnet_klass = partial(ResnetBlock, **attn_kwargs)

        layer_attns = cast_tuple(layer_attns, num_layers)
        layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
        layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

        use_linear_attn = cast_tuple(use_linear_attn, num_layers)
        use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)

        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])

        # downsample klass

        downsample_klass = Downsample

        if cross_embed_downsample:
            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

        # initial resnet block (for memory efficient unet)

        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None

        # scale for resnet skip connections

        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
        reversed_layer_params = list(map(reversed, layer_params))

        # downsampling layers

        skip_connect_dims = [] # keep track of skip connection dimensions

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
            is_last = ind >= (num_resolutions - 1)

            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

            if layer_attn:
                transformer_block_klass = TransformerBlock
            elif layer_use_linear_attn:
                transformer_block_klass = LinearAttentionTransformerBlock
            else:
                transformer_block_klass = Identity

            current_dim = dim_in

            # whether to pre-downsample, from memory efficient unet

            pre_downsample = None

            if memory_efficient:
                pre_downsample = downsample_klass(dim_in, dim_out)
                current_dim = dim_out

            skip_connect_dims.append(current_dim)

            # whether to do post-downsample, for non-memory efficient unet

            post_downsample = None
            if not memory_efficient:
                # last resolution keeps spatial size: parallel 3x3 + 1x1 convs instead of a downsample
                post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv2d(dim_in, dim_out, 3, padding = 1), nn.Conv2d(dim_in, dim_out, 1))

            self.downs.append(nn.ModuleList([
                pre_downsample,
                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                post_downsample
            ]))

        # middle layers

        mid_dim = dims[-1]

        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

        # upsample klass

        upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample

        # upsampling layers

        upsample_fmap_dims = []

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
            is_last = ind == (len(in_out) - 1)

            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

            if layer_attn:
                transformer_block_klass = TransformerBlock
            elif layer_use_linear_attn:
                transformer_block_klass = LinearAttentionTransformerBlock
            else:
                transformer_block_klass = Identity

            skip_connect_dim = skip_connect_dims.pop()

            upsample_fmap_dims.append(dim_out)

            self.ups.append(nn.ModuleList([
                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
            ]))

        # whether to combine feature maps from all upsample blocks before final resnet block out

        self.upsample_combiner = UpsampleCombiner(
            dim = dim,
            enabled = combine_upsample_fmaps,
            dim_ins = upsample_fmap_dims,
            dim_outs = dim
        )

        # whether to do a final residual from initial conv to the final resnet block out

        self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
        final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)

        # final optional resnet block and convolution out

        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None

        final_conv_dim_in = dim if final_resnet_block else final_conv_dim
        final_conv_dim_in += (channels if lowres_cond else 0)

        self.final_conv = nn.Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)

        # zero-init the last conv so the network starts as (near) identity in the residual sense
        zero_init_(self.final_conv)
1394
+
1395
+ # if the current settings for the unet are not correct
1396
+ # for cascading DDPM, then reinit the unet with the right settings
1397
    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text
    ):
        """Return self if the given cascading-DDPM settings already match,
        otherwise construct a fresh unet of the same class with the saved
        constructor kwargs (self._locals) overridden by the given values.
        """
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            channels_out == self.channels_out:
            return self

        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )

        # note: the reinitialized unet does NOT inherit this one's weights
        return self.__class__(**{**self._locals, **updated_kwargs})
1422
+
1423
+ # methods for returning the full unet config as well as its parameter state
1424
+
1425
    def to_config_and_state_dict(self):
        """Return (constructor kwargs captured in __init__, parameter state dict)."""
        return self._locals, self.state_dict()
1427
+
1428
+ # class method for rehydrating the unet from its config and state dict
1429
+
1430
    @classmethod
    def from_config_and_state_dict(klass, config, state_dict):
        """Rebuild a unet from (config, state_dict) as produced by
        `to_config_and_state_dict`, loading the weights into a fresh instance."""
        unet = klass(**config)
        unet.load_state_dict(state_dict)
        return unet
1435
+
1436
+ # methods for persisting unet to disk
1437
+
1438
    def persist_to_file(self, path):
        """Save this unet's constructor config and weights to `path` via
        torch.save, creating the parent directory if needed."""
        path = Path(path)
        path.parents[0].mkdir(exist_ok = True, parents = True)

        config, state_dict = self.to_config_and_state_dict()
        pkg = dict(config = config, state_dict = state_dict)
        torch.save(pkg, str(path))
1445
+
1446
+ # class method for rehydrating the unet from file saved with `persist_to_file`
1447
+
1448
+ @classmethod
1449
+ def hydrate_from_file(klass, path):
1450
+ path = Path(path)
1451
+ assert path.exists()
1452
+ pkg = torch.load(str(path))
1453
+
1454
+ assert 'config' in pkg and 'state_dict' in pkg
1455
+ config, state_dict = pkg['config'], pkg['state_dict']
1456
+
1457
+ return Unet.from_config_and_state_dict(config, state_dict)
1458
+
1459
+ # forward with classifier free guidance
1460
+
1461
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        """Classifier-free guidance: blend conditional and unconditional
        predictions as null + (cond - null) * cond_scale.

        cond_scale == 1 returns the plain conditional forward pass (no extra
        unconditional pass is run).
        """
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        # second pass with conditioning fully dropped gives the "null" prediction
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
1474
+
1475
    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        self_cond = None,
        cond_drop_prob = 0.
    ):
        """Denoising forward pass of the unet.

        Args:
            x: noised input batch; channel-concatenated below with self-conditioning,
               low-res conditioning and conditioning images before the first conv.
            time: diffusion time conditioning, fed through `to_time_hiddens`.
            lowres_cond_img: low-resolution image from the previous cascade stage
                (required iff `self.lowres_cond`).
            lowres_noise_times: augmentation-noise level of `lowres_cond_img`
                (required iff `self.lowres_cond`).
            text_embeds / text_mask: text conditioning tokens and their mask.
            cond_images: extra image conditioning (required iff `self.has_cond_image`).
            self_cond: previous x0 estimate for self-conditioning.
            cond_drop_prob: probability of dropping text conditioning per batch
                element (classifier-free guidance training).

        Returns:
            output of `self.final_conv` (the unet prediction).
        """
        batch_size, device = x.shape[0], x.device

        # condition on self

        if self.self_cond:
            # when no previous x0 estimate exists, condition on zeros
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x, self_cond), dim = 1)

        # add low resolution conditioning, if present

        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
        assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

        # condition on input image

        assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

        if exists(cond_images):
            assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
            cond_images = resize_image_to(cond_images, x.shape[-1])
            x = torch.cat((cond_images, x), dim = 1)

        # initial convolution

        x = self.init_conv(x)

        # init conv residual

        if self.init_conv_to_final_conv_residual:
            # keep a copy for the top-most residual concatenated back in at the end
            init_conv_residual = x.clone()

        # time conditioning

        time_hiddens = self.to_time_hiddens(time)

        # derive time tokens

        time_tokens = self.to_time_tokens(time_hiddens)
        t = self.to_time_cond(time_hiddens)

        # add lowres time conditioning to time hiddens
        # and add lowres time tokens along sequence dimension for attention

        if self.lowres_cond:
            lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
            lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
            lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

            t = t + lowres_t
            time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

        # text conditioning

        text_tokens = None

        if exists(text_embeds) and self.cond_on_text:

            # conditional dropout
            # per-batch-element bernoulli mask: True = keep the text conditioning

            text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

            text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
            text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

            # calculate text embeds

            text_tokens = self.text_to_cond(text_embeds)

            # truncate / pad text tokens to exactly `max_text_len`

            text_tokens = text_tokens[:, :self.max_text_len]

            if exists(text_mask):
                text_mask = text_mask[:, :self.max_text_len]

            text_tokens_len = text_tokens.shape[1]
            remainder = self.max_text_len - text_tokens_len

            if remainder > 0:
                text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

            if exists(text_mask):
                if remainder > 0:
                    text_mask = F.pad(text_mask, (0, remainder), value = False)

                # combine padding mask with the conditional-dropout keep mask
                text_mask = rearrange(text_mask, 'b n -> b n 1')
                text_keep_mask_embed = text_mask & text_keep_mask_embed

            null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

            # dropped / padded positions are replaced with the learned null embedding

            text_tokens = torch.where(
                text_keep_mask_embed,
                text_tokens,
                null_text_embed
            )

            if exists(self.attn_pool):
                text_tokens = self.attn_pool(text_tokens)

            # extra non-attention conditioning by projecting and then summing text embeddings to time
            # termed as text hiddens

            mean_pooled_text_tokens = text_tokens.mean(dim = -2)

            text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

            null_text_hidden = self.null_text_hidden.to(t.dtype)

            text_hiddens = torch.where(
                text_keep_mask_hidden,
                text_hiddens,
                null_text_hidden
            )

            t = t + text_hiddens

        # main conditioning tokens (c)

        c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

        # normalize conditioning tokens

        c = self.norm_cond(c)

        # initial resnet block (for memory efficient unet)

        if exists(self.init_resnet_block):
            x = self.init_resnet_block(x, t)

        # go through the layers of the unet, down and up

        hiddens = []

        for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
            if exists(pre_downsample):
                x = pre_downsample(x)

            x = init_block(x, t, c)

            for resnet_block in resnet_blocks:
                x = resnet_block(x, t)
                hiddens.append(x)

            x = attn_block(x, c)
            hiddens.append(x)

            if exists(post_downsample):
                x = post_downsample(x)

        x = self.mid_block1(x, t, c)

        if exists(self.mid_attn):
            x = self.mid_attn(x)

        x = self.mid_block2(x, t, c)

        # skip connections pop hiddens in reverse (stack order), scaled by `skip_connect_scale`

        add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

        up_hiddens = []

        for init_block, resnet_blocks, attn_block, upsample in self.ups:
            x = add_skip_connection(x)
            x = init_block(x, t, c)

            for resnet_block in resnet_blocks:
                x = add_skip_connection(x)
                x = resnet_block(x, t)

            x = attn_block(x, c)
            up_hiddens.append(x.contiguous())
            x = upsample(x)

        # whether to combine all feature maps from upsample blocks

        x = self.upsample_combiner(x, up_hiddens)

        # final top-most residual if needed

        if self.init_conv_to_final_conv_residual:
            x = torch.cat((x, init_conv_residual), dim = 1)

        if exists(self.final_res_block):
            x = self.final_res_block(x, t)

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

        return self.final_conv(x)
1677
+
1678
# null unet

class NullUnet(nn.Module):
    """Identity placeholder unet.

    Stands in for a real unet in a cascade slot that is not being trained
    or sampled. Carries a single dummy parameter so that optimizers and
    device-moving code expecting at least one parameter keep working.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        self.dummy_parameter = nn.Parameter(torch.zeros(1))

    def cast_model_parameters(self, *args, **kwargs):
        # nothing to recast - return self unchanged
        return self

    def forward(self, x, *args, **kwargs):
        # identity: pass the input straight through, ignoring all conditioning
        return x
1691
+
1692
# predefined unets, with configs lining up with hyperparameters in appendix of paper

class BaseUnet64(Unet):
    """Base 64x64 unet preset (hyperparameters from the Imagen paper appendix)."""

    def __init__(self, *args, **kwargs):
        preset = dict(
            dim = 512,
            dim_mults = (1, 2, 3, 4),
            num_resnet_blocks = 3,
            layer_attns = (False, True, True, True),
            layer_cross_attns = (False, True, True, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = False
        )
        # explicit caller kwargs take precedence over the preset
        preset.update(kwargs)
        super().__init__(*args, **preset)
1707
+
1708
class SRUnet256(Unet):
    """64 -> 256 super-resolution unet preset (memory-efficient variant)."""

    def __init__(self, *args, **kwargs):
        preset = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = (False, False, False, True),
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        # explicit caller kwargs take precedence over the preset
        preset.update(kwargs)
        super().__init__(*args, **preset)
1721
+
1722
class SRUnet1024(Unet):
    """256 -> 1024 super-resolution unet preset (no self-attention at any stage)."""

    def __init__(self, *args, **kwargs):
        preset = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = False,
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        # explicit caller kwargs take precedence over the preset
        preset.update(kwargs)
        super().__init__(*args, **preset)
1735
+
1736
+ # main imagen ddpm class, which is a cascading DDPM from Ho et al.
1737
+
1738
+ class Imagen(nn.Module):
1739
    def __init__(
        self,
        unets,
        *,
        image_sizes,                                # for cascading ddpm, image size at each stage
        text_encoder_name = DEFAULT_T5_NAME,
        text_embed_dim = None,
        channels = 3,
        timesteps = 1000,
        sample_timesteps=100,
        cond_drop_prob = 0.1,
        loss_type = 'l2',
        noise_schedules = 'cosine',
        pred_objectives = 'noise',
        random_crop_sizes = None,
        lowres_noise_schedule = 'linear',
        lowres_sample_noise_level = 0.2,            # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
        per_sample_random_aug_noise_level = False,  # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
        condition_on_text = True,
        auto_normalize_img = True,                  # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
        p2_loss_weight_gamma = 0.5,                 # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
        p2_loss_weight_k = 1,
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.95,     # unsure what this was based on perusal of paper
        only_train_unet_number = None
    ):
        """Cascading DDPM (Imagen, Ho et al.).

        Per-unet hyperparameters (`timesteps`, `sample_timesteps`,
        `noise_schedules`, `pred_objectives`, `random_crop_sizes`,
        `dynamic_thresholding`, `p2_loss_weight_gamma`) accept either a
        scalar, broadcast to all unets, or a tuple with one entry per unet.
        The first unet is cast to be unconditioned on a low-res image; all
        subsequent unets are conditioned on the previous stage's output.
        """
        super().__init__()

        # loss

        if loss_type == 'l1':
            loss_fn = F.l1_loss
        elif loss_type == 'l2':
            loss_fn = F.mse_loss
        elif loss_type == 'huber':
            loss_fn = F.smooth_l1_loss
        else:
            raise NotImplementedError()

        self.loss_type = loss_type
        self.loss_fn = loss_fn

        # conditioning hparams

        self.condition_on_text = condition_on_text
        self.unconditional = not condition_on_text

        # channels

        self.channels = channels

        # automatically take care of ensuring that first unet is unconditional
        # while the rest of the unets are conditioned on the low resolution image produced by previous unet

        unets = cast_tuple(unets)
        num_unets = len(unets)

        # determine noise schedules per unet

        timesteps = cast_tuple(timesteps, num_unets)
        sample_timesteps = cast_tuple(sample_timesteps, num_unets)

        # make sure noise schedule defaults to 'cosine', 'cosine', and then 'linear' for rest of super-resoluting unets

        noise_schedules = cast_tuple(noise_schedules)
        noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine')
        noise_schedules = pad_tuple_to_length(noise_schedules, num_unets, 'linear')

        # construct noise schedulers
        # two parallel sets: full-resolution schedulers for training, and
        # (possibly shorter) schedulers used only at sampling time

        noise_scheduler_klass = GaussianDiffusionContinuousTimes
        self.noise_schedulers = nn.ModuleList([])

        for timestep, noise_schedule in zip(timesteps, noise_schedules):
            noise_scheduler = noise_scheduler_klass(noise_schedule = noise_schedule, timesteps = timestep)
            self.noise_schedulers.append(noise_scheduler)

        self.noise_schedulers_sample = nn.ModuleList([])

        for sample_timestep, noise_schedule in zip(sample_timesteps, noise_schedules):
            noise_scheduler_sample = noise_scheduler_klass(noise_schedule=noise_schedule, timesteps=sample_timestep)
            self.noise_schedulers_sample.append(noise_scheduler_sample)

        # randomly cropping for upsampler training

        self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
        assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'

        # lowres augmentation noise schedule

        self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)

        # ddpm objectives - predicting noise by default

        self.pred_objectives = cast_tuple(pred_objectives, num_unets)

        # get text encoder

        self.text_encoder_name = text_encoder_name
        self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))

        self.encode_text = partial(t5_encode_text, name = text_encoder_name)

        # construct unets

        self.unets = nn.ModuleList([])

        self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
        self.only_train_unet_number = only_train_unet_number

        for ind, one_unet in enumerate(unets):
            assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
            is_first = ind == 0

            # recast each unet to the conditioning configuration this cascade requires
            one_unet = one_unet.cast_model_parameters(
                lowres_cond = not is_first,
                cond_on_text = self.condition_on_text,
                text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
                channels = self.channels,
                channels_out = self.channels
            )

            self.unets.append(one_unet)

        # unet image sizes

        image_sizes = cast_tuple(image_sizes)
        self.image_sizes = image_sizes

        assert num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({len(unets)}) for resolutions {image_sizes}'

        self.sample_channels = cast_tuple(self.channels, num_unets)

        # determine whether we are training on images or video

        is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
        self.is_video = is_video

        # video tensors carry one extra (frame) dimension, hence the extra '1'
        self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
        self.resize_to = resize_video_to if is_video else resize_image_to

        # cascading ddpm related stuff

        lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
        assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

        self.lowres_sample_noise_level = lowres_sample_noise_level
        self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level

        # classifier free guidance

        self.cond_drop_prob = cond_drop_prob
        self.can_classifier_guidance = cond_drop_prob > 0.

        # normalize and unnormalize image functions

        self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
        self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
        self.input_image_range = (0. if auto_normalize_img else -1., 1.)

        # dynamic thresholding

        self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
        self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

        # p2 loss weight

        self.p2_loss_weight_k = p2_loss_weight_k
        self.p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)

        assert all([(gamma_value <= 2) for gamma_value in self.p2_loss_weight_gamma]), 'in paper, they noticed any gamma greater than 2 is harmful'

        # one temp parameter for keeping track of device

        self.register_buffer('_temp', torch.tensor([0.]), persistent = False)

        # default to device of unets passed in

        self.to(next(self.unets.parameters()).device)
1918
+
1919
+ def force_unconditional_(self):
1920
+ self.condition_on_text = False
1921
+ self.unconditional = True
1922
+
1923
+ for unet in self.unets:
1924
+ unet.cond_on_text = False
1925
+
1926
    @property
    def device(self):
        # device is tracked via the non-persistent `_temp` buffer registered in
        # __init__, so it follows the module through .to() / .cuda() moves even
        # while the individual unets are being shuttled between devices
        return self._temp.device
1929
+
1930
+ def get_unet(self, unet_number):
1931
+ assert 0 < unet_number <= len(self.unets)
1932
+ index = unet_number - 1
1933
+
1934
+ if isinstance(self.unets, nn.ModuleList):
1935
+ unets_list = [unet for unet in self.unets]
1936
+ delattr(self, 'unets')
1937
+ self.unets = unets_list
1938
+
1939
+ if index != self.unet_being_trained_index:
1940
+ for unet_index, unet in enumerate(self.unets):
1941
+ unet.to(self.device if unet_index == index else 'cpu')
1942
+
1943
+ self.unet_being_trained_index = index
1944
+ return self.unets[index]
1945
+
1946
+ def reset_unets_all_one_device(self, device = None):
1947
+ device = default(device, self.device)
1948
+ self.unets = nn.ModuleList([*self.unets])
1949
+ self.unets.to(device)
1950
+
1951
+ self.unet_being_trained_index = -1
1952
+
1953
+ @contextmanager
1954
+ def one_unet_in_gpu(self, unet_number = None, unet = None):
1955
+ assert exists(unet_number) ^ exists(unet)
1956
+
1957
+ if exists(unet_number):
1958
+ unet = self.unets[unet_number - 1]
1959
+
1960
+ devices = [module_device(unet) for unet in self.unets]
1961
+ self.unets.cpu()
1962
+ unet.to(self.device)
1963
+
1964
+ yield
1965
+
1966
+ for unet, device in zip(self.unets, devices):
1967
+ unet.to(device)
1968
+
1969
+ # overriding state dict functions
1970
+
1971
+ def state_dict(self, *args, **kwargs):
1972
+ self.reset_unets_all_one_device()
1973
+ return super().state_dict(*args, **kwargs)
1974
+
1975
+ def load_state_dict(self, *args, **kwargs):
1976
+ self.reset_unets_all_one_device()
1977
+ return super().load_state_dict(*args, **kwargs)
1978
+
1979
    # gaussian diffusion methods

    def p_mean_variance(
        self,
        unet,
        x,
        t,
        *,
        noise_scheduler,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        lowres_cond_img = None,
        self_cond = None,
        lowres_noise_times = None,
        cond_scale = 1.,
        model_output = None,
        t_next = None,
        pred_objective = 'noise',
        dynamic_threshold = True
    ):
        """Compute the posterior mean/variance of p(x_{t_next} | x_t).

        Runs the unet (with classifier-free guidance when `cond_scale != 1`),
        converts its prediction into an x0 estimate according to
        `pred_objective`, optionally applies dynamic thresholding, and feeds
        the clamped x0 into the scheduler's q_posterior.

        Returns:
            ((mean, variance, log_variance), x_start) — the tuple shape is
            whatever `noise_scheduler.q_posterior` returns, plus the x0 estimate.
        """
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

        # `model_output` lets a caller reuse a prediction already computed elsewhere
        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, noise_scheduler.get_condition(t), text_embeds = text_embeds, text_mask = text_mask, cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times)))

        if pred_objective == 'noise':
            x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
        elif pred_objective == 'x_start':
            x_start = pred
        else:
            raise ValueError(f'unknown objective {pred_objective}')

        if dynamic_threshold:
            # following pseudocode in appendix
            # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
            s = torch.quantile(
                rearrange(x_start, 'b ... -> b (...)').abs(),
                self.dynamic_thresholding_percentile,
                dim = -1
            )

            # never shrink below the static [-1, 1] range
            s.clamp_(min = 1.)
            s = right_pad_dims_to(x_start, s)
            x_start = x_start.clamp(-s, s) / s
        else:
            x_start.clamp_(-1., 1.)

        mean_and_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
        return mean_and_variance, x_start
2028
+
2029
    @torch.no_grad()
    def p_sample(
        self,
        unet,
        x,
        t,
        *,
        noise_scheduler,
        t_next = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_scale = 1.,
        self_cond = None,
        lowres_cond_img = None,
        lowres_noise_times = None,
        pred_objective = 'noise',
        dynamic_threshold = True
    ):
        """Draw one ancestral sampling step: x_t -> x_{t_next}.

        Returns:
            (pred, x_start): the sample for the next timestep and the model's
            current x0 estimate (used by the caller for self-conditioning).
        """
        b, *_, device = *x.shape, x.device
        (model_mean, _, model_log_variance), x_start = self.p_mean_variance(unet, x = x, t = t, t_next = t_next, noise_scheduler = noise_scheduler, text_embeds = text_embeds, text_mask = text_mask, cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_times = lowres_noise_times, pred_objective = pred_objective, dynamic_threshold = dynamic_threshold)
        noise = torch.randn_like(x)
        # no noise when t == 0
        # continuous-time schedulers signal the final step via t_next == 0;
        # discrete ones via t == 0
        is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
        nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred, x_start
2056
+
2057
    @torch.no_grad()
    def p_sample_loop(
        self,
        unet,
        shape,
        *,
        noise_scheduler,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        cond_scale = 1,
        pred_objective = 'noise',
        dynamic_threshold = True,
        use_tqdm = True
    ):
        """Full reverse-diffusion sampling loop for a single unet.

        Starts from gaussian noise of `shape` (optionally offset by
        `init_images`) and iterates `p_sample` over the scheduler's sampling
        timesteps. Supports RePaint-style inpainting: at each step the known
        region is re-noised to the current timestep and pasted in, with
        `inpaint_resample_times` resampling passes per step.

        Returns:
            the final image batch, passed through `self.unnormalize_img`.
        """
        device = self.device

        batch = shape[0]
        img = torch.randn(shape, device = device)

        # for initialization with an image or video

        if exists(init_images):
            img += init_images

        # keep track of x0, for self conditioning

        x_start = None

        # prepare inpainting

        has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
        resample_times = inpaint_resample_times if has_inpainting else 1

        if has_inpainting:
            inpaint_images = self.normalize_img(inpaint_images)
            inpaint_images = self.resize_to(inpaint_images, shape[-1])
            inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()

        # time

        timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)

        # whether to skip any steps

        skip_steps = default(skip_steps, 0)
        timesteps = timesteps[skip_steps:]

        for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm):
            is_last_timestep = times_next == 0

            # inner resampling loop (RePaint); a single pass when not inpainting
            for r in reversed(range(resample_times)):
                is_last_resample_step = r == 0

                if has_inpainting:
                    # paste in the known region, noised to the current timestep
                    noised_inpaint_images, _ = noise_scheduler.q_sample(inpaint_images, t = times)
                    img = img * ~inpaint_masks + noised_inpaint_images * inpaint_masks

                self_cond = x_start if unet.self_cond else None

                img, x_start = self.p_sample(
                    unet,
                    img,
                    times,
                    t_next = times_next,
                    text_embeds = text_embeds,
                    text_mask = text_mask,
                    cond_images = cond_images,
                    cond_scale = cond_scale,
                    self_cond = self_cond,
                    lowres_cond_img = lowres_cond_img,
                    lowres_noise_times = lowres_noise_times,
                    noise_scheduler = noise_scheduler,
                    pred_objective = pred_objective,
                    dynamic_threshold = dynamic_threshold
                )

                if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
                    # walk the sample back from t_next to t for another resampling pass
                    renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times)

                    img = torch.where(
                        self.right_pad_dims_to_datatype(is_last_timestep),
                        img,
                        renoised_img
                    )

        img.clamp_(-1., 1.)

        # final inpainting

        if has_inpainting:
            img = img * ~inpaint_masks + inpaint_images * inpaint_masks

        unnormalize_img = self.unnormalize_img(img)
        return unnormalize_img
2159
+
2160
    @torch.no_grad()
    @eval_decorator
    def sample(
        self,
        texts: List[str] = None,
        text_masks = None,
        text_embeds = None,
        video_frames = None,
        cond_images = None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        batch_size = 1,
        cond_scale = 1.,
        lowres_sample_noise_level = None,
        start_at_unet_number = 1,
        start_image_or_video = None,
        stop_at_unet_number = None,
        return_all_unet_outputs = False,
        return_pil_images = False,
        device = None,
        use_tqdm = True
    ):
        """Sample images (or video tensors) from the full cascade.

        Text conditioning may be given as raw `texts` (encoded here with the
        configured T5 encoder) or precomputed `text_embeds`/`text_masks`.
        Each stage after the first is conditioned on a noised, resized copy of
        the previous stage's output. `cond_scale`, `init_images` and
        `skip_steps` may be scalars or per-unet tuples.

        Returns:
            the last unet's output (or all outputs when
            `return_all_unet_outputs`), as tensors or PIL images when
            `return_pil_images` is set.
        """
        device = default(device, self.device)
        self.reset_unets_all_one_device(device = device)

        cond_images = maybe(cast_uint8_images_to_float)(cond_images)

        if exists(texts) and not exists(text_embeds) and not self.unconditional:
            assert all([*map(len, texts)]), 'text cannot be empty'

            # text encoding is kept in full precision
            with autocast(enabled = False):
                text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

            text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))

        if not self.unconditional:
            assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'

            text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
            batch_size = text_embeds.shape[0]

        if exists(inpaint_images):
            if self.unconditional:
                if batch_size == 1: # assume researcher wants to broadcast along inpainted images
                    batch_size = inpaint_images.shape[0]

            assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
            assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'

        assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
        assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
        assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

        assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'

        outputs = []

        is_cuda = next(self.parameters()).is_cuda
        device = next(self.parameters()).device

        lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)

        num_unets = len(self.unets)

        # condition scaling

        cond_scale = cast_tuple(cond_scale, num_unets)

        # add frame dimension for video

        assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

        frame_dims = (video_frames,) if self.is_video else tuple()

        # for initial image and skipping steps

        init_images = cast_tuple(init_images, num_unets)
        init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]

        skip_steps = cast_tuple(skip_steps, num_unets)

        # handle starting at a unet greater than 1, for training only-upscaler training

        if start_at_unet_number > 1:
            assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
            assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
            assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

            prev_image_size = self.image_sizes[start_at_unet_number - 2]
            img = self.resize_to(start_image_or_video, prev_image_size)

        # go through each unet in cascade

        for unet_number, unet, channel, image_size, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.noise_schedulers_sample, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):

            if unet_number < start_at_unet_number:
                continue

            assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets'

            # on cuda, keep only the active unet on the gpu
            context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()

            with context:
                lowres_cond_img = lowres_noise_times = None
                shape = (batch_size, channel, *frame_dims, *image_size)

                if unet.lowres_cond:
                    # noise the previous stage's output to the fixed sampling noise level
                    lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

                    lowres_cond_img = self.resize_to(img, image_size)

                    lowres_cond_img = self.normalize_img(lowres_cond_img)
                    lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

                if exists(unet_init_images):
                    unet_init_images = self.resize_to(unet_init_images, image_size)

                shape = (batch_size, self.channels, *frame_dims, *image_size)

                img = self.p_sample_loop(
                    unet,
                    shape,
                    text_embeds = text_embeds,
                    text_mask = text_masks,
                    cond_images = cond_images,
                    inpaint_images = inpaint_images,
                    inpaint_masks = inpaint_masks,
                    inpaint_resample_times = inpaint_resample_times,
                    init_images = unet_init_images,
                    skip_steps = unet_skip_steps,
                    cond_scale = unet_cond_scale,
                    lowres_cond_img = lowres_cond_img,
                    lowres_noise_times = lowres_noise_times,
                    noise_scheduler = noise_scheduler,
                    pred_objective = pred_objective,
                    dynamic_threshold = dynamic_threshold,
                    use_tqdm = use_tqdm
                )

                outputs.append(img)

            if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
                break

        output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs

        if not return_pil_images:
            return outputs[output_index]

        if not return_all_unet_outputs:
            outputs = outputs[-1:]

        assert not self.is_video, 'converting sampled video tensor to video file is not supported yet'

        pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))

        return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
2320
+
2321
    def p_losses(
        self,
        unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel],
        x_start,
        times,
        *,
        noise_scheduler,
        lowres_cond_img = None,
        lowres_aug_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        noise = None,
        times_next = None,
        pred_objective = 'noise',
        p2_loss_weight_gamma = 0.,
        random_crop_size = None
    ):
        """Single diffusion training step for one unet: noise `x_start` at `times`,
        predict with `unet`, and return the (optionally p2-reweighted) mean loss.

        `pred_objective` selects the regression target: 'noise' (epsilon) or
        'x_start'. `times_next` is accepted but unused in this implementation.
        """
        is_video = x_start.ndim == 5

        noise = default(noise, lambda: torch.randn_like(x_start))

        # normalize to [-1, 1]

        x_start = self.normalize_img(x_start)
        lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

        # random cropping during training
        # for upsamplers

        if exists(random_crop_size):
            if is_video:
                # fold frames into the batch so the 2d crop applies per-frame
                frames = x_start.shape[2]
                x_start, lowres_cond_img, noise = rearrange_many((x_start, lowres_cond_img, noise), 'b c f h w -> (b f) c h w')

            aug = K.RandomCrop(random_crop_size, p = 1.)

            # make sure low res conditioner and image both get augmented the same way
            # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
            # NOTE(review): this path assumes lowres_cond_img is not None when
            # random_crop_size is set (upsampler-only feature) — confirm callers
            x_start = aug(x_start)
            lowres_cond_img = aug(lowres_cond_img, params = aug._params)
            noise = aug(noise, params = aug._params)

            if is_video:
                x_start, lowres_cond_img, noise = rearrange_many((x_start, lowres_cond_img, noise), '(b f) c h w -> b c f h w', f = frames)

        # get x_t

        x_noisy, log_snr = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

        # also noise the lowres conditioning image
        # at sample time, they then fix the noise level of 0.1 - 0.3

        lowres_cond_img_noisy = None
        if exists(lowres_cond_img):
            lowres_aug_times = default(lowres_aug_times, times)
            lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))

        # time condition

        noise_cond = noise_scheduler.get_condition(times)

        # unet kwargs

        unet_kwargs = dict(
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
            lowres_cond_img = lowres_cond_img_noisy,
            cond_drop_prob = self.cond_drop_prob,
        )

        # self condition if needed

        # Because 'unet' can be an instance of DistributedDataParallel coming from the
        # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
        # access the member 'module' of the wrapped unet instance.
        self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond

        # with 50% probability, condition the unet on its own (detached) x_start estimate
        if self_cond and random() < 0.5:
            with torch.no_grad():
                pred = unet.forward(
                    x_noisy,
                    noise_cond,
                    **unet_kwargs
                ).detach()

            x_start = noise_scheduler.predict_start_from_noise(x_noisy, t = times, noise = pred) if pred_objective == 'noise' else pred

            unet_kwargs = {**unet_kwargs, 'self_cond': x_start}

        # get prediction

        pred = unet.forward(
            x_noisy,
            noise_cond,
            **unet_kwargs
        )

        # prediction objective

        if pred_objective == 'noise':
            target = noise
        elif pred_objective == 'x_start':
            target = x_start
        else:
            raise ValueError(f'unknown objective {pred_objective}')

        # losses

        losses = self.loss_fn(pred, target, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        # p2 loss reweighting

        if p2_loss_weight_gamma > 0:
            loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
            losses = losses * loss_weight

        return losses.mean()
2442
+
2443
+ def forward(
2444
+ self,
2445
+ images,
2446
+ unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
2447
+ texts: List[str] = None,
2448
+ text_embeds = None,
2449
+ text_masks = None,
2450
+ unet_number = None,
2451
+ cond_images = None
2452
+ ):
2453
+ # assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
2454
+ assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
2455
+ unet_number = default(unet_number, 1)
2456
+ assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}'
2457
+
2458
+ images = cast_uint8_images_to_float(images)
2459
+ cond_images = maybe(cast_uint8_images_to_float)(cond_images)
2460
+
2461
+ assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
2462
+
2463
+ unet_index = unet_number - 1
2464
+
2465
+ unet = default(unet, lambda: self.get_unet(unet_number))
2466
+
2467
+ assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
2468
+
2469
+ noise_scheduler = self.noise_schedulers[unet_index]
2470
+ p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
2471
+ pred_objective = self.pred_objectives[unet_index]
2472
+ target_image_size = self.image_sizes[unet_index]
2473
+ random_crop_size = self.random_crop_sizes[unet_index]
2474
+ prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
2475
+
2476
+ b, c, *_, h, w, device, is_video = *images.shape, images.device, images.ndim == 5
2477
+
2478
+ check_shape(images, 'b c ...', c = self.channels)
2479
+ assert h >= target_image_size[0] and w >= target_image_size[1]
2480
+
2481
+ frames = images.shape[2] if is_video else None
2482
+
2483
+ times = noise_scheduler.sample_random_times(b, device = device)
2484
+
2485
+ if exists(texts) and not exists(text_embeds) and not self.unconditional:
2486
+ assert all([*map(len, texts)]), 'text cannot be empty'
2487
+ assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
2488
+
2489
+ with autocast(enabled = False):
2490
+ text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
2491
+
2492
+ text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
2493
+
2494
+ if not self.unconditional:
2495
+ text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
2496
+
2497
+ assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
2498
+ assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'
2499
+
2500
+ assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
2501
+
2502
+ lowres_cond_img = lowres_aug_times = None
2503
+ if exists(prev_image_size):
2504
+ lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
2505
+ lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
2506
+
2507
+ if self.per_sample_random_aug_noise_level:
2508
+ lowres_aug_times = self.lowres_noise_schedule.sample_random_times(b, device = device)
2509
+ else:
2510
+ lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
2511
+ lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = b)
2512
+
2513
+ images = self.resize_to(images, target_image_size)
2514
+
2515
+ return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size)
imagen_pytorch/imagen_video/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from imagen_pytorch.imagen_video.imagen_video import Unet3D
imagen_pytorch/imagen_video/imagen_video.py ADDED
@@ -0,0 +1,1662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import copy
3
+ from typing import List
4
+ from tqdm.auto import tqdm
5
+ from functools import partial, wraps
6
+ from contextlib import contextmanager, nullcontext
7
+ from collections import namedtuple
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn, einsum
13
+
14
+ from einops import rearrange, repeat, reduce
15
+ from einops.layers.torch import Rearrange, Reduce
16
+ from einops_exts import rearrange_many, repeat_many, check_shape
17
+ from einops_exts.torch import EinopsToAndFrom
18
+
19
+ from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
20
+
21
+ # helper functions
22
+
23
def exists(val):
    """True when *val* is not the None sentinel."""
    return val is not None
25
+
26
def identity(t, *args, **kwargs):
    """Return *t* unchanged, ignoring any extra positional/keyword arguments."""
    return t
28
+
29
def first(arr, d = None):
    """First element of *arr*, or the fallback *d* when *arr* is empty."""
    return arr[0] if len(arr) else d
33
+
34
def maybe(fn):
    """Wrap *fn* so that a None argument passes straight through untouched."""
    @wraps(fn)
    def inner(x):
        return fn(x) if x is not None else x
    return inner
41
+
42
def once(fn):
    """Wrap *fn* so only the first call executes; later calls return None."""
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if not called:
            called = True
            return fn(x)

    return inner
52
+
53
# print() that fires only on the very first call — used for one-time warnings
print_once = once(print)
54
+
55
def default(val, d):
    """Return *val* unless it is None, in which case return *d*
    (calling it first when *d* is callable, for lazy defaults)."""
    if val is not None:
        return val
    return d() if callable(d) else d
59
+
60
def cast_tuple(val, length = None):
    """Coerce *val* into a tuple.

    Lists become tuples; scalars are replicated `length` times (once when
    length is None). When *length* is given, the result length is asserted.
    """
    if isinstance(val, list):
        val = tuple(val)

    if not isinstance(val, tuple):
        val = (val,) * (1 if length is None else length)

    if length is not None:
        assert len(val) == length

    return val
70
+
71
def cast_uint8_images_to_float(images):
    """Map uint8 images in [0, 255] to floats in [0, 1]; float inputs pass through unchanged."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
75
+
76
def module_device(module):
    """Device of the module's first parameter (assumes the module has parameters)."""
    return next(iter(module.parameters())).device
78
+
79
def zero_init_(m):
    """In-place zero the weight (and bias, when present) of module *m*."""
    nn.init.zeros_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
83
+
84
def eval_decorator(fn):
    """Run *fn* with the model switched to eval mode, restoring the previous
    training flag afterwards."""
    def inner(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return inner
92
+
93
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad tuple *t* with *fillvalue* up to *length*; no-op when already long enough."""
    if len(t) >= length:
        return t
    return t + (fillvalue,) * (length - len(t))
98
+
99
+ # helper classes
100
+
101
class Identity(nn.Module):
    """No-op module: returns its input unchanged, swallowing any extra
    constructor or forward arguments."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x
107
+
108
+ # tensor helpers
109
+
110
def log(t, eps: float = 1e-12):
    """Numerically safe natural log: clamps *t* below at *eps* first."""
    return t.clamp(min = eps).log()
112
+
113
def l2norm(t):
    """L2-normalize *t* along its final dimension."""
    return F.normalize(t, dim = -1)
115
+
116
def right_pad_dims_to(x, t):
    """Append singleton dims to *t* until it has as many dims as *x*
    (for broadcasting); returns *t* untouched when it is already wide enough."""
    extra = x.ndim - t.ndim
    if extra <= 0:
        return t
    return t.view(*t.shape, *((1,) * extra))
121
+
122
def masked_mean(t, *, dim, mask = None):
    """Mean of (b, n, d) tensor *t* over *dim*; with a (b, n) bool *mask*,
    masked-out positions are excluded from both numerator and denominator."""
    if mask is None:
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    zeroed = t.masked_fill(~mask.unsqueeze(-1), 0.)

    # clamp guards against a fully-masked row dividing by zero
    return zeroed.sum(dim = dim) / denom.clamp(min = 1e-5)
131
+
132
def resize_video_to(
    video,
    target_image_size,
    clamp_range = None
):
    """Nearest-neighbour spatial resize of a (b, c, f, h, w) video.

    Frames are folded into the batch so interpolation is purely spatial.
    NOTE(review): the early-exit compares only the width (shape[-1]) against
    target_image_size — assumes square frames / scalar target; confirm callers.
    """
    orig_video_size = video.shape[-1]

    if orig_video_size == target_image_size:
        return video


    frames = video.shape[2]
    video = rearrange(video, 'b c f h w -> (b f) c h w')

    out = F.interpolate(video, target_image_size, mode = 'nearest')

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    out = rearrange(out, '(b f) c h w -> b c f h w', f = frames)

    return out
154
+
155
+ # classifier free guidance functions
156
+
157
def prob_mask_like(shape, prob, device):
    """Boolean tensor of *shape* whose elements are independently True with
    probability *prob* (fast paths for the deterministic 0 / 1 cases)."""
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
164
+
165
+ # norms and residuals
166
+
167
class LayerNorm(nn.Module):
    """Bias-free layer norm over the last dimension with a learned gain.

    When `stable` is set, the input is first divided by its (detached)
    per-row max for extra numerical stability.
    """

    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = -1, keepdim = True).detach()

        # looser epsilon for half precision
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mean = x.mean(dim = -1, keepdim = True)
        var = x.var(dim = -1, unbiased = False, keepdim = True)
        return (x - mean) * torch.rsqrt(var + eps) * self.g
181
+
182
class ChanLayerNorm(nn.Module):
    """Layer norm over the channel axis (dim 1) of (b, c, f, h, w) feature maps,
    with a learned per-channel gain and optional max-stabilized variant."""

    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = 1, keepdim = True).detach()

        # looser epsilon for half precision
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mean = x.mean(dim = 1, keepdim = True)
        var = x.var(dim = 1, unbiased = False, keepdim = True)
        return (x - mean) * torch.rsqrt(var + eps) * self.g
196
+
197
class Always():
    """Callable that ignores its arguments and always returns the stored value."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val
203
+
204
class Residual(nn.Module):
    """Skip connection wrapper: forward(x) = fn(x, **kwargs) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out + x
211
+
212
class Parallel(nn.Module):
    """Apply every branch module to the same input and sum the results."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum(fn(x) for fn in self.fns)
220
+
221
+ # attention pooling
222
+
223
class PerceiverAttention(nn.Module):
    """One cross-attention step of the Perceiver resampler: learned latents
    attend over the input sequence concatenated with themselves.

    Args:
        dim: feature dim of both sequence and latents.
        dim_head / heads: attention head geometry.
        cosine_sim_attn: l2-normalize q/k and use a fixed scale of 16 instead
            of the usual 1/sqrt(dim_head).
    """
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        # x: (b, n, dim) sequence; latents: (b, m, dim); returns (b, m, dim)
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities and masking

        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            # latent positions are always attendable, so extend the mask with True
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
289
+
290
class PerceiverResampler(nn.Module):
    """Compress a variable-length embedding sequence into a fixed number of
    latent tokens via stacked PerceiverAttention + feed-forward blocks.

    NOTE(review): depends on `FeedForward`, defined elsewhere in this file.
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4,
        cosine_sim_attn = False
    ):
        super().__init__()
        # learned absolute positions for the input sequence (capped at max_seq_len)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        # x: (b, n, dim) -> (b, num_latents (+ mean-pooled extras), dim)
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        if exists(self.to_latents_from_mean_pooled_seq):
            # NOTE(review): an all-True mask is passed here, so padding tokens are
            # included in the mean pool — confirm this is intentional
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            # residual attention + feed-forward
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
343
+
344
+ # attention
345
+
346
class Attention(nn.Module):
    """Self attention with per-head queries and single-headed key/values shared
    across heads (the k/v projection outputs only `dim_head * 2`, and the
    similarity einsum `b h i d, b j d` carries no head axis for k).

    A learned null key/value is always prepended, and optional context tokens
    (e.g. text) can be projected in as extra key/values.
    """
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        causal = False,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.causal = causal

        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        # learned per-head bias used for the null position when attn_bias is given
        self.null_attn_bias = nn.Parameter(torch.randn(heads))

        # learned null key / value, always attendable
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # single-headed k/v (shared across query heads)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            # give the null position its learned bias, then add to similarities
            null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
            attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        if exists(mask):
            # pad with True for the prepended null kv position
            # NOTE(review): the mask is not padded for context tokens — confirm
            # masks are only used when context is None
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
445
+
446
+ # pseudo conv2d that uses conv3d but with kernel size of 1 across frames dimension
447
+
448
def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
    """Pseudo-2d convolution for video: an nn.Conv3d whose kernel/stride are 1
    and padding is 0 along the frames dimension, so only height/width mix."""
    def to_pair(v):
        # accept scalar, list, or 2-tuple; always yield a validated 2-tuple
        if isinstance(v, list):
            v = tuple(v)
        if not isinstance(v, tuple):
            v = (v, v)
        assert len(v) == 2
        return v

    kernel, stride, padding = map(to_pair, (kernel, stride, padding))

    # prepend the no-op temporal components
    kernel = (1, *kernel)
    stride = (1, *stride)
    padding = (0, *padding)

    return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
463
+
464
class Pad(nn.Module):
    """Module form of F.pad with a fixed padding spec and fill value."""

    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, x):
        return F.pad(x, self.padding, value = self.value)
472
+
473
+ # decoder
474
+
475
def Upsample(dim, dim_out = None):
    """Nearest-neighbour 2x upsample followed by a 3x3 pseudo-2d conv.

    NOTE(review): nn.Upsample(scale_factor = 2) on 5-d input scales the frame
    axis as well as height/width — confirm temporal doubling is intended.
    """
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        Conv2d(dim, dim_out, 3, padding = 1)
    )
482
+
483
class PixelShuffleUpsample(nn.Module):
    """2x spatial upsample: 1x1 pseudo-2d conv -> SiLU -> PixelShuffle per frame.

    The conv is initialized so all four shuffled sub-pixels start identical —
    said to reduce checkerboard artifacts in pixel-shuffle upsampling.
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU()
        )

        self.pixel_shuffle = nn.PixelShuffle(2)

        self.init_conv_(conv)

    def init_conv_(self, conv):
        # kaiming-init a quarter of the output channels, then tile 4x so each
        # shuffled sub-pixel shares the same starting weights
        o, i, f, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, f, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        out = self.net(x)
        # PixelShuffle is 2d-only, so fold frames into the batch first
        frames = x.shape[2]
        out = rearrange(out, 'b c f h w -> (b f) c h w')
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b f) c h w -> b c f h w', f = frames)
513
+
514
def Downsample(dim, dim_out = None):
    """Strided pseudo-2d conv (kernel 4, stride 2, pad 1) halving height and
    width; the frame axis is untouched (see Conv2d above)."""
    dim_out = default(dim_out, dim)
    return Conv2d(dim, dim_out, 4, 2, 1)
517
+
518
class SinusoidalPosEmb(nn.Module):
    """Classic transformer sinusoidal embedding of a 1-d tensor of positions:
    maps (b,) -> (b, dim) as [sin(x * freqs), cos(x * freqs)]."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # geometric frequency ladder from 1 down to 1/10000
        decay = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = x.device) * -decay)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
529
+
530
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding with learned frequencies; the raw input is also
    concatenated, so a (b,) input maps to (b, dim + 1)."""

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        x = x[:, None]
        angles = x * self.weights[None, :] * 2 * math.pi
        # layout: [x, sin(angles), cos(angles)]
        return torch.cat((x, angles.sin(), angles.cos()), dim = -1)
543
+
544
class Block(nn.Module):
    """Basic conv unit: GroupNorm -> optional FiLM (scale, shift) -> SiLU ->
    3x3 pseudo-2d conv."""
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        x = self.groupnorm(x)

        if exists(scale_shift):
            # FiLM conditioning; +1 keeps the no-op at zero scale
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.activation(x)
        return self.project(x)
566
+
567
class ResnetBlock(nn.Module):
    """Residual block: two `Block`s with optional FiLM time conditioning
    (applied in the second block), optional cross attention over `cond`
    tokens, and optional global context gating.

    NOTE(review): `squeeze_excite` is accepted but never used in this body.
    """
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            # project time embedding to a (scale, shift) pair for block2
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            # queries are the flattened (frame, height, width) positions
            self.cross_attn = EinopsToAndFrom(
                'b c f h w',
                'b (f h w) c',
                attn_klass(
                    dim = dim_out,
                    context_dim = cond_dim,
                    **attn_kwargs
                )
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()


    def forward(self, x, time_emb = None, cond = None):
        # x: (b, c, f, h, w)

        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            h = self.cross_attn(h, context = cond) + h

        # time conditioning is only injected in the second block
        h = self.block2(h, scale_shift = scale_shift)

        h = h * self.gca(h)

        return h + self.res_conv(x)
633
+
634
class CrossAttention(nn.Module):
    """Standard multi-head cross attention from x to context tokens, with a
    learned null key/value prepended so fully-masked rows remain well-defined."""
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        # x: (b, n, dim); context: (b, m, context_dim); mask: (b, m) bool
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # pad with True for the prepended null kv position
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax computed in fp32 for stability under mixed precision
        attn = sim.softmax(dim = -1, dtype = torch.float32)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
709
+
710
class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of CrossAttention: softmaxes q over features
    and k over sequence positions instead of forming pairwise similarities.
    Reuses the parent's parameters; heads are folded into the batch axis."""
    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            # NOTE(review): mask is reshaped to (b, j, 1) while k/v are
            # ((b h), j, d); this broadcast only works when heads == 1 or
            # batch == 1 — confirm how masks reach this path
            mask = rearrange(mask, 'b n -> b n 1')
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
749
+
750
class LinearAttention(nn.Module):
    """ Linear (kernelized) self-attention over 2d feature maps, with optional
    extra context tokens concatenated into the keys/values.

    Input/output are channel-first feature maps (b, c, x, y). """

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        # q / k / v projections: pointwise conv followed by a depthwise 3x3 conv
        # (the depthwise conv injects local spatial context into each projection)
        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        # optional projection of context tokens to extra key/value pairs
        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        # flatten spatial dims and fold heads into batch
        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
            # append context-derived keys/values after the spatial ones
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        # linear attention: softmax q over features, k over sequence
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
818
+
819
class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # single-channel key map used as spatial attention weights
        self.to_k = Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        # bottleneck MLP (as 1x1 convs) producing per-channel gates in (0, 1)
        self.net = nn.Sequential(
            Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)
        # flatten all spatial (or spatio-temporal) dims
        x, context = rearrange_many((x, context), 'b n ... -> b n (...)')
        # attention-weighted global pooling of x by the softmaxed key map
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        # restore two trailing singleton spatial dims for the 1x1 convs
        out = rearrange(out, '... -> ... 1 1')
        return self.net(out)
845
+
846
def FeedForward(dim, mult = 2):
    """ Pre-norm two-layer MLP over token embeddings: LayerNorm -> Linear (expand
    by `mult`) -> GELU -> LayerNorm -> Linear (project back to `dim`). """
    inner_dim = int(dim * mult)
    layers = (
        LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        LayerNorm(inner_dim),
        nn.Linear(inner_dim, dim, bias = False),
    )
    return nn.Sequential(*layers)
855
+
856
def ChanFeedForward(dim, mult = 2):
    """ Channel-wise feedforward for 2d feature maps, built from 1x1 convs.
    In the paper, self-attention layers appear to use feedforwards with twice
    the channel width (mult = 2). """
    inner_dim = int(dim * mult)
    layers = (
        ChanLayerNorm(dim),
        Conv2d(dim, inner_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(inner_dim),
        Conv2d(inner_dim, dim, 1, bias = False),
    )
    return nn.Sequential(*layers)
865
+
866
class TransformerBlock(nn.Module):
    """ Stack of full self-attention + channelwise feedforward layers, each with
    a residual connection. Attention is applied over all spatio-temporal
    positions of a video feature map (b, c, f, h, w) via EinopsToAndFrom. """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                # flatten (f h w) into one sequence for attention, then restore shape
                EinopsToAndFrom('b c f h w', 'b (f h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        # context (e.g. text tokens) is passed into each attention layer
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
892
+
893
class LinearAttentionTransformerBlock(nn.Module):
    """ Same residual attention + feedforward stack as TransformerBlock, but with
    LinearAttention (linear complexity in the number of spatial positions)
    instead of full attention. Operates on channel-first feature maps. """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
919
+
920
class CrossEmbedLayer(nn.Module):
    """ Multi-scale embedding layer: several parallel convolutions with different
    kernel sizes (all sharing one stride), whose outputs are concatenated along
    the channel dimension. Channel budget is halved at each successive scale,
    with the remainder assigned to the largest kernel. """

    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        # each kernel must share the stride's parity so the padding keeps outputs aligned
        assert all((kernel % 2) == (stride % 2) for kernel in kernel_sizes)
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # channel dimension per scale: dim_out/2, dim_out/4, ... and the remainder last
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales.append(dim_out - sum(dim_scales))

        self.convs = nn.ModuleList([
            Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)
            for kernel, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        # concat conv outputs channel-wise, in ascending kernel-size order
        return torch.cat([conv(x) for conv in self.convs], dim = 1)
946
+
947
class UpsampleCombiner(nn.Module):
    """ Optionally combines feature maps gathered from all upsampling stages with
    the final feature map (resized to match, projected through a Block, then
    concatenated along channels). When disabled, acts as identity on x.

    `self.dim_out` exposes the resulting channel count so downstream layers
    can be sized accordingly. """

    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            # disabled: output channels are unchanged, no convs created
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # resize every collected feature map to x's spatial size before projecting
        fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps]
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        return torch.cat((x, *outs), dim = 1)
980
+
981
class DynamicPositionBias(nn.Module):
    """ MLP-generated relative position bias for attention: maps each relative
    offset (a scalar in [-(n-1), n-1]) through a small MLP to per-head bias
    values, returning a (heads, n, n) bias matrix. """

    def __init__(
        self,
        dim,
        *,
        heads,
        depth
    ):
        super().__init__()
        self.mlp = nn.ModuleList([])

        # input layer: scalar relative position -> dim
        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim),
            nn.SiLU()
        ))

        # hidden layers
        for _ in range(max(depth - 1, 0)):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                LayerNorm(dim),
                nn.SiLU()
            ))

        # output layer: one bias value per attention head
        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, n, device, dtype):
        i = torch.arange(n, device = device)
        j = torch.arange(n, device = device)

        # pairwise relative offsets, shifted into [0, 2n-2] to index the MLP outputs
        indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j')
        indices += (n - 1)

        # all 2n-1 possible relative positions, fed through the MLP once
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')

        for layer in self.mlp:
            pos = layer(pos)

        # gather per-pair biases and move heads to the leading dim
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias
1023
+
1024
class Unet3D(nn.Module):
    """ 3D (video) U-Net denoiser for Imagen-style cascaded diffusion.

    Operates on video tensors of shape (batch, channels, frames, height, width).
    Spatial processing mirrors the 2d Unet (resnet blocks, optional attention /
    linear attention transformer blocks, cross attention on text tokens), with
    temporal attention + temporal position-encoding convs interleaved at every
    resolution so information flows across frames.

    Conditioning inputs:
    - `time` (noise level), embedded via learned sinusoidal embeddings
    - optional text embeddings (classifier-free-guidance ready, with learned
      null embeddings for dropped conditioning)
    - optional low-resolution image + its augmentation noise times (cascading)
    - optional conditioning images concatenated to the input

    Fix vs. previous revision: `hydrate_from_file` previously rehydrated
    through the 2d `Unet` class; it now constructs `klass` (i.e. Unet3D or a
    subclass) so checkpoints restore with the correct architecture.
    """

    def __init__(
        self,
        *,
        dim,
        image_embed_dim = 1024,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,
        dim_mults = (1, 2, 4, 8),
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
        layer_attns = False,
        layer_attns_depth = 1,
        layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
        attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        time_rel_pos_bias_depth = 2,
        time_causal_attn = True,
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        resnet_groups = 8,
        init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
        init_cross_embed = True,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3,
        cosine_sim_attn = False,
        self_cond = False,
        combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
        pixel_shuffle_upsample = True # may address checkboard artifacts
    ):
        super().__init__()

        # guide researchers

        assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'

        if dim < 128:
            print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')

        # save locals to take care of some hyperparameters for cascading DDPM

        self._locals = locals()
        self._locals.pop('self', None)
        self._locals.pop('__class__', None)

        self.self_cond = self_cond

        # determine dimensions

        self.channels = channels
        self.channels_out = default(channels_out, channels)

        # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
        # (2) in self conditioning, one appends the predict x0 (x_start)
        init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
        init_dim = default(init_dim, dim)

        # optional image conditioning

        self.has_cond_image = cond_images_channels > 0
        self.cond_images_channels = cond_images_channels

        init_channels += cond_images_channels

        # initial convolution

        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time conditioning

        cond_dim = default(cond_dim, dim)
        time_cond_dim = dim * 4 * (2 if lowres_cond else 1)

        # embedding time for log(snr) noise from continuous version

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
        sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1

        self.to_time_hiddens = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
            nn.SiLU()
        )

        self.to_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )

        # project to time tokens as well as time hiddens

        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

        # low res aug noise conditioning

        self.lowres_cond = lowres_cond

        if lowres_cond:
            self.to_lowres_time_hiddens = nn.Sequential(
                LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
                nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
                nn.SiLU()
            )

            self.to_lowres_time_cond = nn.Sequential(
                nn.Linear(time_cond_dim, time_cond_dim)
            )

            self.to_lowres_time_tokens = nn.Sequential(
                nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
                Rearrange('b (r d) -> b r d', r = num_time_tokens)
            )

        # normalizations

        self.norm_cond = nn.LayerNorm(cond_dim)

        # text encoding conditioning (optional)

        self.text_to_cond = None

        if cond_on_text:
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)

        # finer control over whether to condition on text encodings

        self.cond_on_text = cond_on_text

        # attention pooling

        self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None

        # for classifier free guidance

        self.max_text_len = max_text_len

        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
        self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))

        # for non-attention based text conditioning at all points in the network where time is also conditioned

        self.to_text_non_attn_cond = None

        if cond_on_text:
            self.to_text_non_attn_cond = nn.Sequential(
                nn.LayerNorm(cond_dim),
                nn.Linear(cond_dim, time_cond_dim),
                nn.SiLU(),
                nn.Linear(time_cond_dim, time_cond_dim)
            )

        # attention related params

        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)

        num_layers = len(in_out)

        # temporal attention - attention across video frames

        temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1)
        temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim)))

        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', '(b h w) f c', Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn})))

        # temporal attention relative positional encoding

        self.time_rel_pos_bias = DynamicPositionBias(dim = dim * 2, heads = attn_heads, depth = time_rel_pos_bias_depth)

        # resnet block klass

        num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
        resnet_groups = cast_tuple(resnet_groups, num_layers)

        resnet_klass = partial(ResnetBlock, **attn_kwargs)

        layer_attns = cast_tuple(layer_attns, num_layers)
        layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
        layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])

        # downsample klass

        downsample_klass = Downsample

        if cross_embed_downsample:
            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

        # initial resnet block (for memory efficient unet)

        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None

        self.init_temporal_peg = temporal_peg(init_dim)
        self.init_temporal_attn = temporal_attn(init_dim)

        # scale for resnet skip connections

        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns]
        reversed_layer_params = list(map(reversed, layer_params))

        # downsampling layers

        skip_connect_dims = [] # keep track of skip connection dimensions

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(in_out, *layer_params)):
            is_last = ind >= (num_resolutions - 1)

            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)

            current_dim = dim_in

            # whether to pre-downsample, from memory efficient unet

            pre_downsample = None

            if memory_efficient:
                pre_downsample = downsample_klass(dim_in, dim_out)
                current_dim = dim_out

            skip_connect_dims.append(current_dim)

            # whether to do post-downsample, for non-memory efficient unet

            post_downsample = None
            if not memory_efficient:
                post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(Conv2d(dim_in, dim_out, 3, padding = 1), Conv2d(dim_in, dim_out, 1))

            self.downs.append(nn.ModuleList([
                pre_downsample,
                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                temporal_peg(current_dim),
                temporal_attn(current_dim),
                post_downsample
            ]))

        # middle layers

        mid_dim = dims[-1]

        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
        self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
        self.mid_temporal_peg = temporal_peg(mid_dim)
        self.mid_temporal_attn = temporal_attn(mid_dim)
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

        # upsample klass

        upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample

        # upsampling layers

        upsample_fmap_dims = []

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
            is_last = ind == (len(in_out) - 1)
            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)

            skip_connect_dim = skip_connect_dims.pop()

            upsample_fmap_dims.append(dim_out)

            self.ups.append(nn.ModuleList([
                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                temporal_peg(dim_out),
                temporal_attn(dim_out),
                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
            ]))

        # whether to combine feature maps from all upsample blocks before final resnet block out

        self.upsample_combiner = UpsampleCombiner(
            dim = dim,
            enabled = combine_upsample_fmaps,
            dim_ins = upsample_fmap_dims,
            dim_outs = dim
        )

        # whether to do a final residual from initial conv to the final resnet block out

        self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
        final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)

        # final optional resnet block and convolution out

        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None

        final_conv_dim_in = dim if final_resnet_block else final_conv_dim
        final_conv_dim_in += (channels if lowres_cond else 0)

        self.final_conv = Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)

        # zero-init the output conv so the untrained unet starts as identity-ish noise predictor
        zero_init_(self.final_conv)

    # if the current settings for the unet are not correct
    # for cascading DDPM, then reinit the unet with the right settings
    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text
    ):
        """ Return self if the given cascade settings already match, otherwise
        construct a fresh unet of the same class with the updated settings. """
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            channels_out == self.channels_out:
            return self

        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )

        return self.__class__(**{**self._locals, **updated_kwargs})

    # methods for returning the full unet config as well as its parameter state

    def to_config_and_state_dict(self):
        """ Return (init kwargs, state_dict) — everything needed to rebuild this unet. """
        return self._locals, self.state_dict()

    # class method for rehydrating the unet from its config and state dict

    @classmethod
    def from_config_and_state_dict(klass, config, state_dict):
        """ Rebuild a unet from the output of `to_config_and_state_dict`. """
        unet = klass(**config)
        unet.load_state_dict(state_dict)
        return unet

    # methods for persisting unet to disk

    def persist_to_file(self, path):
        """ Save config + weights to `path` (parent directory created if needed). """
        path = Path(path)
        path.parents[0].mkdir(exist_ok = True, parents = True)

        config, state_dict = self.to_config_and_state_dict()
        pkg = dict(config = config, state_dict = state_dict)
        torch.save(pkg, str(path))

    # class method for rehydrating the unet from file saved with `persist_to_file`

    @classmethod
    def hydrate_from_file(klass, path):
        """ Load a unet previously saved with `persist_to_file`. """
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        assert 'config' in pkg and 'state_dict' in pkg
        config, state_dict = pkg['config'], pkg['state_dict']

        # bugfix: rehydrate through this class (previously went through the 2d `Unet`,
        # which would construct the wrong architecture for a Unet3D checkpoint)
        return klass.from_config_and_state_dict(config, state_dict)

    # forward with classifier free guidance

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        """ Classifier-free guidance: logits = null + (cond - null) * cond_scale.
        With cond_scale == 1 this is a plain conditional forward pass. """
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        self_cond = None,
        cond_drop_prob = 0.
    ):
        """ Denoise `x` (batch, channels, frames, height, width) at noise level `time`.

        Optional conditioning: low-res video + its noise times (cascading),
        text embeddings with mask (classifier-free guidance via cond_drop_prob),
        conditioning images, and self-conditioning input. Returns the predicted
        output with `self.channels_out` channels. """
        assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)'

        batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype

        # add self conditioning if needed

        if self.self_cond:
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x, self_cond), dim = 1)

        # add low resolution conditioning, if present

        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
        assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

        # condition on input image

        assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

        if exists(cond_images):
            assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
            cond_images = resize_video_to(cond_images, x.shape[-1])
            x = torch.cat((cond_images, x), dim = 1)

        # get time relative positions

        time_attn_bias = self.time_rel_pos_bias(frames, device = device, dtype = dtype)

        # initial convolution

        x = self.init_conv(x)

        x = self.init_temporal_peg(x)
        x = self.init_temporal_attn(x, attn_bias = time_attn_bias)

        # init conv residual

        if self.init_conv_to_final_conv_residual:
            init_conv_residual = x.clone()

        # time conditioning

        time_hiddens = self.to_time_hiddens(time)

        # derive time tokens

        time_tokens = self.to_time_tokens(time_hiddens)
        t = self.to_time_cond(time_hiddens)

        # add lowres time conditioning to time hiddens
        # and add lowres time tokens along sequence dimension for attention

        if self.lowres_cond:
            lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
            lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
            lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

            t = t + lowres_t
            time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

        # text conditioning

        text_tokens = None

        if exists(text_embeds) and self.cond_on_text:

            # conditional dropout

            text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

            text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
            text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

            # calculate text embeds

            text_tokens = self.text_to_cond(text_embeds)

            # truncate / pad text tokens to the fixed max_text_len
            text_tokens = text_tokens[:, :self.max_text_len]

            if exists(text_mask):
                text_mask = text_mask[:, :self.max_text_len]

            text_tokens_len = text_tokens.shape[1]
            remainder = self.max_text_len - text_tokens_len

            if remainder > 0:
                text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

            if exists(text_mask):
                if remainder > 0:
                    text_mask = F.pad(text_mask, (0, remainder), value = False)

                text_mask = rearrange(text_mask, 'b n -> b n 1')
                text_keep_mask_embed = text_mask & text_keep_mask_embed

            null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

            # replace dropped-out / padded token positions with the learned null embedding
            text_tokens = torch.where(
                text_keep_mask_embed,
                text_tokens,
                null_text_embed
            )

            if exists(self.attn_pool):
                text_tokens = self.attn_pool(text_tokens)

            # extra non-attention conditioning by projecting and then summing text embeddings to time
            # termed as text hiddens

            mean_pooled_text_tokens = text_tokens.mean(dim = -2)

            text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

            null_text_hidden = self.null_text_hidden.to(t.dtype)

            text_hiddens = torch.where(
                text_keep_mask_hidden,
                text_hiddens,
                null_text_hidden
            )

            t = t + text_hiddens

        # main conditioning tokens (c)

        c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

        # normalize conditioning tokens

        c = self.norm_cond(c)

        # initial resnet block (for memory efficient unet)

        if exists(self.init_resnet_block):
            x = self.init_resnet_block(x, t)

        # go through the layers of the unet, down and up

        hiddens = []

        for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, post_downsample in self.downs:
            if exists(pre_downsample):
                x = pre_downsample(x)

            x = init_block(x, t, c)

            for resnet_block in resnet_blocks:
                x = resnet_block(x, t)
                hiddens.append(x)

            x = attn_block(x, c)
            x = temporal_peg(x)
            x = temporal_attn(x, attn_bias = time_attn_bias)

            hiddens.append(x)

            if exists(post_downsample):
                x = post_downsample(x)

        x = self.mid_block1(x, t, c)

        if exists(self.mid_attn):
            x = self.mid_attn(x)

        x = self.mid_temporal_peg(x)
        x = self.mid_temporal_attn(x, attn_bias = time_attn_bias)

        x = self.mid_block2(x, t, c)

        # skip connections are popped LIFO, scaled per Imagen's skip connection scaling
        add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

        up_hiddens = []

        for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, upsample in self.ups:
            x = add_skip_connection(x)
            x = init_block(x, t, c)

            for resnet_block in resnet_blocks:
                x = add_skip_connection(x)
                x = resnet_block(x, t)

            x = attn_block(x, c)
            x = temporal_peg(x)
            x = temporal_attn(x, attn_bias = time_attn_bias)

            up_hiddens.append(x.contiguous())
            x = upsample(x)

        # whether to combine all feature maps from upsample blocks

        x = self.upsample_combiner(x, up_hiddens)

        # final top-most residual if needed

        if self.init_conv_to_final_conv_residual:
            x = torch.cat((x, init_conv_residual), dim = 1)

        if exists(self.final_res_block):
            x = self.final_res_block(x, t)

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

        return self.final_conv(x)
imagen_pytorch/joint_imagen.py ADDED
@@ -0,0 +1,1942 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from contextlib import contextmanager, nullcontext
3
+ from functools import partial
4
+ from pathlib import Path
5
+ from random import random
6
+ from typing import List, Union
7
+
8
+ import kornia.augmentation as K
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms as T
12
+ from einops import rearrange, reduce, repeat
13
+ from einops.layers.torch import Rearrange
14
+ from einops_exts import check_shape, rearrange_many
15
+ from einops_exts.torch import EinopsToAndFrom
16
+ from torch import nn
17
+ from torch.cuda.amp import autocast
18
+ from torch.nn.parallel import DistributedDataParallel
19
+ from tqdm.auto import tqdm
20
+ from torch.special import expm1
21
+
22
+ from imagen_pytorch.imagen_pytorch import (
23
+ Attention, CrossEmbedLayer, Downsample, GaussianDiffusionContinuousTimes,
24
+ Identity, LearnedSinusoidalPosEmb, LinearAttentionTransformerBlock,
25
+ NullUnet, Parallel, PerceiverResampler, PixelShuffleUpsample,
26
+ Residual, ResnetBlock, TransformerBlock, Upsample, UpsampleCombiner,
27
+ cast_tuple, cast_uint8_images_to_float, default, eval_decorator,
28
+ exists, first, identity, is_float_dtype, maybe, module_device,
29
+ normalize_neg_one_to_one, pad_tuple_to_length, print_once, prob_mask_like,
30
+ resize_image_to, right_pad_dims_to, unnormalize_zero_to_one, zero_init_,
31
+ beta_linear_log_snr, alpha_cosine_log_snr, log, log_snr_to_alpha_sigma)
32
+ from imagen_pytorch.imagen_video.imagen_video import Unet3D, resize_video_to
33
+ from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim, t5_encode_text
34
+
35
+
36
def log_1_min_a(a):
    """Numerically-guarded log(1 - exp(a)) for log-probability inputs.

    The 1e-40 epsilon keeps the log finite when exp(a) is (numerically) 1.
    """
    one_minus = 1 - a.exp()
    return torch.log(one_minus + 1e-40)
38
+
39
+
40
def log_add_exp(a, b):
    """Numerically stable log(exp(a) + exp(b)) via the max-shift trick."""
    shift = torch.max(a, b)
    summed = torch.exp(a - shift) + torch.exp(b - shift)
    return shift + torch.log(summed)
43
+
44
+
45
def extract(a, t, x_shape):
    """Gather per-batch schedule values a[t] and reshape for broadcasting.

    Returns a tensor of shape (batch, 1, 1, ...) with as many trailing
    singleton dims as needed to broadcast against a tensor of shape x_shape.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    broadcast_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(broadcast_shape)
49
+
50
+
51
def log_categorical(log_x_start, log_prob):
    """Expected log-probability of a (log) one-hot target under log_prob.

    Weights log_prob by the probabilities exp(log_x_start) and reduces over
    the class dimension (dim=1).
    """
    weighted = log_x_start.exp() * log_prob
    return weighted.sum(dim=1)
53
+
54
+
55
def index_to_log_onehot(x, num_classes):
    """Turn integer class indices into clamped log one-hot tensors.

    Accepts (B, 1, H, W) or (B, ...) integer tensors; the class axis of the
    result is placed at dim 1. Zeros in the one-hot are clamped to 1e-30
    before the log so the output stays finite.
    """
    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
    # collapse a singleton channel dim: (B, 1, H, W) -> (B, H, W)
    if len(x.size()) == 4 and x.size(1) == 1:
        x = x.squeeze(1)
    onehot = F.one_hot(x, num_classes)
    # move the new trailing class axis to dim 1: (B, ..., C) -> (B, C, ...)
    perm = (0, -1) + tuple(range(1, len(x.size())))
    onehot = onehot.permute(perm)
    return torch.log(onehot.float().clamp(min=1e-30))
64
+
65
+
66
def log_onehot_to_index(log_x):
    """Invert index_to_log_onehot: argmax over the class axis (dim 1), keeping the dim."""
    return log_x.argmax(dim=1, keepdim=True)
68
+
69
+
70
def sum_except_batch(x, num_dims=1):
    '''
    Sums all dimensions except the first.

    Args:
        x: Tensor, shape (batch_size, ...)
        num_dims: int, number of batch dims (default=1)

    Returns:
        x_sum: Tensor, shape (batch_size,)
    '''
    flattened = x.reshape(*x.shape[:num_dims], -1)
    return flattened.sum(-1)
82
+
83
+
84
@torch.jit.script
def alpha_cosine_p_log_snr(t, p: float = 0.8, s: float = 0.008):
    """Cosine-p log-SNR schedule: cosine schedule generalized by exponent p.

    With p = 1 this reduces to the standard cosine schedule; the `eps` guard
    keeps the inner log argument strictly positive near t = 0.
    NOTE: `log` here is the project-local clamped-log helper imported from
    imagen_pytorch.imagen_pytorch, not math.log.
    """
    # not sure if this accounts for beta being clipped to 0.999 in discrete version
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** (-2 * p)) - 1, eps=1e-5)
88
+
89
+
90
class MultinomialDiffusion(nn.Module):
    """Categorical (multinomial) diffusion over discrete labels, in log-space.

    Implements the forward corruption q(x_t | x_0) — a mixture of the data
    distribution and a uniform distribution over `num_classes` — and the
    posterior q(x_{t-1} | x_t, x_0) needed for sampling, plus the variational
    loss. Per-step alphas are derived from the same continuous log-SNR
    schedules used by the Gaussian image diffusion, so image and label
    diffusion share a noise schedule.

    All distributions are carried as log-probabilities with the class axis at
    dim 1 (see `index_to_log_onehot`).
    """

    def __init__(self, num_classes, *, noise_schedule, p=1.0, timesteps=1000):
        super().__init__()

        # choose the continuous log-SNR schedule used to derive discrete alphas
        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        elif noise_schedule == "cosine_p":
            self.log_snr = partial(alpha_cosine_p_log_snr, p=p)
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        # cumulative alpha_bar_t = alpha(t)^2 sampled at discrete timesteps,
        # then per-step alphas recovered by dividing consecutive cumprods
        cumprod_alpha = torch.tensor([log_snr_to_alpha_sigma(self.log_snr(
            torch.tensor(t / timesteps)))[0] ** 2 for t in range(timesteps)])
        alphas = cumprod_alpha / F.pad(cumprod_alpha, (1, 0), value=cumprod_alpha[0])[:-1]
        # log-space schedule buffers (move with the module across devices)
        self.register_buffer('log_alpha', torch.log(alphas))
        self.register_buffer('log_1_min_alpha', log_1_min_a(self.log_alpha))
        self.register_buffer('log_cumprod_alpha', torch.cumsum(self.log_alpha, axis=0))
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_a(self.log_cumprod_alpha))
        self.num_classes = num_classes
        self.num_timesteps = timesteps

    def change_times_dtype(self, t):
        """Map continuous times in (0, 1) to integer timestep indices.

        Times coming from the continuous Gaussian side are floats in (0, 1);
        they are scaled by num_timesteps and floored. Integer-valued floats
        are just cast.
        """
        if t.dtype != torch.int64:
            if ((t < 1) * (t > 0)).any():
                return torch.floor(t * self.num_timesteps).to(torch.int64)
            else:
                return t.to(torch.int64)
        return t

    def get_times(self, batch_size, noise_level, *, device):
        # NOTE(review): raises immediately — the return below is unreachable,
        # kept as a sketch of the intended (continuous-time) behavior.
        raise NotImplementedError
        return torch.full((batch_size,), noise_level, device=device, dtype=torch.float32)

    def sample_random_times(self, batch_size, max_thres=0.999, *, device):
        # NOTE(review): raises immediately — unreachable sketch kept below.
        raise NotImplementedError
        return torch.zeros((batch_size,), device=device).float().uniform_(0, max_thres)

    def get_condition(self, times):
        # NOTE(review): raises immediately — unreachable sketch kept below.
        raise NotImplementedError
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        # NOTE(review): raises immediately — unreachable sketch kept below.
        raise NotImplementedError
        times = torch.linspace(1., 0., self.num_timesteps + 1, device=device)
        times = repeat(times, 't -> b t', b=batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim=0)
        times = times.unbind(dim=-1)
        return times

    def q_pred(self, log_x_start, t):
        """log q(x_t | x_0): alpha_bar_t * x_0 + (1 - alpha_bar_t) * uniform."""
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)
        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - math.log(self.num_classes)
        )
        return log_probs

    def log_sample_categorical(self, logits):
        """Draw a one-hot sample from categorical logits via the Gumbel-max trick."""
        uniform = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
        sample = (gumbel_noise + logits).argmax(dim=1)
        log_sample = index_to_log_onehot(sample, self.num_classes)
        return log_sample

    def q_sample(self, log_x_start, t):
        """Sample x_t ~ q(x_t | x_0) as a log one-hot tensor."""
        t = self.change_times_dtype(t)  # caused by continuous timesteps.
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        log_sample = self.log_sample_categorical(log_EV_qxt_x0)
        return log_sample

    def q_pred_one_timestep(self, log_x_t, t):
        """log q(x_t | x_{t-1}) for a single step of the forward process."""
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)
        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - math.log(self.num_classes)
        )
        return log_probs

    def q_posterior(self, log_x_start, log_x_t, t):
        """log q(x_{t-1} | x_t, x_0), the normalized reverse-process posterior.

        At t == 0 the x_0 prediction itself is substituted, since there is no
        earlier step.
        """
        t = self.change_times_dtype(t)  # caused by continuous timesteps.
        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
        # where q(xt | xt-1, x0) = q(xt | xt-1).

        # EV_log_qxt_x0 = self.q_pred(log_x_start, t)

        # print('sum exp', EV_log_qxt_x0.exp().sum(1).mean())
        # assert False

        # log_qxt_x0 = (log_x_t.exp() * EV_log_qxt_x0).sum(dim=1)

        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)

        # unnormed_logprobs = log_EV_qxtmin_x0 +
        # log q_pred_one_timestep(x_t, t)
        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
        # Not very easy to see why this is true. But it is :)
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        # normalize over the class axis
        log_EV_xtmin_given_xt_given_xstart = \
            unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)

        return log_EV_xtmin_given_xt_given_xstart

    def q_sample_from_to(self, log_x_from, from_t, to_t):
        """Sample x_{to_t} given x_{from_t} by jumping forward in the chain.

        Used when re-noising between arbitrary noise levels; at to_t == 0 the
        deterministic argmax is taken instead of a stochastic sample.
        """
        shape, device, dtype = log_x_from.shape, log_x_from.device, log_x_from.dtype
        batch = shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device=device, dtype=dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device=device, dtype=dtype)

        from_t = self.change_times_dtype(from_t)  # caused by continuous timesteps.
        to_t = self.change_times_dtype(to_t)  # caused by continuous timesteps.

        # ratio of cumprod alphas gives the transition kernel between the two levels
        log_cumprod_alpha_to_t = extract(self.log_cumprod_alpha, to_t, log_x_from.shape)
        log_cumprod_alpha_from_t = extract(self.log_cumprod_alpha, from_t, log_x_from.shape)
        log_probs = log_add_exp(
            log_x_from + log_cumprod_alpha_to_t - log_cumprod_alpha_from_t,
            log_1_min_a(log_cumprod_alpha_to_t - log_cumprod_alpha_from_t) - math.log(self.num_classes)
        )

        # at to_t == 0, take the mode; otherwise sample stochastically
        mask = (to_t == torch.zeros_like(to_t)).float()[:, None, None, None]
        log_sample = index_to_log_onehot(log_probs.argmax(dim=1), self.num_classes) * mask \
            + self.log_sample_categorical(log_probs) * (1. - mask)

        return log_sample

    def predict_start_from_noise(self, x_t, t, noise):
        # not applicable to categorical diffusion (no additive noise to invert)
        raise NotImplementedError

    # calculate loss

    def multinomial_kl(self, log_prob1, log_prob2):
        """KL(p1 || p2) for categorical distributions given as log-probs, reduced over dim 1."""
        return (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)

    def kl_prior(self, log_x_start):
        """KL between q(x_T | x_0) and the uniform prior over classes, per batch element."""
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device).long()

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)

    def loss_fn(self, target_log_lbl, pred_lbl, t, log_lbl):
        """Variational-bound loss for the label diffusion, in bits per dimension.

        Uses the KL term for t > 0 and the decoder NLL at t == 0, importance
        weighted by the (uniform) timestep probability, plus the prior KL.
        """
        t = self.change_times_dtype(t)
        pt = torch.ones_like(t).float() / self.num_timesteps

        kl = self.multinomial_kl(target_log_lbl, pred_lbl)
        kl = sum_except_batch(kl)

        decoder_nll = -log_categorical(log_lbl, pred_lbl)
        decoder_nll = sum_except_batch(decoder_nll)
        # at t == 0 the discrete decoder NLL replaces the KL term
        mask = (t == torch.zeros_like(t)).float()
        kl = mask * decoder_nll + (1. - mask) * kl

        kl_prior = self.kl_prior(log_lbl)
        vb_loss = kl / pt + kl_prior

        # normalize to bits per dimension
        loss = vb_loss / (math.log(2) * pred_lbl.shape[1:].numel())
        return loss
267
+
268
+
269
class LabelEmbedding(nn.Module):
    """Embed an integer label map (B, 1, H, W) into a dense feature map (B, C, H, W)."""

    def __init__(self, num_classes, channels):
        super().__init__()
        self.emb_layer = nn.Embedding(num_classes, channels)

    def forward(self, x):
        assert x.dim() == 4, f'x.shape should be (B, 1, H, W) but {x.shape}'
        assert x.size(1) == 1, f'x.shape should be (B, 1, H, W) but {x.shape}'
        # (B, 1, H, W) -> (B, H, W) -> embed -> (B, H, W, C) -> (B, C, H, W)
        indices = x.long().squeeze(1)
        embedded = self.emb_layer(indices)
        return embedded.permute(0, 3, 1, 2)
280
+
281
+
282
+ class JointUnet(nn.Module):
283
+ def __init__(
284
+ self,
285
+ *,
286
+ dim,
287
+ num_classes,
288
+ image_embed_dim=1024,
289
+ text_embed_dim=get_encoded_dim(DEFAULT_T5_NAME),
290
+ num_resnet_blocks=1,
291
+ cond_dim=None,
292
+ num_image_tokens=4,
293
+ num_time_tokens=2,
294
+ learned_sinu_pos_emb_dim=16,
295
+ out_dim=None,
296
+ dim_mults=(1, 2, 4, 8),
297
+ cond_images_channels=0,
298
+ channels=3,
299
+ channels_lbl=3,
300
+ channels_out=None,
301
+ attn_dim_head=64,
302
+ attn_heads=8,
303
+ ff_mult=2.,
304
+ lowres_cond=False, # for cascading diffusion - https://cascaded-diffusion.github.io/
305
+ layer_attns=True,
306
+ layer_attns_depth=1,
307
+ # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
308
+ layer_attns_add_text_cond=True,
309
+ # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
310
+ attend_at_middle=True,
311
+ layer_cross_attns=True,
312
+ use_linear_attn=False,
313
+ use_linear_cross_attn=False,
314
+ cond_on_text=True,
315
+ max_text_len=256,
316
+ init_dim=None,
317
+ resnet_groups=8,
318
+ init_conv_kernel_size=7, # kernel size of initial conv, if not using cross embed
319
+ init_cross_embed=True,
320
+ init_cross_embed_kernel_sizes=(3, 7, 15),
321
+ cross_embed_downsample=False,
322
+ cross_embed_downsample_kernel_sizes=(2, 4),
323
+ attn_pool_text=True,
324
+ attn_pool_num_latents=32,
325
+ dropout=0.,
326
+ memory_efficient=False,
327
+ init_conv_to_final_conv_residual=False,
328
+ use_global_context_attn=True,
329
+ scale_skip_connection=True,
330
+ final_resnet_block=True,
331
+ final_conv_kernel_size=3,
332
+ cosine_sim_attn=False,
333
+ self_cond=False,
334
+ combine_upsample_fmaps=False, # combine feature maps from all upsample blocks, used in unet squared successfully
335
+ pixel_shuffle_upsample=True # may address checkboard artifacts
336
+ ):
337
+ super().__init__()
338
+
339
+ # guide researchers
340
+
341
+ assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'
342
+
343
+ if dim < 128:
344
+ print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')
345
+
346
+ # save locals to take care of some hyperparameters for cascading DDPM
347
+
348
+ self._locals = locals()
349
+ self._locals.pop('self', None)
350
+ self._locals.pop('__class__', None)
351
+
352
+ # determine dimensions
353
+
354
+ self.channels = channels
355
+ self.channels_out = default(channels_out, channels)
356
+
357
+ # label embedding
358
+
359
+ self.num_classes = num_classes
360
+ self.init_emb_seg = LabelEmbedding(self.num_classes, channels_lbl)
361
+ self.init_emb_seg_lowres = LabelEmbedding(self.num_classes, channels_lbl) if lowres_cond else None
362
+
363
+ # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
364
+ # (2) in self conditioning, one appends the predict x0 (x_start)
365
+ # (3) in joint diffusion, label condition appends on image.
366
+ init_channels = (channels + channels_lbl) * (1 + int(lowres_cond) + int(self_cond)) # Joint Imagen
367
+ init_dim = default(init_dim, dim)
368
+
369
+ self.self_cond = self_cond
370
+ if self_cond:
371
+ self.self_cond_lbl_emb = LabelEmbedding(self.num_classes, channels_lbl)
372
+
373
+ # optional image conditioning
374
+
375
+ self.has_cond_image = cond_images_channels > 0
376
+ self.cond_images_channels = cond_images_channels
377
+
378
+ init_channels += cond_images_channels
379
+
380
+ # initial convolution
381
+
382
+ self.init_conv = CrossEmbedLayer(init_channels, dim_out=init_dim, kernel_sizes=init_cross_embed_kernel_sizes, stride=1) if init_cross_embed else nn.Conv2d(
383
+ init_channels, init_dim, init_conv_kernel_size, padding=init_conv_kernel_size // 2)
384
+
385
+ dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
386
+ in_out = list(zip(dims[:-1], dims[1:]))
387
+
388
+ # time conditioning
389
+
390
+ cond_dim = default(cond_dim, dim)
391
+ time_cond_dim = dim * 4 * (2 if lowres_cond else 1)
392
+
393
+ # embedding time for log(snr) noise from continuous version
394
+
395
+ sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
396
+ sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
397
+
398
+ self.to_time_hiddens = nn.Sequential(
399
+ sinu_pos_emb,
400
+ nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
401
+ nn.SiLU()
402
+ )
403
+
404
+ self.to_time_cond = nn.Sequential(
405
+ nn.Linear(time_cond_dim, time_cond_dim)
406
+ )
407
+
408
+ # project to time tokens as well as time hiddens
409
+
410
+ self.to_time_tokens = nn.Sequential(
411
+ nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
412
+ Rearrange('b (r d) -> b r d', r=num_time_tokens)
413
+ )
414
+
415
+ # low res aug noise conditioning
416
+
417
+ self.lowres_cond = lowres_cond
418
+
419
+ if lowres_cond:
420
+ self.to_lowres_time_hiddens = nn.Sequential(
421
+ LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
422
+ nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
423
+ nn.SiLU()
424
+ )
425
+
426
+ self.to_lowres_time_cond = nn.Sequential(
427
+ nn.Linear(time_cond_dim, time_cond_dim)
428
+ )
429
+
430
+ self.to_lowres_time_tokens = nn.Sequential(
431
+ nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
432
+ Rearrange('b (r d) -> b r d', r=num_time_tokens)
433
+ )
434
+
435
+ # normalizations
436
+
437
+ self.norm_cond = nn.LayerNorm(cond_dim)
438
+
439
+ # text encoding conditioning (optional)
440
+
441
+ self.text_to_cond = None
442
+
443
+ if cond_on_text:
444
+ assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
445
+ self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
446
+
447
+ # finer control over whether to condition on text encodings
448
+
449
+ self.cond_on_text = cond_on_text
450
+
451
+ # attention pooling
452
+
453
+ self.attn_pool = PerceiverResampler(dim=cond_dim, depth=2, dim_head=attn_dim_head, heads=attn_heads,
454
+ num_latents=attn_pool_num_latents, cosine_sim_attn=cosine_sim_attn) if attn_pool_text else None
455
+
456
+ # for classifier free guidance
457
+
458
+ self.max_text_len = max_text_len
459
+
460
+ self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
461
+ self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
462
+
463
+ # for non-attention based text conditioning at all points in the network where time is also conditioned
464
+
465
+ self.to_text_non_attn_cond = None
466
+
467
+ if cond_on_text:
468
+ self.to_text_non_attn_cond = nn.Sequential(
469
+ nn.LayerNorm(cond_dim),
470
+ nn.Linear(cond_dim, time_cond_dim),
471
+ nn.SiLU(),
472
+ nn.Linear(time_cond_dim, time_cond_dim)
473
+ )
474
+
475
+ # attention related params
476
+
477
+ attn_kwargs = dict(heads=attn_heads, dim_head=attn_dim_head, cosine_sim_attn=cosine_sim_attn)
478
+
479
+ num_layers = len(in_out)
480
+
481
+ # resnet block klass
482
+
483
+ num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
484
+ resnet_groups = cast_tuple(resnet_groups, num_layers)
485
+
486
+ resnet_klass = partial(ResnetBlock, **attn_kwargs)
487
+
488
+ layer_attns = cast_tuple(layer_attns, num_layers)
489
+ layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
490
+ layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
491
+
492
+ use_linear_attn = cast_tuple(use_linear_attn, num_layers)
493
+ use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)
494
+
495
+ assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
496
+
497
+ # downsample klass
498
+
499
+ downsample_klass = Downsample
500
+
501
+ if cross_embed_downsample:
502
+ downsample_klass = partial(CrossEmbedLayer, kernel_sizes=cross_embed_downsample_kernel_sizes)
503
+
504
+ # initial resnet block (for memory efficient unet)
505
+
506
+ self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim=time_cond_dim,
507
+ groups=resnet_groups[0], use_gca=use_global_context_attn) if memory_efficient else None
508
+
509
+ # scale for resnet skip connections
510
+
511
+ self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
512
+
513
+ # layers
514
+
515
+ self.downs = nn.ModuleList([])
516
+ self.ups = nn.ModuleList([])
517
+ num_resolutions = len(in_out)
518
+
519
+ layer_params = [num_resnet_blocks, resnet_groups, layer_attns,
520
+ layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
521
+ reversed_layer_params = list(map(reversed, layer_params))
522
+
523
+ # downsampling layers
524
+
525
+ skip_connect_dims = [] # keep track of skip connection dimensions
526
+
527
+ for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
528
+ is_last = ind >= (num_resolutions - 1)
529
+
530
+ layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
531
+
532
+ if layer_attn:
533
+ transformer_block_klass = TransformerBlock
534
+ elif layer_use_linear_attn:
535
+ transformer_block_klass = LinearAttentionTransformerBlock
536
+ else:
537
+ transformer_block_klass = Identity
538
+
539
+ current_dim = dim_in
540
+
541
+ # whether to pre-downsample, from memory efficient unet
542
+
543
+ pre_downsample = None
544
+
545
+ if memory_efficient:
546
+ pre_downsample = downsample_klass(dim_in, dim_out)
547
+ current_dim = dim_out
548
+
549
+ skip_connect_dims.append(current_dim)
550
+
551
+ # whether to do post-downsample, for non-memory efficient unet
552
+
553
+ post_downsample = None
554
+ if not memory_efficient:
555
+ post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(
556
+ nn.Conv2d(dim_in, dim_out, 3, padding=1), nn.Conv2d(dim_in, dim_out, 1))
557
+
558
+ self.downs.append(nn.ModuleList([
559
+ pre_downsample,
560
+ resnet_klass(current_dim, current_dim, cond_dim=layer_cond_dim,
561
+ linear_attn=layer_use_linear_cross_attn, time_cond_dim=time_cond_dim, groups=groups),
562
+ nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim=time_cond_dim,
563
+ groups=groups, use_gca=use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
564
+ transformer_block_klass(dim=current_dim, depth=layer_attn_depth,
565
+ ff_mult=ff_mult, context_dim=cond_dim, **attn_kwargs),
566
+ post_downsample
567
+ ]))
568
+
569
+ # middle layers
570
+
571
+ mid_dim = dims[-1]
572
+
573
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim=cond_dim,
574
+ time_cond_dim=time_cond_dim, groups=resnet_groups[-1])
575
+ self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(
576
+ Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
577
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim=cond_dim,
578
+ time_cond_dim=time_cond_dim, groups=resnet_groups[-1])
579
+
580
+ # upsample klass
581
+
582
+ upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
583
+
584
+ # upsampling layers
585
+
586
+ upsample_fmap_dims = []
587
+
588
+ for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
589
+ is_last = ind == (len(in_out) - 1)
590
+
591
+ layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
592
+
593
+ if layer_attn:
594
+ transformer_block_klass = TransformerBlock
595
+ elif layer_use_linear_attn:
596
+ transformer_block_klass = LinearAttentionTransformerBlock
597
+ else:
598
+ transformer_block_klass = Identity
599
+
600
+ skip_connect_dim = skip_connect_dims.pop()
601
+
602
+ upsample_fmap_dims.append(dim_out)
603
+
604
+ self.ups.append(nn.ModuleList([
605
+ resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim=layer_cond_dim,
606
+ linear_attn=layer_use_linear_cross_attn, time_cond_dim=time_cond_dim, groups=groups),
607
+ nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim=time_cond_dim,
608
+ groups=groups, use_gca=use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
609
+ transformer_block_klass(dim=dim_out, depth=layer_attn_depth, ff_mult=ff_mult,
610
+ context_dim=cond_dim, **attn_kwargs),
611
+ upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity(),
612
+ ]))
613
+
614
+ # whether to combine feature maps from all upsample blocks before final resnet block out
615
+
616
+ self.upsample_combiner = UpsampleCombiner(
617
+ dim=dim,
618
+ enabled=combine_upsample_fmaps,
619
+ dim_ins=upsample_fmap_dims,
620
+ dim_outs=dim
621
+ )
622
+
623
+ # whether to do a final residual from initial conv to the final resnet block out
624
+
625
+ self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
626
+ final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)
627
+
628
+ # final optional resnet block and convolution out
629
+
630
+ self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim=time_cond_dim,
631
+ groups=resnet_groups[0], use_gca=True) if final_resnet_block else None
632
+
633
+ final_conv_dim_in = dim if final_resnet_block else final_conv_dim
634
+ final_conv_dim_in += (channels + channels_lbl) if lowres_cond else 0
635
+
636
+ self.final_conv = nn.Conv2d(final_conv_dim_in, self.channels_out,
637
+ final_conv_kernel_size, padding=final_conv_kernel_size // 2)
638
+ self.final_conv_seg = nn.Conv2d(final_conv_dim_in, self.num_classes,
639
+ final_conv_kernel_size, padding=final_conv_kernel_size // 2)
640
+
641
+ zero_init_(self.final_conv)
642
+ zero_init_(self.final_conv_seg)
643
+
644
+ # if the current settings for the unet are not correct
645
+ # for cascading DDPM, then reinit the unet with the right settings
646
+ def cast_model_parameters(
647
+ self,
648
+ *,
649
+ lowres_cond,
650
+ text_embed_dim,
651
+ channels,
652
+ channels_out,
653
+ cond_on_text
654
+ ):
655
+ if lowres_cond == self.lowres_cond and \
656
+ channels == self.channels and \
657
+ cond_on_text == self.cond_on_text and \
658
+ text_embed_dim == self._locals['text_embed_dim'] and \
659
+ channels_out == self.channels_out:
660
+ return self
661
+
662
+ updated_kwargs = dict(
663
+ lowres_cond=lowres_cond,
664
+ text_embed_dim=text_embed_dim,
665
+ channels=channels,
666
+ channels_out=channels_out,
667
+ cond_on_text=cond_on_text
668
+ )
669
+
670
+ return self.__class__(**{**self._locals, **updated_kwargs})
671
+
672
+ # methods for returning the full unet config as well as its parameter state
673
+
674
+ def to_config_and_state_dict(self):
675
+ return self._locals, self.state_dict()
676
+
677
+ # class method for rehydrating the unet from its config and state dict
678
+
679
+ @classmethod
680
+ def from_config_and_state_dict(klass, config, state_dict):
681
+ unet = klass(**config)
682
+ unet.load_state_dict(state_dict)
683
+ return unet
684
+
685
+ # methods for persisting unet to disk
686
+
687
+ def persist_to_file(self, path):
688
+ path = Path(path)
689
+ path.parents[0].mkdir(exist_ok=True, parents=True)
690
+
691
+ config, state_dict = self.to_config_and_state_dict()
692
+ pkg = dict(config=config, state_dict=state_dict)
693
+ torch.save(pkg, str(path))
694
+
695
+ # class method for rehydrating the unet from file saved with `persist_to_file`
696
+
697
+ @classmethod
698
+ def hydrate_from_file(klass, path):
699
+ path = Path(path)
700
+ assert path.exists()
701
+ pkg = torch.load(str(path))
702
+
703
+ assert 'config' in pkg and 'state_dict' in pkg
704
+ config, state_dict = pkg['config'], pkg['state_dict']
705
+
706
+ return JointUnet.from_config_and_state_dict(config, state_dict)
707
+
708
+ # forward with classifier free guidance
709
+
710
+ def forward_with_cond_scale(
711
+ self,
712
+ *args,
713
+ cond_scale=1.,
714
+ **kwargs
715
+ ):
716
+ logits, logits_seg = self.forward(*args, **kwargs)
717
+
718
+ if cond_scale == 1:
719
+ return logits, logits_seg
720
+
721
+ null_logits, null_logits_seg = self.forward(*args, cond_drop_prob=1., **kwargs)
722
+
723
+ cond_logits = null_logits + (logits - null_logits) * cond_scale
724
+
725
+ # TODO: CFG of categorical is not clear.
726
+ cond_logits_seg = null_logits_seg + (logits_seg - null_logits_seg) * cond_scale
727
+
728
+ return cond_logits, cond_logits_seg
729
+
730
def forward(
    self,
    x,
    lbl,
    time,
    *,
    lowres_cond_img=None,
    lowres_cond_lbl=None,
    lowres_noise_times=None,
    text_embeds=None,
    text_mask=None,
    cond_images=None,
    self_cond=None,
    self_cond_lbl=None,
    cond_drop_prob=0.
):
    """Joint denoising forward pass over an image `x` and a segmentation label map `lbl`.

    The label map is embedded and concatenated channel-wise onto the image, then the
    pair flows through one shared U-Net. Returns a tuple
    (image prediction from `final_conv`, segmentation logits from `final_conv_seg`).

    `cond_drop_prob` is the probability of dropping text conditioning per batch element
    (used for classifier-free guidance training).
    NOTE(review): `lbl` is assumed to be an integer label map broadcastable through
    `init_emb_seg` to the image's spatial size — confirm against callers.
    """
    batch_size, device = x.shape[0], x.device

    # joint imagen: embed the discrete label map and stack it onto the image channels

    lbl = self.init_emb_seg(lbl.long())
    x = torch.cat((x, lbl), dim=1)

    # condition on self (self-conditioning on previous x0 / label predictions)

    if self.self_cond:
        self_cond = default(self_cond, lambda: torch.zeros_like(x))
        if self_cond_lbl is None:
            self_cond_lbl = torch.zeros_like(lbl)
        else:
            self_cond_lbl = self.self_cond_lbl_emb(self_cond_lbl.long())
        x = torch.cat((x, self_cond, self_cond_lbl), dim=1)

    # add low resolution conditioning, if present

    assert not (self.lowres_cond and not exists(lowres_cond_img)
                ), 'low resolution conditioning image must be present'
    assert not (self.lowres_cond and not exists(lowres_noise_times)
                ), 'low resolution conditioning noise time must be present'

    if exists(lowres_cond_img) and exists(lowres_cond_lbl):
        lowres_cond_lbl = self.init_emb_seg_lowres(lowres_cond_lbl.long())
        x = torch.cat((x, lowres_cond_img, lowres_cond_lbl), dim=1)

    # condition on input image

    assert not (self.has_cond_image ^ exists(cond_images)), \
        'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

    if exists(cond_images):
        assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
        cond_images = resize_image_to(cond_images, x.shape[-1])
        x = torch.cat((cond_images, x), dim=1)

    # initial convolution

    x = self.init_conv(x)

    # init conv residual (saved for the optional top-most skip at the end)

    if self.init_conv_to_final_conv_residual:
        init_conv_residual = x.clone()

    # time conditioning

    time_hiddens = self.to_time_hiddens(time)

    # derive time tokens (for attention) and the time conditioning vector `t`

    time_tokens = self.to_time_tokens(time_hiddens)
    t = self.to_time_cond(time_hiddens)

    # add lowres time conditioning to time hiddens
    # and add lowres time tokens along sequence dimension for attention

    if self.lowres_cond:
        lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
        lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
        lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

        t = t + lowres_t
        time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim=-2)

    # text conditioning

    text_tokens = None

    if exists(text_embeds) and self.cond_on_text:

        # conditional dropout: per-sample keep mask for classifier-free guidance

        text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)

        text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
        text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

        # calculate text embeds

        text_tokens = self.text_to_cond(text_embeds)

        # truncate / pad the token sequence to exactly max_text_len

        text_tokens = text_tokens[:, :self.max_text_len]

        if exists(text_mask):
            text_mask = text_mask[:, :self.max_text_len]

        text_tokens_len = text_tokens.shape[1]
        remainder = self.max_text_len - text_tokens_len

        if remainder > 0:
            text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

        if exists(text_mask):
            if remainder > 0:
                text_mask = F.pad(text_mask, (0, remainder), value=False)

            text_mask = rearrange(text_mask, 'b n -> b n 1')
            text_keep_mask_embed = text_mask & text_keep_mask_embed

        null_text_embed = self.null_text_embed.to(text_tokens.dtype)  # for some reason pytorch AMP not working

        # dropped samples are replaced by the learned null embedding

        text_tokens = torch.where(
            text_keep_mask_embed,
            text_tokens,
            null_text_embed
        )

        if exists(self.attn_pool):
            text_tokens = self.attn_pool(text_tokens)

        # extra non-attention conditioning by projecting and then summing text embeddings to time
        # termed as text hiddens

        mean_pooled_text_tokens = text_tokens.mean(dim=-2)

        text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

        null_text_hidden = self.null_text_hidden.to(t.dtype)

        text_hiddens = torch.where(
            text_keep_mask_hidden,
            text_hiddens,
            null_text_hidden
        )

        t = t + text_hiddens

    # main conditioning tokens (c): time tokens, optionally followed by text tokens

    c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim=-2)

    # normalize conditioning tokens

    c = self.norm_cond(c)

    # initial resnet block (for memory efficient unet)

    if exists(self.init_resnet_block):
        x = self.init_resnet_block(x, t)

    # go through the layers of the unet, down and up

    hiddens = []

    for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
        if exists(pre_downsample):
            x = pre_downsample(x)

        x = init_block(x, t, c)

        for resnet_block in resnet_blocks:
            x = resnet_block(x, t)
            hiddens.append(x)

        x = attn_block(x, c)
        hiddens.append(x)

        if exists(post_downsample):
            x = post_downsample(x)

    x = self.mid_block1(x, t, c)

    if exists(self.mid_attn):
        x = self.mid_attn(x)

    x = self.mid_block2(x, t, c)

    # pops the matching down-path activation (LIFO) and concatenates it, scaled

    def add_skip_connection(x): return torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim=1)

    up_hiddens = []

    for init_block, resnet_blocks, attn_block, upsample in self.ups:
        x = add_skip_connection(x)
        x = init_block(x, t, c)

        for resnet_block in resnet_blocks:
            x = add_skip_connection(x)
            x = resnet_block(x, t)

        x = attn_block(x, c)
        up_hiddens.append(x.contiguous())

        x = upsample(x)

    # whether to combine all feature maps from upsample blocks

    x = self.upsample_combiner(x, up_hiddens)

    # final top-most residual if needed

    if self.init_conv_to_final_conv_residual:
        x = torch.cat((x, init_conv_residual), dim=1)

    if exists(self.final_res_block):
        x = self.final_res_block(x, t)

    if exists(lowres_cond_img) and exists(lowres_cond_lbl):
        x = torch.cat((x, lowres_cond_img, lowres_cond_lbl), dim=1)

    # two heads: image prediction and segmentation logits

    return self.final_conv(x), self.final_conv_seg(x)
950
# predefined unets, with configs lining up with hyperparameters in appendix of paper


class BaseJointUnet(JointUnet):
    """JointUnet preconfigured with the base-stage hyperparameters; caller kwargs override defaults."""

    def __init__(self, *args, **kwargs):
        base_defaults = dict(
            dim=128,
            dim_mults=(1, 2, 4, 8),
            num_resnet_blocks=(2, 4, 8, 8),
            layer_attns=(False, False, False, True),
            layer_cross_attns=(False, False, False, True),
            attn_heads=8,
            ff_mult=2.,
            memory_efficient=True
        )
        super().__init__(*args, **{**base_defaults, **kwargs})
967
+
968
class SRJointUnet(JointUnet):
    """JointUnet preconfigured for super-resolution stages (no self-attention in any layer);
    caller kwargs override defaults."""

    def __init__(self, *args, **kwargs):
        sr_defaults = dict(
            dim=128,
            dim_mults=(1, 2, 4, 8),
            num_resnet_blocks=(2, 4, 8, 8),
            layer_attns=False,
            layer_cross_attns=(False, False, False, True),
            attn_heads=8,
            ff_mult=2.,
            memory_efficient=True
        )
        super().__init__(*args, **{**sr_defaults, **kwargs})
982
+ # main imagen ddpm class, which is a cascading DDPM from Ho et al.
983
+
984
+
985
+ class JointImagen(nn.Module):
986
def __init__(
    self,
    unets,
    *,
    image_sizes,  # for cascading ddpm, image size at each stage
    num_classes,
    text_encoder_name=DEFAULT_T5_NAME,
    text_embed_dim=None,
    channels=3,
    timesteps=1000,
    sample_timesteps=100,
    cond_drop_prob=0.1,
    loss_type='l2',
    noise_schedules='cosine',
    noise_schedules_lbl='cosine_p',
    cosine_p_lbl=1.0,
    pred_objectives='noise',
    random_crop_sizes=None,
    lowres_noise_schedule='linear',
    # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
    lowres_sample_noise_level=0.2,
    # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
    per_sample_random_aug_noise_level=False,
    lowres_max_thres=0.999,
    condition_on_text=True,
    # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
    auto_normalize_img=True,
    # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
    p2_loss_weight_gamma=0.5,
    p2_loss_weight_k=1,
    dynamic_thresholding=True,
    dynamic_thresholding_percentile=0.95,  # unsure what this was based on perusal of paper
    only_train_unet_number=None,
):
    """Cascading joint (image + segmentation label) diffusion model.

    Builds, per unet in the cascade: a continuous-time Gaussian scheduler for the
    image, a MultinomialDiffusion scheduler for the label map (both for training
    and for sampling with `sample_timesteps`), plus shared lowres-augmentation
    schedulers. Per-unet hyperparameters (`timesteps`, `pred_objectives`,
    `dynamic_thresholding`, ...) may be given as scalars (broadcast) or tuples.

    `num_classes` is the number of segmentation classes for the categorical
    diffusion. `image_sizes` must be a tuple of (height, width) pairs, one per unet.
    """
    super().__init__()

    # joint: number of segmentation classes for the categorical (multinomial) diffusion

    self.num_classes = num_classes

    # loss

    if loss_type == 'l1':
        loss_fn = F.l1_loss
    elif loss_type == 'l2':
        loss_fn = F.mse_loss
    elif loss_type == 'huber':
        loss_fn = F.smooth_l1_loss
    else:
        raise NotImplementedError()

    self.loss_type = loss_type
    self.loss_fn = loss_fn

    # conditioning hparams

    self.condition_on_text = condition_on_text
    self.unconditional = not condition_on_text

    # channels

    self.channels = channels

    # automatically take care of ensuring that first unet is unconditional
    # while the rest of the unets are conditioned on the low resolution image produced by previous unet

    unets = cast_tuple(unets)
    num_unets = len(unets)

    # determine noise schedules per unet

    timesteps = cast_tuple(timesteps, num_unets)
    sample_timesteps = cast_tuple(sample_timesteps, num_unets)

    # make sure noise schedule defaults to 'cosine', 'cosine', and then 'linear' for rest of super-resoluting unets

    noise_schedules = cast_tuple(noise_schedules)
    noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine')
    noise_schedules = pad_tuple_to_length(noise_schedules, num_unets, 'linear')
    noise_schedules_lbl = cast_tuple(noise_schedules_lbl)
    noise_schedules_lbl = pad_tuple_to_length(noise_schedules_lbl, 2, 'cosine_p')
    noise_schedules_lbl = pad_tuple_to_length(noise_schedules_lbl, num_unets, 'linear')

    # construct noise schedulers (training-time: full `timesteps`)

    noise_scheduler_klass = GaussianDiffusionContinuousTimes
    noise_scheduler_lbl_klass = MultinomialDiffusion
    self.noise_schedulers = nn.ModuleList([])
    self.noise_schedulers_lbl = nn.ModuleList([])

    for timestep, noise_schedule, noise_schedule_lbl in zip(timesteps, noise_schedules, noise_schedules_lbl):
        noise_scheduler = noise_scheduler_klass(noise_schedule=noise_schedule, timesteps=timestep)
        self.noise_schedulers.append(noise_scheduler)
        noise_scheduler_lbl = noise_scheduler_lbl_klass(
            num_classes, noise_schedule=noise_schedule_lbl, timesteps=timestep, p=cosine_p_lbl)
        self.noise_schedulers_lbl.append(noise_scheduler_lbl)

    # sampling-time schedulers: same schedules but with the (shorter) `sample_timesteps`

    self.noise_schedulers_sample = nn.ModuleList([])
    self.noise_schedulers_lbl_sample = nn.ModuleList([])

    for sample_timestep, noise_schedule, noise_schedule_lbl in zip(sample_timesteps, noise_schedules, noise_schedules_lbl):
        noise_scheduler_sample = noise_scheduler_klass(noise_schedule=noise_schedule, timesteps=sample_timestep)
        self.noise_schedulers_sample.append(noise_scheduler_sample)
        noise_scheduler_lbl_sample = noise_scheduler_lbl_klass(
            num_classes, noise_schedule=noise_schedule_lbl, timesteps=sample_timestep, p=cosine_p_lbl)
        self.noise_schedulers_lbl_sample.append(noise_scheduler_lbl_sample)

    # randomly cropping for upsampler training

    self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
    assert all(map(lambda x: x is None or (isinstance(x, (tuple, list)) and len(x) == 2), self.random_crop_sizes))
    assert not exists(first(self.random_crop_sizes)), \
        'you should not need to randomly crop image during training for base unet, only for upsamplers '\
        '- so pass in `random_crop_sizes = (None, 128, 256)` as example'

    # lowres augmentation noise schedule

    self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule=lowres_noise_schedule)
    self.lowres_noise_schedule_lbl = MultinomialDiffusion(
        num_classes, noise_schedule=lowres_noise_schedule, p=cosine_p_lbl)

    # ddpm objectives - predicting noise by default

    self.pred_objectives = cast_tuple(pred_objectives, num_unets)

    # get text encoder

    self.text_encoder_name = text_encoder_name
    self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))

    self.encode_text = partial(t5_encode_text, name=text_encoder_name)

    # construct unets

    self.unets = nn.ModuleList([])

    self.unet_being_trained_index = -1  # keeps track of which unet is being trained at the moment
    self.only_train_unet_number = only_train_unet_number

    for ind, one_unet in enumerate(unets):
        assert isinstance(one_unet, (JointUnet, Unet3D, NullUnet))
        is_first = ind == 0

        one_unet = one_unet.cast_model_parameters(
            lowres_cond=not is_first,
            cond_on_text=self.condition_on_text,
            text_embed_dim=self.text_embed_dim if self.condition_on_text else None,
            channels=self.channels,
            channels_out=self.channels
        )

        self.unets.append(one_unet)

    # unet image sizes

    self.image_sizes = cast_tuple(image_sizes)
    assert all(map(lambda x: isinstance(x, (tuple, list)) and len(x) == 2, self.image_sizes))

    assert num_unets == len(self.image_sizes), \
        f'you did not supply the correct number of u-nets ({len(unets)}) for resolutions {self.image_sizes}'

    self.sample_channels = cast_tuple(self.channels, num_unets)

    # determine whether we are training on images or video

    is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
    self.is_video = is_video

    self.right_pad_dims_to_datatype = partial(rearrange, pattern=(
        'b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
    self.resize_to = resize_video_to if is_video else resize_image_to

    # cascading ddpm related stuff

    lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
    assert lowres_conditions == (False, *((True,) * (num_unets - 1))), \
        'the first unet must be unconditioned (by low resolution image), ' \
        'and the rest of the unets must have `lowres_cond` set to True'

    self.lowres_sample_noise_level = lowres_sample_noise_level
    self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level
    self.lowres_max_thres = lowres_max_thres

    # classifier free guidance

    self.cond_drop_prob = cond_drop_prob
    self.can_classifier_guidance = cond_drop_prob > 0.

    # normalize and unnormalize image functions

    self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
    self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
    self.input_image_range = (0. if auto_normalize_img else -1., 1.)

    # dynamic thresholding

    self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
    self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

    # p2 loss weight

    self.p2_loss_weight_k = p2_loss_weight_k
    self.p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)

    assert all([(gamma_value <= 2) for gamma_value in self.p2_loss_weight_gamma]), \
        'in paper, they noticed any gamma greater than 2 is harmful'

    # one temp parameter for keeping track of device

    self.register_buffer('_temp', torch.tensor([0.]), persistent=False)

    # default to device of unets passed in

    self.to(next(self.unets.parameters()).device)
1201
def force_unconditional_(self):
    """In-place switch: disable text conditioning on this model and on every sub-unet."""
    self.condition_on_text = False
    self.unconditional = True

    for one_unet in self.unets:
        one_unet.cond_on_text = False
1208
@property
def device(self):
    """Device of this module, tracked via the non-persistent `_temp` buffer."""
    return self._temp.device
1211
+
1212
def get_unet(self, unet_number):
    """Return unet #`unet_number` (1-indexed), keeping only that unet on self.device
    and parking all the others on cpu."""
    assert 0 < unet_number <= len(self.unets)
    target = unet_number - 1

    # unwrap the ModuleList into a plain python list, so each unet can live on
    # its own device independently of the parent module
    if isinstance(self.unets, nn.ModuleList):
        plain_list = list(self.unets)
        delattr(self, 'unets')
        self.unets = plain_list

    # only shuffle devices when the requested unet changed
    if target != self.unet_being_trained_index:
        for position, one_unet in enumerate(self.unets):
            one_unet.to(self.device if position == target else 'cpu')

    self.unet_being_trained_index = target
    return self.unets[target]
1228
def reset_unets_all_one_device(self, device=None):
    """Gather all unets back into a ModuleList living on one device (defaults to self.device)."""
    target_device = default(device, self.device)
    self.unets = nn.ModuleList(list(self.unets))
    self.unets.to(target_device)

    # -1 means "no single unet is currently pinned for training"
    self.unet_being_trained_index = -1
1235
@contextmanager
def one_unet_in_gpu(self, unet_number=None, unet=None):
    """Context manager: temporarily move one unet (selected by 1-indexed number or
    by reference) onto self.device, with every unet restored to its previous
    device on exit.

    Exactly one of `unet_number` / `unet` must be given.
    """
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    # remember where each unet currently lives, then park them all on cpu
    devices = [module_device(unet) for unet in self.unets]
    self.unets.cpu()
    unet.to(self.device)

    yield

    # restore the original device placement
    for unet, device in zip(self.unets, devices):
        unet.to(device)
1251
# overriding state dict functions

def state_dict(self, *args, **kwargs):
    """Gather every unet onto one device first, so the returned state dict is complete and consistent."""
    self.reset_unets_all_one_device()
    return super().state_dict(*args, **kwargs)
1257
def load_state_dict(self, *args, **kwargs):
    """Gather every unet onto one device before delegating, so all parameters are present to load into."""
    self.reset_unets_all_one_device()
    return super().load_state_dict(*args, **kwargs)
1261
# gaussian diffusion methods

def p_mean_variance(
    self,
    unet: JointUnet,
    x,
    log_lbl,
    t,
    *,
    noise_scheduler: GaussianDiffusionContinuousTimes,
    noise_scheduler_lbl: MultinomialDiffusion,
    text_embeds=None,
    text_mask=None,
    cond_images=None,
    lowres_cond_img=None,
    lowres_cond_lbl=None,
    self_cond=None,
    self_cond_lbl=None,
    lowres_noise_times=None,
    cond_scale=1.,
    model_output=None,
    t_next=None,
    pred_objective='noise',
    dynamic_threshold=True,
):
    """One reverse-diffusion model evaluation for the joint (image, label) pair.

    Runs the unet (with classifier-free guidance at `cond_scale`), converts the
    image prediction to an x0 estimate per `pred_objective`, optionally applies
    dynamic thresholding, and computes the Gaussian posterior for the image and
    the categorical posterior sample for the label map.

    Returns ((mean, variance, log_variance), sampled log-onehot label,
    x0 estimate, label x0 estimate) — the last is currently always None.
    `log_lbl` is the label map in log-onehot form.
    """
    assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
    lbl = log_onehot_to_index(log_lbl)
    pred, pred_lbl = default(model_output, lambda: unet.forward_with_cond_scale(
        x, lbl, noise_scheduler.get_condition(t),
        text_embeds=text_embeds, text_mask=text_mask,
        cond_images=cond_images, cond_scale=cond_scale,
        lowres_cond_img=lowres_cond_img, lowres_cond_lbl=lowres_cond_lbl,
        self_cond=self_cond, self_cond_lbl=self_cond_lbl,
        lowres_noise_times=self.lowres_noise_schedule.get_condition(lowres_noise_times)))
    # label head: logits -> log-probs -> categorical posterior at time t
    pred_lbl = F.log_softmax(pred_lbl, dim=1)
    pred_lbl = noise_scheduler_lbl.q_posterior(pred_lbl, log_lbl, t)

    if pred_objective == 'noise':
        x_start = noise_scheduler.predict_start_from_noise(x, t=t, noise=pred)
        # lbl_start = noise_scheduler_lbl.predict_start_from_noise(log_lbl, t=t, noise=pred_lbl) # TODO ???
    elif pred_objective == 'x_start':
        x_start = pred
        # lbl_start = pred_lbl
    else:
        raise ValueError(f'unknown objective {pred_objective}')
    # label x0 estimate intentionally unused for now (see TODO above)
    lbl_start = None

    if dynamic_threshold:
        # following pseudocode in appendix
        # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
        s = torch.quantile(
            rearrange(x_start, 'b ... -> b (...)').abs(),
            self.dynamic_thresholding_percentile,
            dim=-1
        )

        s.clamp_(min=1.)
        s = right_pad_dims_to(x_start, s)
        x_start = x_start.clamp(-s, s) / s
    else:
        x_start.clamp_(-1., 1.)

    mean_and_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t, t_next=t_next)
    log_lbl = noise_scheduler_lbl.log_sample_categorical(pred_lbl)
    return mean_and_variance, log_lbl, x_start, lbl_start
1327
@torch.no_grad()
def p_sample(
    self,
    unet,
    x,
    log_lbl,
    t,
    *,
    noise_scheduler,
    noise_scheduler_lbl,
    t_next=None,
    text_embeds=None,
    text_mask=None,
    cond_images=None,
    cond_scale=1.,
    self_cond=None,
    self_cond_lbl=None,
    lowres_cond_img=None,
    lowres_cond_lbl=None,
    lowres_noise_times=None,
    pred_objective='noise',
    dynamic_threshold=True,
):
    """Single ancestral sampling step: draw (x_{t-1}, label_{t-1}) given (x_t, label_t).

    The label step is already sampled inside `p_mean_variance`; this function adds
    the Gaussian noise for the image step (suppressed at the final timestep).
    Returns (image sample, log-onehot label sample, x0 estimate, label x0 estimate).
    """
    b, *_, device = *x.shape, x.device
    (model_mean, _, model_log_variance), pred_lbl, x_start, lbl_start = self.p_mean_variance(
        unet, x=x, log_lbl=log_lbl, t=t, t_next=t_next,
        noise_scheduler=noise_scheduler, noise_scheduler_lbl=noise_scheduler_lbl,
        text_embeds=text_embeds, text_mask=text_mask,
        cond_images=cond_images, cond_scale=cond_scale,
        lowres_cond_img=lowres_cond_img, lowres_cond_lbl=lowres_cond_lbl,
        self_cond=self_cond, self_cond_lbl=self_cond_lbl,
        lowres_noise_times=lowres_noise_times,
        pred_objective=pred_objective, dynamic_threshold=dynamic_threshold)
    noise = torch.randn_like(x)
    # no noise when t == 0
    # (continuous-time schedulers signal the last step via t_next == 0, discrete ones via t == 0)
    is_last_sampling_timestep = (t_next == 0) if isinstance(
        noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
    nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    return pred, pred_lbl, x_start, lbl_start
1368
@torch.no_grad()
def p_sample_loop(
    self,
    unet,
    shape,
    *,
    noise_scheduler: GaussianDiffusionContinuousTimes,
    noise_scheduler_lbl: MultinomialDiffusion,
    lowres_cond_img=None,
    lowres_cond_lbl=None,
    lowres_noise_times=None,
    text_embeds=None,
    text_mask=None,
    cond_images=None,
    inpaint_images=None,
    inpaint_labels=None,
    inpaint_masks=None,
    inpaint_resample_times=5,
    init_images=None,
    init_labels=None,
    skip_steps=None,
    cond_scale=1,
    pred_objective='noise',
    dynamic_threshold=True,
    use_tqdm=True
):
    """Full reverse-diffusion loop for one unet of the cascade.

    Starts from Gaussian noise (image) and a uniform categorical sample (label map)
    of the given `shape` = (batch, channels, h, w), then iterates `p_sample` over
    the scheduler's timesteps. Supports RePaint-style inpainting: known regions are
    re-noised to the current timestep and pasted in, with `inpaint_resample_times`
    resampling passes per step. `inpaint_masks` carries two mask channels,
    (image mask, label mask).

    Returns (unnormalized image in [0, 1] range, integer label map).
    """
    assert init_labels is None, 'not implemented yet'
    device = self.device

    batch, _, h, w = shape
    img = torch.randn(shape, device=device)
    # label chain starts from the uniform categorical distribution
    uniform_logits = torch.zeros((batch, self.num_classes) + (h, w), device=device)
    log_lbl = noise_scheduler_lbl.log_sample_categorical(uniform_logits)

    # for initialization with an image or video

    if exists(init_images):
        img += init_images
    # TODO init_labels

    # keep track of x0, for self conditioning

    x_start = None
    lbl_start = None

    # prepare inpainting

    has_inpainting = exists(inpaint_images) and exists(inpaint_labels) and exists(inpaint_masks)
    resample_times = inpaint_resample_times if has_inpainting else 1

    if has_inpainting:
        assert inpaint_masks.shape[1] == 2, \
            f'inpaint mask is a tuple of (mask_image, mask_label) but now:\n{inpaint_labels}'
        inpaint_images = self.normalize_img(inpaint_images)
        inpaint_images = self.resize_to(inpaint_images, shape[-2:])

        log_inpaint_labels = index_to_log_onehot(inpaint_labels.long(), self.num_classes)
        log_inpaint_labels = self.resize_to(log_inpaint_labels, shape[-2:])

        inpaint_masks_image = self.resize_to(inpaint_masks[:, [0]], shape[-2:]).bool()
        inpaint_masks_label = self.resize_to(inpaint_masks[:, [1]], shape[-2:]).bool()

    # time (clamp continuous times strictly below 1)

    timesteps = noise_scheduler.get_sampling_timesteps(batch, device=device)
    timesteps = [t * (t < 1.) + (1 - 1e-7) * (t >= 1.) for t in timesteps]

    # whether to skip any steps

    skip_steps = default(skip_steps, 0)
    timesteps = timesteps[skip_steps:]

    for times, times_next in tqdm(timesteps, desc='sampling loop time step', total=len(timesteps), disable=not use_tqdm):
        is_last_timestep = times_next == 0

        for r in reversed(range(resample_times)):
            is_last_resample_step = r == 0

            if has_inpainting:
                # paste in the known regions, noised to the current timestep
                noised_inpaint_images, _ = noise_scheduler.q_sample(inpaint_images, t=times)
                img = img * ~inpaint_masks_image + noised_inpaint_images * inpaint_masks_image
                log_noised_inpaint_labels = noise_scheduler_lbl.q_sample(log_inpaint_labels, t=times)
                log_lbl = log_lbl * ~inpaint_masks_label + log_noised_inpaint_labels * inpaint_masks_label

            self_cond = x_start if unet.self_cond else None
            self_cond_lbl = lbl_start if unet.self_cond else None

            img, log_lbl, x_start, lbl_start = self.p_sample(
                unet,
                img,
                log_lbl,
                times,
                t_next=times_next,
                text_embeds=text_embeds,
                text_mask=text_mask,
                cond_images=cond_images,
                cond_scale=cond_scale,
                self_cond=self_cond,
                self_cond_lbl=self_cond_lbl,
                lowres_cond_img=lowres_cond_img,
                lowres_cond_lbl=lowres_cond_lbl,
                lowres_noise_times=lowres_noise_times,
                noise_scheduler=noise_scheduler,
                noise_scheduler_lbl=noise_scheduler_lbl,
                pred_objective=pred_objective,
                dynamic_threshold=dynamic_threshold,
            )

            if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
                # RePaint: jump back from t_next to t for the next resample pass
                renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times)
                img = torch.where(
                    self.right_pad_dims_to_datatype(is_last_timestep),
                    img,
                    renoised_img
                )
                renoised_log_lbl = noise_scheduler_lbl.q_sample_from_to(log_lbl, times_next, times)
                log_lbl = torch.where(
                    self.right_pad_dims_to_datatype(is_last_timestep),
                    log_lbl,
                    renoised_log_lbl
                )

        img.clamp_(-1., 1.)

    # final inpainting

    if has_inpainting:
        img = img * ~inpaint_masks_image + inpaint_images * inpaint_masks_image
        log_lbl = log_lbl * ~inpaint_masks_label + log_inpaint_labels * inpaint_masks_label

    unnormalize_img = self.unnormalize_img(img)
    lbl = log_onehot_to_index(log_lbl)
    return unnormalize_img, lbl
1502
+ @torch.no_grad()
1503
+ @eval_decorator
1504
+ def sample(
1505
+ self,
1506
+ texts: List[str] = None,
1507
+ text_masks=None,
1508
+ text_embeds=None,
1509
+ video_frames=None,
1510
+ cond_images=None,
1511
+ inpaint_images=None,
1512
+ inpaint_labels=None,
1513
+ inpaint_masks=None,
1514
+ inpaint_resample_times=5,
1515
+ init_images=None,
1516
+ init_labels=None,
1517
+ skip_steps=None,
1518
+ batch_size=1,
1519
+ cond_scale=1.,
1520
+ lowres_sample_noise_level=None,
1521
+ start_at_unet_number=1,
1522
+ start_image_or_video=None,
1523
+ start_label_or_video=None,
1524
+ stop_at_unet_number=None,
1525
+ return_all_unet_outputs=False,
1526
+ return_pil_images=False,
1527
+ device=None,
1528
+ use_tqdm=True
1529
+ ):
1530
+ device = default(device, self.device)
1531
+ self.reset_unets_all_one_device(device=device)
1532
+
1533
+ cond_images = maybe(cast_uint8_images_to_float)(cond_images)
1534
+
1535
+ if exists(texts) and not exists(text_embeds) and not self.unconditional:
1536
+ assert all([*map(len, texts)]), 'text cannot be empty'
1537
+
1538
+ with autocast(enabled=False):
1539
+ text_embeds, text_masks = self.encode_text(texts, return_attn_mask=True)
1540
+
1541
+ text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
1542
+
1543
+ if not self.unconditional:
1544
+ assert exists(text_embeds), \
1545
+ 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'
1546
+
1547
+ text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim=-1))
1548
+ batch_size = text_embeds.shape[0]
1549
+
1550
+ if exists(inpaint_images) and exists(inpaint_labels):
1551
+ if self.unconditional:
1552
+ if batch_size == 1: # assume researcher wants to broadcast along inpainted images
1553
+ batch_size = inpaint_images.shape[0]
1554
+
1555
+ assert inpaint_images.shape[0] == batch_size, \
1556
+ 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
1557
+ assert inpaint_labels.shape[0] == batch_size, \
1558
+ 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
1559
+ assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), \
1560
+ 'number of inpainting images must be equal to the number of text to be conditioned on'
1561
+ assert not (self.condition_on_text and inpaint_labels.shape[0] != text_embeds.shape[0]), \
1562
+ 'number of inpainting images must be equal to the number of text to be conditioned on'
1563
+
1564
+ assert not (self.condition_on_text and not exists(text_embeds)), \
1565
+ 'text or text encodings must be passed into imagen if specified'
1566
+ assert not (not self.condition_on_text and exists(text_embeds)), \
1567
+ 'imagen specified not to be conditioned on text, yet it is presented'
1568
+ assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), \
1569
+ f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
1570
+
1571
+ assert (not (exists(inpaint_images) or exists(inpaint_labels) or exists(inpaint_masks))) \
1572
+ or (exists(inpaint_images) and exists(inpaint_labels) and exists(inpaint_masks)), \
1573
+ 'inpaint images, labels and masks must be both passed in to do inpainting'
1574
+
1575
+ outputs = []
1576
+
1577
+ is_cuda = next(self.parameters()).is_cuda
1578
+ device = next(self.parameters()).device
1579
+
1580
+ lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)
1581
+
1582
+ num_unets = len(self.unets)
1583
+
1584
+ # condition scaling
1585
+
1586
+ cond_scale = cast_tuple(cond_scale, num_unets)
1587
+
1588
+ # add frame dimension for video
1589
+
1590
+ assert not (self.is_video and not exists(video_frames)
1591
+ ), 'video_frames must be passed in on sample time if training on video'
1592
+
1593
+ frame_dims = (video_frames,) if self.is_video else tuple()
1594
+
1595
+ # for initial image and skipping steps
1596
+
1597
+ init_images = cast_tuple(init_images, num_unets)
1598
+ init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]
1599
+ init_labels = cast_tuple(init_labels, num_unets)
1600
+
1601
+ skip_steps = cast_tuple(skip_steps, num_unets)
1602
+
1603
+ # handle starting at a unet greater than 1, for training only-upscaler training
1604
+
1605
+ if start_at_unet_number > 1:
1606
+ assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
1607
+ assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
1608
+ assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
1609
+ assert exists(start_label_or_video), 'starting image or video must be supplied if only doing upscaling'
1610
+
1611
+ prev_image_size = self.image_sizes[start_at_unet_number - 2]
1612
+ img = self.resize_to(start_image_or_video, prev_image_size)
1613
+ lbl = self.resize_to(start_label_or_video, prev_image_size)
1614
+
1615
+ # go through each unet in cascade
1616
+
1617
+ for unet_number, unet, channel, image_size, noise_scheduler, noise_scheduler_lbl, pred_objective, \
1618
+ dynamic_threshold, unet_cond_scale, unet_init_images, unet_init_labels, unet_skip_steps \
1619
+ in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes,
1620
+ self.noise_schedulers_sample, self.noise_schedulers_lbl_sample, self.pred_objectives,
1621
+ self.dynamic_thresholding, cond_scale, init_images, init_labels, skip_steps),
1622
+ disable=not use_tqdm):
1623
+
1624
+ if unet_number < start_at_unet_number:
1625
+ continue
1626
+
1627
+ assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets'
1628
+
1629
+ context = self.one_unet_in_gpu(unet=unet) if is_cuda else nullcontext()
1630
+
1631
+ with context:
1632
+ lowres_cond_img = lowres_cond_lbl = lowres_noise_times = None
1633
+ shape = (batch_size, channel, *frame_dims, *image_size)
1634
+
1635
+ if unet.lowres_cond:
1636
+ lowres_noise_times = self.lowres_noise_schedule.get_times(
1637
+ batch_size, lowres_sample_noise_level, device=device)
1638
+
1639
+ lowres_cond_img = self.resize_to(img, image_size)
1640
+ lowres_cond_lbl = self.resize_to(lbl.float(), image_size)
1641
+
1642
+ lowres_cond_img = self.normalize_img(lowres_cond_img)
1643
+ lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(
1644
+ x_start=lowres_cond_img, t=lowres_noise_times, noise=torch.randn_like(lowres_cond_img))
1645
+ lowres_cond_log_lbl = index_to_log_onehot(lowres_cond_lbl.long(), self.num_classes)
1646
+ lowres_cond_log_lbl_noisy = self.lowres_noise_schedule_lbl.q_sample(
1647
+ lowres_cond_log_lbl, t=lowres_noise_times)
1648
+ lowres_cond_lbl_noisy = log_onehot_to_index(lowres_cond_log_lbl_noisy)
1649
+ lowres_cond_lbl = lowres_cond_lbl_noisy # change just naming
1650
+
1651
+ if exists(unet_init_images) and exists(unet_init_labels):
1652
+ unet_init_images = self.resize_to(unet_init_images, image_size)
1653
+ unet_init_labels = self.resize_to(unet_init_labels, image_size)
1654
+
1655
+ shape = (batch_size, self.channels, *frame_dims, *image_size)
1656
+
1657
+ img, lbl = self.p_sample_loop(
1658
+ unet,
1659
+ shape,
1660
+ text_embeds=text_embeds,
1661
+ text_mask=text_masks,
1662
+ cond_images=cond_images,
1663
+ inpaint_images=inpaint_images,
1664
+ inpaint_labels=inpaint_labels,
1665
+ inpaint_masks=inpaint_masks,
1666
+ inpaint_resample_times=inpaint_resample_times,
1667
+ init_images=unet_init_images,
1668
+ init_labels=unet_init_labels,
1669
+ skip_steps=unet_skip_steps,
1670
+ cond_scale=unet_cond_scale,
1671
+ lowres_cond_img=lowres_cond_img,
1672
+ lowres_cond_lbl=lowres_cond_lbl,
1673
+ lowres_noise_times=lowres_noise_times,
1674
+ noise_scheduler=noise_scheduler,
1675
+ noise_scheduler_lbl=noise_scheduler_lbl,
1676
+ pred_objective=pred_objective,
1677
+ dynamic_threshold=dynamic_threshold,
1678
+ use_tqdm=use_tqdm
1679
+ )
1680
+
1681
+ outputs.append((img.cpu(), lbl.cpu()))
1682
+
1683
+ if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
1684
+ break
1685
+
1686
+ # either return last unet output or all unet outputs
1687
+ output_index = -1 if not return_all_unet_outputs else slice(None)
1688
+
1689
+ if not return_pil_images:
1690
+ return outputs[output_index]
1691
+
1692
+ if not return_all_unet_outputs:
1693
+ outputs = outputs[-1:]
1694
+
1695
+ assert not self.is_video, 'converting sampled video tensor to video file is not supported yet'
1696
+
1697
+ # TODO lbl pil_images
1698
+ pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim=0))), outputs))
1699
+
1700
+ # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
1701
+ return pil_images[output_index]
1702
+
1703
+ def p_losses(
1704
+ self,
1705
+ unet: Union[JointUnet, Unet3D, NullUnet, DistributedDataParallel],
1706
+ x_start,
1707
+ lbl_start,
1708
+ times,
1709
+ *,
1710
+ noise_scheduler: GaussianDiffusionContinuousTimes,
1711
+ noise_scheduler_lbl: MultinomialDiffusion,
1712
+ lowres_cond_img=None,
1713
+ lowres_cond_lbl=None,
1714
+ lowres_aug_times=None,
1715
+ text_embeds=None,
1716
+ text_mask=None,
1717
+ cond_images=None,
1718
+ noise=None,
1719
+ noise_lbl=None,
1720
+ times_next=None,
1721
+ pred_objective='noise',
1722
+ p2_loss_weight_gamma=0.,
1723
+ random_crop_size=None
1724
+ ):
1725
+ is_video = x_start.ndim == 5
1726
+
1727
+ noise = default(noise, lambda: torch.randn_like(x_start))
1728
+ # noise_lbl = default(noise_lbl, lambda: torch.randn_like(x_start)) # TODO
1729
+
1730
+ # normalize to [-1, 1]
1731
+
1732
+ x_start = self.normalize_img(x_start)
1733
+ lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
1734
+
1735
+ # random cropping during training
1736
+ # for upsamplers
1737
+
1738
+ if exists(random_crop_size):
1739
+ if is_video:
1740
+ frames = x_start.shape[2]
1741
+ x_start, lowres_cond_img, noise = rearrange_many(
1742
+ (x_start, lowres_cond_img, noise), 'b c f h w -> (b f) c h w')
1743
+
1744
+ aug = K.RandomCrop(random_crop_size, p=1.)
1745
+
1746
+ # make sure low res conditioner and image both get augmented the same way
1747
+ # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
1748
+ x_start = aug(x_start)
1749
+ lbl_start = aug(lbl_start, params=aug._params)
1750
+ lowres_cond_img = aug(lowres_cond_img, params=aug._params)
1751
+ lowres_cond_lbl = aug(lowres_cond_lbl, params=aug._params)
1752
+ noise = aug(noise, params=aug._params)
1753
+
1754
+ if is_video:
1755
+ x_start, lowres_cond_img, noise = rearrange_many(
1756
+ (x_start, lowres_cond_img, noise), '(b f) c h w -> b c f h w', f=frames)
1757
+
1758
+ # get x_t
1759
+
1760
+ x_noisy, log_snr = noise_scheduler.q_sample(x_start=x_start, t=times, noise=noise)
1761
+ log_lbl_start = index_to_log_onehot(lbl_start.long(), self.num_classes)
1762
+ log_lbl_noisy = noise_scheduler_lbl.q_sample(log_lbl_start, t=times)
1763
+ lbl_noisy = log_onehot_to_index(log_lbl_noisy)
1764
+
1765
+ # also noise the lowres conditioning image
1766
+ # at sample time, they then fix the noise level of 0.1 - 0.3
1767
+
1768
+ lowres_cond_img_noisy = None
1769
+ lowres_cond_lbl_noisy = None
1770
+ if exists(lowres_cond_img) and exists(lowres_cond_lbl):
1771
+ lowres_aug_times = default(lowres_aug_times, times)
1772
+ lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(
1773
+ x_start=lowres_cond_img, t=lowres_aug_times, noise=torch.randn_like(lowres_cond_img))
1774
+ lowres_cond_log_lbl = index_to_log_onehot(lowres_cond_lbl.long(), self.num_classes)
1775
+ lowres_cond_log_lbl_noisy = self.lowres_noise_schedule_lbl.q_sample(
1776
+ lowres_cond_log_lbl, t=lowres_aug_times)
1777
+ lowres_cond_lbl_noisy = log_onehot_to_index(lowres_cond_log_lbl_noisy)
1778
+
1779
+ # time condition
1780
+
1781
+ noise_cond = noise_scheduler.get_condition(times)
1782
+
1783
+ # unet kwargs
1784
+
1785
+ unet_kwargs = dict(
1786
+ text_embeds=text_embeds,
1787
+ text_mask=text_mask,
1788
+ cond_images=cond_images,
1789
+ lowres_noise_times=self.lowres_noise_schedule.get_condition(lowres_aug_times),
1790
+ lowres_cond_img=lowres_cond_img_noisy,
1791
+ lowres_cond_lbl=lowres_cond_lbl_noisy,
1792
+ cond_drop_prob=self.cond_drop_prob,
1793
+ )
1794
+
1795
+ # self condition if needed
1796
+
1797
+ # Because 'unet' can be an instance of DistributedDataParallel coming from the
1798
+ # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
1799
+ # access the member 'module' of the wrapped unet instance.
1800
+ self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond
1801
+
1802
+ if self_cond and random() < 0.5:
1803
+ with torch.no_grad():
1804
+ pred, pred_lbl = unet.forward(
1805
+ x_noisy,
1806
+ lbl_noisy,
1807
+ noise_cond,
1808
+ **unet_kwargs
1809
+ ).detach()
1810
+ pred_lbl = F.log_softmax(pred_lbl, dim=1)
1811
+ pred_lbl = noise_scheduler_lbl.q_posterior(pred_lbl, log_lbl_noisy, times)
1812
+
1813
+ x_start = noise_scheduler.predict_start_from_noise(
1814
+ x_noisy, t=times, noise=pred) if pred_objective == 'noise' else pred
1815
+ # lbl_start = noise_scheduler_lbl.predict_start_from_noise(
1816
+ # lbl_noisy, t=times, noise=pred_lbl) if pred_objective == 'noise' else pred_lbl # TODO ???
1817
+ lbl_start = None
1818
+
1819
+ unet_kwargs = {**unet_kwargs, 'self_cond': x_start, 'self_cond_lbl': lbl_start}
1820
+
1821
+ # get prediction
1822
+
1823
+ pred, pred_lbl = unet.forward(
1824
+ x_noisy,
1825
+ lbl_noisy,
1826
+ noise_cond,
1827
+ **unet_kwargs
1828
+ )
1829
+ pred_lbl = F.log_softmax(pred_lbl, dim=1)
1830
+ pred_lbl_post = noise_scheduler_lbl.q_posterior(pred_lbl, log_lbl_noisy, times)
1831
+
1832
+ # prediction objective
1833
+
1834
+ if pred_objective == 'noise':
1835
+ target = noise
1836
+ elif pred_objective == 'x_start':
1837
+ target = x_start
1838
+ else:
1839
+ raise ValueError(f'unknown objective {pred_objective}')
1840
+ target_log_lbl = noise_scheduler_lbl.q_posterior(log_lbl_start, log_lbl_noisy, times)
1841
+
1842
+ # losses
1843
+
1844
+ losses = self.loss_fn(pred, target, reduction='none')
1845
+ losses = reduce(losses, 'b ... -> b', 'mean')
1846
+ losses_lbl = noise_scheduler_lbl.loss_fn(target_log_lbl, pred_lbl_post, times, log_lbl_start)
1847
+
1848
+ # p2 loss reweighting
1849
+
1850
+ if p2_loss_weight_gamma > 0:
1851
+ loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
1852
+ losses = losses * loss_weight
1853
+ losses_lbl = losses_lbl * loss_weight
1854
+
1855
+ return losses.mean(), losses_lbl.mean()
1856
+
1857
+ def forward(
1858
+ self,
1859
+ images,
1860
+ labels,
1861
+ unet: Union[JointUnet, Unet3D, NullUnet, DistributedDataParallel] = None,
1862
+ texts: List[str] = None,
1863
+ text_embeds=None,
1864
+ text_masks=None,
1865
+ unet_number=None,
1866
+ cond_images=None
1867
+ ):
1868
+ # assert images.shape[-1] == images.shape[-2], \
1869
+ # f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
1870
+ assert not (len(self.unets) > 1 and not exists(unet_number)), \
1871
+ f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
1872
+ unet_number = default(unet_number, 1)
1873
+ assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, \
1874
+ 'you can only train on unet #{self.only_train_unet_number}'
1875
+
1876
+ images = cast_uint8_images_to_float(images)
1877
+ cond_images = maybe(cast_uint8_images_to_float)(cond_images)
1878
+
1879
+ assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
1880
+
1881
+ unet_index = unet_number - 1
1882
+
1883
+ unet = default(unet, lambda: self.get_unet(unet_number))
1884
+
1885
+ assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
1886
+
1887
+ noise_scheduler = self.noise_schedulers[unet_index]
1888
+ noise_scheduler_lbl = self.noise_schedulers_lbl[unet_index]
1889
+ p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
1890
+ pred_objective = self.pred_objectives[unet_index]
1891
+ target_image_size = self.image_sizes[unet_index]
1892
+ random_crop_size = self.random_crop_sizes[unet_index]
1893
+ prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
1894
+
1895
+ b, c, *_, h, w, device, is_video = *images.shape, images.device, images.ndim == 5
1896
+
1897
+ check_shape(images, 'b c ...', c=self.channels)
1898
+ assert h >= target_image_size[0] and w >= target_image_size[1]
1899
+
1900
+ frames = images.shape[2] if is_video else None
1901
+
1902
+ times = noise_scheduler.sample_random_times(b, device=device)
1903
+
1904
+ if exists(texts) and not exists(text_embeds) and not self.unconditional:
1905
+ assert all([*map(len, texts)]), 'text cannot be empty'
1906
+ assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
1907
+
1908
+ with autocast(enabled=False):
1909
+ text_embeds, text_masks = self.encode_text(texts, return_attn_mask=True)
1910
+
1911
+ text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
1912
+
1913
+ if not self.unconditional:
1914
+ text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim=-1))
1915
+
1916
+ assert not (self.condition_on_text and not exists(text_embeds)
1917
+ ), 'text or text encodings must be passed into decoder if specified'
1918
+ assert not (not self.condition_on_text and exists(text_embeds)
1919
+ ), 'decoder specified not to be conditioned on text, yet it is presented'
1920
+
1921
+ assert not (exists(
1922
+ text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
1923
+
1924
+ lowres_cond_img = lowres_cond_lbl = lowres_aug_times = None
1925
+ if exists(prev_image_size):
1926
+ lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range=self.input_image_range)
1927
+ lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range=self.input_image_range)
1928
+ lowres_cond_lbl = self.resize_to(labels, prev_image_size, clamp_range=None)
1929
+ lowres_cond_lbl = self.resize_to(lowres_cond_lbl, target_image_size, clamp_range=None)
1930
+
1931
+ if self.per_sample_random_aug_noise_level:
1932
+ lowres_aug_times = self.lowres_noise_schedule.sample_random_times(
1933
+ b, self.lowres_max_thres, device=device)
1934
+ else:
1935
+ lowres_aug_time = self.lowres_noise_schedule.sample_random_times(
1936
+ 1, self.lowres_max_thres, device=device)
1937
+ lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b=b)
1938
+
1939
+ images = self.resize_to(images, target_image_size)
1940
+ labels = self.resize_to(labels, target_image_size)
1941
+
1942
+ return self.p_losses(unet, images, labels, times, text_embeds=text_embeds, text_mask=text_masks, cond_images=cond_images, noise_scheduler=noise_scheduler, noise_scheduler_lbl=noise_scheduler_lbl, lowres_cond_img=lowres_cond_img, lowres_cond_lbl=lowres_cond_lbl, lowres_aug_times=lowres_aug_times, pred_objective=pred_objective, p2_loss_weight_gamma=p2_loss_weight_gamma, random_crop_size=random_crop_size)
imagen_pytorch/t5.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from typing import List
4
+ from transformers import T5Tokenizer, T5EncoderModel, T5Config
5
+ from einops import rearrange
6
+
7
+ transformers.logging.set_verbosity_error()
8
+
9
def exists(val):
    """Predicate: has `val` been supplied (i.e. is it non-None)?"""
    return not (val is None)
11
+
12
def default(val, d):
    """Return `val` when present, otherwise the fallback `d` (called if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
16
+
17
+ # config
18
+
19
+ MAX_LENGTH = 256
20
+
21
+ DEFAULT_T5_NAME = 'google/t5-v1_1-base'
22
+
23
+ T5_CONFIGS = {}
24
+
25
+ # singleton globals
26
+
27
def get_tokenizer(name):
    """Load the T5 tokenizer for `name`, caching weights under ./checkpoints."""
    return T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH, cache_dir="checkpoints")
30
+
31
def get_model(name):
    """Load the T5 encoder-only model for `name`, caching weights under ./checkpoints."""
    return T5EncoderModel.from_pretrained(name, cache_dir="checkpoints")
34
+
35
def get_model_and_tokenizer(name):
    """Return (model, tokenizer) for `name`, memoized in the module-level T5_CONFIGS."""
    global T5_CONFIGS

    entry = T5_CONFIGS.setdefault(name, dict())
    if "model" not in entry:
        entry["model"] = get_model(name)
    if "tokenizer" not in entry:
        entry["tokenizer"] = get_tokenizer(name)

    return entry['model'], entry['tokenizer']
46
+
47
def get_encoded_dim(name):
    """Return the encoder hidden size (d_model) for `name` without loading weights if possible."""
    entry = T5_CONFIGS.get(name)
    if entry is None:
        # only the config is needed for the dimension; skip loading the model
        config = T5Config.from_pretrained(name, cache_dir="checkpoints")
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in entry:
        config = entry["config"]
    elif "model" in entry:
        config = entry["model"].config
    else:
        assert False
    return config.d_model
59
+
60
+ # encoding text
61
+
62
def t5_tokenize(
    texts: List[str],
    name = DEFAULT_T5_NAME
):
    """Tokenize `texts`; returns (input_ids, attention_mask) on the model's device.

    Moves the encoder to CUDA first when available, so subsequent encoding runs there.
    """
    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    return encoded.input_ids.to(device), encoded.attention_mask.to(device)
84
+
85
def t5_encode_tokenized_text(
    token_ids,
    attn_mask = None,
    pad_id = None,
    name = DEFAULT_T5_NAME
):
    """Run the frozen T5 encoder over `token_ids`; padded positions are zeroed.

    Either `attn_mask` must be given, or `pad_id` so a mask can be derived.
    """
    assert exists(attn_mask) or exists(pad_id)
    t5, _ = get_model_and_tokenizer(name)

    if attn_mask is None:
        attn_mask = (token_ids != pad_id).long()

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = token_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask.bool()

    # force all embeddings at padding positions to exactly 0.
    return encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.)
106
+
107
def t5_encode_text(
    texts: List[str],
    name = DEFAULT_T5_NAME,
    return_attn_mask = False
):
    """Tokenize and T5-encode `texts`; optionally also return the boolean attention mask."""
    token_ids, attn_mask = t5_tokenize(texts, name = name)
    encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)

    if not return_attn_mask:
        return encoded_text

    return encoded_text, attn_mask.bool()
imagen_pytorch/trainer.py ADDED
@@ -0,0 +1,1782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import os
3
+ import time
4
+ import copy
5
+ from pathlib import Path
6
+ from math import ceil
7
+ from contextlib import contextmanager, nullcontext
8
+ from functools import partial, wraps
9
+ from collections.abc import Iterable
10
+
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import random_split, DataLoader
15
+ from torch.optim import Adam
16
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
17
+ from torch.cuda.amp import autocast, GradScaler
18
+
19
+ import pytorch_warmup as warmup
20
+
21
+ from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
22
+ from imagen_pytorch.elucidated_imagen import ElucidatedImagen
23
+ from imagen_pytorch.joint_imagen import JointImagen
24
+ from imagen_pytorch.data import cycle
25
+
26
+ from imagen_pytorch.version import __version__
27
+ from packaging import version
28
+
29
+ import numpy as np
30
+
31
+ from ema_pytorch import EMA
32
+
33
+ from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
34
+
35
+ from fsspec.core import url_to_fs
36
+ from fsspec.implementations.local import LocalFileSystem
37
+
38
+ # helper functions
39
+
40
+
41
def exists(val):
    """Predicate: has `val` been supplied (i.e. is it non-None)?"""
    return not (val is None)
43
+
44
+
45
def default(val, d):
    """Return `val` when present, otherwise the fallback `d` (called if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
49
+
50
+
51
def cast_tuple(val, length=1):
    """Coerce `val` to a tuple: lists convert, tuples pass through, scalars repeat `length` times."""
    if isinstance(val, list):
        return tuple(val)
    if isinstance(val, tuple):
        return val
    return (val,) * length
56
+
57
+
58
def find_first(fn, arr):
    """Return the index of the first element of `arr` satisfying `fn`, or -1.

    NOTE(review): this helper is redefined later in this module with different
    semantics (returning the matching element instead of its index); the later
    definition shadows this one at import time.
    """
    for ind, el in enumerate(arr):
        if fn(el):
            return ind
    return -1
63
+
64
+
65
def pick_and_pop(keys, d):
    """Remove `keys` from dict `d` (mutating it) and return them as a new dict."""
    return {key: d.pop(key) for key in keys}
68
+
69
+
70
def group_dict_by_key(cond, d):
    """Split `d` into a (matching, non-matching) pair of dicts according to `cond(key)`."""
    matched, unmatched = dict(), dict()
    for key, value in d.items():
        target = matched if cond(key) else unmatched
        target[key] = value
    return matched, unmatched
77
+
78
+
79
def string_begins_with(prefix, str):
    """Return True when `str` starts with `prefix`.

    NOTE(review): the parameter name shadows the builtin `str`; kept as-is
    since callers may pass it by keyword.
    """
    return str.startswith(prefix)
81
+
82
+
83
def group_by_key_prefix(prefix, d):
    """Partition `d` into (keys starting with `prefix`, the rest)."""
    return group_dict_by_key(partial(string_begins_with, prefix), d)
85
+
86
+
87
def groupby_prefix_and_trim(prefix, d):
    """Split off `prefix`-keyed entries of `d`, stripping the prefix from their keys.

    Returns (trimmed_prefixed_entries, remaining_entries).
    """
    with_prefix, without_prefix = group_dict_by_key(partial(string_begins_with, prefix), d)
    trimmed = {key[len(prefix):]: value for key, value in with_prefix.items()}
    return trimmed, without_prefix
91
+
92
+
93
def num_to_groups(num, divisor):
    """Split `num` into chunks of size `divisor` plus one smaller remainder chunk."""
    full, remainder = divmod(num, divisor)
    groups = [divisor] * full
    if remainder:
        groups.append(remainder)
    return groups
100
+
101
+ # url to fs, bucket, path - for checkpointing to cloud
102
+
103
+
104
def url_to_bucket(url):
    """Extract the bucket name from a gs:// or s3:// url; plain paths pass through unchanged."""
    if '://' not in url:
        return url

    scheme, rest = url.split('://')

    if scheme not in {'gs', 's3'}:
        raise ValueError(f'storage type prefix "{scheme}" is not supported yet')
    return rest.split('/')[0]
114
+
115
+ # decorators
116
+
117
+
118
def eval_decorator(fn):
    """Decorator: run `fn` with the model in eval mode, then restore the prior training flag."""
    def wrapper(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return wrapper
126
+
127
+
128
def cast_torch_tensor(fn, cast_fp16=False):
    """Decorator: normalize all array-like args of a model method to torch tensors.

    - numpy arrays are converted to tensors
    - tensors are moved to `_device` (default: model.device) unless `_cast_device=False`
    - with cast_fp16=True and model.cast_half_at_training, non-bool tensors are cast to half
    The magic kwargs `_device` and `_cast_device` are consumed and not forwarded to `fn`.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', model.device)
        cast_device = kwargs.pop('_cast_device', True)

        should_cast_fp16 = cast_fp16 and model.cast_half_at_training

        # flatten positional args and kwarg values so they can be transformed uniformly
        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        if should_cast_fp16:
            # bool tensors (e.g. attention masks) must keep their dtype
            all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(
                t, torch.Tensor) and t.dtype != torch.bool else t, all_args))

        # re-split the flattened tuple back into positional args and kwargs
        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner
154
+
155
+ # gradient accumulation functions
156
+
157
+
158
def split_iterable(it, split_size):
    """Chunk a sliceable sequence into pieces of at most `split_size` elements."""
    return [it[start:start + split_size] for start in range(0, len(it), split_size)]
164
+
165
+
166
def split(t, split_size=None):
    """Split `t` into chunks of `split_size`.

    Tensors are split along dim 0 via torch.split; other iterables are chunked
    by slicing. When `split_size` is None, `t` is returned unchanged.

    Raises:
        TypeError: if `t` is neither a tensor nor an iterable.
    """
    if not exists(split_size):
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim=0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    # BUG FIX: previously `return TypeError` handed back the exception class
    # instead of signalling the error — raise it properly
    raise TypeError(f'cannot split object of type {type(t).__name__}')
177
+
178
+
179
def find_first(cond, arr):
    """Return the first element of `arr` for which `cond` holds, else None.

    NOTE(review): shadows an earlier index-returning `find_first` in this module.
    """
    matches = (el for el in arr if cond(el))
    return next(matches, None)
184
+
185
+
186
def split_args_and_kwargs(*args, split_size=None, **kwargs):
    """Yield (chunk_fraction, (args_chunk, kwargs_chunk)) pairs.

    Splits every tensor / iterable argument along the batch dimension into
    chunks of `split_size`; non-splittable arguments (and None) are repeated
    unchanged for every chunk. The batch size is taken from the first tensor
    argument, which must exist. `chunk_fraction` is chunk_size / batch_size,
    useful for re-weighting accumulated losses.
    """
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    split_kwargs_index = len_all_args - dict_len

    # split each splittable argument; broadcast everything else across chunks
    split_all_args = [split(arg, split_size=split_size) if exists(arg) and isinstance(
        arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    chunk_sizes = tuple(map(len, split_all_args[0]))

    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        # re-split the flattened chunk back into positional args and kwargs
        chunked_args, chunked_kwargs_values = chunked_all_args[:
                                                              split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)
210
+
211
+ # imagen trainer
212
+
213
+
214
def imagen_sample_in_chunks(fn):
    """Decorator: when `max_batch_size` is passed, run `fn` in sub-batches and
    concatenate the results along dim 0.

    Unconditional models are chunked by adjusting the `batch_size` kwarg;
    conditional models chunk every tensor/iterable argument via
    split_args_and_kwargs.
    """
    @wraps(fn)
    def inner(self, *args, max_batch_size=None, **kwargs):
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        if self.imagen.unconditional:
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs)
                       in split_args_and_kwargs(*args, split_size=max_batch_size, **kwargs)]

        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim=0)

        # multiple outputs per call: concatenate each output position separately
        return list(map(lambda t: torch.cat(t, dim=0), list(zip(*outputs))))

    return inner
234
+
235
+
236
def restore_parts(state_dict_target, state_dict_from):
    """Copy every same-named, same-shaped parameter of `state_dict_from` into
    `state_dict_target` (in place).

    Missing names are skipped silently; shape mismatches are skipped with a
    printed report. Returns the (mutated) target state dict.
    """
    for name, param in state_dict_from.items():

        if name not in state_dict_target:
            continue

        if param.size() == state_dict_target[name].size():
            state_dict_target[name].copy_(param)
        else:
            # BUG FIX: the message previously had an unbalanced parenthesis
            print(f"layer {name} ({param.size()}) different than target: {state_dict_target[name].size()}")

    return state_dict_target
248
+
249
+
250
def load_unet_from_trainer(trainer, checkpoint_path, src_unet_idx, tgt_unet_idx, only_model=True):
    """Load the weights of unet `src_unet_idx` (model + EMA) from a trainer
    checkpoint file into unet `tgt_unet_idx` of `trainer`.

    Only the model/EMA weights are restored; optimizer and scheduler state
    restoration is not implemented (only_model must be True).
    """
    assert only_model == True  # TODO optimizer, scheduler ...
    ckpt = torch.load(checkpoint_path, map_location='cpu')

    # BUG FIX: match on the full dotted prefix and strip by its length —
    # the old startswith(f'unets.{i}') also captured e.g. unets.10 when i == 1,
    # and the hard-coded key[8:] / key[2:] slices broke for indices >= 10
    prefix = f'unets.{src_unet_idx}.'
    state_dict = OrderedDict(
        (key[len(prefix):], val) for key, val in ckpt['model'].items() if key.startswith(prefix)
    )
    trainer.imagen.unets[tgt_unet_idx].load_state_dict(state_dict)

    ema_prefix = f'{src_unet_idx}.'
    state_dict = OrderedDict(
        (key[len(ema_prefix):], val) for key, val in ckpt['ema'].items() if key.startswith(ema_prefix)
    )
    trainer.ema_unets[tgt_unet_idx].load_state_dict(state_dict)
    return
264
+
265
+
266
class ImagenTrainer(nn.Module):
    """Training harness for a (cascaded) Imagen model.

    Wraps an ``Imagen`` / ``ElucidatedImagen`` instance and manages, per unet:
    an Adam optimizer, a GradScaler, an optional warmup + cosine-decay
    scheduler, and an optional EMA copy. Distributed / mixed-precision
    execution is delegated to HuggingFace ``Accelerator``. Only one unet may
    be trained at a time, and only one trainer may exist per process when
    running distributed (enforced via the class-level ``locked`` flag).
    Checkpoints are written through fsspec as ``checkpoint.<total_steps>.pt``.
    """

    # class-level guard: set to True once a distributed trainer is constructed,
    # preventing a second trainer in the same process
    locked = False

    def __init__(
        self,
        imagen=None,
        imagen_checkpoint_path=None,
        use_ema=True,
        lr=1e-4,
        eps=1e-8,
        beta1=0.9,
        beta2=0.99,
        max_grad_norm=None,
        group_wd_params=True,
        warmup_steps=None,
        cosine_decay_max_steps=None,
        only_train_unet_number=None,
        fp16=False,
        precision=None,
        split_batches=True,
        dl_tuple_output_keywords_names=('images', 'texts'),
        verbose=True,
        split_valid_fraction=0.025,
        split_valid_from_train=False,
        split_random_seed=42,
        checkpoint_path=None,
        checkpoint_every=None,
        checkpoint_fs=None,
        fs_kwargs: dict = None,
        max_checkpoints_keep=20,
        **kwargs
    ):
        """Build optimizers / scalers / schedulers (and EMA copies) for every unet.

        ``lr``, ``eps``, ``warmup_steps`` and ``cosine_decay_max_steps`` may each be a
        scalar or a tuple with one entry per unet. Remaining ``kwargs`` are split by
        prefix: ``ema_*`` goes to the EMA wrapper, ``accelerate_*`` to ``Accelerator``,
        and whatever is left is forwarded to the Adam constructor.
        """
        super().__init__()
        assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
        assert exists(imagen) ^ exists(
            imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'

        # determine filesystem, using fsspec, for saving to local filesystem or cloud

        self.fs = checkpoint_fs

        if not exists(self.fs):
            fs_kwargs = default(fs_kwargs, {})
            self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)

        # NOTE(review): this isinstance check requires a concrete imagen instance,
        # so the imagen_checkpoint_path-only branch allowed by the XOR assert above
        # can never get past this line — confirm whether that path is still intended.
        assert isinstance(imagen, (Imagen, ElucidatedImagen))
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        # elucidated or not

        self.is_elucidated = isinstance(imagen, ElucidatedImagen)

        # create accelerator instance

        accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)

        assert not (fp16 and exists(precision)
                    ), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
        accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')

        self.accelerator = Accelerator(**{
            'split_batches': split_batches,
            'mixed_precision': accelerator_mixed_precision,
            'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters=True)], **accelerate_kwargs})

        # lock out further trainers in this process only when actually distributed
        ImagenTrainer.locked = self.is_distributed

        # cast data to fp16 at training time if needed

        self.cast_half_at_training = accelerator_mixed_precision == 'fp16'

        # grad scaler must be managed outside of accelerator

        grad_scaler_enabled = fp16

        # imagen, unets and ema unets

        self.imagen = imagen
        self.num_unets = len(self.imagen.unets)

        # EMA copies are only kept on the main process
        self.use_ema = use_ema and self.is_main
        self.ema_unets = nn.ModuleList([])

        # keep track of what unet is being trained on
        # only going to allow 1 unet training at a time

        self.ema_unet_being_trained_index = -1  # keeps track of which ema unet is being trained on

        # data related functions

        self.train_dl_iter = None
        self.train_dl = None

        self.valid_dl_iter = None
        self.valid_dl = None

        self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names

        # auto splitting validation from training, if dataset is passed in

        self.split_valid_from_train = split_valid_from_train

        assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
        self.split_valid_fraction = split_valid_fraction
        self.split_random_seed = split_random_seed

        # be able to finely customize learning rate, weight decay
        # per unet

        lr, eps, warmup_steps, cosine_decay_max_steps = map(
            partial(cast_tuple, length=self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))

        for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
            optimizer = Adam(
                unet.parameters(),
                lr=unet_lr,
                eps=unet_eps,
                betas=(beta1, beta2),
                **kwargs
            )

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled=grad_scaler_enabled)

            scheduler = warmup_scheduler = None

            if exists(unet_cosine_decay_max_steps):
                scheduler = CosineAnnealingLR(optimizer, T_max=unet_cosine_decay_max_steps)

            if exists(unet_warmup_steps):
                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=unet_warmup_steps)

            # fall back to a constant-lr scheduler so later code can assume one exists
            if not exists(scheduler):
                scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

            # set on object

            setattr(self, f'optim{ind}', optimizer)  # cannot use pytorch ModuleList for some reason with optimizers
            setattr(self, f'scaler{ind}', scaler)
            setattr(self, f'scheduler{ind}', scheduler)
            setattr(self, f'warmup{ind}', warmup_scheduler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        # step tracker and misc

        # one step counter per unet, persisted with the module state
        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        self.verbose = verbose

        # automatic set devices based on what accelerator decided

        self.imagen.to(self.device)
        self.to(self.device)

        # checkpointing

        # checkpoint_path and checkpoint_every must be given together (or not at all)
        assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
        self.checkpoint_path = checkpoint_path
        self.checkpoint_every = checkpoint_every
        self.max_checkpoints_keep = max_checkpoints_keep

        self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main

        if exists(checkpoint_path) and self.can_checkpoint:
            bucket = url_to_bucket(checkpoint_path)

            if not self.fs.exists(bucket):
                self.fs.mkdir(bucket)

            # auto-resume from the newest checkpoint in the folder, if any
            self.load_from_checkpoint_folder()

        # only allowing training for unet

        self.only_train_unet_number = only_train_unet_number
        self.validate_and_set_unet_being_trained(only_train_unet_number)

    # computed values

    @property
    def device(self):
        """Device chosen by the accelerator."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """True when running with more than one process or a distributed backend."""
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        """True only on the global main process."""
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        """True only on the main process of this node."""
        return self.accelerator.is_local_main_process

    @property
    def unwrapped_unet(self):
        """The currently-trained unet with any DDP/accelerate wrappers removed."""
        return self.accelerator.unwrap_model(self.unet_being_trained)

    # optimizer helper functions

    def get_lr(self, unet_number):
        """Return the current learning rate of the optimizer for 1-indexed ``unet_number``."""
        self.validate_unet_number(unet_number)
        unet_index = unet_number - 1

        optim = getattr(self, f'optim{unet_index}')

        return optim.param_groups[0]['lr']

    # function for allowing only one unet from being trained at a time

    def validate_and_set_unet_being_trained(self, unet_number=None):
        """Pin training to a single unet; once set it cannot be changed in this process."""
        if exists(unet_number):
            self.validate_unet_number(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

        self.only_train_unet_number = unet_number
        self.imagen.only_train_unet_number = unet_number

        if not exists(unet_number):
            return

        self.wrap_unet(unet_number)

    def wrap_unet(self, unet_number):
        """Prepare the chosen unet, its optimizer and scheduler with the accelerator (idempotent)."""
        # guard: only one unet may ever be wrapped per trainer
        if hasattr(self, 'one_unet_wrapped'):
            return

        unet = self.imagen.get_unet(unet_number)
        self.unet_being_trained = self.accelerator.prepare(unet)
        unet_index = unet_number - 1

        optimizer = getattr(self, f'optim{unet_index}')
        scheduler = getattr(self, f'scheduler{unet_index}')

        optimizer = self.accelerator.prepare(optimizer)

        if exists(scheduler):
            scheduler = self.accelerator.prepare(scheduler)

        setattr(self, f'optim{unet_index}', optimizer)
        setattr(self, f'scheduler{unet_index}', scheduler)

        self.one_unet_wrapped = True

    # hacking accelerator due to not having separate gradscaler per optimizer

    def set_accelerator_scaler(self, unet_number):
        """Swap this unet's GradScaler onto the accelerator and its prepared optimizers."""
        unet_number = self.validate_unet_number(unet_number)
        scaler = getattr(self, f'scaler{unet_number - 1}')

        self.accelerator.scaler = scaler
        for optimizer in self.accelerator._optimizers:
            optimizer.scaler = scaler

    # helper print

    def print(self, msg):
        """Print via the accelerator, but only on the main process and when verbose."""
        if not self.is_main:
            return

        if not self.verbose:
            return

        return self.accelerator.print(msg)

    # validating the unet number

    def validate_unet_number(self, unet_number=None):
        """Default to unet 1 when there is only one unet; assert the number is in range.

        NOTE(review): with multiple unets, passing ``unet_number=None`` reaches the
        ``0 < unet_number`` comparison and raises TypeError rather than a clear
        assertion — callers must supply an explicit number in that case.
        """
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
        return unet_number

    # number of training steps taken

    def num_steps_taken(self, unet_number=None):
        """Return the number of optimizer steps recorded for the given (1-indexed) unet."""
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        return self.steps[unet_number - 1].item()

    def print_untrained_unets(self):
        """Warn about unets with zero recorded steps (NullUnet placeholders excluded)."""
        print_final_error = False

        for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
            if steps > 0 or isinstance(unet, NullUnet):
                continue

            self.print(f'unet {ind + 1} has not been trained')
            print_final_error = True

        if print_final_error:
            self.print(
                'when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

    # data related functions

    def add_train_dataloader(self, dl=None):
        """Register an already-built training dataloader (prepared by the accelerator)."""
        if not exists(dl):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'
        self.train_dl = self.accelerator.prepare(dl)

    def add_valid_dataloader(self, dl):
        """Register an already-built validation dataloader (prepared by the accelerator)."""
        if not exists(dl):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'
        self.valid_dl = self.accelerator.prepare(dl)

    def add_train_dataset(self, ds=None, *, batch_size, **dl_kwargs):
        """Build a training dataloader from a dataset, optionally carving out a
        deterministic validation split (controlled by ``split_valid_from_train``)."""
        if not exists(ds):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'

        valid_ds = None
        if self.split_valid_from_train:
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            # seeded generator keeps the split identical across runs/processes
            ds, valid_ds = random_split(ds, [train_size, valid_size],
                                        generator=torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples '
                       f'and validating with randomly splitted {len(valid_ds)} samples')

        dl = DataLoader(ds, batch_size=batch_size, **dl_kwargs)
        self.train_dl = self.accelerator.prepare(dl)

        if not self.split_valid_from_train:
            return

        self.add_valid_dataset(valid_ds, batch_size=batch_size, **dl_kwargs)

    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        """Build and register a validation dataloader from a dataset."""
        if not exists(ds):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'

        dl = DataLoader(ds, batch_size=batch_size, **dl_kwargs)
        self.valid_dl = self.accelerator.prepare(dl)

    def create_train_iter(self):
        """Lazily create an infinite (cycled) iterator over the training dataloader."""
        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

        if exists(self.train_dl_iter):
            return

        self.train_dl_iter = cycle(self.train_dl)

    def create_valid_iter(self):
        """Lazily create an infinite (cycled) iterator over the validation dataloader."""
        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

        if exists(self.valid_dl_iter):
            return

        self.valid_dl_iter = cycle(self.valid_dl)

    def train_step(self, unet_number=None, **kwargs):
        """One training step: pull a batch, forward/backward, then optimizer update."""
        self.create_train_iter()
        loss = self.step_with_dl_iter(self.train_dl_iter, unet_number=unet_number, **kwargs)
        self.update(unet_number=unet_number)
        return loss

    @torch.no_grad()
    @eval_decorator
    def valid_step(self, **kwargs):
        """One validation step (no grad, eval mode); pass ``use_ema_unets=True`` to
        evaluate with the EMA weights swapped in."""
        self.create_valid_iter()

        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext

        with context():
            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
        return loss

    def step_with_dl_iter(self, dl_iter, **kwargs):
        """Pull the next batch tuple, map it onto ``dl_tuple_output_keywords_names``,
        and forward through the model; batch entries override same-named kwargs."""
        dl_tuple_output = cast_tuple(next(dl_iter))
        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
        loss = self.forward(**{**kwargs, **model_input})
        return loss

    # checkpointing functions

    @property
    def all_checkpoints_sorted(self):
        """All ``*.pt`` checkpoints in the folder, newest (highest step count) first.

        Relies on the ``checkpoint.<steps>.pt`` naming: the step count is the
        second-to-last dot-separated field.
        """
        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
        checkpoints = self.fs.glob(glob_pattern)
        sorted_checkpoints = sorted(checkpoints, key=lambda x: int(str(x).split('.')[-2]), reverse=True)
        return sorted_checkpoints

    def load_from_checkpoint_folder(self, last_total_steps=-1):
        """Load a specific checkpoint by total step count, or the newest one by default."""
        if last_total_steps != -1:
            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
            self.load(filepath)
            return

        sorted_checkpoints = self.all_checkpoints_sorted

        if len(sorted_checkpoints) == 0:
            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
            return

        last_checkpoint = sorted_checkpoints[0]
        self.load(last_checkpoint)

    def save_to_checkpoint_folder(self):
        """Save ``checkpoint.<total_steps>.pt`` and prune old checkpoints beyond
        ``max_checkpoints_keep`` (pruning disabled when that value is <= 0)."""
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        total_steps = int(self.steps.sum().item())
        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')

        self.save(filepath)

        if self.max_checkpoints_keep <= 0:
            return

        sorted_checkpoints = self.all_checkpoints_sorted
        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]

        for checkpoint in checkpoints_to_discard:
            self.fs.rm(checkpoint)

    # saving and loading functions

    def save(
        self,
        path,
        overwrite=True,
        without_optim_and_sched=False,
        **kwargs
    ):
        """Serialize trainer state (model, steps, per-unet optim/scaler/sched, EMA)
        to ``path`` through fsspec.

        Extra ``kwargs`` are stored verbatim in the checkpoint dict. When the imagen
        carries a ``_config``, it is embedded so the checkpoint is usable from the CLI.
        """
        # self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite)

        self.reset_ema_unets_all_one_device()

        save_obj = dict(
            model=self.imagen.state_dict(),
            version=__version__,
            steps=self.steps.cpu(),
            **kwargs
        )

        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # determine if imagen config is available

        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')

            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        # save to path

        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')

    def load(self, path, only_model=False, strict=True, noop_if_not_exist=False):
        """Restore trainer state from ``path``.

        Loads on CPU, warns on package-version mismatch, and falls back to a
        shape-matching partial load (``restore_parts``) if a strict load fails.
        With ``only_model=True`` only the model weights are restored; otherwise
        steps, per-unet optimizers/scalers/schedulers and EMA weights follow.
        Returns the raw loaded checkpoint dict.
        """
        fs = self.fs

        if noop_if_not_exist and not fs.exists(path):
            self.print(f'trainer checkpoint not found at {str(path)}')
            return

        assert fs.exists(path), f'{path} does not exist'

        self.reset_ema_unets_all_one_device()

        # to avoid extra GPU memory usage in main process when using Accelerate

        with fs.open(path) as f:
            loaded_obj = torch.load(f, map_location='cpu')

        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.print(
                f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

        try:
            self.imagen.load_state_dict(loaded_obj['model'], strict=strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                      loaded_obj['model']))

        if only_model:
            return loaded_obj

        self.steps.copy_(loaded_obj['steps'])

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler) and scheduler_key in loaded_obj:
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
                warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

            if exists(optimizer):
                try:
                    optimizer.load_state_dict(loaded_obj[optimizer_key])
                    scaler.load_state_dict(loaded_obj[scaler_key])
                except:
                    # deliberate best-effort: optim/scaler state may be incompatible
                    # (e.g. mixed precision toggled since the last run)
                    self.print(
                        'could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

        if self.use_ema:
            assert 'ema' in loaded_obj
            try:
                self.ema_unets.load_state_dict(loaded_obj['ema'], strict=strict)
            except RuntimeError:
                print("Failed loading state dict. Trying partial load")
                self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                             loaded_obj['ema']))

        self.print(f'checkpoint loaded from {path}')
        return loaded_obj

    # managing ema unets and their devices

    @property
    def unets(self):
        """The EMA model copies, exposed as a fresh ModuleList."""
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def get_ema_unet(self, unet_number=None):
        """Return the EMA wrapper for the given unet, moving it to the training
        device and parking all other EMA unets on CPU."""
        if not self.use_ema:
            return

        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        # temporarily detach ema_unets from the module tree (plain list) so only
        # the active EMA unet lives on the training device
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]

    def reset_ema_unets_all_one_device(self, device=None):
        """Re-wrap the EMA unets in a ModuleList and move them all to one device
        (used before state-dict operations and EMA sampling)."""
        if not self.use_ema:
            return

        device = default(device, self.device)
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        self.ema_unet_being_trained_index = -1

    @torch.no_grad()
    @contextmanager
    def use_ema_unets(self):
        """Context manager that swaps the EMA unets into ``self.imagen`` for the
        duration of the block, then restores the trainable unets and the EMA
        models' original devices. No-op when EMA is disabled."""
        if not self.use_ema:
            output = yield
            return output

        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        self.unets.eval()

        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets  # swap in exponential moving averaged unets for sampling

        output = yield

        self.imagen.unets = trainable_unets  # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output

    def print_unet_devices(self):
        """Debug helper: print the device of every unet (and EMA unet, if enabled)."""
        self.print('unet devices:')
        for i, unet in enumerate(self.imagen.unets):
            device = next(unet.parameters()).device
            self.print(f'\tunet {i}: {device}')

        if not self.use_ema:
            return

        self.print('\nema unet devices:')
        for i, ema_unet in enumerate(self.ema_unets):
            device = next(ema_unet.parameters()).device
            self.print(f'\tema unet {i}: {device}')

    # overriding state dict functions

    def state_dict(self, *args, **kwargs):
        """Gather EMA unets onto one device before serializing module state."""
        self.reset_ema_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Gather EMA unets onto one device before restoring module state."""
        self.reset_ema_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # encoding text functions

    def encode_text(self, text, **kwargs):
        """Delegate text encoding to the wrapped imagen."""
        return self.imagen.encode_text(text, **kwargs)

    # forwarding functions and gradient step updates

    def update(self, unet_number=None):
        """Apply one optimizer step for the given unet: clip grads (optional),
        step + zero the optimizer, update EMA, advance schedulers, bump the step
        counter, and checkpoint every ``checkpoint_every`` total steps."""
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        index = unet_number - 1
        unet = self.unet_being_trained

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')
        scheduler = getattr(self, f'scheduler{index}')
        warmup_scheduler = getattr(self, f'warmup{index}')

        # set the grad scaler on the accelerator, since we are managing one per u-net

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.get_ema_unet(unet_number)
            ema_unet.update()

        # scheduler, if needed

        maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

        with maybe_warmup_context:
            if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped:  # recommended in the docs
                scheduler.step()

        # one-hot increments only the counter for this unet
        self.steps += F.one_hot(torch.tensor(unet_number - 1, device=self.steps.device), num_classes=len(self.steps))

        if not exists(self.checkpoint_path):
            return

        total_steps = int(self.steps.sum().item())

        # only checkpoint on exact multiples of checkpoint_every
        if total_steps % self.checkpoint_every:
            return

        self.save_to_checkpoint_folder()

    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        """Sample from the imagen, using the EMA unets by default (pass
        ``use_non_ema=True`` to sample with the raw training weights).
        tqdm output is suppressed on non-main processes."""
        context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

        self.print_untrained_unets()

        if not self.is_main:
            kwargs['use_tqdm'] = False

        with context():
            output = self.imagen.sample(*args, device=self.device, **kwargs)

        return output

    @partial(cast_torch_tensor, cast_fp16=True)
    def forward(
        self,
        *args,
        unet_number=None,
        max_batch_size=None,
        **kwargs
    ):
        """Compute the training loss for one batch, optionally split into chunks of
        ``max_batch_size``, backpropagating per chunk when in training mode.

        Returns the accumulated loss as a plain Python float (chunk losses are
        weighted by their fraction of the batch and summed via ``.item()``).
        """
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        assert not exists(
            self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

        total_loss = 0.

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size=max_batch_size, **kwargs):
            with self.accelerator.autocast():
                loss = self.imagen(*chunked_args, unet=self.unet_being_trained,
                                   unet_number=unet_number, **chunked_kwargs)
                loss = loss * chunk_size_frac

            total_loss += loss.item()

            if self.training:
                self.accelerator.backward(loss)

        return total_loss
1019
+
1020
+
1021
+ class JointImagenTrainer(nn.Module):
1022
+ locked = False
1023
+
1024
+ def __init__(
1025
+ self,
1026
+ imagen=None,
1027
+ imagen_checkpoint_path=None,
1028
+ use_ema=True,
1029
+ lr=1e-4,
1030
+ eps=1e-8,
1031
+ beta1=0.9,
1032
+ beta2=0.99,
1033
+ max_grad_norm=None,
1034
+ group_wd_params=True,
1035
+ warmup_steps=None,
1036
+ cosine_decay_max_steps=None,
1037
+ only_train_unet_number=None,
1038
+ fp16=False,
1039
+ precision=None,
1040
+ split_batches=True,
1041
+ dl_tuple_output_keywords_names=('images', 'labels', 'texts'),
1042
+ verbose=True,
1043
+ split_valid_fraction=0.025,
1044
+ split_valid_from_train=False,
1045
+ split_random_seed=42,
1046
+ checkpoint_path=None,
1047
+ checkpoint_every=None,
1048
+ checkpoint_fs=None,
1049
+ fs_kwargs: dict = None,
1050
+ max_checkpoints_keep=20,
1051
+ lambdas=(1., 1.), # lambdas for image / label losses
1052
+ **kwargs
1053
+ ):
1054
+ super().__init__()
1055
+ assert not JointImagenTrainer.locked, 'JointImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
1056
+ assert exists(imagen) ^ exists(
1057
+ imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'
1058
+
1059
+ # save lambdas for backward
1060
+
1061
+ self.lambdas = lambdas
1062
+
1063
+ # determine filesystem, using fsspec, for saving to local filesystem or cloud
1064
+
1065
+ self.fs = checkpoint_fs
1066
+
1067
+ if not exists(self.fs):
1068
+ fs_kwargs = default(fs_kwargs, {})
1069
+ self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)
1070
+
1071
+ assert isinstance(imagen, (JointImagen, )) # ElucidatedImagen is not implemented yet
1072
+ ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
1073
+
1074
+ # elucidated or not
1075
+
1076
+ self.is_elucidated = isinstance(imagen, ElucidatedImagen)
1077
+
1078
+ # create accelerator instance
1079
+
1080
+ accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)
1081
+
1082
+ assert not (fp16 and exists(precision)
1083
+ ), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
1084
+ accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')
1085
+
1086
+ self.accelerator = Accelerator(**{
1087
+ 'split_batches': split_batches,
1088
+ 'mixed_precision': accelerator_mixed_precision,
1089
+ 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters=True)], **accelerate_kwargs})
1090
+
1091
+ JointImagenTrainer.locked = self.is_distributed
1092
+
1093
+ # cast data to fp16 at training time if needed
1094
+
1095
+ self.cast_half_at_training = accelerator_mixed_precision == 'fp16'
1096
+
1097
+ # grad scaler must be managed outside of accelerator
1098
+
1099
+ grad_scaler_enabled = fp16
1100
+
1101
+ # imagen, unets and ema unets
1102
+
1103
+ self.imagen = imagen
1104
+ self.num_unets = len(self.imagen.unets)
1105
+
1106
+ self.use_ema = use_ema and self.is_main
1107
+ self.ema_unets = nn.ModuleList([])
1108
+
1109
+ # keep track of what unet is being trained on
1110
+ # only going to allow 1 unet training at a time
1111
+
1112
+ self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on
1113
+
1114
+ # data related functions
1115
+
1116
+ self.train_dl_iter = None
1117
+ self.train_dl = None
1118
+
1119
+ self.valid_dl_iter = None
1120
+ self.valid_dl = None
1121
+
1122
+ self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names
1123
+
1124
+ # auto splitting validation from training, if dataset is passed in
1125
+
1126
+ self.split_valid_from_train = split_valid_from_train
1127
+
1128
+ assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
1129
+ self.split_valid_fraction = split_valid_fraction
1130
+ self.split_random_seed = split_random_seed
1131
+
1132
+ # be able to finely customize learning rate, weight decay
1133
+ # per unet
1134
+
1135
+ lr, eps, warmup_steps, cosine_decay_max_steps = map(
1136
+ partial(cast_tuple, length=self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))
1137
+
1138
+ for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
1139
+ optimizer = Adam(
1140
+ unet.parameters(),
1141
+ lr=unet_lr,
1142
+ eps=unet_eps,
1143
+ betas=(beta1, beta2),
1144
+ **kwargs
1145
+ )
1146
+
1147
+ if self.use_ema:
1148
+ self.ema_unets.append(EMA(unet, **ema_kwargs))
1149
+
1150
+ scaler = GradScaler(enabled=grad_scaler_enabled)
1151
+
1152
+ scheduler = warmup_scheduler = None
1153
+
1154
+ if exists(unet_cosine_decay_max_steps):
1155
+ scheduler = CosineAnnealingLR(optimizer, T_max=unet_cosine_decay_max_steps)
1156
+
1157
+ if exists(unet_warmup_steps):
1158
+ warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=unet_warmup_steps)
1159
+
1160
+ if not exists(scheduler):
1161
+ scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
1162
+
1163
+ # set on object
1164
+
1165
+ setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
1166
+ setattr(self, f'scaler{ind}', scaler)
1167
+ setattr(self, f'scheduler{ind}', scheduler)
1168
+ setattr(self, f'warmup{ind}', warmup_scheduler)
1169
+
1170
+ # gradient clipping if needed
1171
+
1172
+ self.max_grad_norm = max_grad_norm
1173
+
1174
+ # step tracker and misc
1175
+
1176
+ self.register_buffer('steps', torch.tensor([0] * self.num_unets))
1177
+
1178
+ self.verbose = verbose
1179
+
1180
+ # automatic set devices based on what accelerator decided
1181
+
1182
+ self.imagen.to(self.device)
1183
+ self.to(self.device)
1184
+
1185
+ # checkpointing
1186
+
1187
+ assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
1188
+ self.checkpoint_path = checkpoint_path
1189
+ self.checkpoint_every = checkpoint_every
1190
+ self.max_checkpoints_keep = max_checkpoints_keep
1191
+
1192
+ self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main
1193
+
1194
+ if exists(checkpoint_path) and self.can_checkpoint:
1195
+ bucket = url_to_bucket(checkpoint_path)
1196
+
1197
+ if not self.fs.exists(bucket):
1198
+ self.fs.mkdir(bucket)
1199
+
1200
+ self.load_from_checkpoint_folder()
1201
+
1202
+ # only allowing training for unet
1203
+
1204
+ self.only_train_unet_number = only_train_unet_number
1205
+ self.validate_and_set_unet_being_trained(only_train_unet_number)
1206
+
1207
+ # computed values
1208
+
1209
+ @property
1210
+ def device(self):
1211
+ return self.accelerator.device
1212
+
1213
+ @property
1214
+ def is_distributed(self):
1215
+ return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
1216
+
1217
+ @property
1218
+ def is_main(self):
1219
+ return self.accelerator.is_main_process
1220
+
1221
+ @property
1222
+ def is_local_main(self):
1223
+ return self.accelerator.is_local_main_process
1224
+
1225
+ @property
1226
+ def unwrapped_unet(self):
1227
+ return self.accelerator.unwrap_model(self.unet_being_trained)
1228
+
1229
+ # optimizer helper functions
1230
+
1231
+ def get_lr(self, unet_number):
1232
+ self.validate_unet_number(unet_number)
1233
+ unet_index = unet_number - 1
1234
+
1235
+ optim = getattr(self, f'optim{unet_index}')
1236
+
1237
+ return optim.param_groups[0]['lr']
1238
+
1239
+ # function for allowing only one unet from being trained at a time
1240
+
1241
def validate_and_set_unet_being_trained(self, unet_number=None):
    """Lock the trainer (and imagen) to a single unet, wrapping it on first use.

    Once set, switching to a different unet number is disallowed — save a
    checkpoint and resume with a fresh trainer instead.
    """
    if exists(unet_number):
        self.validate_unet_number(unet_number)

    assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

    self.only_train_unet_number = unet_number
    self.imagen.only_train_unet_number = unet_number

    if exists(unet_number):
        self.wrap_unet(unet_number)
1254
+
1255
def wrap_unet(self, unet_number):
    """Prepare the chosen unet plus its optimizer/scheduler with the accelerator, once.

    The `one_unet_wrapped` sentinel attribute guarantees this runs a single
    time per trainer, since only one unet is trained at a time.
    """
    if hasattr(self, 'one_unet_wrapped'):
        return

    unet = self.imagen.get_unet(unet_number)
    self.unet_being_trained = self.accelerator.prepare(unet)
    unet_index = unet_number - 1

    optimizer = getattr(self, f'optim{unet_index}')
    scheduler = getattr(self, f'scheduler{unet_index}')

    optimizer = self.accelerator.prepare(optimizer)

    if exists(scheduler):
        scheduler = self.accelerator.prepare(scheduler)

    # store the accelerator-prepared versions back onto the trainer attributes
    setattr(self, f'optim{unet_index}', optimizer)
    setattr(self, f'scheduler{unet_index}', scheduler)

    self.one_unet_wrapped = True
1275
+
1276
+ # hacking accelerator due to not having separate gradscaler per optimizer
1277
+
1278
def set_accelerator_scaler(self, unet_number):
    """Install the per-unet grad scaler on the accelerator and its optimizers.

    Needed because accelerate keeps one scaler total, while this trainer
    maintains one scaler per unet.
    """
    unet_number = self.validate_unet_number(unet_number)
    scaler = getattr(self, f'scaler{unet_number - 1}')

    self.accelerator.scaler = scaler
    for opt in self.accelerator._optimizers:
        opt.scaler = scaler
1285
+
1286
+ # helper print
1287
+
1288
def print(self, msg):
    """Print via the accelerator, but only on the main process and when verbose."""
    if self.is_main and self.verbose:
        return self.accelerator.print(msg)
1296
+
1297
+ # validating the unet number
1298
+
1299
def validate_unet_number(self, unet_number=None):
    """Resolve and sanity-check a 1-indexed unet number.

    Defaults to 1 when the trainer manages a single unet. With multiple
    unets, a missing number now fails with a clear assertion instead of a
    confusing `TypeError` from comparing None against an int.
    """
    if self.num_unets == 1:
        unet_number = default(unet_number, 1)

    # explicit check before the range comparison below
    assert exists(unet_number), 'unet_number must be specified when training more than one unet'

    assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
    return unet_number
1305
+
1306
+ # number of training steps taken
1307
+
1308
def num_steps_taken(self, unet_number=None):
    """Number of gradient steps recorded for the given unet (defaults to 1 when single-unet)."""
    if self.num_unets == 1:
        unet_number = default(unet_number, 1)

    step_counter = self.steps[unet_number - 1]
    return step_counter.item()
1313
+
1314
def print_untrained_unets(self):
    """Warn about unets with zero recorded training steps (NullUnets are exempt)."""
    any_untrained = False

    for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
        if steps > 0 or isinstance(unet, NullUnet):
            continue

        self.print(f'unet {ind + 1} has not been trained')
        any_untrained = True

    if any_untrained:
        self.print(
            'when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')
1327
+
1328
+ # data related functions
1329
+
1330
def add_train_dataloader(self, dl=None):
    """Register (and accelerator-prepare) the training dataloader; no-op when `dl` is None."""
    if dl is None:
        return

    assert not exists(self.train_dl), 'training dataloader was already added'
    self.train_dl = self.accelerator.prepare(dl)
1336
+
1337
def add_valid_dataloader(self, dl=None):
    """Register (and accelerator-prepare) the validation dataloader.

    `dl` now defaults to None (making the call a no-op), mirroring
    `add_train_dataloader` — backward-compatible for existing callers.
    """
    if not exists(dl):
        return

    assert not exists(self.valid_dl), 'validation dataloader was already added'
    self.valid_dl = self.accelerator.prepare(dl)
1343
+
1344
def add_train_dataset(self, ds=None, *, batch_size, **dl_kwargs):
    """Build and register a training dataloader from a dataset.

    When `split_valid_from_train` is set, a validation subset of fraction
    `split_valid_fraction` is carved out deterministically (seeded by
    `split_random_seed`) and registered via `add_valid_dataset`.
    """
    if not exists(ds):
        return

    assert not exists(self.train_dl), 'training dataloader was already added'

    valid_ds = None
    if self.split_valid_from_train:
        train_size = int((1 - self.split_valid_fraction) * len(ds))
        valid_size = len(ds) - train_size

        # deterministic split so train/valid membership is reproducible across runs
        ds, valid_ds = random_split(ds, [train_size, valid_size],
                                    generator=torch.Generator().manual_seed(self.split_random_seed))
        self.print(
            f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples')

    dl = DataLoader(ds, batch_size=batch_size, **dl_kwargs)
    self.train_dl = self.accelerator.prepare(dl)

    if not self.split_valid_from_train:
        return

    self.add_valid_dataset(valid_ds, batch_size=batch_size, **dl_kwargs)
1367
+
1368
def add_valid_dataset(self, ds=None, *, batch_size, **dl_kwargs):
    """Build and register a validation dataloader from a dataset.

    `ds` now defaults to None (making the call a no-op), mirroring
    `add_train_dataset` — backward-compatible for existing callers.
    """
    if not exists(ds):
        return

    assert not exists(self.valid_dl), 'validation dataloader was already added'

    dl = DataLoader(ds, batch_size=batch_size, **dl_kwargs)
    self.valid_dl = self.accelerator.prepare(dl)
1376
+
1377
def create_train_iter(self):
    """Lazily build an infinite iterator over the training dataloader."""
    assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

    if self.train_dl_iter is None:
        self.train_dl_iter = cycle(self.train_dl)
1384
+
1385
def create_valid_iter(self):
    """Lazily build an infinite iterator over the validation dataloader."""
    assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

    if self.valid_dl_iter is None:
        self.valid_dl_iter = cycle(self.valid_dl)
1392
+
1393
def train_step(self, unet_number=None, **kwargs):
    """Run one full optimization step: forward on the next training batch, then update."""
    self.create_train_iter()

    loss = self.step_with_dl_iter(self.train_dl_iter, unet_number=unet_number, **kwargs)
    self.update(unet_number=unet_number)

    return loss
1398
+
1399
@torch.no_grad()
@eval_decorator
def valid_step(self, **kwargs):
    """Run a single validation forward pass (no gradient update).

    Pops `use_ema_unets` from kwargs; when truthy, validation runs with
    the EMA copies of the unets swapped in.
    """
    self.create_valid_iter()

    context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext

    with context():
        loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
    return loss
1409
+
1410
def step_with_dl_iter(self, dl_iter, **kwargs):
    """Pull one batch from `dl_iter`, map its fields onto keyword names, and run forward."""
    batch = cast_tuple(next(dl_iter))
    named_batch = dict(zip(self.dl_tuple_output_keywords_names, batch))
    return self.forward(**{**kwargs, **named_batch})
1415
+
1416
+ # checkpointing functions
1417
+
1418
@property
def all_checkpoints_sorted(self):
    """All checkpoint files in the checkpoint folder, newest (highest step count) first."""
    pattern = os.path.join(self.checkpoint_path, '*.pt')
    found = self.fs.glob(pattern)
    # filenames look like checkpoint.<total_steps>.pt — sort by the embedded step number
    return sorted(found, key=lambda p: int(str(p).split('.')[-2]), reverse=True)
1424
+
1425
def load_from_checkpoint_folder(self, last_total_steps=-1):
    """Load from the checkpoint folder: a specific total-step count, or the latest when -1."""
    if last_total_steps != -1:
        self.load(os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt'))
        return

    checkpoints = self.all_checkpoints_sorted

    if not checkpoints:
        self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
        return

    self.load(checkpoints[0])
1439
+
1440
def save_to_checkpoint_folder(self):
    """Save a checkpoint named by total steps, then prune the oldest beyond the keep limit."""
    self.accelerator.wait_for_everyone()

    if not self.can_checkpoint:
        return

    total_steps = int(self.steps.sum().item())
    self.save(os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt'))

    # non-positive keep count means "keep everything"
    if self.max_checkpoints_keep <= 0:
        return

    for stale in self.all_checkpoints_sorted[self.max_checkpoints_keep:]:
        self.fs.rm(stale)
1459
+
1460
+ # saving and loading functions
1461
+
1462
def save(
    self,
    path,
    overwrite=True,
    without_optim_and_sched=False,
    **kwargs
):
    """Serialize trainer state (model, steps, per-unet optim/sched/scaler, EMA) to `path`.

    Extra `kwargs` are stored verbatim in the checkpoint dict. Only the
    process allowed to checkpoint actually writes. When `overwrite` is False
    and the file exists, an assertion fails.
    """
    # self.accelerator.wait_for_everyone()

    if not self.can_checkpoint:
        return

    fs = self.fs

    assert not (fs.exists(path) and not overwrite)

    # consolidate EMA unets onto one device before taking the state dict
    self.reset_ema_unets_all_one_device()

    save_obj = dict(
        model=self.imagen.state_dict(),
        version=__version__,
        steps=self.steps.cpu(),
        **kwargs
    )

    # optionally skip optimizer/scheduler state (e.g. inference-only checkpoints)
    save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

    for ind in save_optim_and_sched_iter:
        scaler_key = f'scaler{ind}'
        optimizer_key = f'optim{ind}'
        scheduler_key = f'scheduler{ind}'
        warmup_scheduler_key = f'warmup{ind}'

        scaler = getattr(self, scaler_key)
        optimizer = getattr(self, optimizer_key)
        scheduler = getattr(self, scheduler_key)
        warmup_scheduler = getattr(self, warmup_scheduler_key)

        if exists(scheduler):
            save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

        if exists(warmup_scheduler):
            save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

        save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

    if self.use_ema:
        save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

    # determine if imagen config is available

    if hasattr(self.imagen, '_config'):
        self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')

        save_obj = {
            **save_obj,
            'imagen_type': 'elucidated' if self.is_elucidated else 'original',
            'imagen_params': self.imagen._config
        }

    # save to path

    with fs.open(path, 'wb') as f:
        torch.save(save_obj, f)

    self.print(f'checkpoint saved to {path}')
1528
+
1529
def load(self, path, only_model=False, strict=True, noop_if_not_exist=False):
    """Load trainer state from `path`.

    Restores model weights (falling back to a partial load on shape
    mismatch), and unless `only_model` is set, also restores the step
    counters plus each unet's optimizer/scheduler/scaler state and the EMA
    weights.

    Returns the raw checkpoint dict, or None when `noop_if_not_exist`
    is set and the file is missing.
    """
    fs = self.fs

    if noop_if_not_exist and not fs.exists(path):
        self.print(f'trainer checkpoint not found at {str(path)}')
        return

    assert fs.exists(path), f'{path} does not exist'

    self.reset_ema_unets_all_one_device()

    # to avoid extra GPU memory usage in main process when using Accelerate

    with fs.open(path) as f:
        loaded_obj = torch.load(f, map_location='cpu')

    if version.parse(__version__) != version.parse(loaded_obj['version']):
        self.print(f'loading saved imagen at version {loaded_obj["version"]}, '
                   f'but current package version is {__version__}')

    try:
        self.imagen.load_state_dict(loaded_obj['model'], strict=strict)
    except RuntimeError:
        print("Failed loading state dict. Trying partial load")
        self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                  loaded_obj['model']))

    if only_model:
        return loaded_obj

    self.steps.copy_(loaded_obj['steps'])

    for ind in range(0, self.num_unets):
        scaler_key = f'scaler{ind}'
        optimizer_key = f'optim{ind}'
        scheduler_key = f'scheduler{ind}'
        warmup_scheduler_key = f'warmup{ind}'

        scaler = getattr(self, scaler_key)
        optimizer = getattr(self, optimizer_key)
        scheduler = getattr(self, scheduler_key)
        warmup_scheduler = getattr(self, warmup_scheduler_key)

        if exists(scheduler) and scheduler_key in loaded_obj:
            scheduler.load_state_dict(loaded_obj[scheduler_key])

        if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
            warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

        if exists(optimizer):
            try:
                optimizer.load_state_dict(loaded_obj[optimizer_key])
                scaler.load_state_dict(loaded_obj[scaler_key])
            # was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed; best-effort behavior is preserved
            except Exception:
                self.print('could not load optimizer and scaler, '
                           'possibly because you have turned on mixed precision training since the last run. '
                           'resuming with new optimizer and scalers')

    if self.use_ema:
        assert 'ema' in loaded_obj
        try:
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict=strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                         loaded_obj['ema']))

    self.print(f'checkpoint loaded from {path}')
    return loaded_obj
1598
+
1599
+ # managing ema unets and their devices
1600
+
1601
@property
def unets(self):
    """The EMA copies of each unet, wrapped in a ModuleList."""
    ema_models = [ema.ema_model for ema in self.ema_unets]
    return nn.ModuleList(ema_models)
1604
+
1605
def get_ema_unet(self, unet_number=None):
    """Return the EMA counterpart of the given unet, moving it alone onto the training device.

    All other EMA unets are pushed to CPU to save memory. On first call,
    `ema_unets` is unwrapped from a ModuleList into a plain list so the
    device shuffling does not interact with module registration.
    """
    if not self.use_ema:
        return

    unet_number = self.validate_unet_number(unet_number)
    index = unet_number - 1

    # NOTE(review): this checks self.unets, which is a property that always
    # returns a ModuleList — presumably self.ema_unets was intended; verify
    if isinstance(self.unets, nn.ModuleList):
        unets_list = [unet for unet in self.ema_unets]
        delattr(self, 'ema_unets')
        self.ema_unets = unets_list

    # only shuffle devices when switching to a different unet
    if index != self.ema_unet_being_trained_index:
        for unet_index, unet in enumerate(self.ema_unets):
            unet.to(self.device if unet_index == index else 'cpu')

    self.ema_unet_being_trained_index = index
    return self.ema_unets[index]
1623
+
1624
def reset_ema_unets_all_one_device(self, device=None):
    """Move every EMA unet onto one device (default: trainer device) and reset tracking."""
    if not self.use_ema:
        return

    target = default(device, self.device)

    # re-wrap as a ModuleList so .to() applies to all EMA unets at once
    self.ema_unets = nn.ModuleList([*self.ema_unets])
    self.ema_unets.to(target)

    self.ema_unet_being_trained_index = -1
1633
+
1634
@torch.no_grad()
@contextmanager
def use_ema_unets(self):
    """Context manager that temporarily swaps the EMA unets into the imagen model.

    On exit the trainable unets are restored and each EMA model is moved
    back to its original device. When EMA is disabled this is a plain
    pass-through yield.
    """
    if not self.use_ema:
        output = yield
        return output

    self.reset_ema_unets_all_one_device()
    self.imagen.reset_unets_all_one_device()

    self.unets.eval()

    trainable_unets = self.imagen.unets
    self.imagen.unets = self.unets  # swap in exponential moving averaged unets for sampling

    output = yield

    self.imagen.unets = trainable_unets  # restore original training unets

    # cast the ema_model unets back to original device
    for ema in self.ema_unets:
        ema.restore_ema_model_device()

    return output
1658
+
1659
def print_unet_devices(self):
    """Log which device each unet (and each EMA unet, if enabled) currently lives on."""
    self.print('unet devices:')
    for i, unet in enumerate(self.imagen.unets):
        device = next(unet.parameters()).device
        self.print(f'\tunet {i}: {device}')

    if not self.use_ema:
        return

    self.print('\nema unet devices:')
    for i, ema_unet in enumerate(self.ema_unets):
        device = next(ema_unet.parameters()).device
        self.print(f'\tema unet {i}: {device}')
1672
+
1673
+ # overriding state dict functions
1674
+
1675
def state_dict(self, *args, **kwargs):
    """Consolidate EMA unets onto one device before delegating to the parent state_dict."""
    self.reset_ema_unets_all_one_device()
    return super().state_dict(*args, **kwargs)
1678
+
1679
def load_state_dict(self, *args, **kwargs):
    """Consolidate EMA unets onto one device before delegating to the parent load_state_dict."""
    self.reset_ema_unets_all_one_device()
    return super().load_state_dict(*args, **kwargs)
1682
+
1683
+ # encoding text functions
1684
+
1685
def encode_text(self, text, **kwargs):
    """Delegate text encoding to the underlying imagen model."""
    return self.imagen.encode_text(text, **kwargs)
1687
+
1688
+ # forwarding functions and gradient step updates
1689
+
1690
def update(self, unet_number=None):
    """Apply one optimizer step for the given unet, plus EMA/scheduler bookkeeping.

    Assumes gradients have already been accumulated by `forward`. Clips
    gradients when `max_grad_norm` is set, steps the optimizer, updates the
    EMA copy, advances the LR scheduler (inside the warmup dampening
    context if any), bumps the per-unet step counter, and periodically
    writes a checkpoint.
    """
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    self.set_accelerator_scaler(unet_number)

    index = unet_number - 1
    unet = self.unet_being_trained

    optimizer = getattr(self, f'optim{index}')
    scaler = getattr(self, f'scaler{index}')
    scheduler = getattr(self, f'scheduler{index}')
    warmup_scheduler = getattr(self, f'warmup{index}')

    # set the grad scaler on the accelerator, since we are managing one per u-net

    if exists(self.max_grad_norm):
        self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

    optimizer.step()
    optimizer.zero_grad()

    if self.use_ema:
        ema_unet = self.get_ema_unet(unet_number)
        ema_unet.update()

    # scheduler, if needed

    maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

    with maybe_warmup_context:
        if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped:  # recommended in the docs
            scheduler.step()

    # one-hot so only this unet's step counter is incremented
    self.steps += F.one_hot(torch.tensor(unet_number - 1, device=self.steps.device), num_classes=len(self.steps))

    if not exists(self.checkpoint_path):
        return

    total_steps = int(self.steps.sum().item())

    # save only when total steps is an exact multiple of checkpoint_every
    if total_steps % self.checkpoint_every:
        return

    self.save_to_checkpoint_folder()
1734
+
1735
@torch.no_grad()
@cast_torch_tensor
@imagen_sample_in_chunks
def sample(self, *args, **kwargs):
    """Sample from the cascade, using EMA unets unless `use_non_ema=True` is passed.

    Warns about untrained unets first; tqdm output is silenced on non-main
    processes.
    """
    context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

    self.print_untrained_unets()

    if not self.is_main:
        kwargs['use_tqdm'] = False

    with context():
        output = self.imagen.sample(*args, device=self.device, **kwargs)

    return output
1750
+
1751
@partial(cast_torch_tensor, cast_fp16=True)
def forward(
    self,
    *args,
    unet_number=None,
    max_batch_size=None,
    **kwargs
):
    """Run the training forward pass, accumulating gradients chunk by chunk.

    The batch is split into chunks of at most `max_batch_size`; each chunk's
    (image loss, segmentation loss) pair is scaled by its fraction of the
    batch, backpropagated (weighted by `self.lambdas`) when in training
    mode, and summed into the returned totals.

    Returns (total_loss, total_loss_seg) as Python floats.
    """
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    self.set_accelerator_scaler(unet_number)

    assert not exists(
        self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

    total_loss = 0.
    total_loss_seg = 0.

    for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size=max_batch_size, **kwargs):
        with self.accelerator.autocast():
            loss, loss_seg = self.imagen(*chunked_args, unet=self.unet_being_trained,
                                         unet_number=unet_number, **chunked_kwargs)
            # scale so the accumulated gradient equals one full-batch step
            loss = loss * chunk_size_frac
            loss_seg = loss_seg * chunk_size_frac

        total_loss += loss.item()
        total_loss_seg += loss_seg.item()

        if self.training:
            self.accelerator.backward(loss * self.lambdas[0] + loss_seg * self.lambdas[1])

    return total_loss, total_loss_seg
imagen_pytorch/utils.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from functools import reduce
4
+ from pathlib import Path
5
+
6
+ from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
7
+ from ema_pytorch import EMA
8
+
9
def exists(val):
    """True when `val` is not None (falsy values like 0 or '' still count as existing)."""
    return val is not None
11
+
12
def safeget(dictionary, keys, default = None):
    """Fetch a nested dict value by dotted path, e.g. safeget(d, 'a.b.c').

    Returns `default` whenever a lookup misses or an intermediate value is
    not a dict.
    """
    current = dictionary
    for key in keys.split('.'):
        if not isinstance(current, dict):
            return default
        current = current.get(key, default)
    return current
14
+
15
def load_imagen_from_checkpoint(
    checkpoint_path,
    load_weights = True,
    load_ema_if_available = False
):
    """Reconstruct an Imagen model from a trainer checkpoint file.

    Uses the `imagen_type` / `imagen_params` saved by the trainer to rebuild
    the model via its config class, then loads either the raw weights or
    (when available and requested) the EMA weights.
    """
    model_path = Path(checkpoint_path)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'checkpoint not found at {full_model_path}'
    loaded = torch.load(str(model_path), map_location='cpu')

    imagen_params = safeget(loaded, 'imagen_params')
    imagen_type = safeget(loaded, 'imagen_type')

    # check presence BEFORE dispatching on the type, so a checkpoint saved
    # without its config fails with the intended message instead of
    # 'unknown imagen type None'
    assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint'

    if imagen_type == 'original':
        imagen_klass = ImagenConfig
    elif imagen_type == 'elucidated':
        imagen_klass = ElucidatedImagenConfig
    else:
        raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')

    imagen = imagen_klass(**imagen_params).create()

    if not load_weights:
        return imagen

    has_ema = 'ema' in loaded
    should_load_ema = has_ema and load_ema_if_available

    imagen.load_state_dict(loaded['model'])

    if not should_load_ema:
        print('loading non-EMA version of unets')
        return imagen

    ema_unets = nn.ModuleList([])
    for unet in imagen.unets:
        ema_unets.append(EMA(unet))

    ema_unets.load_state_dict(loaded['ema'])

    # copy each EMA shadow back into the corresponding live unet
    for unet, ema_unet in zip(imagen.unets, ema_unets):
        unet.load_state_dict(ema_unet.ema_model.state_dict())

    print('loaded EMA version of unets')
    return imagen
imagen_pytorch/version.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = '1.11.14'
pyproject.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [tool.autopep8]
2
+ max_line_length = 120
3
+ ignore = ["E402"]
repaint/LICENSES/LICENSE ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ you may not use this file except in compliance with the License.
4
+ You may obtain a copy of the License at
5
+
6
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+
8
+ The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
repaint/LICENSES/LICENSE_guided_diffusion ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 OpenAI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
repaint/LICENSES/README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # License and Acknowledgement
2
+
3
+ A big thanks to the following contributors who open-sourced their code and thereby helped us a lot in developing RePaint!
4
+
5
+ This repository was forked from:
6
+ https://github.com/openai/guided-diffusion
7
+
8
+ It contains code from:
9
+ https://github.com/hojonathanho/diffusion
10
+
11
+ If we missed a contribution, please contact us.
repaint/README.md ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RePaint
2
+ **Inpainting using Denoising Diffusion Probabilistic Models**
3
+
4
+
5
+ CVPR 2022 [[Paper]](https://bit.ly/3b1ABEb)
6
+
7
+ [![Denoising_Diffusion_Inpainting_Animation](https://user-images.githubusercontent.com/11280511/150849757-5cd762cb-07a3-46aa-a906-0fe4606eba3b.gif)](#)
8
+
9
+ ## Setup
10
+
11
+ ### 1. Code
12
+
13
+ ```bash
14
+ git clone https://github.com/andreas128/RePaint.git
15
+ ```
16
+
17
+ ### 2. Environment
18
+ ```bash
19
+ pip install numpy torch blobfile tqdm pyYaml pillow # e.g. torch 1.7.1+cu110.
20
+ ```
21
+
22
+ ### 3. Download models and data
23
+
24
+ ```bash
25
+ pip install --upgrade gdown && bash ./download.sh
26
+ ```
27
+
28
+ That downloads the models for ImageNet, CelebA-HQ, and Places2, as well as the face example and example masks.
29
+
30
+
31
+ ### 4. Run example
32
+ ```bash
33
+ python test.py --conf_path confs/face_example.yml
34
+ ```
35
+ Find the output in `./log/face_example/inpainted`
36
+
37
+ *Note: After refactoring the code, we did not reevaluate all experiments.*
38
+
39
+ <br>
40
+
41
+ # RePaint fills a missing image part using diffusion models
42
+
43
+ <table border="0" cellspacing="0" cellpadding="0">
44
+ <tr>
45
+ <td><img alt="RePaint Inpainting using Denoising Diffusion Probabilistic Models Demo 1" src="https://user-images.githubusercontent.com/11280511/150766080-9f3d7bc9-99f2-472e-9e5d-b6ed456340d1.gif"></td>
46
+ <td><img alt="RePaint Inpainting using Denoising Diffusion Probabilistic Models Demo 2" src="https://user-images.githubusercontent.com/11280511/150766125-adf5a3cb-17f2-432c-a8f6-ce0b97122819.gif"></td>
47
+ </tr>
48
+ </table>
49
+
50
+ **What are the blue parts?** <br>
51
+ Those parts are missing and therefore have to be filled by RePaint. <br> RePaint generates the missing parts inspired by the known parts.
52
+
53
+ **How does it work?** <br>
54
+ RePaint starts from pure noise. Then the image is denoised step-by-step. <br> It uses the known part to fill the unknown part in each step.
55
+
56
+ **Why does the noise level fluctuate during generation?** <br>
57
+ Our noise schedule improves the harmony between the generated and <br> the known part [[4.2 Resampling]](https://bit.ly/3b1ABEb).
58
+
59
+ <br>
60
+
61
+ ## Details on data
62
+
63
+ **Which datasets and masks have a ready-to-use config file?**
64
+
65
+ We provide config files for ImageNet (inet256), CelebA-HQ (c256) and Places2 (p256) for the masks "thin", "thick", "every second line", "super-resolution", "expand" and "half" in [`./confs`](https://github.com/andreas128/RePaint/tree/main/confs). You can use them as shown in the example above.
66
+
67
+ **How to prepare the test data?**
68
+
69
+ We use [LaMa](https://github.com/saic-mdal/lama) for validation and testing. Follow their instructions and add the images as specified in the config files. When you download the data using `download.sh`, you can see examples of masks we used.
70
+
71
+ **How to apply it to other images?**
72
+
73
+ Copy the config file for the dataset that matches your data best (for faces aligned like CelebA-HQ `_c256`, for diverse images `_inet256`). Then set the [`gt_path`](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/confs/face_example.yml#L70) and [`mask_path`](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/confs/face_example.yml#L71) to where your input is. The masks have the value 255 for known regions and 0 for unknown areas (the ones that get generated).
74
+
75
+ **How to apply it for other datasets?**
76
+
77
+ If you work with other data than faces, places or general images, train a model using the [guided-diffusion](https://github.com/openai/guided-diffusion) repository. Note that RePaint is an inference scheme. We do not train or finetune the diffusion model but condition pre-trained models.
78
+
79
+ ## Adapt the code
80
+
81
+ **How to design a new schedule?**
82
+
83
+ Fill in your own parameters in this [line](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/guided_diffusion/scheduler.py#L180) to visualize the schedule using `python guided_diffusion/scheduler.py`. Then copy a config file, set your parameters in these [lines](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/confs/face_example.yml#L61-L65) and run the inference using `python test.py --conf_path confs/my_schedule.yml`.
84
+
85
+ **How to speed up the inference?**
86
+
87
+ The following settings are in the [schedule_jump_params](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/confs/face_example.yml#L61) key in the config files. You can visualize them as described above.
88
+
89
+ - Reduce `t_T`, the total number of steps (without resampling). The lower it is, the more noise gets removed per step.
90
+ - Reduce `jump_n_sample` to resample fewer times.
91
+ - Apply resampling not from the beginning but only after a specific time by setting `start_resampling`.
92
+
93
+ ## Code overview
94
+
95
+ - **Schedule:** The list of diffusion times t which will be traversed are obtained in this [line](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L503). e.g. times = [249, 248, 249, 248, 247, 248, 247, 248, 247, 246, ...]
96
+ - **Denoise:** Reverse diffusion steps from x<sub>t</sub> (more noise) to a x<sub>t-1</sub> (less noisy) are done below this [line](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L515).
97
+ - **Predict:** The model is called [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L237) and obtains x<sub>t</sub> and the time t to predict a tensor with 6 channels containing information about the mean and variance of x<sub>t-1</sub>. Then the value range of the variance is adjusted [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L252). The mean of x<sub>t-1</sub> is obtained by the weighted sum of the estimated [x<sub>0</sub>](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L270) and x<sub>t</sub> [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L189). The obtained mean and variance is used [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L402) to sample x<sub>t-1</sub>. (This is the original reverse step from [guided-diffusion](https://github.com/openai/guided-diffusion.git). )
98
+ - **Condition:** The known part of the input image needs to have the same amount of noise as the part that the diffusion model generates to join them. The required amount of noise is calculated [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L368) and added to the known part [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L371). The generated and sampled parts get joined using a mask [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L373).
99
+ - **Undo:** The forward diffusion steps from x<sub>t-1</sub> to x<sub>t</sub> is done after this [line](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L536). The noise gets added to x<sub>t-1</sub> [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L176).
100
+
101
+ ## Issues
102
+
103
+ **Do you have further questions?**
104
+
105
+ Please open an [issue](https://github.com/andreas128/RePaint/issues), and we will try to help you.
106
+
107
+ **Did you find a mistake?**
108
+
109
+ Please create a pull request. For example, by clicking the pencil button at the top right of the GitHub page.
110
+
111
+ <br>
112
+
113
+ # RePaint on diverse content and shapes of missing regions
114
+
115
+ The blue region is unknown and filled by RePaint:
116
+
117
+ ![Denoising Diffusion Probabilistic Models Inpainting](https://user-images.githubusercontent.com/11280511/150803812-a4729ef8-6ad4-46aa-ae99-8c27fbb2ea2e.png)
118
+
119
+
120
+ **Note: RePaint creates many meaningful fillings.** <br>
121
+ 1) **Face:** Expressions and features like an earring or a mole. <br>
122
+ 2) **Computer:** The computer screen shows different images, text, and even a logo. <br>
123
+ 3) **Greens:** RePaint makes sense of the tiny known part and incorporates it in a beetle, spaghetti, and plants. <br>
124
+ 4) **Garden:** From simple filling like a curtain to complex filling like a human. <br>
125
+
126
+
127
+ <br>
128
+
129
+ # Extreme Case 1: Generate every second line
130
+ ![Denoising_Diffusion_Probabilistic_Models_Inpainting_Every_Second_Line](https://user-images.githubusercontent.com/11280511/150818064-29789cbe-73c7-45de-a955-9fad5fb24c0e.png)
131
+
132
+ - Every second line of the input image is unknown.
133
+ - Most inpainting methods fail on such masks.
134
+
135
+
136
+ <br>
137
+
138
+ # Extreme Case 2: Upscale an image
139
+ ![Denoising_Diffusion_Probabilistic_Models_Inpainting_Super_Resolution](https://user-images.githubusercontent.com/11280511/150818741-5ed19a0b-1cf8-4f28-9e57-2e4c12303c3e.png)
140
+
141
+ - The inpainting only knows pixels with a strided access of 2.
142
+ - A ratio of 3/4 of the image has to be filled.
143
+ - This is equivalent to Super-Resolution with the Nearest Neighbor kernel.
144
+
145
+ <br>
146
+
147
+ # RePaint conditions the diffusion model on the known part
148
+
149
+ - RePaint uses unconditionally trained Denoising Diffusion Probabilistic Models.
150
+ - We condition during inference on the given image content.
151
+
152
+ ![Denoising Diffusion Probabilistic Models Inpainting Method](https://user-images.githubusercontent.com/11280511/180631151-59b6674b-bf2c-4501-8307-03c9f5f593ae.gif)
153
+
154
+ **Intuition of one conditioned denoising step:**
155
+ 1) **Sample the known part:** Add Gaussian noise to the known regions of the image. <br> We obtain a noisy image that follows the denoising process exactly.
156
+ 2) **Denoise one step:** Denoise the previous image for one step. This generates <br> content for the unknown region conditioned on the known region.
157
+ 3) **Join:** Merge the images from both steps.
158
+
159
+ Details are in Algorithm 1 on Page 5. [[Paper]](https://bit.ly/3b1ABEb)
160
+
161
+
162
+ <br>
163
+
164
+ # How to harmonize the generated with the known part?
165
+
166
+ - **Fail:** When using only the algorithm above, the filling is not well harmonized with the known part (n=1).
167
+ - **Fix:** When applying the [[4.2 Resampling]](https://bit.ly/3b1ABEb) technique, the images are better harmonized (n>1).
168
+
169
+ <img width="1577" alt="Diffusion Model Resampling" src="https://user-images.githubusercontent.com/11280511/150822917-737c00b0-b6bb-439d-a5bf-e73238d30990.png">
170
+
171
+ <br>
172
+
173
+ # RePaint Fails
174
+ - The ImageNet model is biased towards inpainting dogs.
175
+ - This is due to the high ratio of dog images in ImageNet.
176
+
177
+ <img width="1653" alt="RePaint Fails" src="https://user-images.githubusercontent.com/11280511/150853163-b965f59c-5ad4-485b-816e-4391e77b5199.png">
178
+
179
+ <br>
180
+
181
+ # User Study State-of-the-Art Comparison
182
+
183
+ - Outperforms autoregression-based and GAN-based SOTA methods, <br> with 95% significance for all masks except for two inconclusive cases.
184
+ - The user study was done for six different masks on three datasets.
185
+ - RePaint outperformed SOTA methods in 42 of 44 cases. [[Paper]](https://bit.ly/3b1ABEb)
186
+
187
+ <br>
188
+
189
+ # Explore the Visual Examples
190
+ - Datasets: CelebA-HQ, ImageNet, Places2
191
+ - Masks: Random strokes, half image, huge, sparse
192
+ - Explore more examples like this in the [[Appendix]](https://bit.ly/3b1ABEb).
193
+
194
+
195
+ <img width="1556" alt="Denosing Diffusion Inpainting Examples" src="https://user-images.githubusercontent.com/11280511/150864677-0eb482ae-c114-4b0b-b1e0-9be9574da307.png">
196
+
197
+
198
+ <br>
199
+
200
+
201
+ # Acknowledgement
202
+
203
+ This work was supported by the ETH Zürich Fund (OK), a Huawei Technologies Oy (Finland) project, and an Nvidia GPU grant.
204
+
205
+ This repository is based on [guided-diffusion](https://github.com/openai/guided-diffusion.git) from OpenAI.
repaint/conf_mgt/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+
18
+ from conf_mgt.conf_base import Default_Conf
repaint/conf_mgt/conf_base.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ from functools import lru_cache
18
+ import os
19
+ import torch
20
+ from utils import imwrite
21
+
22
+ from collections import defaultdict
23
+ from os.path import isfile, expanduser
24
+
25
+ def to_file_ext(img_names, ext):
26
+ img_names_out = []
27
+ for img_name in img_names:
28
+ splits = img_name.split('.')
29
+ if not len(splits) == 2:
30
+ raise RuntimeError("File name needs exactly one '.':", img_name)
31
+ img_names_out.append(splits[0] + '.' + ext)
32
+
33
+ return img_names_out
34
+
35
+ def write_images(imgs, img_names, dir_path):
36
+ os.makedirs(dir_path, exist_ok=True)
37
+
38
+ for image_name, image in zip(img_names, imgs):
39
+ out_path = os.path.join(dir_path, image_name)
40
+ imwrite(img=image, path=out_path)
41
+
42
+
43
+
44
+ class NoneDict(defaultdict):
45
+ def __init__(self):
46
+ super().__init__(self.return_None)
47
+
48
+ @staticmethod
49
+ def return_None():
50
+ return None
51
+
52
+ def __getattr__(self, attr):
53
+ return self.get(attr)
54
+
55
+
56
+ class Default_Conf(NoneDict):
57
+ def __init__(self):
58
+ pass
59
+
60
+ def get_dataloader(self, dset='train', dsName=None, batch_size=None, return_dataset=False):
61
+
62
+ if batch_size is None:
63
+ batch_size = self.batch_size
64
+
65
+ candidates = self['data'][dset]
66
+ ds_conf = candidates[dsName].copy()
67
+
68
+ if ds_conf.get('mask_loader', False):
69
+ from guided_diffusion.image_datasets import load_data_inpa
70
+ return load_data_inpa(**ds_conf, conf=self)
71
+ else:
72
+ raise NotImplementedError()
73
+
74
+ def get_debug_variance_path(self):
75
+ return os.path.expanduser(os.path.join(self.get_default_eval_conf()['paths']['root'], 'debug/debug_variance'))
76
+
77
+ @ staticmethod
78
+ def device():
79
+ return 'cuda' if torch.cuda.is_available() else 'cpu'
80
+
81
+ def eval_imswrite(self, srs=None, img_names=None, dset=None, name=None, ext='png', lrs=None, gts=None, gt_keep_masks=None, verify_same=True):
82
+ img_names = to_file_ext(img_names, ext)
83
+
84
+ if dset is None:
85
+ dset = self.get_default_eval_name()
86
+
87
+ max_len = self['data'][dset][name].get('max_len')
88
+
89
+ if srs is not None:
90
+ sr_dir_path = expanduser(self['data'][dset][name]['paths']['srs'])
91
+ write_images(srs, img_names, sr_dir_path)
92
+
93
+ if gt_keep_masks is not None:
94
+ mask_dir_path = expanduser(
95
+ self['data'][dset][name]['paths']['gt_keep_masks'])
96
+ write_images(gt_keep_masks, img_names, mask_dir_path)
97
+
98
+ gts_path = self['data'][dset][name]['paths'].get('gts')
99
+ if gts is not None and gts_path:
100
+ gt_dir_path = expanduser(gts_path)
101
+ write_images(gts, img_names, gt_dir_path)
102
+
103
+ if lrs is not None:
104
+ lrs_dir_path = expanduser(
105
+ self['data'][dset][name]['paths']['lrs'])
106
+ write_images(lrs, img_names, lrs_dir_path)
107
+
108
+ def get_default_eval_name(self):
109
+ candidates = self['data']['eval'].keys()
110
+ if len(candidates) != 1:
111
+ raise RuntimeError(
112
+ f"Need exactly one candidate for {self.name}: {candidates}")
113
+ return list(candidates)[0]
114
+
115
+ def pget(self, name, default=None):
116
+ if '.' in name:
117
+ names = name.split('.')
118
+ else:
119
+ names = [name]
120
+
121
+ sub_dict = self
122
+ for name in names:
123
+ sub_dict = sub_dict.get(name, default)
124
+
125
+ if sub_dict == None:
126
+ return default
127
+
128
+ return sub_dict
repaint/confs/face_example.yml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ latex_name: RePaint
53
+ method_name: Repaint
54
+ image_size: 256
55
+ model_path: ./data/pretrained/celeba256_250000.pt
56
+ name: face_example
57
+ inpa_inj_sched_prev: true
58
+ n_jobs: 1
59
+ print_estimated_vars: true
60
+ inpa_inj_sched_prev_cumnoise: false
61
+ schedule_jump_params:
62
+ t_T: 250
63
+ n_sample: 5 # for GCDP, 1 for image2layout and 5 for layout2image
64
+ jump_length: 10
65
+ jump_n_sample: 10
66
+ data:
67
+ eval:
68
+ paper_face_mask:
69
+ mask_loader: true
70
+ gt_path: ./data/datasets/gts/gcdp
71
+ mask_path: ./data/datasets/gt_keep_masks/gcdp
72
+ image_size: 256
73
+ class_cond: false
74
+ deterministic: true
75
+ random_crop: false
76
+ random_flip: false
77
+ return_dict: true
78
+ drop_last: false
79
+ batch_size: 1
80
+ return_dataloader: true
81
+ offset: 0
82
+ max_len: 8
83
+ paths:
84
+ srs: ./log/face_example/inpainted
85
+ lrs: ./log/face_example/gt_masked
86
+ gts: ./log/face_example/gt
87
+ gt_keep_masks: ./log/face_example/gt_keep_mask
repaint/confs/test_c256_ev2li.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: ./data/pretrained/celeba256_250000.pt
54
+ name: test_c256_ev2li
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_c256_ev2li_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/c256
69
+ mask_path: ./data/datasets/gt_keep_masks/ev2li
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: fix_ev2li_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_c256_ev2li/inpainted
84
+ lrs: ./log/test_c256_ev2li/gt_masked
85
+ gts: ./log/test_c256_ev2li/gt
86
+ gt_keep_masks: ./log/test_c256_ev2li/gt_keep_mask
repaint/confs/test_c256_ex64.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
54
+ name: test_c256_ex64
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_c256_ex64_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/c256
69
+ mask_path: ./data/datasets/gt_keep_masks/ex64
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: fix_ex64_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_c256_ex64/inpainted
84
+ lrs: ./log/test_c256_ex64/gt_masked
85
+ gts: ./log/test_c256_ex64/gt
86
+ gt_keep_masks: ./log/test_c256_ex64/gt_keep_mask
repaint/confs/test_c256_genhalf.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
54
+ name: test_c256_genhalf
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_c256_genhalf_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/c256
69
+ mask_path: ./data/datasets/gt_keep_masks/genhalf
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: fix_genhalf_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_c256_genhalf/inpainted
84
+ lrs: ./log/test_c256_genhalf/gt_masked
85
+ gts: ./log/test_c256_genhalf/gt
86
+ gt_keep_masks: ./log/test_c256_genhalf/gt_keep_mask
repaint/confs/test_c256_nn2.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
54
+ name: test_c256_nn2
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_c256_nn2_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/c256
69
+ mask_path: ./data/datasets/gt_keep_masks/nn2
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: fix_nn2_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_c256_nn2/inpainted
84
+ lrs: ./log/test_c256_nn2/gt_masked
85
+ gts: ./log/test_c256_nn2/gt
86
+ gt_keep_masks: ./log/test_c256_nn2/gt_keep_mask
repaint/confs/test_c256_thick.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
54
+ name: test_c256_thick
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_c256_thick_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/c256
69
+ mask_path: ./data/datasets/gt_keep_masks/thick
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: random_thick_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_c256_thick/inpainted
84
+ lrs: ./log/test_c256_thick/gt_masked
85
+ gts: ./log/test_c256_thick/gt
86
+ gt_keep_masks: ./log/test_c256_thick/gt_keep_mask
repaint/confs/test_c256_thin.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
54
+ name: test_c256_thin
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_c256_thin_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/c256
69
+ mask_path: ./data/datasets/gt_keep_masks/thin
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: random_thin_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_c256_thin/inpainted
84
+ lrs: ./log/test_c256_thin/gt_masked
85
+ gts: ./log/test_c256_thin/gt
86
+ gt_keep_masks: ./log/test_c256_thin/gt_keep_mask
repaint/confs/test_inet256_ev2li.yml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: true
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: true
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 1.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ classifier_path: ./data/pretrained/256x256_classifier.pt
54
+ model_path: ./data/pretrained/256x256_diffusion.pt
55
+ name: test_inet256_ev2li
56
+ inpa_inj_sched_prev: true
57
+ n_jobs: 25
58
+ print_estimated_vars: true
59
+ inpa_inj_sched_prev_cumnoise: false
60
+ schedule_jump_params:
61
+ t_T: 250
62
+ n_sample: 1
63
+ jump_length: 10
64
+ jump_n_sample: 10
65
+ data:
66
+ eval:
67
+ lama_inet256_ev2li_n100_test:
68
+ mask_loader: true
69
+ gt_path: ./data/datasets/gts/inet256
70
+ mask_path: ./data/datasets/gt_keep_masks/ev2li
71
+ image_size: 256
72
+ class_cond: false
73
+ deterministic: true
74
+ random_crop: false
75
+ random_flip: false
76
+ return_dict: true
77
+ drop_last: false
78
+ batch_size: 4
79
+ return_dataloader: true
80
+ ds_conf:
81
+ name: random_ev2li_256
82
+ max_len: 100
83
+ paths:
84
+ srs: ./log/test_inet256_ev2li/inpainted
85
+ lrs: ./log/test_inet256_ev2li/gt_masked
86
+ gts: ./log/test_inet256_ev2li/gt
87
+ gt_keep_masks: ./log/test_inet256_ev2li/gt_keep_mask
repaint/confs/test_inet256_ex64.yml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: true
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: true
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 1.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ classifier_path: ./data/pretrained/256x256_classifier.pt
54
+ model_path: ./data/pretrained/256x256_diffusion.pt
55
+ name: test_inet256_ex64
56
+ inpa_inj_sched_prev: true
57
+ n_jobs: 25
58
+ print_estimated_vars: true
59
+ inpa_inj_sched_prev_cumnoise: false
60
+ schedule_jump_params:
61
+ t_T: 250
62
+ n_sample: 1
63
+ jump_length: 10
64
+ jump_n_sample: 10
65
+ data:
66
+ eval:
67
+ lama_inet256_ex64_n100_test:
68
+ mask_loader: true
69
+ gt_path: ./data/datasets/gts/inet256
70
+ mask_path: ./data/datasets/gt_keep_masks/ex64
71
+ image_size: 256
72
+ class_cond: false
73
+ deterministic: true
74
+ random_crop: false
75
+ random_flip: false
76
+ return_dict: true
77
+ drop_last: false
78
+ batch_size: 4
79
+ return_dataloader: true
80
+ ds_conf:
81
+ name: random_ex64_256
82
+ max_len: 100
83
+ paths:
84
+ srs: ./log/test_inet256_ex64/inpainted
85
+ lrs: ./log/test_inet256_ex64/gt_masked
86
+ gts: ./log/test_inet256_ex64/gt
87
+ gt_keep_masks: ./log/test_inet256_ex64/gt_keep_mask
repaint/confs/test_inet256_genhalf.yml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: true
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: true
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 1.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ classifier_path: ./data/pretrained/256x256_classifier.pt
54
+ model_path: ./data/pretrained/256x256_diffusion.pt
55
+ name: test_inet256_genhalf
56
+ inpa_inj_sched_prev: true
57
+ n_jobs: 25
58
+ print_estimated_vars: true
59
+ inpa_inj_sched_prev_cumnoise: false
60
+ schedule_jump_params:
61
+ t_T: 250
62
+ n_sample: 1
63
+ jump_length: 10
64
+ jump_n_sample: 10
65
+ data:
66
+ eval:
67
+ lama_inet256_genhalf_n100_test:
68
+ mask_loader: true
69
+ gt_path: ./data/datasets/gts/inet256
70
+ mask_path: ./data/datasets/gt_keep_masks/genhalf
71
+ image_size: 256
72
+ class_cond: false
73
+ deterministic: true
74
+ random_crop: false
75
+ random_flip: false
76
+ return_dict: true
77
+ drop_last: false
78
+ batch_size: 4
79
+ return_dataloader: true
80
+ ds_conf:
81
+ name: random_genhalf_256
82
+ max_len: 100
83
+ paths:
84
+ srs: ./log/test_inet256_genhalf/inpainted
85
+ lrs: ./log/test_inet256_genhalf/gt_masked
86
+ gts: ./log/test_inet256_genhalf/gt
87
+ gt_keep_masks: ./log/test_inet256_genhalf/gt_keep_mask
repaint/confs/test_inet256_nn2.yml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: true
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: true
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 1.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ classifier_path: ./data/pretrained/256x256_classifier.pt
54
+ model_path: ./data/pretrained/256x256_diffusion.pt
55
+ name: test_inet256_nn2
56
+ inpa_inj_sched_prev: true
57
+ n_jobs: 25
58
+ print_estimated_vars: true
59
+ inpa_inj_sched_prev_cumnoise: false
60
+ schedule_jump_params:
61
+ t_T: 250
62
+ n_sample: 1
63
+ jump_length: 10
64
+ jump_n_sample: 10
65
+ data:
66
+ eval:
67
+ lama_inet256_nn2_n100_test:
68
+ mask_loader: true
69
+ gt_path: ./data/datasets/gts/inet256
70
+ mask_path: ./data/datasets/gt_keep_masks/nn2
71
+ image_size: 256
72
+ class_cond: false
73
+ deterministic: true
74
+ random_crop: false
75
+ random_flip: false
76
+ return_dict: true
77
+ drop_last: false
78
+ batch_size: 4
79
+ return_dataloader: true
80
+ ds_conf:
81
+ name: random_nn2_256
82
+ max_len: 100
83
+ paths:
84
+ srs: ./log/test_inet256_nn2/inpainted
85
+ lrs: ./log/test_inet256_nn2/gt_masked
86
+ gts: ./log/test_inet256_nn2/gt
87
+ gt_keep_masks: ./log/test_inet256_nn2/gt_keep_mask
repaint/confs/test_inet256_thick.yml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: true
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: true
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 1.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ classifier_path: ./data/pretrained/256x256_classifier.pt
54
+ model_path: ./data/pretrained/256x256_diffusion.pt
55
+ name: test_inet256_thick
56
+ inpa_inj_sched_prev: true
57
+ n_jobs: 25
58
+ print_estimated_vars: true
59
+ inpa_inj_sched_prev_cumnoise: false
60
+ schedule_jump_params:
61
+ t_T: 250
62
+ n_sample: 1
63
+ jump_length: 10
64
+ jump_n_sample: 10
65
+ data:
66
+ eval:
67
+ lama_inet256_thick_n100_test:
68
+ mask_loader: true
69
+ gt_path: ./data/datasets/gts/inet256
70
+ mask_path: ./data/datasets/gt_keep_masks/thick
71
+ image_size: 256
72
+ class_cond: false
73
+ deterministic: true
74
+ random_crop: false
75
+ random_flip: false
76
+ return_dict: true
77
+ drop_last: false
78
+ batch_size: 4
79
+ return_dataloader: true
80
+ ds_conf:
81
+ name: random_thick_256
82
+ max_len: 100
83
+ paths:
84
+ srs: ./log/test_inet256_thick/inpainted
85
+ lrs: ./log/test_inet256_thick/gt_masked
86
+ gts: ./log/test_inet256_thick/gt
87
+ gt_keep_masks: ./log/test_inet256_thick/gt_keep_mask
repaint/confs/test_inet256_thin.yml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: true
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: true
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 1.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ classifier_path: ./data/pretrained/256x256_classifier.pt
54
+ model_path: ./data/pretrained/256x256_diffusion.pt
55
+ name: test_inet256_thin
56
+ inpa_inj_sched_prev: true
57
+ n_jobs: 25
58
+ print_estimated_vars: true
59
+ inpa_inj_sched_prev_cumnoise: false
60
+ schedule_jump_params:
61
+ t_T: 250
62
+ n_sample: 1
63
+ jump_length: 10
64
+ jump_n_sample: 10
65
+ data:
66
+ eval:
67
+ lama_inet256_thin_n100_test:
68
+ mask_loader: true
69
+ gt_path: ./data/datasets/gts/inet256
70
+ mask_path: ./data/datasets/gt_keep_masks/thin
71
+ image_size: 256
72
+ class_cond: false
73
+ deterministic: true
74
+ random_crop: false
75
+ random_flip: false
76
+ return_dict: true
77
+ drop_last: false
78
+ batch_size: 4
79
+ return_dataloader: true
80
+ ds_conf:
81
+ name: random_thin_256
82
+ max_len: 100
83
+ paths:
84
+ srs: ./log/test_inet256_thin/inpainted
85
+ lrs: ./log/test_inet256_thin/gt_masked
86
+ gts: ./log/test_inet256_thin/gt
87
+ gt_keep_masks: ./log/test_inet256_thin/gt_keep_mask
repaint/confs/test_p256_ev2li.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: ./data/pretrained/places256_300000.pt
54
+ name: test_p256_ev2li
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_p256_ev2li_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/p256
69
+ mask_path: ./data/datasets/gt_keep_masks/ev2li
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: random_ev2li_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_p256_ev2li/inpainted
84
+ lrs: ./log/test_p256_ev2li/gt_masked
85
+ gts: ./log/test_p256_ev2li/gt
86
+ gt_keep_masks: ./log/test_p256_ev2li/gt_keep_mask
repaint/confs/test_p256_ex64.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: ./data/pretrained/places256_300000.pt
54
+ name: test_p256_ex64
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_p256_ex64_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/p256
69
+ mask_path: ./data/datasets/gt_keep_masks/ex64
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: random_ex64_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_p256_ex64/inpainted
84
+ lrs: ./log/test_p256_ex64/gt_masked
85
+ gts: ./log/test_p256_ex64/gt
86
+ gt_keep_masks: ./log/test_p256_ex64/gt_keep_mask
repaint/confs/test_p256_genhalf.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: ./data/pretrained/places256_300000.pt
54
+ name: test_p256_genhalf
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_p256_genhalf_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/p256
69
+ mask_path: ./data/datasets/gt_keep_masks/genhalf
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: random_genhalf_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_p256_genhalf/inpainted
84
+ lrs: ./log/test_p256_genhalf/gt_masked
85
+ gts: ./log/test_p256_genhalf/gt
86
+ gt_keep_masks: ./log/test_p256_genhalf/gt_keep_mask
repaint/confs/test_p256_nn2.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: ./data/pretrained/places256_300000.pt
54
+ name: test_p256_nn2
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_p256_nn2_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/p256
69
+ mask_path: ./data/datasets/gt_keep_masks/nn2
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: random_nn2_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_p256_nn2/inpainted
84
+ lrs: ./log/test_p256_nn2/gt_masked
85
+ gts: ./log/test_p256_nn2/gt
86
+ gt_keep_masks: ./log/test_p256_nn2/gt_keep_mask
repaint/confs/test_p256_thick.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: ./data/pretrained/places256_300000.pt
54
+ name: test_p256_thick
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_p256_thick_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/p256
69
+ mask_path: ./data/datasets/gt_keep_masks/thick
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: random_thick_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_p256_thick/inpainted
84
+ lrs: ./log/test_p256_thick/gt_masked
85
+ gts: ./log/test_p256_thick/gt
86
+ gt_keep_masks: ./log/test_p256_thick/gt_keep_mask
repaint/confs/test_p256_thin.yml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ attention_resolutions: 32,16,8
18
+ class_cond: false
19
+ diffusion_steps: 1000
20
+ learn_sigma: true
21
+ noise_schedule: linear
22
+ num_channels: 256
23
+ num_head_channels: 64
24
+ num_heads: 4
25
+ num_res_blocks: 2
26
+ resblock_updown: true
27
+ use_fp16: false
28
+ use_scale_shift_norm: true
29
+ classifier_scale: 4.0
30
+ lr_kernel_n_std: 2
31
+ num_samples: 100
32
+ show_progress: true
33
+ timestep_respacing: '250'
34
+ use_kl: false
35
+ predict_xstart: false
36
+ rescale_timesteps: false
37
+ rescale_learned_sigmas: false
38
+ classifier_use_fp16: false
39
+ classifier_width: 128
40
+ classifier_depth: 2
41
+ classifier_attention_resolutions: 32,16,8
42
+ classifier_use_scale_shift_norm: true
43
+ classifier_resblock_updown: true
44
+ classifier_pool: attention
45
+ num_heads_upsample: -1
46
+ channel_mult: ''
47
+ dropout: 0.0
48
+ use_checkpoint: false
49
+ use_new_attention_order: false
50
+ clip_denoised: true
51
+ use_ddim: false
52
+ image_size: 256
53
+ model_path: ./data/pretrained/places256_300000.pt
54
+ name: test_p256_thin
55
+ inpa_inj_sched_prev: true
56
+ n_jobs: 25
57
+ print_estimated_vars: true
58
+ inpa_inj_sched_prev_cumnoise: false
59
+ schedule_jump_params:
60
+ t_T: 250
61
+ n_sample: 1
62
+ jump_length: 10
63
+ jump_n_sample: 10
64
+ data:
65
+ eval:
66
+ lama_p256_thin_n100_test:
67
+ mask_loader: true
68
+ gt_path: ./data/datasets/gts/p256
69
+ mask_path: ./data/datasets/gt_keep_masks/thin
70
+ image_size: 256
71
+ class_cond: false
72
+ deterministic: true
73
+ random_crop: false
74
+ random_flip: false
75
+ return_dict: true
76
+ drop_last: false
77
+ batch_size: 4
78
+ return_dataloader: true
79
+ ds_conf:
80
+ name: random_thin_256
81
+ max_len: 100
82
+ paths:
83
+ srs: ./log/test_p256_thin/inpainted
84
+ lrs: ./log/test_p256_thin/gt_masked
85
+ gts: ./log/test_p256_thin/gt
86
+ gt_keep_masks: ./log/test_p256_thin/gt_keep_mask
repaint/download.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ (
4
+ mkdir -p data/pretrained
5
+ cd data/pretrained
6
+
7
+ wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_classifier.pt # Trained by OpenAI
8
+ wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion.pt # Trained by OpenAI
9
+
10
+ gdown https://drive.google.com/uc?id=1norNWWGYP3EZ_o05DmoW1ryKuKMmhlCX
11
+ gdown https://drive.google.com/uc?id=1QEl-btGbzQz6IwkXiFGd49uQNTUtTHsk
12
+ )
13
+
14
+ # data
15
+ (
16
+ gdown https://drive.google.com/uc?id=1Q_dxuyI41AAmSv9ti3780BwaJQqwvwMv
17
+ unzip data.zip
18
+ rm data.zip
19
+ )
repaint/guided_diffusion/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ """
18
+ Based on "Improved Denoising Diffusion Probabilistic Models".
19
+ """
repaint/guided_diffusion/dist_util.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ """
18
+ Helpers for distributed training.
19
+ """
20
+
21
+ import io
22
+
23
+ import blobfile as bf
24
+ import torch as th
25
+
26
+
27
def dev(device):
    """
    Resolve the torch device to use.

    :param device: an explicit device (string or ``th.device``), or ``None``
        to auto-select.
    :return: ``th.device("cuda")`` when no device was requested and CUDA is
        available; ``th.device("cpu")`` when no device was requested and CUDA
        is unavailable; otherwise the requested device.
    """
    if device is None:
        # Prefer the GPU when one is present; fall back to CPU.
        if th.cuda.is_available():
            return th.device("cuda")  # was f"cuda": f-prefix had no placeholders
        return th.device("cpu")
    return th.device(device)
36
+
37
+
38
def load_state_dict(path, backend=None, **kwargs):
    """
    Load a PyTorch checkpoint from ``path``.

    The file is read fully into memory through ``blobfile`` (so remote/blob
    paths work transparently) before being handed to ``th.load``. Extra
    ``kwargs`` are forwarded to ``th.load``. ``backend`` is accepted for API
    compatibility but is not used here.
    """
    with bf.BlobFile(path, "rb") as handle:
        raw = handle.read()
    buffer = io.BytesIO(raw)
    return th.load(buffer, **kwargs)
42
+
43
+