Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- .gitignore +29 -7
- README.md +16 -24
- datasets/celeba.py +209 -0
- datasets/cityscapes.py +303 -0
- docs/gcdp.png +3 -0
- environment.yaml +66 -0
- example.py +25 -0
- imagen_pytorch/__init__.py +26 -0
- imagen_pytorch/cli.py +52 -0
- imagen_pytorch/configs.py +181 -0
- imagen_pytorch/data.py +73 -0
- imagen_pytorch/elucidated_imagen.py +846 -0
- imagen_pytorch/imagen_pytorch.py +2515 -0
- imagen_pytorch/imagen_video/__init__.py +1 -0
- imagen_pytorch/imagen_video/imagen_video.py +1662 -0
- imagen_pytorch/joint_imagen.py +1942 -0
- imagen_pytorch/t5.py +119 -0
- imagen_pytorch/trainer.py +1782 -0
- imagen_pytorch/utils.py +61 -0
- imagen_pytorch/version.py +1 -0
- pyproject.toml +3 -0
- repaint/LICENSES/LICENSE +13 -0
- repaint/LICENSES/LICENSE_guided_diffusion +21 -0
- repaint/LICENSES/README.md +11 -0
- repaint/README.md +205 -0
- repaint/conf_mgt/__init__.py +18 -0
- repaint/conf_mgt/conf_base.py +128 -0
- repaint/confs/face_example.yml +87 -0
- repaint/confs/test_c256_ev2li.yml +86 -0
- repaint/confs/test_c256_ex64.yml +86 -0
- repaint/confs/test_c256_genhalf.yml +86 -0
- repaint/confs/test_c256_nn2.yml +86 -0
- repaint/confs/test_c256_thick.yml +86 -0
- repaint/confs/test_c256_thin.yml +86 -0
- repaint/confs/test_inet256_ev2li.yml +87 -0
- repaint/confs/test_inet256_ex64.yml +87 -0
- repaint/confs/test_inet256_genhalf.yml +87 -0
- repaint/confs/test_inet256_nn2.yml +87 -0
- repaint/confs/test_inet256_thick.yml +87 -0
- repaint/confs/test_inet256_thin.yml +87 -0
- repaint/confs/test_p256_ev2li.yml +86 -0
- repaint/confs/test_p256_ex64.yml +86 -0
- repaint/confs/test_p256_genhalf.yml +86 -0
- repaint/confs/test_p256_nn2.yml +86 -0
- repaint/confs/test_p256_thick.yml +86 -0
- repaint/confs/test_p256_thin.yml +86 -0
- repaint/download.sh +19 -0
- repaint/guided_diffusion/__init__.py +19 -0
- repaint/guided_diffusion/dist_util.py +43 -0
.gitattributes
CHANGED
|
@@ -47,3 +47,4 @@ tedigan/ext/experiment/inference_coupled/input_label.png filter=lfs diff=lfs mer
|
|
| 47 |
tedigan/ext/experiment/inference_results/input_label.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
uniteandconquer/utils/faces.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
uniteandconquer/utils/natural.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 47 |
tedigan/ext/experiment/inference_results/input_label.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
uniteandconquer/utils/faces.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
uniteandconquer/utils/natural.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
docs/gcdp.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
|
@@ -1,3 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Byte-compiled / optimized / DLL files
|
| 2 |
__pycache__/
|
| 3 |
*.py[cod]
|
|
@@ -106,10 +132,8 @@ ipython_config.py
|
|
| 106 |
#pdm.lock
|
| 107 |
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
# in version control.
|
| 109 |
-
# https://pdm.fming.dev/
|
| 110 |
.pdm.toml
|
| 111 |
-
.pdm-python
|
| 112 |
-
.pdm-build/
|
| 113 |
|
| 114 |
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 115 |
__pypackages__/
|
|
@@ -161,7 +185,5 @@ cython_debug/
|
|
| 161 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 162 |
#.idea/
|
| 163 |
|
| 164 |
-
#
|
| 165 |
-
|
| 166 |
-
*.pth
|
| 167 |
-
*.ckpt
|
|
|
|
| 1 |
+
logs
|
| 2 |
+
debug
|
| 3 |
+
wandb_dir
|
| 4 |
+
checkpoints
|
| 5 |
+
squeue.txt
|
| 6 |
+
results
|
| 7 |
+
|
| 8 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python,linux
|
| 9 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux
|
| 10 |
+
|
| 11 |
+
### Linux ###
|
| 12 |
+
*~
|
| 13 |
+
|
| 14 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
| 15 |
+
.fuse_hidden*
|
| 16 |
+
|
| 17 |
+
# KDE directory preferences
|
| 18 |
+
.directory
|
| 19 |
+
|
| 20 |
+
# Linux trash folder which might appear on any partition or disk
|
| 21 |
+
.Trash-*
|
| 22 |
+
|
| 23 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
| 24 |
+
.nfs*
|
| 25 |
+
|
| 26 |
+
### Python ###
|
| 27 |
# Byte-compiled / optimized / DLL files
|
| 28 |
__pycache__/
|
| 29 |
*.py[cod]
|
|
|
|
| 132 |
#pdm.lock
|
| 133 |
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 134 |
# in version control.
|
| 135 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 136 |
.pdm.toml
|
|
|
|
|
|
|
| 137 |
|
| 138 |
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 139 |
__pypackages__/
|
|
|
|
| 185 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 186 |
#.idea/
|
| 187 |
|
| 188 |
+
# End of https://www.toptal.com/developers/gitignore/api/python,linux
|
| 189 |
+
_baselines/datasetGAN/StyleGAN.pytorch/outputs/cityscapes/2023-03-05/log.txt
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,24 +1,16 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
-
|
| 6 |
-
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
- [x] [TediGAN (CVPR 2021)](https://github.com/IIGROUP/TediGAN)
|
| 19 |
-
- [x] [UniteandConquer (CVPR 2023)](https://github.com/Nithin-GK/UniteandConquer)
|
| 20 |
-
- [x] [Collaborative-Diffusion (CVPR 2023)](https://github.com/ziqihuangg/Collaborative-Diffusion)
|
| 21 |
-
- [x] [GCDP (ICCV 2023)](https://github.com/pmh9960/GCDP) (Text2Image)
|
| 22 |
-
- [x] [PixelFace+ (MM 2023)](https://github.com/qazwsx671713/PixelFace-Plus)
|
| 23 |
-
- [ ] [Diffusion-driven GAN Inversion (CVPR 2024)](https://github.com/1211sh/Diffusion-driven_GAN-Inversion/)
|
| 24 |
-
- [ ] [MM2Latent (ECCVW 2024)](https://github.com/Open-Debin/MM2Latent)
|
|
|
|
| 1 |
+
# Evaluation
|
| 2 |
+
|
| 3 |
+
```bash
|
| 4 |
+
CUDA_VISIBLE_DEVICES=4 python test.py --model_type=base_128x128 \
|
| 5 |
+
--checkpoint_path checkpoints/celeba/base_128x128_flip_100/checkpoint.500000.pt \
|
| 6 |
+
--end_sample_idx=1 \
|
| 7 |
+
--test_batch_size=1 \
|
| 8 |
+
--dataset celeba \
|
| 9 |
+
--num_classes 19 \
|
| 10 |
+
--save_path=results/celeba/base.png \
|
| 11 |
+
--test_captions "The woman wears earrings. She has wavy hair. She is attractive."
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
CUDA_VISIBLE_DEVICES=4 python test.py --conf_path confs/face_example.yml
|
| 16 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
datasets/celeba.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import random
|
| 4 |
+
from collections import namedtuple
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
from torchvision.transforms import Compose, InterpolationMode, RandomCrop, RandomHorizontalFlip, Resize, ToTensor
|
| 14 |
+
|
| 15 |
+
CelebaClass = namedtuple('CelebaClass', ['name', 'id', 'color'])
|
| 16 |
+
# autopep8: off
|
| 17 |
+
classes = [
|
| 18 |
+
CelebaClass('background', 0, ( 0, 0, 0)),
|
| 19 |
+
CelebaClass('skin', 1, (204, 0, 0)),
|
| 20 |
+
CelebaClass('nose', 2, ( 76, 153, 0)),
|
| 21 |
+
CelebaClass('eye_g', 3, (204, 204, 0)),
|
| 22 |
+
CelebaClass('l_eye', 4, ( 51, 51, 255)),
|
| 23 |
+
CelebaClass('r_eye', 5, (204, 0, 204)),
|
| 24 |
+
CelebaClass('l_brow', 6, ( 0, 255, 255)),
|
| 25 |
+
CelebaClass('r_brow', 7, (255, 204, 204)),
|
| 26 |
+
CelebaClass('l_ear', 8, (102, 51, 0)),
|
| 27 |
+
CelebaClass('r_ear', 9, (255, 0, 0)),
|
| 28 |
+
CelebaClass('mouth', 10, (102, 204, 0)),
|
| 29 |
+
CelebaClass('u_lip', 11, (255, 255, 0)),
|
| 30 |
+
CelebaClass('l_lip', 12, ( 0, 0, 153)),
|
| 31 |
+
CelebaClass('hair', 13, ( 0, 0, 204)),
|
| 32 |
+
CelebaClass('hat', 14, (255, 51, 153)),
|
| 33 |
+
CelebaClass('ear_r', 15, ( 0, 204, 204)),
|
| 34 |
+
CelebaClass('neck_l', 16, ( 0, 51, 0)),
|
| 35 |
+
CelebaClass('neck', 17, (255, 153, 51)),
|
| 36 |
+
CelebaClass('cloth', 18, ( 0, 204, 0)),
|
| 37 |
+
]
|
| 38 |
+
# autopep8: on
|
| 39 |
+
num_classes = 19
|
| 40 |
+
mapping_id = torch.tensor([x.id for x in classes])
|
| 41 |
+
colors = torch.tensor([cls.color for cls in classes])
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def normalize_to_neg_one_to_one(img):
|
| 45 |
+
return img * 2 - 1
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def unnormalize_to_zero_to_one(img):
|
| 49 |
+
return (img + 1) * 0.5
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def unnormalize_and_clamp_to_zero_to_one(img):
|
| 53 |
+
return torch.clamp(unnormalize_to_zero_to_one(img.cpu()), 0, 1)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def exists(val):
|
| 57 |
+
return val is not None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def default(val, d):
|
| 61 |
+
if exists(val):
|
| 62 |
+
return val
|
| 63 |
+
return d() if callable(d) else d
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ToTensorNoNorm():
|
| 67 |
+
def __call__(self, X_i):
|
| 68 |
+
X_i = np.array(X_i)
|
| 69 |
+
|
| 70 |
+
if len(X_i.shape) == 2:
|
| 71 |
+
# Add channel dim.
|
| 72 |
+
X_i = X_i[:, :, None]
|
| 73 |
+
|
| 74 |
+
return torch.from_numpy(np.array(X_i, copy=False)).permute(2, 0, 1)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def interpolate_3d(x, *args, **kwargs):
|
| 78 |
+
return F.interpolate(x.unsqueeze(0), *args, **kwargs).squeeze(0)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class RandomResize(nn.Module):
|
| 82 |
+
def __init__(self, scale=(0.5, 2.0), mode='nearest'):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.scale = scale
|
| 85 |
+
self.mode = mode
|
| 86 |
+
|
| 87 |
+
def get_random_scale(self):
|
| 88 |
+
return random.uniform(*self.scale)
|
| 89 |
+
|
| 90 |
+
def forward(self, x):
|
| 91 |
+
random_scale = self.get_random_scale()
|
| 92 |
+
x = interpolate_3d(x, scale_factor=random_scale, mode=self.mode)
|
| 93 |
+
return x
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def read_jsonl(jsonl_path):
|
| 97 |
+
import jsonlines
|
| 98 |
+
lines = []
|
| 99 |
+
with jsonlines.open(jsonl_path, 'r') as f:
|
| 100 |
+
for line in f.iter():
|
| 101 |
+
lines.append(line)
|
| 102 |
+
return lines
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class CelebaDataset(Dataset):
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
root="",
|
| 109 |
+
split='train',
|
| 110 |
+
side_x=128,
|
| 111 |
+
side_y=128,
|
| 112 |
+
caption_list_dir='',
|
| 113 |
+
augmentation_type='flip',
|
| 114 |
+
):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.root = Path(root)
|
| 117 |
+
self.image_dir = osp.join(self.root, 'CelebA-HQ-img')
|
| 118 |
+
self.label_dir = osp.join(self.root, 'CelebAMask-HQ-mask-anno', 'preprocessed')
|
| 119 |
+
self.split = split
|
| 120 |
+
self.side_x = side_x
|
| 121 |
+
self.side_y = side_y
|
| 122 |
+
|
| 123 |
+
self.caption_list_dir = caption_list_dir
|
| 124 |
+
captions_jsonl = read_jsonl(osp.join(self.caption_list_dir, f'{split}_captions.jsonl'))
|
| 125 |
+
self.caption_dict = {}
|
| 126 |
+
for caption_jsonl in captions_jsonl:
|
| 127 |
+
self.caption_dict[osp.splitext(caption_jsonl['file_name'])[0]] = caption_jsonl['text']
|
| 128 |
+
|
| 129 |
+
if augmentation_type == 'none':
|
| 130 |
+
self.augmentation = Compose([
|
| 131 |
+
Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
|
| 132 |
+
# ToTensor(),
|
| 133 |
+
])
|
| 134 |
+
elif augmentation_type == 'flip':
|
| 135 |
+
self.augmentation = Compose([
|
| 136 |
+
Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
|
| 137 |
+
RandomHorizontalFlip(p=0.5),
|
| 138 |
+
# ToTensor(),
|
| 139 |
+
])
|
| 140 |
+
elif 'resizedCrop' in augmentation_type:
|
| 141 |
+
scale = [float(s) for s in augmentation_type.split('_')[1:]]
|
| 142 |
+
assert len(scale) == 2, scale
|
| 143 |
+
self.augmentation = Compose([
|
| 144 |
+
RandomResize(scale=scale, mode='nearest'),
|
| 145 |
+
RandomCrop((1024, 1024)),
|
| 146 |
+
Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
|
| 147 |
+
RandomHorizontalFlip(p=0.5),
|
| 148 |
+
# ToTensor(),
|
| 149 |
+
])
|
| 150 |
+
else:
|
| 151 |
+
raise NotImplementedError(augmentation_type)
|
| 152 |
+
|
| 153 |
+
# verification
|
| 154 |
+
self.images = sorted([osp.join(self.image_dir, file) for file in os.listdir(self.image_dir)
|
| 155 |
+
if osp.splitext(file)[0] in self.caption_dict.keys()])
|
| 156 |
+
self.labels = sorted([osp.join(self.label_dir, file) for file in os.listdir(self.label_dir)
|
| 157 |
+
if osp.splitext(file)[0] in self.caption_dict.keys()])
|
| 158 |
+
|
| 159 |
+
assert len(self.images) == len(self.labels), f'{len(self.images)} != {len(self.labels)}'
|
| 160 |
+
for img, lbl in zip(self.images, self.labels):
|
| 161 |
+
assert osp.splitext(osp.basename(img))[0] == osp.splitext(osp.basename(lbl))[0]
|
| 162 |
+
|
| 163 |
+
def __len__(self):
|
| 164 |
+
return len(self.images)
|
| 165 |
+
|
| 166 |
+
def random_sample(self):
|
| 167 |
+
return self.__getitem__(random.randint(0, self.__len__() - 1))
|
| 168 |
+
|
| 169 |
+
def sequential_sample(self, ind):
|
| 170 |
+
if ind >= self.__len__() - 1:
|
| 171 |
+
return self.__getitem__(0)
|
| 172 |
+
return self.__getitem__(ind + 1)
|
| 173 |
+
|
| 174 |
+
def skip_sample(self, ind):
|
| 175 |
+
return self.sequential_sample(ind=ind)
|
| 176 |
+
|
| 177 |
+
def get_caption_list_objects(self, idx):
|
| 178 |
+
filename = osp.splitext(osp.basename(self.images[idx]))[0]
|
| 179 |
+
caption = random.choice(self.caption_dict[filename])
|
| 180 |
+
return caption
|
| 181 |
+
|
| 182 |
+
def __getitem__(self, idx):
|
| 183 |
+
# load image label
|
| 184 |
+
try:
|
| 185 |
+
original_pil_image = Image.open(self.images[idx]).convert("RGB")
|
| 186 |
+
original_pil_target = Image.open(self.labels[idx])
|
| 187 |
+
except (OSError, ValueError) as e:
|
| 188 |
+
print(f"An exception occurred trying to load file {self.images[idx]}.")
|
| 189 |
+
print(f"Skipping index {idx}")
|
| 190 |
+
return self.skip_sample(idx)
|
| 191 |
+
|
| 192 |
+
# Transforms
|
| 193 |
+
image = Resize((1024, 1024), InterpolationMode.NEAREST)(ToTensor()(original_pil_image))
|
| 194 |
+
label = Resize((1024, 1024), InterpolationMode.NEAREST)(ToTensorNoNorm()(original_pil_target).float())
|
| 195 |
+
img_lbl = self.augmentation(torch.cat([image, label]))
|
| 196 |
+
|
| 197 |
+
caption = self.get_caption_list_objects(idx)
|
| 198 |
+
|
| 199 |
+
return img_lbl[:3], img_lbl[3:], caption
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def transform_lbl(lbl: torch.Tensor, *args, **kwargs):
|
| 203 |
+
lbl = lbl.long()
|
| 204 |
+
if lbl.size(1) == 1:
|
| 205 |
+
# Remove single channel axis.
|
| 206 |
+
lbl = lbl[:, 0]
|
| 207 |
+
rgbs = colors[lbl]
|
| 208 |
+
rgbs = rgbs.permute(0, 3, 1, 2)
|
| 209 |
+
return rgbs / 255.
|
datasets/cityscapes.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from collections import namedtuple
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from cityscapesscripts.helpers.labels import trainId2label
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
from torchvision.transforms import Compose, InterpolationMode, RandomCrop, RandomHorizontalFlip, Resize, ToTensor
|
| 15 |
+
|
| 16 |
+
CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
|
| 17 |
+
'has_instances', 'ignore_in_eval', 'color'])
|
| 18 |
+
# autopep8: off
|
| 19 |
+
classes = [
|
| 20 |
+
CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, ( 0, 0, 0)),
|
| 21 |
+
CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, ( 0, 0, 0)),
|
| 22 |
+
CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, ( 0, 0, 0)),
|
| 23 |
+
CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, ( 0, 0, 0)),
|
| 24 |
+
CityscapesClass('static', 4, 255, 'void', 0, False, True, ( 0, 0, 0)),
|
| 25 |
+
CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
|
| 26 |
+
CityscapesClass('ground', 6, 255, 'void', 0, False, True, ( 81, 0, 81)),
|
| 27 |
+
CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
|
| 28 |
+
CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
|
| 29 |
+
CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
|
| 30 |
+
CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
|
| 31 |
+
CityscapesClass('building', 11, 2, 'construction', 2, False, False, ( 70, 70, 70)),
|
| 32 |
+
CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
|
| 33 |
+
CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
|
| 34 |
+
CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
|
| 35 |
+
CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
|
| 36 |
+
CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
|
| 37 |
+
CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
|
| 38 |
+
CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
|
| 39 |
+
CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
|
| 40 |
+
CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
|
| 41 |
+
CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
|
| 42 |
+
CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
|
| 43 |
+
CityscapesClass('sky', 23, 10, 'sky', 5, False, False, ( 70, 130, 180)),
|
| 44 |
+
CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
|
| 45 |
+
CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
|
| 46 |
+
CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, ( 0, 0, 142)),
|
| 47 |
+
CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, ( 0, 0, 70)),
|
| 48 |
+
CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, ( 0, 60, 100)),
|
| 49 |
+
CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, ( 0, 0, 90)),
|
| 50 |
+
CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, ( 0, 0, 110)),
|
| 51 |
+
CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, ( 0, 80, 100)),
|
| 52 |
+
CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, ( 0, 0, 230)),
|
| 53 |
+
CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
|
| 54 |
+
CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, ( 0, 0, 142)),
|
| 55 |
+
]
|
| 56 |
+
# autopep8: on
|
| 57 |
+
|
| 58 |
+
map_id_to_id = torch.tensor([x.id for x in classes])
|
| 59 |
+
map_id_to_category_id = torch.tensor([x.category_id for x in classes])
|
| 60 |
+
map_id_to_train_id = torch.tensor([x.train_id for x in classes])
|
| 61 |
+
id_type_to_classes = dict(
|
| 62 |
+
id=dict(num_classes=34,
|
| 63 |
+
map_fn=torch.tensor([x if x not in (-1, ) else 0 for x in map_id_to_id]),
|
| 64 |
+
names=[cls.name for cls in classes][:-1]),
|
| 65 |
+
category_id=dict(num_classes=8,
|
| 66 |
+
map_fn=map_id_to_category_id,
|
| 67 |
+
names=[cls.name for cls in classes][:-1]), # TODO it is wrong
|
| 68 |
+
train_id=dict(num_classes=20,
|
| 69 |
+
map_fn=torch.tensor([x if x not in (-1, 255) else 19 for x in map_id_to_train_id]),
|
| 70 |
+
names=[i.name for i in classes if i.train_id != 255][:-1] + ['unlabeled']),
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def normalize_to_neg_one_to_one(img):
|
| 75 |
+
return img * 2 - 1
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def unnormalize_to_zero_to_one(img):
|
| 79 |
+
return (img + 1) * 0.5
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def unnormalize_and_clamp_to_zero_to_one(img):
|
| 83 |
+
return torch.clamp(unnormalize_to_zero_to_one(img.cpu()), 0, 1)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def exists(val):
|
| 87 |
+
return val is not None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def default(val, d):
|
| 91 |
+
if exists(val):
|
| 92 |
+
return val
|
| 93 |
+
return d() if callable(d) else d
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ToTensorNoNorm():
|
| 97 |
+
def __call__(self, X_i):
|
| 98 |
+
X_i = np.array(X_i)
|
| 99 |
+
|
| 100 |
+
if len(X_i.shape) == 2:
|
| 101 |
+
# Add channel dim.
|
| 102 |
+
X_i = X_i[:, :, None]
|
| 103 |
+
|
| 104 |
+
return torch.from_numpy(np.array(X_i, copy=False)).permute(2, 0, 1)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def interpolate_3d(x, *args, **kwargs):
|
| 108 |
+
return F.interpolate(x.unsqueeze(0), *args, **kwargs).squeeze(0)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class RandomResize(nn.Module):
|
| 112 |
+
def __init__(self, scale=(0.5, 2.0), mode='nearest'):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.scale = scale
|
| 115 |
+
self.mode = mode
|
| 116 |
+
|
| 117 |
+
def get_random_scale(self):
|
| 118 |
+
return random.uniform(*self.scale)
|
| 119 |
+
|
| 120 |
+
def forward(self, x):
|
| 121 |
+
random_scale = self.get_random_scale()
|
| 122 |
+
x = interpolate_3d(x, scale_factor=random_scale, mode=self.mode)
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def read_jsonl(jsonl_path):
|
| 127 |
+
import jsonlines
|
| 128 |
+
lines = []
|
| 129 |
+
with jsonlines.open(jsonl_path, 'r') as f:
|
| 130 |
+
for line in f.iter():
|
| 131 |
+
lines.append(line)
|
| 132 |
+
return lines
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class CityscapesDataset(Dataset):
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
root="",
|
| 139 |
+
split='train',
|
| 140 |
+
side_x=64,
|
| 141 |
+
side_y=64,
|
| 142 |
+
shuffle=False,
|
| 143 |
+
caption_list_dir='',
|
| 144 |
+
id_type='train_id',
|
| 145 |
+
augmentation_type='flip',
|
| 146 |
+
):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.root = Path(root)
|
| 149 |
+
self.image_dir = os.path.join(self.root, 'leftImg8bit')
|
| 150 |
+
self.label_dir = os.path.join(self.root, 'gtFine')
|
| 151 |
+
self.split = split
|
| 152 |
+
self.metadata = read_jsonl(os.path.join(caption_list_dir, f'{split}_captions.jsonl'))
|
| 153 |
+
self.metadata = sorted(self.metadata, key=lambda line: line['file_name'])
|
| 154 |
+
|
| 155 |
+
assert id_type == 'train_id'
|
| 156 |
+
self.map_fn = id_type_to_classes[id_type]['map_fn']
|
| 157 |
+
self.class_names = id_type_to_classes[id_type]['names']
|
| 158 |
+
self.num_classes = id_type_to_classes[id_type]['num_classes']
|
| 159 |
+
|
| 160 |
+
# self.text_ctx_len = text_ctx_len
|
| 161 |
+
self.shuffle = shuffle
|
| 162 |
+
self.side_x = side_x
|
| 163 |
+
self.side_y = side_y
|
| 164 |
+
|
| 165 |
+
if augmentation_type == 'none':
|
| 166 |
+
self.augmentation = Compose([
|
| 167 |
+
Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
|
| 168 |
+
# ToTensor(),
|
| 169 |
+
])
|
| 170 |
+
elif augmentation_type == 'flip':
|
| 171 |
+
self.augmentation = Compose([
|
| 172 |
+
Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
|
| 173 |
+
RandomHorizontalFlip(p=0.5),
|
| 174 |
+
# ToTensor(),
|
| 175 |
+
])
|
| 176 |
+
elif 'resizedCrop' in augmentation_type:
|
| 177 |
+
scale = [float(s) for s in augmentation_type.split('_')[1:]]
|
| 178 |
+
assert len(scale) == 2, scale
|
| 179 |
+
self.augmentation = Compose([
|
| 180 |
+
RandomResize(scale=scale, mode='nearest'),
|
| 181 |
+
RandomCrop((1024, 2048)),
|
| 182 |
+
Resize((side_x, side_y), interpolation=InterpolationMode.NEAREST),
|
| 183 |
+
RandomHorizontalFlip(p=0.5),
|
| 184 |
+
# ToTensor(),
|
| 185 |
+
])
|
| 186 |
+
else:
|
| 187 |
+
raise NotImplementedError(augmentation_type)
|
| 188 |
+
|
| 189 |
+
# filenames of images and labels
|
| 190 |
+
self.images = []
|
| 191 |
+
self.labels = []
|
| 192 |
+
for line in self.metadata:
|
| 193 |
+
cityname = line['file_name'].split('_')[0]
|
| 194 |
+
split = 'val' if cityname in ['frankfurt', 'lindau', 'munster'] else 'train'
|
| 195 |
+
img_dir = os.path.join(self.image_dir, split, cityname, line['file_name'])
|
| 196 |
+
lbl_dir = os.path.join(self.label_dir, split, cityname,
|
| 197 |
+
line['file_name'].replace('leftImg8bit.png', 'gtFine_labelIds.png'))
|
| 198 |
+
assert os.path.isfile(img_dir), img_dir
|
| 199 |
+
assert os.path.isfile(lbl_dir), lbl_dir
|
| 200 |
+
self.images.append(img_dir)
|
| 201 |
+
self.labels.append(lbl_dir)
|
| 202 |
+
|
| 203 |
+
def __len__(self):
|
| 204 |
+
return len(self.images)
|
| 205 |
+
|
| 206 |
+
def random_sample(self):
|
| 207 |
+
return self.__getitem__(random.randint(0, self.__len__() - 1))
|
| 208 |
+
|
| 209 |
+
def sequential_sample(self, ind):
|
| 210 |
+
if ind >= self.__len__() - 1:
|
| 211 |
+
return self.__getitem__(0)
|
| 212 |
+
return self.__getitem__(ind + 1)
|
| 213 |
+
|
| 214 |
+
def skip_sample(self, ind):
|
| 215 |
+
if self.shuffle:
|
| 216 |
+
return self.random_sample()
|
| 217 |
+
return self.sequential_sample(ind=ind)
|
| 218 |
+
|
| 219 |
+
def get_caption_list_objects(self, idx):
|
| 220 |
+
caption = random.choice(self.metadata[idx]['text'])
|
| 221 |
+
return caption
|
| 222 |
+
|
| 223 |
+
def _load_json(self, path):
|
| 224 |
+
with open(path, 'r') as file:
|
| 225 |
+
data = json.load(file)
|
| 226 |
+
return data
|
| 227 |
+
|
| 228 |
+
def __getitem__(self, idx):
|
| 229 |
+
# load image
|
| 230 |
+
try:
|
| 231 |
+
original_pil_image = Image.open(self.images[idx]).convert("RGB")
|
| 232 |
+
original_pil_target = Image.open(self.labels[idx])
|
| 233 |
+
except (OSError, ValueError) as e:
|
| 234 |
+
print(f"An exception occurred trying to load file {self.images[idx]}.")
|
| 235 |
+
print(f"Skipping index {idx}")
|
| 236 |
+
return self.skip_sample(idx)
|
| 237 |
+
|
| 238 |
+
# Transforms
|
| 239 |
+
image = ToTensor()(original_pil_image)
|
| 240 |
+
label = ToTensorNoNorm()(original_pil_target)
|
| 241 |
+
label = self.map_fn[label.long()]
|
| 242 |
+
img_lbl = self.augmentation(torch.cat([image, label]))
|
| 243 |
+
|
| 244 |
+
caption = self.get_caption_list_objects(idx)
|
| 245 |
+
|
| 246 |
+
return img_lbl[:3], img_lbl[3:], caption
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def indices_segmentation_to_img(indices, colors):
|
| 250 |
+
if indices.size(1) == 1:
|
| 251 |
+
# Remove single channel axis.
|
| 252 |
+
indices = indices[:, 0]
|
| 253 |
+
# for train_id
|
| 254 |
+
indices = indices * (indices != 255) + torch.ones_like(indices) * 19 * (indices == 255)
|
| 255 |
+
rgbs = colors[indices]
|
| 256 |
+
rgbs = rgbs.permute(0, 3, 1, 2)
|
| 257 |
+
return rgbs / 255.
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def get_colors_from_id_type(id_type):
|
| 261 |
+
num_classes = len(id_type_to_classes[id_type]['map_fn'].unique())
|
| 262 |
+
colors = torch.zeros((num_classes, 3))
|
| 263 |
+
exist_ids = []
|
| 264 |
+
for idx, cls in enumerate(id_type_to_classes[id_type]['map_fn']):
|
| 265 |
+
if cls == 255:
|
| 266 |
+
cls = 19
|
| 267 |
+
if cls not in exist_ids:
|
| 268 |
+
colors[cls] = torch.tensor(classes[idx].color)
|
| 269 |
+
exist_ids.append(cls)
|
| 270 |
+
return colors
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def transform_lbl(lbl, id_type='id'):
|
| 274 |
+
colors = get_colors_from_id_type(id_type)
|
| 275 |
+
return indices_segmentation_to_img(lbl, colors)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def transform_img_lbl(x, id_type='id', unnorm=True):
|
| 279 |
+
colors = get_colors_from_id_type(id_type)
|
| 280 |
+
|
| 281 |
+
x = x.detach().cpu()
|
| 282 |
+
x = x.unsqueeze(0) if x.dim() == 3 else x
|
| 283 |
+
|
| 284 |
+
# b, _, h, w = x.shape
|
| 285 |
+
img = x[:, :3]
|
| 286 |
+
lbl = x[:, 3:].long()
|
| 287 |
+
img = unnormalize_to_zero_to_one(img) if unnorm else img
|
| 288 |
+
saved_img = torch.cat([img, indices_segmentation_to_img(lbl, colors)]) # b * 2, 3, h ,w
|
| 289 |
+
return saved_img
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def trainId2label_fn(train_id_map):
|
| 293 |
+
saved_label_id = torch.zeros_like(train_id_map)
|
| 294 |
+
for t_id, label in trainId2label.items():
|
| 295 |
+
if label.ignoreInEval:
|
| 296 |
+
continue
|
| 297 |
+
saved_label_id[train_id_map == t_id] = label.id
|
| 298 |
+
return saved_label_id
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def change_19_to_255(id_map):
|
| 302 |
+
id_map[id_map == 19] = 255
|
| 303 |
+
return id_map
|
docs/gcdp.png
ADDED
|
Git LFS Details
|
environment.yaml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
channels:
|
| 2 |
+
- pytorch
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- cudatoolkit=11.3.1
|
| 7 |
+
- python=3.9.17
|
| 8 |
+
- pytorch=1.10.1
|
| 9 |
+
- torchvision=0.11.2
|
| 10 |
+
- pip:
|
| 11 |
+
- accelerate==0.21.0
|
| 12 |
+
- annotated-types==0.5.0
|
| 13 |
+
- appdirs==1.4.4
|
| 14 |
+
- attrs==23.1.0
|
| 15 |
+
- autopep8==2.0.2
|
| 16 |
+
- certifi==2023.7.22
|
| 17 |
+
- charset-normalizer==3.2.0
|
| 18 |
+
- cityscapesscripts==2.2.2
|
| 19 |
+
- click==8.1.6
|
| 20 |
+
- coloredlogs==15.0.1
|
| 21 |
+
- contourpy==1.1.0
|
| 22 |
+
- cycler==0.11.0
|
| 23 |
+
- docker-pycreds==0.4.0
|
| 24 |
+
- einops==0.6.1
|
| 25 |
+
- einops-exts==0.0.4
|
| 26 |
+
- ema-pytorch==0.2.3
|
| 27 |
+
- filelock==3.12.2
|
| 28 |
+
- fonttools==4.41.1
|
| 29 |
+
- fsspec==2023.6.0
|
| 30 |
+
- gitdb==4.0.10
|
| 31 |
+
- gitpython==3.1.32
|
| 32 |
+
- huggingface-hub==0.16.4
|
| 33 |
+
- humanfriendly==10.0
|
| 34 |
+
- idna==3.4
|
| 35 |
+
- importlib-resources==6.0.0
|
| 36 |
+
- jsonlines==3.1.0
|
| 37 |
+
- kiwisolver==1.4.4
|
| 38 |
+
- kornia==0.6.12
|
| 39 |
+
- matplotlib==3.7.2
|
| 40 |
+
- packaging==23.1
|
| 41 |
+
- pathtools==0.1.2
|
| 42 |
+
- protobuf==4.23.4
|
| 43 |
+
- psutil==5.9.5
|
| 44 |
+
- pycodestyle==2.11.0
|
| 45 |
+
- pydantic==2.1.1
|
| 46 |
+
- pydantic-core==2.4.0
|
| 47 |
+
- pyparsing==3.0.9
|
| 48 |
+
- pyquaternion==0.9.9
|
| 49 |
+
- python-dateutil==2.8.2
|
| 50 |
+
- pytorch-warmup==0.1.1
|
| 51 |
+
- pyyaml==6.0.1
|
| 52 |
+
- regex==2023.6.3
|
| 53 |
+
- requests==2.31.0
|
| 54 |
+
- safetensors==0.3.1
|
| 55 |
+
- sentencepiece==0.1.99
|
| 56 |
+
- sentry-sdk==1.28.1
|
| 57 |
+
- setproctitle==1.3.2
|
| 58 |
+
- smmap==5.0.0
|
| 59 |
+
- tokenizers==0.13.3
|
| 60 |
+
- tomli==2.0.1
|
| 61 |
+
- tqdm==4.65.0
|
| 62 |
+
- transformers==4.31.0
|
| 63 |
+
- typing==3.7.4.3
|
| 64 |
+
- urllib3==2.0.4
|
| 65 |
+
- wandb==0.15.7
|
| 66 |
+
- zipp==3.16.2
|
example.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
import numpy as np

# Path to the keep-mask image to rewrite.
image_path = '/home/zouxuechao/mmface/gcdp/repaint/data/datasets/gt_keep_masks/gcdp/base_0_0.png'
image_array = np.array(Image.open(image_path))

# Only rewrite masks of the expected 128x256 RGB layout (numpy shape is H, W, C).
if image_array.shape == (128, 256, 3):
    image_array[:, :128, :] = 0    # left 128x128: black (keep nothing)
    image_array[:, 128:, :] = 255  # right 128x128: white (keep everything)
    Image.fromarray(image_array).save('modified_image.png')
else:
    print("The image does not have the required size (256x128 with 3 channels).")
|
imagen_pytorch/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from imagen_pytorch.imagen_pytorch import Imagen, Unet
|
| 2 |
+
from imagen_pytorch.imagen_pytorch import NullUnet
|
| 3 |
+
from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
|
| 4 |
+
from imagen_pytorch.trainer import ImagenTrainer
|
| 5 |
+
from imagen_pytorch.version import __version__
|
| 6 |
+
|
| 7 |
+
# imagen using the elucidated ddpm from Tero Karras' new paper
|
| 8 |
+
|
| 9 |
+
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
|
| 10 |
+
|
| 11 |
+
# config driven creation of imagen instances
|
| 12 |
+
|
| 13 |
+
from imagen_pytorch.configs import UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig
|
| 14 |
+
|
| 15 |
+
# utils
|
| 16 |
+
|
| 17 |
+
from imagen_pytorch.utils import load_imagen_from_checkpoint
|
| 18 |
+
|
| 19 |
+
# video
|
| 20 |
+
|
| 21 |
+
from imagen_pytorch.imagen_video import Unet3D
|
| 22 |
+
|
| 23 |
+
# joint
|
| 24 |
+
|
| 25 |
+
from imagen_pytorch.joint_imagen import BaseJointUnet, JointImagen, SRJointUnet
|
| 26 |
+
from imagen_pytorch.trainer import ImagenTrainer, JointImagenTrainer
|
imagen_pytorch/cli.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import click
|
| 2 |
+
import torch
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from imagen_pytorch import load_imagen_from_checkpoint
|
| 6 |
+
from imagen_pytorch.version import __version__
|
| 7 |
+
from imagen_pytorch.utils import safeget
|
| 8 |
+
|
| 9 |
+
def exists(val):
    """True unless val is None."""
    return not (val is None)
|
| 11 |
+
|
| 12 |
+
def simple_slugify(text, max_length = 255):
    """Turn free text into a filesystem-friendly slug, truncated to max_length."""
    for old, new in (("-", "_"), (",", ""), (" ", "_"), ("|", "--")):
        text = text.replace(old, new)
    return text.strip('-_')[:max_length]
|
| 14 |
+
|
| 15 |
+
def main():
    """CLI entry point; currently a no-op."""
|
| 17 |
+
|
| 18 |
+
@click.command()
@click.option('--model', default = './imagen.pt', help = 'path to trained Imagen model')
@click.option('--cond_scale', default = 5, help = 'conditioning scale (classifier free guidance) in decoder')
@click.option('--load_ema', default = True, help = 'load EMA version of unets if available')
@click.argument('text')
def imagen(
    model,
    cond_scale,
    load_ema,
    text
):
    """Sample one image for `text` from a trained Imagen checkpoint and save it as a PNG."""
    checkpoint_path = Path(model)
    resolved_path = str(checkpoint_path.resolve())
    assert checkpoint_path.exists(), f'model not found at {resolved_path}'

    checkpoint = torch.load(str(checkpoint_path))

    # report the package version the checkpoint was saved with
    saved_version = safeget(checkpoint, 'version')
    print(f'loading Imagen from {resolved_path}, saved at version {saved_version} - current package version is {__version__}')

    # rebuild the model from its saved config and move it to the GPU
    imagen = load_imagen_from_checkpoint(str(checkpoint_path), load_ema_if_available = load_ema)
    imagen.cuda()

    # sample, slugify the prompt into a filename, and save
    pil_image = imagen.sample(text, cond_scale = cond_scale, return_pil_images = True)

    image_path = f'./{simple_slugify(text)}.png'
    pil_image[0].save(image_path)

    print(f'image saved to {str(image_path)}')
    return
|
imagen_pytorch/configs.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pydantic import BaseModel, validator, root_validator
|
| 3 |
+
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
| 4 |
+
from enum import Enum
|
| 5 |
+
|
| 6 |
+
from imagen_pytorch.imagen_pytorch import Imagen, Unet, Unet3D, NullUnet
|
| 7 |
+
from imagen_pytorch.trainer import ImagenTrainer
|
| 8 |
+
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
|
| 9 |
+
from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim
|
| 10 |
+
|
| 11 |
+
# helper functions
|
| 12 |
+
|
| 13 |
+
def exists(val):
    """True unless val is None."""
    return val is not None

def default(val, d):
    """Return val if it is not None, otherwise the fallback d."""
    if val is None:
        return d
    return val

def ListOrTuple(inner_type):
    """Type-alias factory: a list or a tuple of inner_type."""
    return Union[List[inner_type], Tuple[inner_type]]

def SingleOrList(inner_type):
    """Type-alias factory: a bare inner_type, or a list/tuple of it."""
    return Union[inner_type, ListOrTuple(inner_type)]
|
| 24 |
+
|
| 25 |
+
# noise schedule
|
| 26 |
+
|
| 27 |
+
class NoiseSchedule(Enum):
    """Supported diffusion noise schedules."""
    cosine = 'cosine'
    linear = 'linear'
|
| 30 |
+
|
| 31 |
+
class AllowExtraBaseModel(BaseModel):
    """Pydantic base model that tolerates extra fields and stores enum values directly."""
    class Config:
        extra = "allow"
        use_enum_values = True
|
| 35 |
+
|
| 36 |
+
# imagen pydantic classes
|
| 37 |
+
|
| 38 |
+
class NullUnetConfig(BaseModel):
    """Config stub that builds a placeholder (identity) unet."""
    is_null: bool

    def create(self):
        """Instantiate the NullUnet placeholder."""
        return NullUnet()
|
| 43 |
+
|
| 44 |
+
class UnetConfig(AllowExtraBaseModel):
    """Pydantic schema for a 2d Unet; extra fields pass straight through to the constructor."""
    dim: int
    dim_mults: ListOrTuple(int)
    text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
    cond_dim: int = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    def create(self):
        """Instantiate the Unet from this config."""
        return Unet(**self.dict())
|
| 55 |
+
|
| 56 |
+
class Unet3DConfig(AllowExtraBaseModel):
    """Pydantic schema for a 3d (video) Unet; extra fields pass through to the constructor."""
    dim: int
    dim_mults: ListOrTuple(int)
    text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
    cond_dim: int = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    def create(self):
        """Instantiate the Unet3D from this config."""
        return Unet3D(**self.dict())
|
| 67 |
+
|
| 68 |
+
class ImagenConfig(AllowExtraBaseModel):
    """Pydantic schema describing a full cascading Imagen."""
    unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    image_sizes: ListOrTuple(int)
    video: bool = False
    timesteps: SingleOrList(int) = 1000
    noise_schedules: SingleOrList(NoiseSchedule) = 'cosine'
    text_encoder_name: str = DEFAULT_T5_NAME
    channels: int = 3
    loss_type: str = 'l2'
    cond_drop_prob: float = 0.5

    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
        """Ensure exactly one target resolution per unet."""
        unets = values.get('unets')
        if len(image_sizes) != len(unets):
            raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
        return image_sizes

    def create(self):
        """Instantiate the configured unets and wrap them in an Imagen."""
        decoder_kwargs = self.dict()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        unets = []
        for unet_config, unet_kwargs in zip(self.unets, unets_kwargs):
            # null configs always build a NullUnet; otherwise pick 2d/3d by modality
            if isinstance(unet_config, NullUnetConfig):
                klass = NullUnet
            else:
                klass = Unet3D if is_video else Unet
            unets.append(klass(**unet_kwargs))

        imagen = Imagen(unets, **decoder_kwargs)

        imagen._config = self.dict().copy()  # stash the config for checkpointing
        return imagen
|
| 107 |
+
|
| 108 |
+
class ElucidatedImagenConfig(AllowExtraBaseModel):
    """Pydantic schema for the elucidated (Karras et al., EDM) variant of Imagen."""
    unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    image_sizes: ListOrTuple(int)
    video: bool = False
    text_encoder_name: str = DEFAULT_T5_NAME
    channels: int = 3
    cond_drop_prob: float = 0.5
    num_sample_steps: SingleOrList(int) = 32
    sigma_min: SingleOrList(float) = 0.002
    sigma_max: SingleOrList(int) = 80
    sigma_data: SingleOrList(float) = 0.5
    rho: SingleOrList(int) = 7
    P_mean: SingleOrList(float) = -1.2
    P_std: SingleOrList(float) = 1.2
    S_churn: SingleOrList(int) = 80
    S_tmin: SingleOrList(float) = 0.05
    S_tmax: SingleOrList(int) = 50
    S_noise: SingleOrList(float) = 1.003

    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
        """Ensure exactly one target resolution per unet."""
        unets = values.get('unets')
        if len(image_sizes) != len(unets):
            raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
        return image_sizes

    def create(self):
        """Instantiate the configured unets and wrap them in an ElucidatedImagen."""
        decoder_kwargs = self.dict()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        # NOTE: a dead pre-loop `unet_klass = Unet3D if is_video else Unet` assignment
        # was removed; the class is chosen per unet inside the loop below

        unets = []

        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet

            unets.append(unet_klass(**unet_kwargs))

        imagen = ElucidatedImagen(unets, **decoder_kwargs)

        imagen._config = self.dict().copy()  # stash the config for checkpointing
        return imagen
|
| 157 |
+
|
| 158 |
+
class ImagenTrainerConfig(AllowExtraBaseModel):
    """Pydantic schema for an ImagenTrainer together with the imagen it wraps."""
    imagen: dict
    elucidated: bool = False
    video: bool = False
    use_ema: bool = True
    lr: SingleOrList(float) = 1e-4
    eps: SingleOrList(float) = 1e-8
    beta1: float = 0.9
    beta2: float = 0.99
    max_grad_norm: Optional[float] = None
    group_wd_params: bool = True
    warmup_steps: SingleOrList(Optional[int]) = None
    cosine_decay_max_steps: SingleOrList(Optional[int]) = None

    def create(self):
        """Build the configured Imagen (or ElucidatedImagen) and wrap it in an ImagenTrainer."""
        trainer_kwargs = self.dict()

        imagen_config = trainer_kwargs.pop('imagen')
        elucidated = trainer_kwargs.pop('elucidated')
        # fix: `video` was referenced below without being defined, raising NameError;
        # pop it so it is forwarded to the imagen config and not to ImagenTrainer
        video = trainer_kwargs.pop('video')

        imagen_config_klass = ElucidatedImagenConfig if elucidated else ImagenConfig
        imagen = imagen_config_klass(**{**imagen_config, 'video': video}).create()

        return ImagenTrainer(imagen, **trainer_kwargs)
|
imagen_pytorch/data.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
from torchvision import transforms as T, utils
|
| 8 |
+
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
# helpers functions
|
| 12 |
+
|
| 13 |
+
def exists(val):
    """True unless val is None."""
    return not (val is None)
|
| 15 |
+
|
| 16 |
+
def cycle(dl):
    """Yield items from dl forever, restarting each time it is exhausted."""
    while True:
        yield from dl
|
| 20 |
+
|
| 21 |
+
def convert_image_to(img_type, image):
    """Return image converted to PIL mode img_type, or unchanged if already in that mode."""
    return image if image.mode == img_type else image.convert(img_type)
|
| 25 |
+
|
| 26 |
+
# dataset and dataloader
|
| 27 |
+
|
| 28 |
+
class Dataset(Dataset):
    """Folder-of-images dataset: recursively collects files by extension and
    returns resized, randomly-flipped, center-cropped image tensors.

    NOTE(review): the class name shadows torch.utils.data.Dataset it inherits
    from — kept as-is for backward compatibility with existing imports.
    """
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],
        convert_image_to_type = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # optionally force every image into one PIL mode (e.g. 'RGB') before transforming
        if exists(convert_image_to_type):
            convert_fn = partial(convert_image_to, convert_image_to_type)
        else:
            convert_fn = nn.Identity()

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        return self.transform(Image.open(self.paths[index]))
|
| 58 |
+
|
| 59 |
+
def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = False,
    pin_memory = True
):
    """Create a DataLoader over the images in folder.

    When cycle_dl is True the loader is wrapped so that it repeats forever.
    """
    dataset = Dataset(folder, image_size)
    loader = DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
    return cycle(loader) if cycle_dl else loader
|
imagen_pytorch/elucidated_imagen.py
ADDED
|
@@ -0,0 +1,846 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from math import sqrt
|
| 2 |
+
from random import random
|
| 3 |
+
from functools import partial
|
| 4 |
+
from contextlib import contextmanager, nullcontext
|
| 5 |
+
from typing import List, Union
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
from tqdm.auto import tqdm
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch import nn, einsum
|
| 12 |
+
from torch.cuda.amp import autocast
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 14 |
+
import torchvision.transforms as T
|
| 15 |
+
|
| 16 |
+
import kornia.augmentation as K
|
| 17 |
+
|
| 18 |
+
from einops import rearrange, repeat, reduce
|
| 19 |
+
from einops_exts import rearrange_many
|
| 20 |
+
|
| 21 |
+
from imagen_pytorch.imagen_pytorch import (
|
| 22 |
+
GaussianDiffusionContinuousTimes,
|
| 23 |
+
Unet,
|
| 24 |
+
NullUnet,
|
| 25 |
+
first,
|
| 26 |
+
exists,
|
| 27 |
+
identity,
|
| 28 |
+
maybe,
|
| 29 |
+
default,
|
| 30 |
+
cast_tuple,
|
| 31 |
+
cast_uint8_images_to_float,
|
| 32 |
+
is_float_dtype,
|
| 33 |
+
eval_decorator,
|
| 34 |
+
check_shape,
|
| 35 |
+
pad_tuple_to_length,
|
| 36 |
+
resize_image_to,
|
| 37 |
+
right_pad_dims_to,
|
| 38 |
+
module_device,
|
| 39 |
+
normalize_neg_one_to_one,
|
| 40 |
+
unnormalize_zero_to_one,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
from imagen_pytorch.imagen_video.imagen_video import (
|
| 44 |
+
Unet3D,
|
| 45 |
+
resize_video_to
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
|
| 49 |
+
|
| 50 |
+
# constants
|
| 51 |
+
|
| 52 |
+
# hyperparameters of the elucidated (Karras et al.) sampler, one Hparams per unet
Hparams_fields = (
    'num_sample_steps sigma_min sigma_max sigma_data rho '
    'P_mean P_std S_churn S_tmin S_tmax S_noise'
).split()

Hparams = namedtuple('Hparams', Hparams_fields)
|
| 67 |
+
|
| 68 |
+
# helper functions
|
| 69 |
+
|
| 70 |
+
def log(t, eps = 1e-20):
    """Elementwise natural log with inputs clamped below at eps to avoid -inf."""
    return t.clamp(min = eps).log()
|
| 72 |
+
|
| 73 |
+
# main class
|
| 74 |
+
|
| 75 |
+
class ElucidatedImagen(nn.Module):
|
| 76 |
+
def __init__(
    self,
    unets,
    *,
    image_sizes,                                # for cascading ddpm, image size at each stage
    text_encoder_name = DEFAULT_T5_NAME,
    text_embed_dim = None,
    channels = 3,
    cond_drop_prob = 0.1,
    random_crop_sizes = None,
    lowres_sample_noise_level = 0.2,            # noise level the lowres conditioning image is fixed to at sample time (0.1 or 0.3 in the paper)
    per_sample_random_aug_noise_level = False,  # whether each batch element gets its own aug noise level - turned off due to @marunine's find
    condition_on_text = True,
    auto_normalize_img = True,                  # auto [0, 1] <-> [-1, 1] normalization; disable to feed [-1, 1] images yourself
    dynamic_thresholding = True,
    dynamic_thresholding_percentile = 0.95,     # unsure what this was based on perusal of paper
    only_train_unet_number = None,
    lowres_noise_schedule = 'linear',
    num_sample_steps = 32,                      # number of sampling steps
    sigma_min = 0.002,                          # min noise level
    sigma_max = 80,                             # max noise level
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal noise distribution (training)
    P_std = 1.2,                                # std of log-normal noise distribution (training)
    S_churn = 80,                               # stochastic sampling parameters - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
):
    """Cascading text-to-image diffusion model using the elucidated (EDM) parameterization.

    Each entry of `unets` handles one resolution in `image_sizes`; the first unet is
    unconditional on low-res input, later ones upsample the previous stage's output.
    """
    super().__init__()

    self.only_train_unet_number = only_train_unet_number

    # conditioning hparams

    self.condition_on_text = condition_on_text
    self.unconditional = not condition_on_text

    # channels

    self.channels = channels

    # automatically take care of ensuring that first unet is unconditional
    # while the rest of the unets are conditioned on the low resolution image produced by previous unet

    unets = cast_tuple(unets)
    num_unets = len(unets)

    # randomly cropping for upsampler training

    self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
    assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'

    # lowres augmentation noise schedule

    self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)

    # get text encoder

    self.text_encoder_name = text_encoder_name
    self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))

    self.encode_text = partial(t5_encode_text, name = text_encoder_name)

    # construct unets

    self.unets = nn.ModuleList([])
    self.unet_being_trained_index = -1  # keeps track of which unet is being trained at the moment

    for ind, one_unet in enumerate(unets):
        assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
        is_first = ind == 0

        one_unet = one_unet.cast_model_parameters(
            lowres_cond = not is_first,
            cond_on_text = self.condition_on_text,
            text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
            channels = self.channels,
            channels_out = self.channels
        )

        self.unets.append(one_unet)

    # determine whether we are training on images or video

    is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
    self.is_video = is_video

    self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
    self.resize_to = resize_video_to if is_video else resize_image_to

    # unet image sizes
    # fix: was `cast_tuple(self.image_sizes)`, which raised AttributeError since
    # self.image_sizes had not been assigned yet - cast the constructor argument instead

    self.image_sizes = cast_tuple(image_sizes)
    assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}'

    self.sample_channels = cast_tuple(self.channels, num_unets)

    # cascading ddpm related stuff

    lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
    assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

    self.lowres_sample_noise_level = lowres_sample_noise_level
    self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level

    # classifier free guidance

    self.cond_drop_prob = cond_drop_prob
    self.can_classifier_guidance = cond_drop_prob > 0.

    # normalize and unnormalize image functions

    self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
    self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
    self.input_image_range = (0. if auto_normalize_img else -1., 1.)

    # dynamic thresholding

    self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
    self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

    # elucidating parameters, one Hparams per unet

    hparams = [
        num_sample_steps,
        sigma_min,
        sigma_max,
        sigma_data,
        rho,
        P_mean,
        P_std,
        S_churn,
        S_tmin,
        S_tmax,
        S_noise,
    ]

    hparams = [cast_tuple(hp, num_unets) for hp in hparams]
    self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)]

    # one temp parameter for keeping track of device

    self.register_buffer('_temp', torch.tensor([0.]), persistent = False)

    # default to device of unets passed in

    self.to(next(self.unets.parameters()).device)
|
| 225 |
+
|
| 226 |
+
def force_unconditional_(self):
    """In-place switch to fully unconditional generation.

    Disables text conditioning on this module and on every owned unet,
    so subsequent training / sampling ignores text entirely.
    """
    self.unconditional = True
    self.condition_on_text = False

    for child_unet in self.unets:
        child_unet.cond_on_text = False
|
| 232 |
+
|
| 233 |
+
@property
def device(self):
    """Device this module lives on, tracked via the non-persistent `_temp` buffer."""
    return self._temp.device
|
| 236 |
+
|
| 237 |
+
def get_unet(self, unet_number):
    """Return the unet for 1-indexed `unet_number`, keeping only that unet on
    this module's device and offloading every other unet to CPU.

    Side effects: converts `self.unets` from an `nn.ModuleList` to a plain
    python list (so per-unet `.to()` moves are not tracked by the parent
    module), and records the active index in `self.unet_being_trained_index`.
    """
    assert 0 < unet_number <= len(self.unets)
    index = unet_number - 1

    # unwrap the ModuleList into a plain list; delattr removes the registered
    # submodule so reassignment below stores an ordinary attribute instead
    if isinstance(self.unets, nn.ModuleList):
        unets_list = [unet for unet in self.unets]
        delattr(self, 'unets')
        self.unets = unets_list

    # only shuffle devices when a different unet than last time is requested
    if index != self.unet_being_trained_index:
        for unet_index, unet in enumerate(self.unets):
            unet.to(self.device if unet_index == index else 'cpu')

    self.unet_being_trained_index = index
    return self.unets[index]
|
| 252 |
+
|
| 253 |
+
def reset_unets_all_one_device(self, device = None):
    """Re-wrap all unets into an `nn.ModuleList` and move every one of them
    onto a single device (defaults to this module's device).

    Also clears the single-unet-on-gpu bookkeeping set by `get_unet`.
    """
    target_device = default(device, self.device)

    self.unets = nn.ModuleList([*self.unets])
    self.unets.to(target_device)

    # -1 signals that no particular unet is currently pinned for training
    self.unet_being_trained_index = -1
|
| 259 |
+
|
| 260 |
+
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
    """Context manager: temporarily keep exactly one unet on this module's
    device, all others on CPU, and restore the original placement on exit.

    Exactly one of `unet_number` (1-indexed) or `unet` (the instance itself)
    must be provided.
    """
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    # remember where every unet currently lives so placement can be restored
    # (the comprehension's `unet` is scoped to the comprehension and does not
    # clobber the chosen `unet` above)
    devices = [module_device(unet) for unet in self.unets]
    self.unets.cpu()
    unet.to(self.device)

    yield

    # restore the original device of each unet after the managed block
    for unet, device in zip(self.unets, devices):
        unet.to(device)
|
| 275 |
+
|
| 276 |
+
# overriding state dict functions
|
| 277 |
+
|
| 278 |
+
def state_dict(self, *args, **kwargs):
    """Serialize module state, first gathering all unets back into a single
    ModuleList on one device so their parameters are included consistently."""
    self.reset_unets_all_one_device()
    return super().state_dict(*args, **kwargs)
|
| 281 |
+
|
| 282 |
+
def load_state_dict(self, *args, **kwargs):
    """Load module state, first gathering all unets back into a single
    ModuleList on one device so parameter names line up with `state_dict`."""
    self.reset_unets_all_one_device()
    return super().load_state_dict(*args, **kwargs)
|
| 285 |
+
|
| 286 |
+
# dynamic thresholding
|
| 287 |
+
|
| 288 |
+
def threshold_x_start(self, x_start, dynamic_threshold = True):
    """Clamp a predicted clean image `x_start` to a valid range.

    With `dynamic_threshold = False` this is a plain clamp to [-1, 1].
    Otherwise apply dynamic thresholding (Imagen): per sample, clamp to the
    `dynamic_thresholding_percentile` quantile of absolute values (never below
    1) and rescale back into [-1, 1].
    """
    if not dynamic_threshold:
        return x_start.clamp(-1., 1.)

    # per-sample quantile of |x|, taken over all non-batch dimensions
    flat_abs = rearrange(x_start, 'b ... -> b (...)').abs()
    thresholds = torch.quantile(
        flat_abs,
        self.dynamic_thresholding_percentile,
        dim = -1
    )

    thresholds.clamp_(min = 1.)
    thresholds = right_pad_dims_to(x_start, thresholds)
    return x_start.clamp(-thresholds, thresholds) / thresholds
|
| 301 |
+
|
| 302 |
+
# derived preconditioning params - Table 1
|
| 303 |
+
|
| 304 |
+
def c_skip(self, sigma_data, sigma):
    """Skip-connection scale c_skip(sigma) — Karras et al., Table 1."""
    total_variance = sigma ** 2 + sigma_data ** 2
    return (sigma_data ** 2) / total_variance
|
| 306 |
+
|
| 307 |
+
def c_out(self, sigma_data, sigma):
    """Output scale c_out(sigma) — Karras et al., Table 1."""
    inv_root_variance = (sigma_data ** 2 + sigma ** 2) ** -0.5
    return sigma * sigma_data * inv_root_variance
|
| 309 |
+
|
| 310 |
+
def c_in(self, sigma_data, sigma):
    """Input scale c_in(sigma) — Karras et al., Table 1."""
    total_variance = sigma ** 2 + sigma_data ** 2
    return total_variance ** -0.5
|
| 312 |
+
|
| 313 |
+
def c_noise(self, sigma):
    """Noise conditioning c_noise(sigma) = ln(sigma) / 4 — Karras et al., Table 1.
    `log` is the module-level (clamped) log helper."""
    return 0.25 * log(sigma)
|
| 315 |
+
|
| 316 |
+
# preconditioned network output
|
| 317 |
+
# equation (7) in the paper
|
| 318 |
+
|
| 319 |
+
def preconditioned_network_forward(
    self,
    unet_forward,
    noised_images,
    sigma,
    *,
    sigma_data,
    clamp = False,
    dynamic_threshold = True,
    **kwargs
):
    """Run the raw unet through the Karras preconditioning wrapper
    (equation (7) of "Elucidating the Design Space ..."):

        D(x; sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))

    `sigma` may be a float (broadcast over the batch) or a per-sample tensor.
    If `clamp` is set, the denoised output is additionally thresholded via
    `threshold_x_start`.
    """
    batch, device = noised_images.shape[0], noised_images.device

    # promote a scalar sigma to a per-sample tensor
    if isinstance(sigma, float):
        sigma = torch.full((batch,), sigma, device = device)

    # pad sigma with trailing singleton dims so it broadcasts over image dims
    padded_sigma = self.right_pad_dims_to_datatype(sigma)

    net_out = unet_forward(
        self.c_in(sigma_data, padded_sigma) * noised_images,
        self.c_noise(sigma),
        **kwargs
    )

    out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out

    if not clamp:
        return out

    return self.threshold_x_start(out, dynamic_threshold)
|
| 349 |
+
|
| 350 |
+
# sampling
|
| 351 |
+
|
| 352 |
+
# sample schedule
|
| 353 |
+
# equation (5) in the paper
|
| 354 |
+
|
| 355 |
+
def sample_schedule(
    self,
    num_sample_steps,
    rho,
    sigma_min,
    sigma_max
):
    """Build the Karras noise schedule (equation (5)): `num_sample_steps`
    sigmas interpolated in sigma^(1/rho) space from `sigma_max` down to
    `sigma_min`, followed by a final sigma of 0.

    Returns a float32 tensor of length `num_sample_steps + 1` on this
    module's device.
    """
    inv_rho = 1 / rho
    max_inv = sigma_max ** inv_rho
    min_inv = sigma_min ** inv_rho

    steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
    ramp = steps / (num_sample_steps - 1)
    sigmas = (max_inv + ramp * (min_inv - max_inv)) ** rho

    # append the terminal step, sigma = 0
    return F.pad(sigmas, (0, 1), value = 0.)
|
| 370 |
+
|
| 371 |
+
@torch.no_grad()
def one_unet_sample(
    self,
    unet,
    shape,
    *,
    unet_number,
    clamp = True,
    dynamic_threshold = True,
    cond_scale = 1.,
    use_tqdm = True,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    skip_steps = None,
    sigma_min = None,
    sigma_max = None,
    **kwargs
):
    """Sample images of `shape` from a single unet using the stochastic
    Karras sampler (Algorithm 2: churn + Heun second-order correction),
    with optional RePaint-style inpainting resampling.

    Returns unnormalized images (via `self.unnormalize_img`).
    """
    # get specific sampling hyperparameters for unet

    hp = self.hparams[unet_number - 1]

    sigma_min = default(sigma_min, hp.sigma_min)
    sigma_max = default(sigma_max, hp.sigma_max)

    # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma

    sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max)

    # churn amount gamma is only applied inside [S_tmin, S_tmax]
    gammas = torch.where(
        (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax),
        min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1),
        0.
    )

    sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

    # images is noise at the beginning

    init_sigma = sigmas[0]

    images = init_sigma * torch.randn(shape, device = self.device)

    # initializing with an image

    if exists(init_images):
        images += init_images

    # keeping track of x0, for self conditioning if needed

    x_start = None

    # prepare inpainting images and mask

    has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
    resample_times = inpaint_resample_times if has_inpainting else 1

    if has_inpainting:
        inpaint_images = self.normalize_img(inpaint_images)
        inpaint_images = self.resize_to(inpaint_images, shape[-1])
        # broadcast the mask to have an explicit channel dim, resize, then re-binarize
        inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()

    # unet kwargs

    unet_kwargs = dict(
        sigma_data = hp.sigma_data,
        clamp = clamp,
        dynamic_threshold = dynamic_threshold,
        cond_scale = cond_scale,
        **kwargs
    )

    # gradually denoise

    initial_step = default(skip_steps, 0)
    sigmas_and_gammas = sigmas_and_gammas[initial_step:]

    total_steps = len(sigmas_and_gammas)

    for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm):
        is_last_timestep = ind == (total_steps - 1)

        sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))

        # RePaint: each timestep may be resampled several times when inpainting
        for r in reversed(range(resample_times)):
            is_last_resample_step = r == 0

            eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling

            # "churn": temporarily raise the noise level from sigma to sigma_hat
            sigma_hat = sigma + gamma * sigma
            added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps

            images_hat = images + added_noise

            self_cond = x_start if unet.self_cond else None

            if has_inpainting:
                # keep known (masked) regions pinned to the noised inpaint image
                images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks

            model_output = self.preconditioned_network_forward(
                unet.forward_with_cond_scale,
                images_hat,
                sigma_hat,
                self_cond = self_cond,
                **unet_kwargs
            )

            # Euler step from sigma_hat to sigma_next
            denoised_over_sigma = (images_hat - model_output) / sigma_hat

            images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

            # second order correction, if not the last timestep

            if sigma_next != 0:
                self_cond = model_output if unet.self_cond else None

                model_output_next = self.preconditioned_network_forward(
                    unet.forward_with_cond_scale,
                    images_next,
                    sigma_next,
                    self_cond = self_cond,
                    **unet_kwargs
                )

                # Heun: average the slopes at sigma_hat and sigma_next
                denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
                images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

            images = images_next

            if has_inpainting and not (is_last_resample_step or is_last_timestep):
                # renoise in repaint and then resample
                repaint_noise = torch.randn(shape, device = self.device)
                images = images + (sigma - sigma_next) * repaint_noise

            x_start = model_output # save model output for self conditioning

    images = images.clamp(-1., 1.)

    if has_inpainting:
        # final composite: copy known regions straight from the inpaint image
        images = images * ~inpaint_masks + inpaint_images * inpaint_masks

    return self.unnormalize_img(images)
|
| 515 |
+
|
| 516 |
+
@torch.no_grad()
@eval_decorator
def sample(
    self,
    texts: List[str] = None,
    text_masks = None,
    text_embeds = None,
    cond_images = None,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    skip_steps = None,
    sigma_min = None,
    sigma_max = None,
    video_frames = None,
    batch_size = 1,
    cond_scale = 1.,
    lowres_sample_noise_level = None,
    start_at_unet_number = 1,
    start_image_or_video = None,
    stop_at_unet_number = None,
    return_all_unet_outputs = False,
    return_pil_images = False,
    use_tqdm = True,
    device = None,
):
    """Sample images (or video tensors) through the full unet cascade.

    Each unet after the first is conditioned on a noised, resized version of
    the previous unet's output.  Per-unet settings (`cond_scale`,
    `init_images`, `skip_steps`, `sigma_min`, `sigma_max`) may be scalars or
    tuples of length `len(self.unets)`.  Returns either the last unet's
    output or all outputs (`return_all_unet_outputs`), as tensors or PIL
    images (`return_pil_images`).
    """
    device = default(device, self.device)
    self.reset_unets_all_one_device(device = device)

    cond_images = maybe(cast_uint8_images_to_float)(cond_images)

    # encode raw text captions if embeddings were not supplied directly
    if exists(texts) and not exists(text_embeds) and not self.unconditional:
        assert all([*map(len, texts)]), 'text cannot be empty'

        with autocast(enabled = False):
            text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

        text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))

    if not self.unconditional:
        assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'

        # derive a mask from non-zero embedding rows when none was supplied
        text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
        batch_size = text_embeds.shape[0]

    if exists(inpaint_images):
        if self.unconditional:
            if batch_size == 1: # assume researcher wants to broadcast along inpainted images
                batch_size = inpaint_images.shape[0]

        assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
        assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'

    assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
    assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
    assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

    assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'

    outputs = []

    is_cuda = next(self.parameters()).is_cuda
    device = next(self.parameters()).device

    lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)

    num_unets = len(self.unets)
    cond_scale = cast_tuple(cond_scale, num_unets)

    # handle video and frame dimension

    assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

    frame_dims = (video_frames,) if self.is_video else tuple()

    # initializing with an image or video

    init_images = cast_tuple(init_images, num_unets)
    init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]

    skip_steps = cast_tuple(skip_steps, num_unets)

    sigma_min = cast_tuple(sigma_min, num_unets)
    sigma_max = cast_tuple(sigma_max, num_unets)

    # handle starting at a unet greater than 1, for training only-upscaler training

    if start_at_unet_number > 1:
        assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
        assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
        assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

        # seed `img` as if it were the output of the unet just before the start
        prev_image_size = self.image_sizes[start_at_unet_number - 2]
        img = self.resize_to(start_image_or_video, prev_image_size)

    # go through each unet in cascade

    for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
        if unet_number < start_at_unet_number:
            continue

        assert not isinstance(unet, NullUnet), 'cannot sample from null unet'

        # on cuda, keep only this unet on the gpu while it samples
        context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()

        with context:
            lowres_cond_img = lowres_noise_times = None

            shape = (batch_size, channel, *frame_dims, image_size, image_size)

            if unet.lowres_cond:
                # condition this upsampler on a noised version of the previous output
                lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

                lowres_cond_img = self.resize_to(img, image_size)
                lowres_cond_img = self.normalize_img(lowres_cond_img)

                lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

            if exists(unet_init_images):
                unet_init_images = self.resize_to(unet_init_images, image_size)

            shape = (batch_size, self.channels, *frame_dims, image_size, image_size)

            img = self.one_unet_sample(
                unet,
                shape,
                unet_number = unet_number,
                text_embeds = text_embeds,
                text_mask = text_masks,
                cond_images = cond_images,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = unet_init_images,
                skip_steps = unet_skip_steps,
                sigma_min = unet_sigma_min,
                sigma_max = unet_sigma_max,
                cond_scale = unet_cond_scale,
                lowres_cond_img = lowres_cond_img,
                lowres_noise_times = lowres_noise_times,
                dynamic_threshold = dynamic_threshold,
                use_tqdm = use_tqdm
            )

            outputs.append(img)

        if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
            break

    output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs

    if not return_pil_images:
        return outputs[output_index]

    if not return_all_unet_outputs:
        outputs = outputs[-1:]

    assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet'

    pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))

    return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
|
| 679 |
+
|
| 680 |
+
# training
|
| 681 |
+
|
| 682 |
+
def loss_weight(self, sigma_data, sigma):
    """Per-sample loss weighting lambda(sigma) — Karras et al., Table 1."""
    total_variance = sigma ** 2 + sigma_data ** 2
    return total_variance * (sigma * sigma_data) ** -2
|
| 684 |
+
|
| 685 |
+
def noise_distribution(self, P_mean, P_std, batch_size):
    """Draw `batch_size` training sigmas: log(sigma) ~ N(P_mean, P_std^2),
    i.e. a log-normal noise-level distribution (Karras et al.)."""
    gaussian = torch.randn((batch_size,), device = self.device)
    log_sigmas = P_mean + P_std * gaussian
    return log_sigmas.exp()
|
| 687 |
+
|
| 688 |
+
def forward(
    self,
    images,
    unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
    texts: List[str] = None,
    text_embeds = None,
    text_masks = None,
    unet_number = None,
    cond_images = None
):
    """Compute the training loss for one unet of the cascade.

    Noises `images` at sigmas drawn from the log-normal training
    distribution, runs the preconditioned unet (optionally with self
    conditioning), and returns the weighted mean MSE between the denoised
    prediction and the clean images.
    """
    # fix: report shape[-2] (height) instead of shape[2], which is the frame
    # count for 5-dim video tensors
    assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[-2]}, {images.shape[-1]}'
    assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
    unet_number = default(unet_number, 1)
    # fix: this message was missing its f-prefix, so the placeholder printed verbatim
    assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'

    images = cast_uint8_images_to_float(images)
    cond_images = maybe(cast_uint8_images_to_float)(cond_images)

    assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'

    unet_index = unet_number - 1

    unet = default(unet, lambda: self.get_unet(unet_number))

    assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'

    # per-unet sizes and elucidated hyperparameters
    target_image_size = self.image_sizes[unet_index]
    random_crop_size = self.random_crop_sizes[unet_index]
    prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
    hp = self.hparams[unet_index]

    batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)

    frames = images.shape[2] if is_video else None

    check_shape(images, 'b c ...', c = self.channels)

    assert h >= target_image_size and w >= target_image_size

    # encode raw text captions if embeddings were not supplied directly
    if exists(texts) and not exists(text_embeds) and not self.unconditional:
        assert all([*map(len, texts)]), 'text cannot be empty'
        assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

        with autocast(enabled = False):
            text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

        text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))

    if not self.unconditional:
        # derive a mask from non-zero embedding rows when none was supplied
        text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))

    assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
    assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'

    assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

    # low resolution conditioning for upsampler unets: downsample then
    # upsample the clean images to simulate the previous unet's output
    lowres_cond_img = lowres_aug_times = None
    if exists(prev_image_size):
        lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
        lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)

        if self.per_sample_random_aug_noise_level:
            lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
        else:
            # one shared augmentation noise level for the whole batch
            lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
            lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)

    images = self.resize_to(images, target_image_size)

    # normalize to [-1, 1]

    images = self.normalize_img(images)
    lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

    # random cropping during training
    # for upsamplers

    if exists(random_crop_size):
        aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)

        if is_video:
            images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h w -> (b f) c h w')

        # make sure low res conditioner and image both get augmented the same way
        # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
        images = aug(images)
        lowres_cond_img = aug(lowres_cond_img, params = aug._params)

        if is_video:
            images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h w -> b c f h w', f = frames)

    # noise the lowres conditioning image
    # at sample time, they then fix the noise level of 0.1 - 0.3

    lowres_cond_img_noisy = None
    if exists(lowres_cond_img):
        lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))

    # get the sigmas

    sigmas = self.noise_distribution(hp.P_mean, hp.P_std, batch_size)
    padded_sigmas = self.right_pad_dims_to_datatype(sigmas)

    # noise

    noise = torch.randn_like(images)
    noised_images = images + padded_sigmas * noise # alphas are 1. in the paper

    # unet kwargs

    unet_kwargs = dict(
        sigma_data = hp.sigma_data,
        text_embeds = text_embeds,
        text_mask = text_masks,
        cond_images = cond_images,
        lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
        lowres_cond_img = lowres_cond_img_noisy,
        cond_drop_prob = self.cond_drop_prob,
    )

    # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower

    # Because 'unet' can be an instance of DistributedDataParallel coming from the
    # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
    # access the member 'module' of the wrapped unet instance.
    # fix: the non-DDP branch previously evaluated to the unet module itself (always
    # truthy), which enabled self conditioning even when unet.self_cond was False
    self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond

    if self_cond and random() < 0.5:
        with torch.no_grad():
            # first pass predicts x0 without gradients, fed back in as conditioning
            pred_x0 = self.preconditioned_network_forward(
                unet.forward,
                noised_images,
                sigmas,
                **unet_kwargs
            ).detach()

        unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0}

    # get prediction

    denoised_images = self.preconditioned_network_forward(
        unet.forward,
        noised_images,
        sigmas,
        **unet_kwargs
    )

    # losses

    losses = F.mse_loss(denoised_images, images, reduction = 'none')
    losses = reduce(losses, 'b ... -> b', 'mean')

    # loss weighting

    losses = losses * self.loss_weight(hp.sigma_data, sigmas)

    # return average loss

    return losses.mean()
|
imagen_pytorch/imagen_pytorch.py
ADDED
|
@@ -0,0 +1,2515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import copy
|
| 3 |
+
from random import random
|
| 4 |
+
from typing import List, Union
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
from functools import partial, wraps
|
| 7 |
+
from contextlib import contextmanager, nullcontext
|
| 8 |
+
from collections import namedtuple
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 14 |
+
from torch import nn, einsum
|
| 15 |
+
from torch.cuda.amp import autocast
|
| 16 |
+
from torch.special import expm1
|
| 17 |
+
import torchvision.transforms as T
|
| 18 |
+
|
| 19 |
+
import kornia.augmentation as K
|
| 20 |
+
|
| 21 |
+
from einops import rearrange, repeat, reduce
|
| 22 |
+
from einops.layers.torch import Rearrange, Reduce
|
| 23 |
+
from einops_exts import rearrange_many, repeat_many, check_shape
|
| 24 |
+
from einops_exts.torch import EinopsToAndFrom
|
| 25 |
+
|
| 26 |
+
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
|
| 27 |
+
|
| 28 |
+
from imagen_pytorch.imagen_video.imagen_video import Unet3D, resize_video_to
|
| 29 |
+
|
| 30 |
+
# helper functions
|
| 31 |
+
|
| 32 |
+
def exists(val):
    """Return True when *val* is anything other than None."""
    is_none = val is None
    return not is_none
|
| 34 |
+
|
| 35 |
+
def identity(t, *args, **kwargs):
    """No-op transform: hand back *t* untouched; extra arguments are accepted and ignored."""
    result = t
    return result
|
| 37 |
+
|
| 38 |
+
def first(arr, d = None):
    """Return the first element of *arr*, or the fallback *d* when *arr* is empty."""
    return arr[0] if len(arr) > 0 else d
|
| 42 |
+
|
| 43 |
+
def maybe(fn):
    """Lift *fn* so that None short-circuits: None in -> None out,
    otherwise fn is applied as usual."""
    @wraps(fn)
    def inner(x):
        return fn(x) if exists(x) else x
    return inner
|
| 50 |
+
|
| 51 |
+
def once(fn):
    """Return a wrapper around *fn* that only runs on the first call;
    every later call is a silent no-op returning None."""
    has_run = False

    @wraps(fn)
    def inner(x):
        nonlocal has_run
        if not has_run:
            has_run = True
            return fn(x)

    return inner

# convenience: a print that fires at most once per process (for one-time warnings)
print_once = once(print)
|
| 63 |
+
|
| 64 |
+
def default(val, d):
    """Return *val* when it is not None; otherwise fall back to *d*,
    calling it first when it is callable (lazy default)."""
    if not exists(val):
        return d() if callable(d) else d
    return val
|
| 68 |
+
|
| 69 |
+
def cast_tuple(val, length = None):
    """Coerce *val* to a tuple. Lists convert element-wise; a scalar is repeated
    ``length`` times (once when length is None). When *length* is given, the
    result is asserted to have exactly that many entries."""
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        output = (val,) * default(length, 1)

    if exists(length):
        assert len(output) == length

    return output
|
| 79 |
+
|
| 80 |
+
def is_float_dtype(dtype):
    """Return True when *dtype* is one of torch's floating-point dtypes
    (float64, float32, float16 or bfloat16)."""
    # direct membership test replaces the original any()-over-list-comprehension,
    # which allocated an intermediate list of four bools on every call
    return dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)
|
| 82 |
+
|
| 83 |
+
def cast_uint8_images_to_float(images):
    """Map uint8 images (0..255) into floats in [0, 1]; any other dtype passes through unchanged."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
|
| 87 |
+
|
| 88 |
+
def module_device(module):
    """Infer a module's device from its first parameter."""
    first_param = next(module.parameters())
    return first_param.device
|
| 90 |
+
|
| 91 |
+
def zero_init_(m):
    """Zero-initialise *m* in place: weight always, bias only when the layer has one."""
    nn.init.zeros_(m.weight)
    bias = m.bias
    if exists(bias):
        nn.init.zeros_(bias)
|
| 95 |
+
|
| 96 |
+
def eval_decorator(fn):
    """Decorator: run *fn* with the model switched to eval mode, restoring the
    model's previous training flag before returning the result."""
    def inner(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return inner
|
| 104 |
+
|
| 105 |
+
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad tuple *t* with *fillvalue* up to *length*; tuples already at
    least that long are returned untouched."""
    missing = length - len(t)
    if missing > 0:
        return t + (fillvalue,) * missing
    return t
|
| 110 |
+
|
| 111 |
+
# helper classes
|
| 112 |
+
|
| 113 |
+
class Identity(nn.Module):
    """Pass-through module: forward returns its input unchanged, accepting and
    discarding any extra positional / keyword arguments (both at construction
    and at call time)."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x
|
| 119 |
+
|
| 120 |
+
# tensor helpers
|
| 121 |
+
|
| 122 |
+
def log(t, eps: float = 1e-12):
    """Numerically-safe natural log: clamps *t* from below at *eps* before taking the log."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)
|
| 124 |
+
|
| 125 |
+
def l2norm(t):
    """Scale *t* to unit L2 norm along its last dimension."""
    return F.normalize(t, dim = -1)
|
| 127 |
+
|
| 128 |
+
def right_pad_dims_to(x, t):
    """Append trailing singleton dims to *t* until it has as many dims as *x*
    (so it broadcasts against *x*); returned unchanged when already wide enough."""
    extra = x.ndim - t.ndim
    if extra <= 0:
        return t
    return t.view(*t.shape, *((1,) * extra))
|
| 133 |
+
|
| 134 |
+
def masked_mean(t, *, dim, mask = None):
    """Mean of *t* over *dim*. With a boolean *mask* of shape (b, n), only
    unmasked positions contribute to the sum, and the per-row denominator is
    the unmasked count (clamped to avoid division by zero)."""
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    broadcast_mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~broadcast_mask, 0.)
    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
|
| 143 |
+
|
| 144 |
+
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None
):
    """Nearest-neighbour resize of *image* to *target_image_size* — a no-op when
    the spatial size (last dim) already matches — optionally clamping the output
    into *clamp_range*."""
    if image.shape[-1] == target_image_size:
        return image

    out = F.interpolate(image, target_image_size, mode = 'nearest')

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out
|
| 160 |
+
|
| 161 |
+
# image normalization functions
|
| 162 |
+
# ddpms expect images to be in the range of -1 to 1
|
| 163 |
+
|
| 164 |
+
def normalize_neg_one_to_one(img):
    """Map pixel values from [0, 1] into the [-1, 1] range ddpms expect."""
    return 2 * img - 1
|
| 166 |
+
|
| 167 |
+
def unnormalize_zero_to_one(normed_img):
    """Inverse of normalize_neg_one_to_one: map [-1, 1] back to [0, 1]."""
    return 0.5 * (normed_img + 1)
|
| 169 |
+
|
| 170 |
+
# classifier free guidance functions
|
| 171 |
+
|
| 172 |
+
def prob_mask_like(shape, prob, device):
    """Boolean mask of *shape* where each entry is independently True with
    probability *prob*. prob == 1 and prob == 0 short-circuit to constant
    masks (deterministic, no RNG draw)."""
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    uniform_draw = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return uniform_draw < prob
|
| 179 |
+
|
| 180 |
+
# gaussian diffusion with continuous time helper functions and classes
|
| 181 |
+
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
|
| 182 |
+
|
| 183 |
+
@torch.jit.script
def beta_linear_log_snr(t):
    # Continuous-time log SNR for the linear-beta noise schedule, t in [0, 1].
    # expm1 keeps the computation stable for small t (where the argument ~ 1e-4).
    return -torch.log(expm1(1e-4 + 10 * (t ** 2)))
|
| 186 |
+
|
| 187 |
+
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008)::
    # Continuous-time log SNR for the cosine schedule; `s` is the small offset
    # from the discrete cosine schedule. `log` is the clamped-safe log above.
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
|
| 190 |
+
|
| 191 |
+
def log_snr_to_alpha_sigma(log_snr):
    """Convert a log-SNR value into the diffusion coefficients (alpha, sigma),
    where alpha^2 = sigmoid(log_snr) and sigma^2 = sigmoid(-log_snr), so that
    alpha^2 + sigma^2 == 1."""
    alpha = torch.sqrt(torch.sigmoid(log_snr))
    sigma = torch.sqrt(torch.sigmoid(-log_snr))
    return alpha, sigma
|
| 193 |
+
|
| 194 |
+
class GaussianDiffusionContinuousTimes(nn.Module):
    """Gaussian diffusion parameterised in continuous time t in [0, 1].

    The schedule is expressed as a log-SNR function of t (linear or cosine);
    alpha/sigma coefficients are derived from it via log_snr_to_alpha_sigma.
    `timesteps` only controls how finely sampling discretises [1, 0].
    """
    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        # constant time vector: every batch element at the same noise level
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
        # uniform t ~ U(0, max_thres) per batch element (training)
        return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres)

    def get_condition(self, times):
        # log-SNR used as the conditioning signal; `maybe` passes None through
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        # list of (t, t_next) pairs walking 1 -> 0 in num_timesteps steps,
        # each entry shaped (2, batch)
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        # posterior q(x_{t_next} | x_t, x_start); t_next defaults to one
        # discrete step earlier, clamped at 0
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        # forward process: x_t = alpha(t) * x_start + sigma(t) * noise;
        # also returns the (unpadded) log-SNR for conditioning
        dtype = x_start.dtype

        if isinstance(t, float):
            batch = x_start.shape[0]
            t = torch.full((batch,), t, device = x_start.device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t).type(dtype)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        return alpha * x_start + sigma * noise, log_snr

    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        # re-noise a sample from time from_t to a (noisier) time to_t in one
        # step, reusing the noise decomposition implied by the schedule
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_from))

        log_snr = self.log_snr(from_t)
        log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        log_snr_to = self.log_snr(to_t)
        log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
        alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to)

        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    def predict_start_from_noise(self, x_t, t, noise):
        # invert the forward process: x_start = (x_t - sigma * noise) / alpha,
        # with alpha clamped away from zero for stability near t = 1
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
|
| 284 |
+
|
| 285 |
+
# norms and residuals
|
| 286 |
+
|
| 287 |
+
class LayerNorm(nn.Module):
    # Layer norm with a learned gain only (no bias), normalising over `dim`.
    # `stable = True` first divides by the detached max (the paper's "stable
    # layernorm" trick for numerical stability under mixed precision).
    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        # gain gets trailing singleton dims so it broadcasts when dim != -1
        self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1))))

    def forward(self, x):
        dtype, dim = x.dtype, self.dim

        if self.stable:
            x = x / x.amax(dim = dim, keepdim = True).detach()

        # looser epsilon for half-precision inputs
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = dim, keepdim = True)

        return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype)

# channel-first variant: normalises dim -3 (the channel axis of b c h w tensors)
ChanLayerNorm = partial(LayerNorm, dim = -3)
|
| 308 |
+
|
| 309 |
+
class Always():
    """Callable that ignores every argument and always returns the value it
    was constructed with."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val
|
| 315 |
+
|
| 316 |
+
class Residual(nn.Module):
    """Skip connection wrapper: forward(x, **kwargs) = fn(x, **kwargs) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out + x
|
| 323 |
+
|
| 324 |
+
class Parallel(nn.Module):
    """Apply several modules to the same input and return the sum of their outputs."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum(fn(x) for fn in self.fns)
|
| 332 |
+
|
| 333 |
+
# attention pooling
|
| 334 |
+
|
| 335 |
+
class PerceiverAttention(nn.Module):
    """Perceiver-style cross attention: a small set of latents attends to the
    input sequence (and to itself — latents are concatenated into the keys /
    values, which is where this differs from the original Perceiver).

    With `cosine_sim_attn`, q/k are l2-normalised and a fixed scale of 16 is
    used instead of dim_head ** -0.5.
    """
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        # x: input sequence (b, n, dim); latents: (b, num_latents, dim);
        # mask: optional boolean (b, n) over the input sequence only
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities and masking

        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            # latent positions are never masked, so pad the mask with True
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
|
| 402 |
+
|
| 403 |
+
class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence into a fixed number of latent tokens
    via stacked PerceiverAttention + FeedForward blocks (FeedForward is defined
    elsewhere in this file). Optionally prepends extra latents derived from the
    mean-pooled input sequence.
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # learned latent queries, shared across the batch
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        # x: (b, n, dim) with n <= max_seq_len (pos_emb would index out of
        # range otherwise); mask: optional boolean (b, n)
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        if exists(self.to_latents_from_mean_pooled_seq):
            # NOTE(review): the pooling mask here is all-ones, i.e. it ignores
            # the `mask` argument — presumably intentional, verify upstream
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
|
| 456 |
+
|
| 457 |
+
# attention
|
| 458 |
+
|
| 459 |
+
class Attention(nn.Module):
    """Multi-head (multi-query) attention: queries have `heads` heads but keys /
    values are a single head of size dim_head, shared across query heads.

    A learned null key/value pair is always prepended so attention can "attend
    to nothing" (classifier-free guidance). Optional `context` (e.g. text
    embeddings) contributes extra keys/values; optional `attn_bias` adds a
    T5-style relative position bias to the logits.
    """
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # single shared k/v head (multi-query attention)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # pad with True for the prepended null kv position; NOTE(review):
            # when `context` is also given, its extra key positions are not
            # covered by this mask — confirm callers rely on that
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
|
| 548 |
+
|
| 549 |
+
# decoder
|
| 550 |
+
|
| 551 |
+
def Upsample(dim, dim_out = None):
    """2x nearest-neighbour upsample followed by a 3x3 conv mapping dim -> dim_out
    (dim_out defaults to dim)."""
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1),
    )
|
| 558 |
+
|
| 559 |
+
class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    # 2x upsample via 1x1 conv to 4x channels + PixelShuffle; the conv weight
    # is initialised so all four sub-pixels start identical (ICNR-style init),
    # which suppresses checkerboard artifacts at the start of training.
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        # initialise a quarter of the output channels, then replicate 4x so
        # each output pixel's four sub-pixels share weights at init
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
|
| 588 |
+
|
| 589 |
+
def Downsample(dim, dim_out = None):
    """2x downsample via pixel-unshuffle plus a 1x1 conv.

    # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
    # named SP-conv in the paper, but basically a pixel unshuffle
    """
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
        nn.Conv2d(dim * 4, dim_out, 1)
    )
|
| 597 |
+
|
| 598 |
+
class SinusoidalPosEmb(nn.Module):
    """Classic fixed sinusoidal embedding for a 1d batch of (time) positions.

    Maps a tensor of shape (n,) to embeddings of shape (n, dim).
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # geometric progression of frequencies, spanning wavelengths up to 10000
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)
|
| 609 |
+
|
| 610 |
+
class LearnedSinusoidalPosEmb(nn.Module):
    """ following @crowsonkb 's lead with learned sinusoidal pos emb """
    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        # frequencies are learned parameters rather than the fixed geometric schedule
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        # the raw position is concatenated in front, giving output dim = dim + 1
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered
|
| 626 |
+
|
| 627 |
+
class Block(nn.Module):
    """GroupNorm -> (optional FiLM scale/shift) -> SiLU -> 3x3 conv.

    The basic convolutional unit used by ResnetBlock.
    """
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        x = self.groupnorm(x)

        # FiLM-style conditioning, applied after the norm and before the nonlinearity
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.activation(x)
        return self.project(x)
|
| 649 |
+
|
| 650 |
+
class ResnetBlock(nn.Module):
    """Resnet block with optional time conditioning (FiLM), optional cross attention
    over a conditioning token sequence, and optional global-context gating (GCA).

    Args:
        dim / dim_out: input / output channel counts.
        cond_dim: if given, enables cross attention against conditioning tokens.
        time_cond_dim: if given, enables the time MLP producing per-channel scale/shift.
        linear_attn: use LinearCrossAttention instead of full CrossAttention.
        use_gca: gate the output with GlobalContext attention.
    """
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            # projects the time embedding into a (scale, shift) pair for block2
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            # attention operates on the (h w) sequence view of the feature map
            self.cross_attn = EinopsToAndFrom(
                'b c h w',
                'b (h w) c',
                attn_klass(
                    dim = dim_out,
                    context_dim = cond_dim,
                    **attn_kwargs
                )
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        # Always(1) is a no-op gate when GCA is disabled
        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()

    def forward(self, x, time_emb = None, cond = None):

        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            h = self.cross_attn(h, context = cond) + h

        # note: time conditioning is applied only in the second block
        h = self.block2(h, scale_shift = scale_shift)

        h = h * self.gca(h)

        return h + self.res_conv(x)
|
| 716 |
+
|
| 717 |
+
class CrossAttention(nn.Module):
    """Multi-head cross attention with a learned null key/value.

    The null key/value is prepended so every query always has something to attend to,
    which supports classifier free guidance (dropped-out conditioning).
    """
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        cosine_sim_attn = False
    ):
        super().__init__()
        # with cosine sim attention, q/k are l2-normalized and a fixed scale of 16 replaces 1/sqrt(d)
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        # cosine sim attention

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale

        # masking — pad with True so the prepended null token is always attendable

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax in float32 for numerical stability, then cast back

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
|
| 793 |
+
|
| 794 |
+
class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of CrossAttention (softmax over q and k separately).

    Reuses CrossAttention's parameters; only the forward computation differs.
    """
    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            # fix: broadcast the (batch, seq) mask across heads so it matches the
            # (b h) leading dimension of k / v (a plain 'b n 1' mask fails for heads > 1)
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
|
| 833 |
+
|
| 834 |
+
class LinearAttention(nn.Module):
    """Linear-complexity self attention over 2d feature maps, with optional extra
    context tokens (projected and concatenated into keys/values).
    """
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        # q/k/v each: dropout -> 1x1 pointwise conv -> 3x3 depthwise conv
        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        # linear attention: normalize q over features, k over sequence
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
|
| 902 |
+
|
| 903 |
+
class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # single-channel logits over spatial positions, softmaxed into pooling weights
        self.to_k = nn.Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)
        # flatten spatial dims, then attention-pool the features with the context weights
        x, context = rearrange_many((x, context), 'b n ... -> b n (...)')
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        out = rearrange(out, '... -> ... 1')
        return self.net(out)
|
| 929 |
+
|
| 930 |
+
def FeedForward(dim, mult = 2):
    """Token-wise feedforward: LayerNorm -> Linear (dim -> dim*mult) -> GELU -> LayerNorm -> Linear."""
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False)
    )
|
| 939 |
+
|
| 940 |
+
def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width
    """Channel-wise feedforward over feature maps, using 1x1 convs instead of linears."""
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        ChanLayerNorm(dim),
        nn.Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv2d(hidden_dim, dim, 1, bias = False)
    )
|
| 949 |
+
|
| 950 |
+
class TransformerBlock(nn.Module):
    """Stack of (full self attention over flattened spatial positions, channelwise
    feedforward), each with a residual connection.
    """
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                # attention runs on the 'b (h w) c' sequence view of the feature map
                EinopsToAndFrom('b c h w', 'b (h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
|
| 976 |
+
|
| 977 |
+
class LinearAttentionTransformerBlock(nn.Module):
    """Same shape as TransformerBlock but uses linear attention, for use at
    higher resolutions where full attention is too expensive.
    """
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
|
| 1003 |
+
|
| 1004 |
+
class CrossEmbedLayer(nn.Module):
    """Multi-scale convolutional embedding: several kernel sizes run in parallel
    and their outputs are concatenated along the channel dimension.
    """
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        # kernel and stride parities must match so the 'same'-style padding below is exact
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale: halve per scale, remainder to the last
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)
|
| 1030 |
+
|
| 1031 |
+
class UpsampleCombiner(nn.Module):
    """Optionally fuses feature maps from all upsample stages into the final
    feature map (used successfully in unet squared).

    When disabled, acts as a pass-through and ``self.dim_out == dim``.
    """
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # resize each incoming feature map to the current resolution, project, then concat on channels
        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        return torch.cat((x, *outs), dim = 1)
|
| 1064 |
+
|
| 1065 |
+
class Unet(nn.Module):
|
| 1066 |
+
def __init__(
|
| 1067 |
+
self,
|
| 1068 |
+
*,
|
| 1069 |
+
dim,
|
| 1070 |
+
image_embed_dim = 1024,
|
| 1071 |
+
text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
|
| 1072 |
+
num_resnet_blocks = 1,
|
| 1073 |
+
cond_dim = None,
|
| 1074 |
+
num_image_tokens = 4,
|
| 1075 |
+
num_time_tokens = 2,
|
| 1076 |
+
learned_sinu_pos_emb_dim = 16,
|
| 1077 |
+
out_dim = None,
|
| 1078 |
+
dim_mults=(1, 2, 4, 8),
|
| 1079 |
+
cond_images_channels = 0,
|
| 1080 |
+
channels = 3,
|
| 1081 |
+
channels_out = None,
|
| 1082 |
+
attn_dim_head = 64,
|
| 1083 |
+
attn_heads = 8,
|
| 1084 |
+
ff_mult = 2.,
|
| 1085 |
+
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
| 1086 |
+
layer_attns = True,
|
| 1087 |
+
layer_attns_depth = 1,
|
| 1088 |
+
layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
|
| 1089 |
+
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
| 1090 |
+
layer_cross_attns = True,
|
| 1091 |
+
use_linear_attn = False,
|
| 1092 |
+
use_linear_cross_attn = False,
|
| 1093 |
+
cond_on_text = True,
|
| 1094 |
+
max_text_len = 256,
|
| 1095 |
+
init_dim = None,
|
| 1096 |
+
resnet_groups = 8,
|
| 1097 |
+
init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
|
| 1098 |
+
init_cross_embed = True,
|
| 1099 |
+
init_cross_embed_kernel_sizes = (3, 7, 15),
|
| 1100 |
+
cross_embed_downsample = False,
|
| 1101 |
+
cross_embed_downsample_kernel_sizes = (2, 4),
|
| 1102 |
+
attn_pool_text = True,
|
| 1103 |
+
attn_pool_num_latents = 32,
|
| 1104 |
+
dropout = 0.,
|
| 1105 |
+
memory_efficient = False,
|
| 1106 |
+
init_conv_to_final_conv_residual = False,
|
| 1107 |
+
use_global_context_attn = True,
|
| 1108 |
+
scale_skip_connection = True,
|
| 1109 |
+
final_resnet_block = True,
|
| 1110 |
+
final_conv_kernel_size = 3,
|
| 1111 |
+
cosine_sim_attn = False,
|
| 1112 |
+
self_cond = False,
|
| 1113 |
+
combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
|
| 1114 |
+
pixel_shuffle_upsample = True # may address checkboard artifacts
|
| 1115 |
+
):
|
| 1116 |
+
super().__init__()
|
| 1117 |
+
|
| 1118 |
+
# guide researchers
|
| 1119 |
+
|
| 1120 |
+
assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'
|
| 1121 |
+
|
| 1122 |
+
if dim < 128:
|
| 1123 |
+
print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')
|
| 1124 |
+
|
| 1125 |
+
# save locals to take care of some hyperparameters for cascading DDPM
|
| 1126 |
+
|
| 1127 |
+
self._locals = locals()
|
| 1128 |
+
self._locals.pop('self', None)
|
| 1129 |
+
self._locals.pop('__class__', None)
|
| 1130 |
+
|
| 1131 |
+
# determine dimensions
|
| 1132 |
+
|
| 1133 |
+
self.channels = channels
|
| 1134 |
+
self.channels_out = default(channels_out, channels)
|
| 1135 |
+
|
| 1136 |
+
# (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
| 1137 |
+
# (2) in self conditioning, one appends the predict x0 (x_start)
|
| 1138 |
+
init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
|
| 1139 |
+
init_dim = default(init_dim, dim)
|
| 1140 |
+
|
| 1141 |
+
self.self_cond = self_cond
|
| 1142 |
+
|
| 1143 |
+
# optional image conditioning
|
| 1144 |
+
|
| 1145 |
+
self.has_cond_image = cond_images_channels > 0
|
| 1146 |
+
self.cond_images_channels = cond_images_channels
|
| 1147 |
+
|
| 1148 |
+
init_channels += cond_images_channels
|
| 1149 |
+
|
| 1150 |
+
# initial convolution
|
| 1151 |
+
|
| 1152 |
+
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
| 1153 |
+
|
| 1154 |
+
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
| 1155 |
+
in_out = list(zip(dims[:-1], dims[1:]))
|
| 1156 |
+
|
| 1157 |
+
# time conditioning
|
| 1158 |
+
|
| 1159 |
+
cond_dim = default(cond_dim, dim)
|
| 1160 |
+
time_cond_dim = dim * 4 * (2 if lowres_cond else 1)
|
| 1161 |
+
|
| 1162 |
+
# embedding time for log(snr) noise from continuous version
|
| 1163 |
+
|
| 1164 |
+
sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
|
| 1165 |
+
sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
|
| 1166 |
+
|
| 1167 |
+
self.to_time_hiddens = nn.Sequential(
|
| 1168 |
+
sinu_pos_emb,
|
| 1169 |
+
nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
|
| 1170 |
+
nn.SiLU()
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
self.to_time_cond = nn.Sequential(
|
| 1174 |
+
nn.Linear(time_cond_dim, time_cond_dim)
|
| 1175 |
+
)
|
| 1176 |
+
|
| 1177 |
+
# project to time tokens as well as time hiddens
|
| 1178 |
+
|
| 1179 |
+
self.to_time_tokens = nn.Sequential(
|
| 1180 |
+
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
|
| 1181 |
+
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
| 1182 |
+
)
|
| 1183 |
+
|
| 1184 |
+
# low res aug noise conditioning
|
| 1185 |
+
|
| 1186 |
+
self.lowres_cond = lowres_cond
|
| 1187 |
+
|
| 1188 |
+
if lowres_cond:
|
| 1189 |
+
self.to_lowres_time_hiddens = nn.Sequential(
|
| 1190 |
+
LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
|
| 1191 |
+
nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
|
| 1192 |
+
nn.SiLU()
|
| 1193 |
+
)
|
| 1194 |
+
|
| 1195 |
+
self.to_lowres_time_cond = nn.Sequential(
|
| 1196 |
+
nn.Linear(time_cond_dim, time_cond_dim)
|
| 1197 |
+
)
|
| 1198 |
+
|
| 1199 |
+
self.to_lowres_time_tokens = nn.Sequential(
|
| 1200 |
+
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
|
| 1201 |
+
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
| 1202 |
+
)
|
| 1203 |
+
|
| 1204 |
+
# normalizations
|
| 1205 |
+
|
| 1206 |
+
self.norm_cond = nn.LayerNorm(cond_dim)
|
| 1207 |
+
|
| 1208 |
+
# text encoding conditioning (optional)
|
| 1209 |
+
|
| 1210 |
+
self.text_to_cond = None
|
| 1211 |
+
|
| 1212 |
+
if cond_on_text:
|
| 1213 |
+
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
|
| 1214 |
+
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
| 1215 |
+
|
| 1216 |
+
# finer control over whether to condition on text encodings
|
| 1217 |
+
|
| 1218 |
+
self.cond_on_text = cond_on_text
|
| 1219 |
+
|
| 1220 |
+
# attention pooling
|
| 1221 |
+
|
| 1222 |
+
self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None
|
| 1223 |
+
|
| 1224 |
+
# for classifier free guidance
|
| 1225 |
+
|
| 1226 |
+
self.max_text_len = max_text_len
|
| 1227 |
+
|
| 1228 |
+
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
| 1229 |
+
self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
|
| 1230 |
+
|
| 1231 |
+
# for non-attention based text conditioning at all points in the network where time is also conditioned
|
| 1232 |
+
|
| 1233 |
+
self.to_text_non_attn_cond = None
|
| 1234 |
+
|
| 1235 |
+
if cond_on_text:
|
| 1236 |
+
self.to_text_non_attn_cond = nn.Sequential(
|
| 1237 |
+
nn.LayerNorm(cond_dim),
|
| 1238 |
+
nn.Linear(cond_dim, time_cond_dim),
|
| 1239 |
+
nn.SiLU(),
|
| 1240 |
+
nn.Linear(time_cond_dim, time_cond_dim)
|
| 1241 |
+
)
|
| 1242 |
+
|
| 1243 |
+
# attention related params
|
| 1244 |
+
|
| 1245 |
+
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)
|
| 1246 |
+
|
| 1247 |
+
num_layers = len(in_out)
|
| 1248 |
+
|
| 1249 |
+
# resnet block klass
|
| 1250 |
+
|
| 1251 |
+
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
|
| 1252 |
+
resnet_groups = cast_tuple(resnet_groups, num_layers)
|
| 1253 |
+
|
| 1254 |
+
resnet_klass = partial(ResnetBlock, **attn_kwargs)
|
| 1255 |
+
|
| 1256 |
+
layer_attns = cast_tuple(layer_attns, num_layers)
|
| 1257 |
+
layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
|
| 1258 |
+
layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
|
| 1259 |
+
|
| 1260 |
+
use_linear_attn = cast_tuple(use_linear_attn, num_layers)
|
| 1261 |
+
use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)
|
| 1262 |
+
|
| 1263 |
+
assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
|
| 1264 |
+
|
| 1265 |
+
# downsample klass
|
| 1266 |
+
|
| 1267 |
+
downsample_klass = Downsample
|
| 1268 |
+
|
| 1269 |
+
if cross_embed_downsample:
|
| 1270 |
+
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
| 1271 |
+
|
| 1272 |
+
# initial resnet block (for memory efficient unet)
|
| 1273 |
+
|
| 1274 |
+
self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
|
| 1275 |
+
|
| 1276 |
+
# scale for resnet skip connections
|
| 1277 |
+
|
| 1278 |
+
self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
|
| 1279 |
+
|
| 1280 |
+
# layers
|
| 1281 |
+
|
| 1282 |
+
self.downs = nn.ModuleList([])
|
| 1283 |
+
self.ups = nn.ModuleList([])
|
| 1284 |
+
num_resolutions = len(in_out)
|
| 1285 |
+
|
| 1286 |
+
layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
|
| 1287 |
+
reversed_layer_params = list(map(reversed, layer_params))
|
| 1288 |
+
|
| 1289 |
+
# downsampling layers
|
| 1290 |
+
|
| 1291 |
+
skip_connect_dims = [] # keep track of skip connection dimensions
|
| 1292 |
+
|
| 1293 |
+
for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
|
| 1294 |
+
is_last = ind >= (num_resolutions - 1)
|
| 1295 |
+
|
| 1296 |
+
layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
|
| 1297 |
+
|
| 1298 |
+
if layer_attn:
|
| 1299 |
+
transformer_block_klass = TransformerBlock
|
| 1300 |
+
elif layer_use_linear_attn:
|
| 1301 |
+
transformer_block_klass = LinearAttentionTransformerBlock
|
| 1302 |
+
else:
|
| 1303 |
+
transformer_block_klass = Identity
|
| 1304 |
+
|
| 1305 |
+
current_dim = dim_in
|
| 1306 |
+
|
| 1307 |
+
# whether to pre-downsample, from memory efficient unet
|
| 1308 |
+
|
| 1309 |
+
pre_downsample = None
|
| 1310 |
+
|
| 1311 |
+
if memory_efficient:
|
| 1312 |
+
pre_downsample = downsample_klass(dim_in, dim_out)
|
| 1313 |
+
current_dim = dim_out
|
| 1314 |
+
|
| 1315 |
+
skip_connect_dims.append(current_dim)
|
| 1316 |
+
|
| 1317 |
+
# whether to do post-downsample, for non-memory efficient unet
|
| 1318 |
+
|
| 1319 |
+
post_downsample = None
|
| 1320 |
+
if not memory_efficient:
|
| 1321 |
+
post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv2d(dim_in, dim_out, 3, padding = 1), nn.Conv2d(dim_in, dim_out, 1))
|
| 1322 |
+
|
| 1323 |
+
self.downs.append(nn.ModuleList([
|
| 1324 |
+
pre_downsample,
|
| 1325 |
+
resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
|
| 1326 |
+
nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
|
| 1327 |
+
transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
|
| 1328 |
+
post_downsample
|
| 1329 |
+
]))
|
| 1330 |
+
|
| 1331 |
+
# middle layers
|
| 1332 |
+
|
| 1333 |
+
mid_dim = dims[-1]
|
| 1334 |
+
|
| 1335 |
+
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
| 1336 |
+
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
| 1337 |
+
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
| 1338 |
+
|
| 1339 |
+
# upsample klass
|
| 1340 |
+
|
| 1341 |
+
upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
|
| 1342 |
+
|
| 1343 |
+
# upsampling layers
|
| 1344 |
+
|
| 1345 |
+
upsample_fmap_dims = []
|
| 1346 |
+
|
| 1347 |
+
for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
|
| 1348 |
+
is_last = ind == (len(in_out) - 1)
|
| 1349 |
+
|
| 1350 |
+
layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
|
| 1351 |
+
|
| 1352 |
+
if layer_attn:
|
| 1353 |
+
transformer_block_klass = TransformerBlock
|
| 1354 |
+
elif layer_use_linear_attn:
|
| 1355 |
+
transformer_block_klass = LinearAttentionTransformerBlock
|
| 1356 |
+
else:
|
| 1357 |
+
transformer_block_klass = Identity
|
| 1358 |
+
|
| 1359 |
+
skip_connect_dim = skip_connect_dims.pop()
|
| 1360 |
+
|
| 1361 |
+
upsample_fmap_dims.append(dim_out)
|
| 1362 |
+
|
| 1363 |
+
self.ups.append(nn.ModuleList([
|
| 1364 |
+
resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
|
| 1365 |
+
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
|
| 1366 |
+
transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
|
| 1367 |
+
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
|
| 1368 |
+
]))
|
| 1369 |
+
|
| 1370 |
+
# whether to combine feature maps from all upsample blocks before final resnet block out
|
| 1371 |
+
|
| 1372 |
+
self.upsample_combiner = UpsampleCombiner(
|
| 1373 |
+
dim = dim,
|
| 1374 |
+
enabled = combine_upsample_fmaps,
|
| 1375 |
+
dim_ins = upsample_fmap_dims,
|
| 1376 |
+
dim_outs = dim
|
| 1377 |
+
)
|
| 1378 |
+
|
| 1379 |
+
# whether to do a final residual from initial conv to the final resnet block out
|
| 1380 |
+
|
| 1381 |
+
self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
|
| 1382 |
+
final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)
|
| 1383 |
+
|
| 1384 |
+
# final optional resnet block and convolution out
|
| 1385 |
+
|
| 1386 |
+
self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
|
| 1387 |
+
|
| 1388 |
+
final_conv_dim_in = dim if final_resnet_block else final_conv_dim
|
| 1389 |
+
final_conv_dim_in += (channels if lowres_cond else 0)
|
| 1390 |
+
|
| 1391 |
+
self.final_conv = nn.Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
| 1392 |
+
|
| 1393 |
+
zero_init_(self.final_conv)
|
| 1394 |
+
|
| 1395 |
+
# if the current settings for the unet are not correct
|
| 1396 |
+
# for cascading DDPM, then reinit the unet with the right settings
|
| 1397 |
+
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """Return this unet unchanged when the requested settings already
    match, otherwise rebuild a fresh unet from the stored constructor
    arguments (`self._locals`) with the requested settings swapped in."""
    unchanged = (
        lowres_cond == self.lowres_cond
        and channels == self.channels
        and cond_on_text == self.cond_on_text
        and text_embed_dim == self._locals['text_embed_dim']
        and channels_out == self.channels_out
    )

    if unchanged:
        return self

    overrides = dict(
        lowres_cond = lowres_cond,
        text_embed_dim = text_embed_dim,
        channels = channels,
        channels_out = channels_out,
        cond_on_text = cond_on_text
    )

    # re-instantiate on self.__class__ so subclasses rebuild as themselves
    return self.__class__(**{**self._locals, **overrides})
|
| 1422 |
+
|
| 1423 |
+
# methods for returning the full unet config as well as its parameter state
|
| 1424 |
+
|
| 1425 |
+
def to_config_and_state_dict(self):
    """Return a (config, state_dict) pair: the stored constructor kwargs
    plus the current parameter state, for later rehydration."""
    config = self._locals
    params = self.state_dict()
    return config, params
|
| 1427 |
+
|
| 1428 |
+
# class method for rehydrating the unet from its config and state dict
|
| 1429 |
+
|
| 1430 |
+
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Rebuild a unet from a (config, state_dict) pair as produced by
    `to_config_and_state_dict`."""
    instance = klass(**config)
    instance.load_state_dict(state_dict)
    return instance
|
| 1435 |
+
|
| 1436 |
+
# methods for persisting unet to disk
|
| 1437 |
+
|
| 1438 |
+
def persist_to_file(self, path):
    """Serialize this unet's config and weights to `path` with
    `torch.save`, creating the parent directory if needed."""
    path = Path(path)
    path.parents[0].mkdir(exist_ok = True, parents = True)

    config, state_dict = self.to_config_and_state_dict()
    torch.save(dict(config = config, state_dict = state_dict), str(path))
|
| 1445 |
+
|
| 1446 |
+
# class method for rehydrating the unet from file saved with `persist_to_file`
|
| 1447 |
+
|
| 1448 |
+
@classmethod
def hydrate_from_file(klass, path):
    """Rehydrate a unet from a checkpoint written by `persist_to_file`.

    The checkpoint must be a dict with 'config' (constructor kwargs) and
    'state_dict' (parameters). The unet is rebuilt on the class this
    method is invoked on, so subclasses rehydrate as themselves.
    """
    path = Path(path)
    assert path.exists()
    pkg = torch.load(str(path))

    assert 'config' in pkg and 'state_dict' in pkg
    config, state_dict = pkg['config'], pkg['state_dict']

    # fix: dispatch on `klass` rather than the hard-coded `Unet`, so a
    # checkpoint loaded via a subclass rebuilds that subclass
    return klass.from_config_and_state_dict(config, state_dict)
|
| 1458 |
+
|
| 1459 |
+
# forward with classifier free guidance
|
| 1460 |
+
|
| 1461 |
+
def forward_with_cond_scale(
    self,
    *args,
    cond_scale = 1.,
    **kwargs
):
    """Classifier-free guidance: blend conditional and unconditional
    predictions. `cond_scale == 1` is plain conditional inference."""
    cond_logits = self.forward(*args, **kwargs)

    if cond_scale == 1:
        return cond_logits

    # drop all conditioning for the unconditional branch, then
    # extrapolate from it towards the conditional prediction
    uncond_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
    return uncond_logits + (cond_logits - uncond_logits) * cond_scale
|
| 1474 |
+
|
| 1475 |
+
def forward(
    self,
    x,
    time,
    *,
    lowres_cond_img = None,
    lowres_noise_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    self_cond = None,
    cond_drop_prob = 0.
):
    """Denoise `x` at diffusion time `time`, with optional conditioning.

    Self-conditioning, lowres and conditioning images are concatenated
    onto the input channels; time (plus optional lowres-time and text)
    signals feed both a FiLM-style vector `t` and attention tokens `c`.
    The input then runs through the down path, middle blocks and up path
    with skip connections, and is projected by the final convolution.

    Note: assumes `x` is a (batch, channels, height, width) image tensor
    (channel concatenation on dim 1, resize keyed off x.shape[-1]) —
    confirm against callers.
    """
    batch_size, device = x.shape[0], x.device

    # condition on self

    if self.self_cond:
        # zeros stand in for the x0 estimate when none is provided
        self_cond = default(self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x, self_cond), dim = 1)

    # add low resolution conditioning, if present

    assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
    assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)

    # condition on input image

    assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

    if exists(cond_images):
        assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
        cond_images = resize_image_to(cond_images, x.shape[-1])
        x = torch.cat((cond_images, x), dim = 1)

    # initial convolution

    x = self.init_conv(x)

    # init conv residual

    if self.init_conv_to_final_conv_residual:
        # stashed and concatenated back in just before the final resnet block
        init_conv_residual = x.clone()

    # time conditioning

    time_hiddens = self.to_time_hiddens(time)

    # derive time tokens

    time_tokens = self.to_time_tokens(time_hiddens)
    t = self.to_time_cond(time_hiddens)

    # add lowres time conditioning to time hiddens
    # and add lowres time tokens along sequence dimension for attention

    if self.lowres_cond:
        lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
        lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
        lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

        t = t + lowres_t
        time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

    # text conditioning

    text_tokens = None

    if exists(text_embeds) and self.cond_on_text:

        # conditional dropout: per-batch-element keep mask for
        # classifier-free guidance training

        text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

        text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
        text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

        # calculate text embeds

        text_tokens = self.text_to_cond(text_embeds)

        # truncate / pad text tokens (and mask) to max_text_len

        text_tokens = text_tokens[:, :self.max_text_len]

        if exists(text_mask):
            text_mask = text_mask[:, :self.max_text_len]

        text_tokens_len = text_tokens.shape[1]
        remainder = self.max_text_len - text_tokens_len

        if remainder > 0:
            text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

        if exists(text_mask):
            if remainder > 0:
                text_mask = F.pad(text_mask, (0, remainder), value = False)

            text_mask = rearrange(text_mask, 'b n -> b n 1')
            # masked-out positions fall through to the learned null embed
            text_keep_mask_embed = text_mask & text_keep_mask_embed

        null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

        # dropped (or masked) positions are replaced by the null embedding

        text_tokens = torch.where(
            text_keep_mask_embed,
            text_tokens,
            null_text_embed
        )

        if exists(self.attn_pool):
            text_tokens = self.attn_pool(text_tokens)

        # extra non-attention conditioning by projecting and then summing text embeddings to time
        # termed as text hiddens

        mean_pooled_text_tokens = text_tokens.mean(dim = -2)

        text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

        null_text_hidden = self.null_text_hidden.to(t.dtype)

        text_hiddens = torch.where(
            text_keep_mask_hidden,
            text_hiddens,
            null_text_hidden
        )

        t = t + text_hiddens

    # main conditioning tokens (c)

    c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

    # normalize conditioning tokens

    c = self.norm_cond(c)

    # initial resnet block (for memory efficient unet)

    if exists(self.init_resnet_block):
        x = self.init_resnet_block(x, t)

    # go through the layers of the unet, down and up

    hiddens = []

    for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
        if exists(pre_downsample):
            x = pre_downsample(x)

        x = init_block(x, t, c)

        for resnet_block in resnet_blocks:
            x = resnet_block(x, t)
            hiddens.append(x)

        x = attn_block(x, c)
        hiddens.append(x)

        if exists(post_downsample):
            x = post_downsample(x)

    x = self.mid_block1(x, t, c)

    if exists(self.mid_attn):
        x = self.mid_attn(x)

    x = self.mid_block2(x, t, c)

    # skip connections are consumed LIFO, mirroring the down path pushes

    add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

    up_hiddens = []

    for init_block, resnet_blocks, attn_block, upsample in self.ups:
        x = add_skip_connection(x)
        x = init_block(x, t, c)

        for resnet_block in resnet_blocks:
            x = add_skip_connection(x)
            x = resnet_block(x, t)

        x = attn_block(x, c)
        up_hiddens.append(x.contiguous())
        x = upsample(x)

    # whether to combine all feature maps from upsample blocks

    x = self.upsample_combiner(x, up_hiddens)

    # final top-most residual if needed

    if self.init_conv_to_final_conv_residual:
        x = torch.cat((x, init_conv_residual), dim = 1)

    if exists(self.final_res_block):
        x = self.final_res_block(x, t)

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)

    return self.final_conv(x)
|
| 1677 |
+
|
| 1678 |
+
# null unet
|
| 1679 |
+
|
| 1680 |
+
class NullUnet(nn.Module):
    """Placeholder unet whose forward pass returns its input untouched.

    Carries a single dummy parameter so optimizers and device moves have
    something to act on, and exposes `lowres_cond`/`cast_model_parameters`
    so it can stand in for a real unet inside a cascade.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        self.dummy_parameter = nn.Parameter(torch.tensor([0.]))

    def cast_model_parameters(self, *args, **kwargs):
        # nothing to reconfigure on a no-op unet
        return self

    def forward(self, x, *args, **kwargs):
        # identity: hand back the input unchanged
        return x
|
| 1691 |
+
|
| 1692 |
+
# predefined unets, with configs lining up with hyperparameters in appendix of paper
|
| 1693 |
+
|
| 1694 |
+
class BaseUnet64(Unet):
    """Base 64x64 unet preset, matching the hyperparameters in the
    appendix of the Imagen paper."""

    def __init__(self, *args, **kwargs):
        presets = dict(
            dim = 512,
            dim_mults = (1, 2, 3, 4),
            num_resnet_blocks = 3,
            layer_attns = (False, True, True, True),
            layer_cross_attns = (False, True, True, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = False
        )
        # caller-supplied kwargs win over the preset defaults
        super().__init__(*args, **{**presets, **kwargs})
|
| 1707 |
+
|
| 1708 |
+
class SRUnet256(Unet):
    """Super-resolution unet preset for 256x256 outputs, per the Imagen
    paper appendix (memory-efficient variant)."""

    def __init__(self, *args, **kwargs):
        presets = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = (False, False, False, True),
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        # caller-supplied kwargs win over the preset defaults
        super().__init__(*args, **{**presets, **kwargs})
|
| 1721 |
+
|
| 1722 |
+
class SRUnet1024(Unet):
    """Super-resolution unet preset for 1024x1024 outputs, per the Imagen
    paper appendix (no self-attention layers, memory-efficient)."""

    def __init__(self, *args, **kwargs):
        presets = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = False,
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        # caller-supplied kwargs win over the preset defaults
        super().__init__(*args, **{**presets, **kwargs})
|
| 1735 |
+
|
| 1736 |
+
# main imagen ddpm class, which is a cascading DDPM from Ho et al.
|
| 1737 |
+
|
| 1738 |
+
class Imagen(nn.Module):
|
| 1739 |
+
def __init__(
    self,
    unets,
    *,
    image_sizes,                                # for cascading ddpm, image size at each stage
    text_encoder_name = DEFAULT_T5_NAME,
    text_embed_dim = None,
    channels = 3,
    timesteps = 1000,
    sample_timesteps = 100,
    cond_drop_prob = 0.1,
    loss_type = 'l2',
    noise_schedules = 'cosine',
    pred_objectives = 'noise',
    random_crop_sizes = None,
    lowres_noise_schedule = 'linear',
    lowres_sample_noise_level = 0.2,            # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
    per_sample_random_aug_noise_level = False,  # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
    condition_on_text = True,
    auto_normalize_img = True,                  # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
    p2_loss_weight_gamma = 0.5,                 # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
    p2_loss_weight_k = 1,
    dynamic_thresholding = True,
    dynamic_thresholding_percentile = 0.95,     # unsure what this was based on perusal of paper
    only_train_unet_number = None
):
    """Cascading DDPM wrapper: chains one base unet and zero or more
    super-resolution unets, each with its own noise schedule, timesteps,
    prediction objective and thresholding settings.

    Per-unet options (`timesteps`, `noise_schedules`, `pred_objectives`,
    `random_crop_sizes`, `dynamic_thresholding`, `p2_loss_weight_gamma`)
    may be given as a scalar (broadcast to all unets) or a tuple.
    """
    super().__init__()

    # loss

    if loss_type == 'l1':
        loss_fn = F.l1_loss
    elif loss_type == 'l2':
        loss_fn = F.mse_loss
    elif loss_type == 'huber':
        loss_fn = F.smooth_l1_loss
    else:
        # fix: name the offending value instead of raising a bare NotImplementedError()
        raise NotImplementedError(f"loss_type '{loss_type}' is not supported - choose 'l1', 'l2' or 'huber'")

    self.loss_type = loss_type
    self.loss_fn = loss_fn

    # conditioning hparams

    self.condition_on_text = condition_on_text
    self.unconditional = not condition_on_text

    # channels

    self.channels = channels

    # automatically take care of ensuring that first unet is unconditional
    # while the rest of the unets are conditioned on the low resolution image produced by previous unet

    unets = cast_tuple(unets)
    num_unets = len(unets)

    # determine noise schedules per unet

    timesteps = cast_tuple(timesteps, num_unets)
    sample_timesteps = cast_tuple(sample_timesteps, num_unets)

    # make sure noise schedule defaults to 'cosine', 'cosine', and then 'linear' for rest of super-resoluting unets

    noise_schedules = cast_tuple(noise_schedules)
    noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine')
    noise_schedules = pad_tuple_to_length(noise_schedules, num_unets, 'linear')

    # construct noise schedulers (training-time, full timesteps)

    noise_scheduler_klass = GaussianDiffusionContinuousTimes
    self.noise_schedulers = nn.ModuleList([])

    for timestep, noise_schedule in zip(timesteps, noise_schedules):
        noise_scheduler = noise_scheduler_klass(noise_schedule = noise_schedule, timesteps = timestep)
        self.noise_schedulers.append(noise_scheduler)

    # separate (typically shorter) schedulers used at sampling time

    self.noise_schedulers_sample = nn.ModuleList([])

    for sample_timestep, noise_schedule in zip(sample_timesteps, noise_schedules):
        noise_scheduler_sample = noise_scheduler_klass(noise_schedule = noise_schedule, timesteps = sample_timestep)
        self.noise_schedulers_sample.append(noise_scheduler_sample)

    # randomly cropping for upsampler training

    self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
    assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'

    # lowres augmentation noise schedule

    self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)

    # ddpm objectives - predicting noise by default

    self.pred_objectives = cast_tuple(pred_objectives, num_unets)

    # get text encoder

    self.text_encoder_name = text_encoder_name
    self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))

    self.encode_text = partial(t5_encode_text, name = text_encoder_name)

    # construct unets

    self.unets = nn.ModuleList([])

    self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
    self.only_train_unet_number = only_train_unet_number

    for ind, one_unet in enumerate(unets):
        assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
        is_first = ind == 0

        # rebuild each unet with the settings its position in the cascade requires
        one_unet = one_unet.cast_model_parameters(
            lowres_cond = not is_first,
            cond_on_text = self.condition_on_text,
            text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
            channels = self.channels,
            channels_out = self.channels
        )

        self.unets.append(one_unet)

    # unet image sizes

    image_sizes = cast_tuple(image_sizes)
    self.image_sizes = image_sizes

    assert num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({len(unets)}) for resolutions {image_sizes}'

    self.sample_channels = cast_tuple(self.channels, num_unets)

    # determine whether we are training on images or video

    is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
    self.is_video = is_video

    # video tensors carry one extra (frame) dimension to broadcast over
    self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
    self.resize_to = resize_video_to if is_video else resize_image_to

    # cascading ddpm related stuff

    lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
    assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

    self.lowres_sample_noise_level = lowres_sample_noise_level
    self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level

    # classifier free guidance

    self.cond_drop_prob = cond_drop_prob
    self.can_classifier_guidance = cond_drop_prob > 0.

    # normalize and unnormalize image functions

    self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
    self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
    self.input_image_range = (0. if auto_normalize_img else -1., 1.)

    # dynamic thresholding

    self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
    self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

    # p2 loss weight

    self.p2_loss_weight_k = p2_loss_weight_k
    self.p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)

    assert all([(gamma_value <= 2) for gamma_value in self.p2_loss_weight_gamma]), 'in paper, they noticed any gamma greater than 2 is harmful'

    # one temp parameter for keeping track of device

    self.register_buffer('_temp', torch.tensor([0.]), persistent = False)

    # default to device of unets passed in

    self.to(next(self.unets.parameters()).device)
|
| 1918 |
+
|
| 1919 |
+
def force_unconditional_(self):
    """In-place switch to fully unconditional generation: turns off text
    conditioning on this model and on every wrapped unet."""
    self.condition_on_text = False
    self.unconditional = True

    for one_unet in self.unets:
        one_unet.cond_on_text = False
|
| 1925 |
+
|
| 1926 |
+
@property
def device(self):
    # device is tracked via the persistent=False `_temp` buffer registered
    # in __init__, which follows the module through .to()/.cuda() moves
    return self._temp.device
|
| 1929 |
+
|
| 1930 |
+
def get_unet(self, unet_number):
    """Fetch the 1-indexed unet, moving it to the model device and all
    other unets to cpu (keeps only one unet in accelerator memory)."""
    assert 0 < unet_number <= len(self.unets)
    index = unet_number - 1

    # swap the registered ModuleList for a plain python list so the per-unet
    # .to() moves below are not undone by whole-module device moves
    if isinstance(self.unets, nn.ModuleList):
        unets_list = [unet for unet in self.unets]
        delattr(self, 'unets')
        self.unets = unets_list

    # only shuffle devices when a different unet is being selected
    if index != self.unet_being_trained_index:
        for unet_index, unet in enumerate(self.unets):
            unet.to(self.device if unet_index == index else 'cpu')

    self.unet_being_trained_index = index
    return self.unets[index]
|
| 1945 |
+
|
| 1946 |
+
def reset_unets_all_one_device(self, device = None):
    """Re-register every unet in a ModuleList and move them all onto one
    device (the model's own device when none is given)."""
    target_device = device if device is not None else self.device
    self.unets = nn.ModuleList([*self.unets])
    self.unets.to(target_device)

    # no single unet is pinned to the accelerator anymore
    self.unet_being_trained_index = -1
|
| 1952 |
+
|
| 1953 |
+
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
    # context manager: move exactly one unet (selected by 1-indexed number
    # or by reference) onto the model device, park the rest on cpu, and
    # restore the previous device placements after the with-block
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    # remember where every unet currently lives so we can restore below
    devices = [module_device(unet) for unet in self.unets]
    self.unets.cpu()
    unet.to(self.device)

    yield

    # NOTE(review): if the with-body raises, this restore is skipped —
    # a try/finally would be safer; confirm intended behavior
    for unet, device in zip(self.unets, devices):
        unet.to(device)
|
| 1968 |
+
|
| 1969 |
+
# overriding state dict functions
|
| 1970 |
+
|
| 1971 |
+
def state_dict(self, *args, **kwargs):
    """Gather every unet back onto one device before serializing, so the
    saved state is consistent regardless of single-unet training moves."""
    self.reset_unets_all_one_device()
    return super().state_dict(*args, **kwargs)
|
| 1974 |
+
|
| 1975 |
+
def load_state_dict(self, *args, **kwargs):
    """Gather every unet back onto one device before restoring weights,
    mirroring the consolidation done in `state_dict`."""
    self.reset_unets_all_one_device()
    return super().load_state_dict(*args, **kwargs)
|
| 1978 |
+
|
| 1979 |
+
# gaussian diffusion methods
|
| 1980 |
+
|
| 1981 |
+
def p_mean_variance(
    self,
    unet,
    x,
    t,
    *,
    noise_scheduler,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    lowres_cond_img = None,
    self_cond = None,
    lowres_noise_times = None,
    cond_scale = 1.,
    model_output = None,
    t_next = None,
    pred_objective = 'noise',
    dynamic_threshold = True
):
    """Posterior mean/variance for one reverse-diffusion step.

    Runs the unet (with classifier-free guidance when cond_scale != 1),
    converts its prediction into an x_start estimate per
    `pred_objective`, clamps it (dynamically or to [-1, 1]), and returns
    the scheduler's posterior parameters together with the x_start
    estimate (reused by callers for self-conditioning).
    """
    assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

    # allow a precomputed model output to be passed in; otherwise run the unet
    pred = default(model_output, lambda: unet.forward_with_cond_scale(x, noise_scheduler.get_condition(t), text_embeds = text_embeds, text_mask = text_mask, cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times)))

    # recover the x_start estimate from the network prediction
    if pred_objective == 'noise':
        x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
    elif pred_objective == 'x_start':
        x_start = pred
    else:
        raise ValueError(f'unknown objective {pred_objective}')

    if dynamic_threshold:
        # following pseudocode in appendix
        # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
        s = torch.quantile(
            rearrange(x_start, 'b ... -> b (...)').abs(),
            self.dynamic_thresholding_percentile,
            dim = -1
        )

        # never shrink below the static [-1, 1] range
        s.clamp_(min = 1.)
        s = right_pad_dims_to(x_start, s)
        x_start = x_start.clamp(-s, s) / s
    else:
        x_start.clamp_(-1., 1.)

    mean_and_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
    return mean_and_variance, x_start
|
| 2028 |
+
|
| 2029 |
+
@torch.no_grad()
def p_sample(
    self,
    unet,
    x,
    t,
    *,
    noise_scheduler,
    t_next = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    cond_scale = 1.,
    self_cond = None,
    lowres_cond_img = None,
    lowres_noise_times = None,
    pred_objective = 'noise',
    dynamic_threshold = True
):
    """Draw one reverse-diffusion sample from p(x_{t_next} | x_t).

    Returns (sample, x_start_estimate); the x_start estimate is reused by
    the sampling loop for self-conditioning.
    """
    b, *_, device = *x.shape, x.device
    (model_mean, _, model_log_variance), x_start = self.p_mean_variance(unet, x = x, t = t, t_next = t_next, noise_scheduler = noise_scheduler, text_embeds = text_embeds, text_mask = text_mask, cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_times = lowres_noise_times, pred_objective = pred_objective, dynamic_threshold = dynamic_threshold)
    noise = torch.randn_like(x)
    # no noise when t == 0
    is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
    # per-batch-element mask: zero out the noise term on the final step
    nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    return pred, x_start
|
| 2056 |
+
|
| 2057 |
+
@torch.no_grad()
def p_sample_loop(
    self,
    unet,
    shape,
    *,
    noise_scheduler,
    lowres_cond_img = None,
    lowres_noise_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    skip_steps = None,
    cond_scale = 1,
    pred_objective = 'noise',
    dynamic_threshold = True,
    use_tqdm = True
):
    """Full reverse-diffusion sampling loop for a single unet.

    Starts from gaussian noise of `shape` (optionally offset by
    `init_images`), steps through the scheduler's sampling timesteps —
    with optional inpainting via repeated renoise/resample passes — and
    returns the unnormalized result.
    """
    device = self.device

    batch = shape[0]
    img = torch.randn(shape, device = device)

    # for initialization with an image or video

    if exists(init_images):
        img += init_images

    # keep track of x0, for self conditioning

    x_start = None

    # prepare inpainting

    has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
    resample_times = inpaint_resample_times if has_inpainting else 1

    if has_inpainting:
        inpaint_images = self.normalize_img(inpaint_images)
        inpaint_images = self.resize_to(inpaint_images, shape[-1])
        inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()

    # time

    timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)

    # whether to skip any steps

    skip_steps = default(skip_steps, 0)
    timesteps = timesteps[skip_steps:]

    for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm):
        is_last_timestep = times_next == 0

        for r in reversed(range(resample_times)):
            is_last_resample_step = r == 0

            if has_inpainting:
                # keep the known region consistent with the current noise level
                noised_inpaint_images, _ = noise_scheduler.q_sample(inpaint_images, t = times)
                img = img * ~inpaint_masks + noised_inpaint_images * inpaint_masks

            self_cond = x_start if unet.self_cond else None

            img, x_start = self.p_sample(
                unet,
                img,
                times,
                t_next = times_next,
                text_embeds = text_embeds,
                text_mask = text_mask,
                cond_images = cond_images,
                cond_scale = cond_scale,
                self_cond = self_cond,
                lowres_cond_img = lowres_cond_img,
                lowres_noise_times = lowres_noise_times,
                noise_scheduler = noise_scheduler,
                pred_objective = pred_objective,
                dynamic_threshold = dynamic_threshold
            )

            if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
                # renoise back up to the current timestep so the next
                # resample pass can re-harmonize generated and known regions
                renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times)

                img = torch.where(
                    self.right_pad_dims_to_datatype(is_last_timestep),
                    img,
                    renoised_img
                )

    img.clamp_(-1., 1.)

    # final inpainting

    if has_inpainting:
        img = img * ~inpaint_masks + inpaint_images * inpaint_masks

    unnormalize_img = self.unnormalize_img(img)
    return unnormalize_img
|
| 2159 |
+
|
| 2160 |
+
    @torch.no_grad()
    @eval_decorator
    def sample(
        self,
        texts: List[str] = None,
        text_masks = None,
        text_embeds = None,
        video_frames = None,
        cond_images = None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        batch_size = 1,
        cond_scale = 1.,
        lowres_sample_noise_level = None,
        start_at_unet_number = 1,
        start_image_or_video = None,
        stop_at_unet_number = None,
        return_all_unet_outputs = False,
        return_pil_images = False,
        device = None,
        use_tqdm = True
    ):
        """Sample images (or video tensors) from the cascading diffusion model.

        Raw ``texts`` are encoded into embeddings when ``text_embeds`` is not
        supplied; conditioning / inpainting arguments are validated; then each
        unet in the cascade is run in turn, with the previous unet's output
        resized and noised as the low-res conditioner for the next.

        Returns the last unet's output (or all unet outputs when
        ``return_all_unet_outputs``), optionally converted to PIL images.
        """
        device = default(device, self.device)
        self.reset_unets_all_one_device(device = device)

        cond_images = maybe(cast_uint8_images_to_float)(cond_images)

        # derive text embeddings from raw strings when only texts were given
        if exists(texts) and not exists(text_embeds) and not self.unconditional:
            assert all([*map(len, texts)]), 'text cannot be empty'

            # text encoding is done in full precision
            with autocast(enabled = False):
                text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

            text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))

        if not self.unconditional:
            assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'

            # default mask marks positions whose embedding vector is non-zero
            text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
            batch_size = text_embeds.shape[0]

        if exists(inpaint_images):
            if self.unconditional:
                if batch_size == 1: # assume researcher wants to broadcast along inpainted images
                    batch_size = inpaint_images.shape[0]

            assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
            assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'

        assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
        assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
        assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

        # inpainting needs both the images and their masks
        assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'

        outputs = []

        is_cuda = next(self.parameters()).is_cuda
        device = next(self.parameters()).device

        lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)

        num_unets = len(self.unets)

        # condition scaling

        cond_scale = cast_tuple(cond_scale, num_unets)

        # add frame dimension for video

        assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

        frame_dims = (video_frames,) if self.is_video else tuple()

        # for initial image and skipping steps

        init_images = cast_tuple(init_images, num_unets)
        init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]

        skip_steps = cast_tuple(skip_steps, num_unets)

        # handle starting at a unet greater than 1, for training only-upscaler training

        if start_at_unet_number > 1:
            assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
            assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
            assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

            # treat the supplied image as if produced by the previous unet
            prev_image_size = self.image_sizes[start_at_unet_number - 2]
            img = self.resize_to(start_image_or_video, prev_image_size)

        # go through each unet in cascade

        for unet_number, unet, channel, image_size, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.noise_schedulers_sample, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):

            if unet_number < start_at_unet_number:
                continue

            assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets'

            # keep only the active unet on the gpu when running on cuda
            context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()

            with context:
                lowres_cond_img = lowres_noise_times = None
                shape = (batch_size, channel, *frame_dims, *image_size)

                if unet.lowres_cond:
                    # noise the low-res conditioner to the fixed sample-time noise level
                    lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

                    lowres_cond_img = self.resize_to(img, image_size)

                    lowres_cond_img = self.normalize_img(lowres_cond_img)
                    lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

                if exists(unet_init_images):
                    unet_init_images = self.resize_to(unet_init_images, image_size)

                shape = (batch_size, self.channels, *frame_dims, *image_size)

                img = self.p_sample_loop(
                    unet,
                    shape,
                    text_embeds = text_embeds,
                    text_mask = text_masks,
                    cond_images = cond_images,
                    inpaint_images = inpaint_images,
                    inpaint_masks = inpaint_masks,
                    inpaint_resample_times = inpaint_resample_times,
                    init_images = unet_init_images,
                    skip_steps = unet_skip_steps,
                    cond_scale = unet_cond_scale,
                    lowres_cond_img = lowres_cond_img,
                    lowres_noise_times = lowres_noise_times,
                    noise_scheduler = noise_scheduler,
                    pred_objective = pred_objective,
                    dynamic_threshold = dynamic_threshold,
                    use_tqdm = use_tqdm
                )

                outputs.append(img)

            if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
                break

        output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs

        if not return_pil_images:
            return outputs[output_index]

        if not return_all_unet_outputs:
            outputs = outputs[-1:]

        assert not self.is_video, 'converting sampled video tensor to video file is not supported yet'

        pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))

        return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
|
| 2320 |
+
|
| 2321 |
+
    def p_losses(
        self,
        unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel],
        x_start,
        times,
        *,
        noise_scheduler,
        lowres_cond_img = None,
        lowres_aug_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        noise = None,
        times_next = None,
        pred_objective = 'noise',
        p2_loss_weight_gamma = 0.,
        random_crop_size = None
    ):
        """Compute the denoising training loss for one unet.

        Normalizes and (optionally) random-crops ``x_start``, diffuses it to
        ``x_noisy`` at ``times``, runs the unet, and returns the mean
        reduced loss against either the noise or the clean image target
        (chosen by ``pred_objective``), optionally p2-reweighted.
        """
        # a rank-5 input is (batch, channels, frames, height, width) video
        is_video = x_start.ndim == 5

        noise = default(noise, lambda: torch.randn_like(x_start))

        # normalize to [-1, 1]

        x_start = self.normalize_img(x_start)
        lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

        # random cropping during training
        # for upsamplers

        if exists(random_crop_size):
            if is_video:
                # fold frames into batch so the 2d crop augmentation can be applied
                frames = x_start.shape[2]
                x_start, lowres_cond_img, noise = rearrange_many((x_start, lowres_cond_img, noise), 'b c f h w -> (b f) c h w')

            aug = K.RandomCrop(random_crop_size, p = 1.)

            # make sure low res conditioner and image both get augmented the same way
            # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
            x_start = aug(x_start)
            lowres_cond_img = aug(lowres_cond_img, params = aug._params)
            noise = aug(noise, params = aug._params)

            if is_video:
                x_start, lowres_cond_img, noise = rearrange_many((x_start, lowres_cond_img, noise), '(b f) c h w -> b c f h w', f = frames)

        # get x_t

        x_noisy, log_snr = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

        # also noise the lowres conditioning image
        # at sample time, they then fix the noise level of 0.1 - 0.3

        lowres_cond_img_noisy = None
        if exists(lowres_cond_img):
            lowres_aug_times = default(lowres_aug_times, times)
            lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))

        # time condition

        noise_cond = noise_scheduler.get_condition(times)

        # unet kwargs

        unet_kwargs = dict(
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
            lowres_cond_img = lowres_cond_img_noisy,
            cond_drop_prob = self.cond_drop_prob,
        )

        # self condition if needed

        # Because 'unet' can be an instance of DistributedDataParallel coming from the
        # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
        # access the member 'module' of the wrapped unet instance.
        self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond

        # with 50% probability, feed the unet its own (detached) x_start estimate
        if self_cond and random() < 0.5:
            with torch.no_grad():
                pred = unet.forward(
                    x_noisy,
                    noise_cond,
                    **unet_kwargs
                ).detach()

            x_start = noise_scheduler.predict_start_from_noise(x_noisy, t = times, noise = pred) if pred_objective == 'noise' else pred

            unet_kwargs = {**unet_kwargs, 'self_cond': x_start}

        # get prediction

        pred = unet.forward(
            x_noisy,
            noise_cond,
            **unet_kwargs
        )

        # prediction objective

        if pred_objective == 'noise':
            target = noise
        elif pred_objective == 'x_start':
            target = x_start
        else:
            raise ValueError(f'unknown objective {pred_objective}')

        # losses

        losses = self.loss_fn(pred, target, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        # p2 loss reweighting

        if p2_loss_weight_gamma > 0:
            loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
            losses = losses * loss_weight

        return losses.mean()
|
| 2442 |
+
|
| 2443 |
+
def forward(
|
| 2444 |
+
self,
|
| 2445 |
+
images,
|
| 2446 |
+
unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
|
| 2447 |
+
texts: List[str] = None,
|
| 2448 |
+
text_embeds = None,
|
| 2449 |
+
text_masks = None,
|
| 2450 |
+
unet_number = None,
|
| 2451 |
+
cond_images = None
|
| 2452 |
+
):
|
| 2453 |
+
# assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
|
| 2454 |
+
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
| 2455 |
+
unet_number = default(unet_number, 1)
|
| 2456 |
+
assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}'
|
| 2457 |
+
|
| 2458 |
+
images = cast_uint8_images_to_float(images)
|
| 2459 |
+
cond_images = maybe(cast_uint8_images_to_float)(cond_images)
|
| 2460 |
+
|
| 2461 |
+
assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
|
| 2462 |
+
|
| 2463 |
+
unet_index = unet_number - 1
|
| 2464 |
+
|
| 2465 |
+
unet = default(unet, lambda: self.get_unet(unet_number))
|
| 2466 |
+
|
| 2467 |
+
assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
|
| 2468 |
+
|
| 2469 |
+
noise_scheduler = self.noise_schedulers[unet_index]
|
| 2470 |
+
p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
|
| 2471 |
+
pred_objective = self.pred_objectives[unet_index]
|
| 2472 |
+
target_image_size = self.image_sizes[unet_index]
|
| 2473 |
+
random_crop_size = self.random_crop_sizes[unet_index]
|
| 2474 |
+
prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
|
| 2475 |
+
|
| 2476 |
+
b, c, *_, h, w, device, is_video = *images.shape, images.device, images.ndim == 5
|
| 2477 |
+
|
| 2478 |
+
check_shape(images, 'b c ...', c = self.channels)
|
| 2479 |
+
assert h >= target_image_size[0] and w >= target_image_size[1]
|
| 2480 |
+
|
| 2481 |
+
frames = images.shape[2] if is_video else None
|
| 2482 |
+
|
| 2483 |
+
times = noise_scheduler.sample_random_times(b, device = device)
|
| 2484 |
+
|
| 2485 |
+
if exists(texts) and not exists(text_embeds) and not self.unconditional:
|
| 2486 |
+
assert all([*map(len, texts)]), 'text cannot be empty'
|
| 2487 |
+
assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
|
| 2488 |
+
|
| 2489 |
+
with autocast(enabled = False):
|
| 2490 |
+
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
|
| 2491 |
+
|
| 2492 |
+
text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
|
| 2493 |
+
|
| 2494 |
+
if not self.unconditional:
|
| 2495 |
+
text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
|
| 2496 |
+
|
| 2497 |
+
assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
|
| 2498 |
+
assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'
|
| 2499 |
+
|
| 2500 |
+
assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
|
| 2501 |
+
|
| 2502 |
+
lowres_cond_img = lowres_aug_times = None
|
| 2503 |
+
if exists(prev_image_size):
|
| 2504 |
+
lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
|
| 2505 |
+
lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
|
| 2506 |
+
|
| 2507 |
+
if self.per_sample_random_aug_noise_level:
|
| 2508 |
+
lowres_aug_times = self.lowres_noise_schedule.sample_random_times(b, device = device)
|
| 2509 |
+
else:
|
| 2510 |
+
lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
|
| 2511 |
+
lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = b)
|
| 2512 |
+
|
| 2513 |
+
images = self.resize_to(images, target_image_size)
|
| 2514 |
+
|
| 2515 |
+
return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size)
|
imagen_pytorch/imagen_video/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from imagen_pytorch.imagen_video.imagen_video import Unet3D
|
imagen_pytorch/imagen_video/imagen_video.py
ADDED
|
@@ -0,0 +1,1662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import copy
|
| 3 |
+
from typing import List
|
| 4 |
+
from tqdm.auto import tqdm
|
| 5 |
+
from functools import partial, wraps
|
| 6 |
+
from contextlib import contextmanager, nullcontext
|
| 7 |
+
from collections import namedtuple
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch import nn, einsum
|
| 13 |
+
|
| 14 |
+
from einops import rearrange, repeat, reduce
|
| 15 |
+
from einops.layers.torch import Rearrange, Reduce
|
| 16 |
+
from einops_exts import rearrange_many, repeat_many, check_shape
|
| 17 |
+
from einops_exts.torch import EinopsToAndFrom
|
| 18 |
+
|
| 19 |
+
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
|
| 20 |
+
|
| 21 |
+
# helper functions
|
| 22 |
+
|
| 23 |
+
def exists(val):
    """True when *val* is anything other than None."""
    return not (val is None)
|
| 25 |
+
|
| 26 |
+
def identity(t, *args, **kwargs):
    """Return *t* unchanged; any extra positional/keyword arguments are ignored."""
    return t
|
| 28 |
+
|
| 29 |
+
def first(arr, d = None):
    """Return the first element of *arr*, or the fallback *d* when empty."""
    return arr[0] if len(arr) else d
|
| 33 |
+
|
| 34 |
+
def maybe(fn):
    """Lift *fn* so a None input passes straight through instead of hitting fn."""
    @wraps(fn)
    def inner(x):
        return x if x is None else fn(x)
    return inner
|
| 41 |
+
|
| 42 |
+
def once(fn):
    """Allow *fn* to execute a single time; every later call returns None."""
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if not called:
            called = True
            return fn(x)

    return inner
|
| 52 |
+
|
| 53 |
+
print_once = once(print)
|
| 54 |
+
|
| 55 |
+
def default(val, d):
    """Return *val* when present, otherwise *d* (calling it first if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
|
| 59 |
+
|
| 60 |
+
def cast_tuple(val, length = None):
    """Coerce *val* into a tuple.

    Lists become tuples; any non-tuple scalar is replicated `length` times
    (once when length is None).  When `length` is given, the resulting tuple
    must be exactly that long.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        output = (val,) * (1 if length is None else length)

    if length is not None:
        assert len(output) == length

    return output
|
| 70 |
+
|
| 71 |
+
def cast_uint8_images_to_float(images):
    """Map uint8 images in [0, 255] to floats in [0, 1]; other dtypes pass through."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
|
| 75 |
+
|
| 76 |
+
def module_device(module):
    """Device of the module's first parameter (assumes the module has parameters)."""
    return next(iter(module.parameters())).device
|
| 78 |
+
|
| 79 |
+
def zero_init_(m):
    """In-place zero of a layer's weight, and its bias when one is present."""
    nn.init.zeros_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
|
| 83 |
+
|
| 84 |
+
def eval_decorator(fn):
    """Decorator: run *fn* with the model in eval mode, then restore its
    previous training/eval state.

    Fix: applies ``functools.wraps`` so the wrapped function keeps its name
    and docstring, consistent with ``maybe`` and ``once`` in this file.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        # restore whatever mode the model was in before the call
        model.train(was_training)
        return out
    return inner
|
| 92 |
+
|
| 93 |
+
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad tuple *t* with *fillvalue* up to *length*; no-op when already long enough."""
    if len(t) >= length:
        return t
    return t + (fillvalue,) * (length - len(t))
|
| 98 |
+
|
| 99 |
+
# helper classes
|
| 100 |
+
|
| 101 |
+
class Identity(nn.Module):
    """An nn.Module that returns its input unchanged, ignoring all extra arguments."""

    def __init__(self, *args, **kwargs):
        # constructor arguments are accepted but discarded, so this can
        # stand in for any layer in a config-driven build
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x
|
| 107 |
+
|
| 108 |
+
# tensor helpers
|
| 109 |
+
|
| 110 |
+
def log(t, eps: float = 1e-12):
    """Numerically safe log: clamp *t* below at *eps* before taking the log."""
    return t.clamp(min = eps).log()
|
| 112 |
+
|
| 113 |
+
def l2norm(t):
    """L2-normalize *t* along its last dimension."""
    return F.normalize(t, dim = -1)
|
| 115 |
+
|
| 116 |
+
def right_pad_dims_to(x, t):
    """Append singleton dims to *t* until it has as many dims as *x* (for broadcasting)."""
    extra_dims = x.ndim - t.ndim
    if extra_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * extra_dims))
|
| 121 |
+
|
| 122 |
+
def masked_mean(t, *, dim, mask = None):
    """Mean of *t* over *dim*, counting only positions where *mask* is True.

    *mask* is a (b, n) boolean tensor broadcast over t's trailing feature dim;
    with no mask this reduces to a plain mean.  The denominator is clamped to
    avoid division by zero for all-False rows.
    """
    if mask is None:
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    zeroed = t.masked_fill(~mask.unsqueeze(-1), 0.)
    return zeroed.sum(dim = dim) / denom.clamp(min = 1e-5)
|
| 131 |
+
|
| 132 |
+
def resize_video_to(
    video,
    target_image_size,
    clamp_range = None
):
    """Spatially resize a video tensor ('b c f h w') to `target_image_size`
    with nearest-neighbor interpolation, optionally clamping the result.

    Returns the input unchanged when it is already at the target size.
    """
    if video.shape[-1] == target_image_size:
        return video

    num_frames = video.shape[2]

    # fold frames into the batch so 2d interpolation can be used
    flat = rearrange(video, 'b c f h w -> (b f) c h w')
    flat = F.interpolate(flat, target_image_size, mode = 'nearest')

    if exists(clamp_range):
        flat = flat.clamp(*clamp_range)

    return rearrange(flat, '(b f) c h w -> b c f h w', f = num_frames)
|
| 154 |
+
|
| 155 |
+
# classifier free guidance functions
|
| 156 |
+
|
| 157 |
+
def prob_mask_like(shape, prob, device):
    """Boolean mask of `shape` where each entry is True with probability `prob`.

    The all-True / all-False cases are handled without any random sampling.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    uniform = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return uniform < prob
|
| 164 |
+
|
| 165 |
+
# norms and residuals
|
| 166 |
+
|
| 167 |
+
class LayerNorm(nn.Module):
    """Layer norm over the last dimension with a learned gain (no bias).

    When `stable`, activations are first divided by their detached max,
    guarding against overflow at reduced precision.
    """
    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = -1, keepdim = True).detach()

        # looser epsilon for half precision
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        mu = torch.mean(x, dim = -1, keepdim = True)
        sigma2 = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        return (x - mu) * (sigma2 + eps).rsqrt() * self.g
|
| 181 |
+
|
| 182 |
+
class ChanLayerNorm(nn.Module):
    """Layer norm over the channel dimension (dim 1) of a 5-d video tensor,
    with a learned per-channel gain and optional max-stabilization."""
    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        # gain broadcasts over (batch, frames, height, width)
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = 1, keepdim = True).detach()

        # looser epsilon for half precision
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        mu = torch.mean(x, dim = 1, keepdim = True)
        sigma2 = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        return (x - mu) * (sigma2 + eps).rsqrt() * self.g
|
| 196 |
+
|
| 197 |
+
class Always():
    """Callable that ignores all arguments and always returns `val`."""
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val
|
| 203 |
+
|
| 204 |
+
class Residual(nn.Module):
    """Wrap `fn` with a skip connection: forward(x) = fn(x, **kwargs) + x."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return x + self.fn(x, **kwargs)
|
| 211 |
+
|
| 212 |
+
class Parallel(nn.Module):
    """Apply several modules to the same input and sum their outputs."""
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum(fn(x) for fn in self.fns)
|
| 220 |
+
|
| 221 |
+
# attention pooling
|
| 222 |
+
|
| 223 |
+
class PerceiverAttention(nn.Module):
    """Cross-attention from learned latents to a sequence.

    As noted in the original comment, keys/values are computed from the
    sequence concatenated with the latents themselves (differs from the
    Perceiver paper). Supports optional cosine-similarity attention.
    """
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        cosine_sim_attn = False
    ):
        super().__init__()
        # with cosine-sim attention a fixed scale of 16 replaces 1/sqrt(d)
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        x = self.norm(x)
        latents = self.norm_latents(latents)

        heads = self.heads

        q = self.to_q(latents)

        # keys / values are derived from the sequence AND the latents
        k, v = self.to_kv(torch.cat((x, latents), dim = -2)).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = heads)

        q = q * self.scale

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale

        if exists(mask):
            neg_inf = -torch.finfo(sim.dtype).max
            # the latent positions are always attendable, so extend with True
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, neg_inf)

        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = heads)
        return self.to_out(out)
|
| 289 |
+
|
| 290 |
+
class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence into a fixed set of latents via
    stacked Perceiver attention + feedforward blocks with residuals.

    Optionally prepends extra latents derived from a mean-pooled summary of
    the sequence (`num_latents_mean_pooled`).
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
                FeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, mask = None):
        batch, n, device = x.shape[0], x.shape[1], x.device

        pos_emb = self.pos_emb(torch.arange(n, device = device))
        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        if exists(self.to_latents_from_mean_pooled_seq):
            # summary latents computed from an (unmasked) mean over the sequence
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
|
| 343 |
+
|
| 344 |
+
# attention
|
| 345 |
+
|
| 346 |
+
class Attention(nn.Module):
    """Attention with multi-head queries over single-head (shared) keys and
    values, a learned null key/value (for classifier-free guidance), optional
    text-conditioning keys/values, relative attention bias, causal masking
    and optional cosine-similarity attention.
    """
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        causal = False,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.causal = causal

        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        # learned bias toward the null key, one scalar per head
        self.null_attn_bias = nn.Parameter(torch.randn(heads))

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # single-head keys / values shared across all query heads
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        batch, n, device = *x.shape[:2], x.device

        x = self.norm(x)

        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # prepend learned null key / value (classifier-free guidance)
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = batch)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # prepend text-conditioning keys / values, if given
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        # query / key similarities; keys broadcast over heads
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale

        # relative positional bias (T5 style), padded for the null key
        if exists(attn_bias):
            null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
            attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
            sim = sim + attn_bias

        neg_inf = -torch.finfo(sim.dtype).max

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, neg_inf)

        if exists(mask):
            # the null key is always attendable
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, neg_inf)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
|
| 445 |
+
|
| 446 |
+
# pseudo conv2d that uses conv3d but with kernel size of 1 across frames dimension
|
| 447 |
+
|
| 448 |
+
def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
    """Build an nn.Conv3d that behaves like a 2d conv on video tensors:
    kernel size 1, stride 1 and padding 0 along the frame dimension."""
    kernel = cast_tuple(kernel, 2)
    stride = cast_tuple(stride, 2)
    padding = cast_tuple(padding, 2)

    # prepend the inert frame dimension when only spatial values were given
    if len(kernel) == 2:
        kernel = (1, *kernel)
    if len(stride) == 2:
        stride = (1, *stride)
    if len(padding) == 2:
        padding = (0, *padding)

    return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
|
| 463 |
+
|
| 464 |
+
class Pad(nn.Module):
    """Module wrapper around F.pad with a fixed padding spec and fill value."""
    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, x):
        return F.pad(x, self.padding, value = self.value)
|
| 472 |
+
|
| 473 |
+
# decoder
|
| 474 |
+
|
| 475 |
+
def Upsample(dim, dim_out = None):
    """2x nearest-neighbor spatial upsample followed by a 3x3 conv."""
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        Conv2d(dim, dim_out, 3, padding = 1)
    )
|
| 482 |
+
|
| 483 |
+
class PixelShuffleUpsample(nn.Module):
    """2x spatial upsample via pixel shuffle, applied frame-wise to video.

    The 1x1 conv is initialized so that every group of 4 output channels
    starts with an identical kernel, i.e. all shuffled sub-pixels begin equal.
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU()
        )

        self.pixel_shuffle = nn.PixelShuffle(2)

        self.init_conv_(conv)

    def init_conv_(self, conv):
        # sample one kernel per output group, then duplicate it 4x
        o, i, f, h, w = conv.weight.shape
        base = torch.empty(o // 4, i, f, h, w)
        nn.init.kaiming_uniform_(base)
        tiled = repeat(base, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        out = self.net(x)
        num_frames = x.shape[2]
        # pixel shuffle works on 2d maps, so fold frames into the batch
        out = rearrange(out, 'b c f h w -> (b f) c h w')
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b f) c h w -> b c f h w', f = num_frames)
|
| 513 |
+
|
| 514 |
+
def Downsample(dim, dim_out = None):
    """2x spatial downsample via a stride-2 4x4 conv."""
    dim_out = default(dim_out, dim)
    return Conv2d(dim, dim_out, 4, 2, 1)
|
| 517 |
+
|
| 518 |
+
class SinusoidalPosEmb(nn.Module):
    """Classic transformer sinusoidal embedding of scalar positions/timesteps."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # log-spaced frequencies from 1 down to 1/10000
        exponent = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = x.device) * -exponent)
        angles = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
|
| 529 |
+
|
| 530 |
+
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal position embedding with learned frequencies; the raw input
    is concatenated alongside the fourier features (output dim = dim + 1)."""
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        return torch.cat((x, fouriered), dim = -1)
|
| 543 |
+
|
| 544 |
+
class Block(nn.Module):
    """GroupNorm -> optional FiLM scale/shift -> SiLU -> 3x3 conv."""
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        x = self.groupnorm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            # FiLM-style conditioning; +1 keeps identity at zero scale
            x = x * (scale + 1) + shift

        return self.project(self.activation(x))
|
| 566 |
+
|
| 567 |
+
class ResnetBlock(nn.Module):
    """Residual block with optional time-embedding FiLM conditioning,
    optional cross attention over a conditioning sequence, and optional
    global-context gating."""
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            # projects time embedding into a scale/shift pair for block2
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            self.cross_attn = EinopsToAndFrom(
                'b c f h w',
                'b (f h w) c',
                attn_klass(
                    dim = dim_out,
                    context_dim = cond_dim,
                    **attn_kwargs
                )
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()

    def forward(self, x, time_emb = None, cond = None):
        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            h = self.cross_attn(h, context = cond) + h

        # time conditioning is applied only within the second block
        h = self.block2(h, scale_shift = scale_shift)

        h = h * self.gca(h)

        return h + self.res_conv(x)
|
| 633 |
+
|
| 634 |
+
class CrossAttention(nn.Module):
    """Multi-head cross attention with a learned null key/value
    (for classifier-free guidance) and optional cosine-sim attention."""
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        # self-attend when no separate context dimension is given
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        batch = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # prepend learned null key / value for classifier-free guidance
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = batch)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale

        if exists(mask):
            neg_inf = -torch.finfo(sim.dtype).max
            # the null key is always visible
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, neg_inf)

        # softmax in fp32 for numerical stability
        attn = sim.softmax(dim = -1, dtype = torch.float32)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
|
| 709 |
+
|
| 710 |
+
class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of CrossAttention: softmax over features for
    queries and over positions for keys, reusing the parent's layers."""
    def forward(self, x, context, mask = None):
        batch = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)

        # learned null key / value for classifier free guidance
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = batch)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        if exists(mask):
            neg_inf = -torch.finfo(x.dtype).max
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b n -> b n 1')
            # masked positions contribute nothing after the softmaxes below
            k = k.masked_fill(~mask, neg_inf)
            v = v.masked_fill(~mask, 0.)

        # linear attention: normalize q over features, k over positions
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        ctx = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, ctx)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
|
| 749 |
+
|
| 750 |
+
class LinearAttention(nn.Module):
    """Linear-complexity self attention over feature maps, with depthwise
    conv-augmented projections and optional extra context keys/values."""
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        # each projection: dropout -> 1x1 conv -> 3x3 depthwise conv
        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        heads, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = (proj(fmap) for proj in (self.to_q, self.to_k, self.to_v))
        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = heads)

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = heads)
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        # linear attention: normalize q over features, k over positions
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        ctx = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, ctx)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = heads, x = x, y = y)

        return self.to_out(self.nonlin(out))
|
| 818 |
+
|
| 819 |
+
class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # single-channel attention map over all positions
        self.to_k = Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)
        x, context = rearrange_many((x, context), 'b n ... -> b n (...)')
        # attention-weighted global pooling over all positions
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        out = rearrange(out, '... -> ... 1 1')
        return self.net(out)
|
| 845 |
+
|
| 846 |
+
def FeedForward(dim, mult = 2):
    """Pre-norm two-layer MLP with GELU activation and no biases."""
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False)
    )
|
| 855 |
+
|
| 856 |
+
def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width
    """Channel-wise (1x1 conv) pre-norm feedforward with GELU, no biases."""
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        ChanLayerNorm(dim),
        Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        Conv2d(hidden_dim, dim, 1, bias = False)
    )
|
| 865 |
+
|
| 866 |
+
class TransformerBlock(nn.Module):
    """Depth-stacked full attention over all space-time positions plus a
    channel feedforward, each wrapped in a residual connection."""
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                EinopsToAndFrom('b c f h w', 'b (f h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
|
| 892 |
+
|
| 893 |
+
class LinearAttentionTransformerBlock(nn.Module):
    """Like TransformerBlock, but uses linear-complexity attention over the
    feature maps instead of full attention."""
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
|
| 919 |
+
|
| 920 |
+
class CrossEmbedLayer(nn.Module):
    """Multi-scale conv embedding: several kernel sizes at the same stride,
    with the output channels split across scales and concatenated."""
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        # kernel and stride parities must match for the padding formula below
        assert all((k % 2) == (stride % 2) for k in kernel_sizes)
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # halve the channel budget per scale; the last scale takes the remainder
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([
            Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)
            for kernel, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        return torch.cat(tuple(conv(x) for conv in self.convs), dim = 1)
|
| 946 |
+
|
| 947 |
+
class UpsampleCombiner(nn.Module):
    """Optionally resizes a set of feature maps to x's spatial size, projects
    each through a Block, and concatenates them onto x along channels.

    When disabled (or given no fmaps), acts as a pass-through.
    """
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([Block(din, dout) for din, dout in zip(dim_ins, dim_outs)])
        # downstream layers need to know the combined channel count
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        resized = [resize_video_to(fmap, target_size) for fmap in fmaps]
        outs = [conv(fmap) for fmap, conv in zip(resized, self.fmap_convs)]
        return torch.cat((x, *outs), dim = 1)
|
| 980 |
+
|
| 981 |
+
class DynamicPositionBias(nn.Module):
    """Relative position bias (heads x n x n) produced by an MLP over the
    scalar relative distance, instead of a learned lookup table."""
    def __init__(
        self,
        dim,
        *,
        heads,
        depth
    ):
        super().__init__()
        layers = [nn.Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim),
            nn.SiLU()
        )]

        for _ in range(max(depth - 1, 0)):
            layers.append(nn.Sequential(
                nn.Linear(dim, dim),
                LayerNorm(dim),
                nn.SiLU()
            ))

        # final projection to one bias value per head
        layers.append(nn.Linear(dim, heads))
        self.mlp = nn.ModuleList(layers)

    def forward(self, n, device, dtype):
        # relative offsets i - j, shifted into [0, 2n - 2] for indexing
        i = torch.arange(n, device = device)
        j = torch.arange(n, device = device)
        rel = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j')
        rel = rel + (n - 1)

        # embed each distinct relative distance once, then gather
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')

        for layer in self.mlp:
            pos = layer(pos)

        bias = pos[rel]
        return rearrange(bias, 'i j h -> h i j')
|
| 1023 |
+
|
| 1024 |
+
class Unet3D(nn.Module):
    """3d (video) u-net denoiser with temporal attention, text/time conditioning,
    and optional low-resolution image conditioning for cascading diffusion.
    """

    def __init__(
        self,
        *,
        dim,
        image_embed_dim = 1024,  # NOTE(review): not referenced in this constructor body; kept in _locals for reinstantiation
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,  # NOTE(review): not referenced in this constructor body
        num_time_tokens = 2,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,  # NOTE(review): not referenced here; see channels_out
        dim_mults=(1, 2, 4, 8),
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
        layer_attns = False,
        layer_attns_depth = 1,
        layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
        attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        time_rel_pos_bias_depth = 2,
        time_causal_attn = True,
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        resnet_groups = 8,
        init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
        init_cross_embed = True,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,  # NOTE(review): not referenced in this constructor body
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3,
        cosine_sim_attn = False,
        self_cond = False,
        combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
        pixel_shuffle_upsample = True # may address checkboard artifacts
    ):
        super().__init__()

        # guide researchers

        assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'

        if dim < 128:
            print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')

        # save locals to take care of some hyperparameters for cascading DDPM
        # (used by cast_model_parameters / to_config_and_state_dict to re-instantiate)

        self._locals = locals()
        self._locals.pop('self', None)
        self._locals.pop('__class__', None)

        self.self_cond = self_cond

        # determine dimensions

        self.channels = channels
        self.channels_out = default(channels_out, channels)

        # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
        # (2) in self conditioning, one appends the predict x0 (x_start)
        init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
        init_dim = default(init_dim, dim)

        # optional image conditioning

        self.has_cond_image = cond_images_channels > 0
        self.cond_images_channels = cond_images_channels

        init_channels += cond_images_channels

        # initial convolution

        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)

        # per-resolution channel counts, e.g. dim_mults (1,2,4,8) -> [init_dim, dim, 2*dim, 4*dim, 8*dim]
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time conditioning

        cond_dim = default(cond_dim, dim)
        time_cond_dim = dim * 4 * (2 if lowres_cond else 1)

        # embedding time for log(snr) noise from continuous version

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
        sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1

        self.to_time_hiddens = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
            nn.SiLU()
        )

        self.to_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )

        # project to time tokens as well as time hiddens

        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

        # low res aug noise conditioning

        self.lowres_cond = lowres_cond

        if lowres_cond:
            # mirror of the time-conditioning pathway, driven by the lowres aug noise level
            self.to_lowres_time_hiddens = nn.Sequential(
                LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
                nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
                nn.SiLU()
            )

            self.to_lowres_time_cond = nn.Sequential(
                nn.Linear(time_cond_dim, time_cond_dim)
            )

            self.to_lowres_time_tokens = nn.Sequential(
                nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
                Rearrange('b (r d) -> b r d', r = num_time_tokens)
            )

        # normalizations

        self.norm_cond = nn.LayerNorm(cond_dim)

        # text encoding conditioning (optional)

        self.text_to_cond = None

        if cond_on_text:
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)

        # finer control over whether to condition on text encodings

        self.cond_on_text = cond_on_text

        # attention pooling

        self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None

        # for classifier free guidance

        self.max_text_len = max_text_len

        # learned "null" embeddings substituted when text conditioning is dropped
        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
        self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))

        # for non-attention based text conditioning at all points in the network where time is also conditioned

        self.to_text_non_attn_cond = None

        if cond_on_text:
            self.to_text_non_attn_cond = nn.Sequential(
                nn.LayerNorm(cond_dim),
                nn.Linear(cond_dim, time_cond_dim),
                nn.SiLU(),
                nn.Linear(time_cond_dim, time_cond_dim)
            )

        # attention related params

        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)

        num_layers = len(in_out)

        # temporal attention - attention across video frames

        # causal variant pads only the past (2 frames); non-causal pads one frame each side
        temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1)
        temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim)))

        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', '(b h w) f c', Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn})))

        # temporal attention relative positional encoding

        self.time_rel_pos_bias = DynamicPositionBias(dim = dim * 2, heads = attn_heads, depth = time_rel_pos_bias_depth)

        # resnet block klass

        num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
        resnet_groups = cast_tuple(resnet_groups, num_layers)

        resnet_klass = partial(ResnetBlock, **attn_kwargs)

        layer_attns = cast_tuple(layer_attns, num_layers)
        layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
        layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])

        # downsample klass

        downsample_klass = Downsample

        if cross_embed_downsample:
            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

        # initial resnet block (for memory efficient unet)

        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None

        self.init_temporal_peg = temporal_peg(init_dim)
        self.init_temporal_attn = temporal_attn(init_dim)

        # scale for resnet skip connections

        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns]
        reversed_layer_params = list(map(reversed, layer_params))

        # downsampling layers

        skip_connect_dims = [] # keep track of skip connection dimensions

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(in_out, *layer_params)):
            is_last = ind >= (num_resolutions - 1)

            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)

            current_dim = dim_in

            # whether to pre-downsample, from memory efficient unet

            pre_downsample = None

            if memory_efficient:
                pre_downsample = downsample_klass(dim_in, dim_out)
                current_dim = dim_out

            skip_connect_dims.append(current_dim)

            # whether to do post-downsample, for non-memory efficient unet

            post_downsample = None
            if not memory_efficient:
                post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(Conv2d(dim_in, dim_out, 3, padding = 1), Conv2d(dim_in, dim_out, 1))

            self.downs.append(nn.ModuleList([
                pre_downsample,
                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                temporal_peg(current_dim),
                temporal_attn(current_dim),
                post_downsample
            ]))

        # middle layers

        mid_dim = dims[-1]

        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
        self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
        self.mid_temporal_peg = temporal_peg(mid_dim)
        self.mid_temporal_attn = temporal_attn(mid_dim)
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

        # upsample klass

        upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample

        # upsampling layers

        upsample_fmap_dims = []

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
            is_last = ind == (len(in_out) - 1)
            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)

            skip_connect_dim = skip_connect_dims.pop()

            upsample_fmap_dims.append(dim_out)

            self.ups.append(nn.ModuleList([
                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                temporal_peg(dim_out),
                temporal_attn(dim_out),
                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
            ]))

        # whether to combine feature maps from all upsample blocks before final resnet block out

        self.upsample_combiner = UpsampleCombiner(
            dim = dim,
            enabled = combine_upsample_fmaps,
            dim_ins = upsample_fmap_dims,
            dim_outs = dim
        )

        # whether to do a final residual from initial conv to the final resnet block out

        self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
        final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)

        # final optional resnet block and convolution out

        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None

        final_conv_dim_in = dim if final_resnet_block else final_conv_dim
        final_conv_dim_in += (channels if lowres_cond else 0)

        self.final_conv = Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)

        # zero-init output conv so the untrained unet starts as an identity-ish residual predictor
        zero_init_(self.final_conv)
| 1362 |
+
|
| 1363 |
+
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """Return self when hyperparameters already match; otherwise build a
    fresh unet of the same class with the overridden settings.
    """
    matches = (
        lowres_cond == self.lowres_cond
        and channels == self.channels
        and cond_on_text == self.cond_on_text
        and text_embed_dim == self._locals['text_embed_dim']
        and channels_out == self.channels_out
    )

    if matches:
        return self

    overrides = dict(
        lowres_cond = lowres_cond,
        text_embed_dim = text_embed_dim,
        channels = channels,
        channels_out = channels_out,
        cond_on_text = cond_on_text
    )

    # re-instantiate with the saved constructor kwargs, applying the overrides
    return self.__class__(**{**self._locals, **overrides})
|
| 1390 |
+
|
| 1391 |
+
# methods for returning the full unet config as well as its parameter state
def to_config_and_state_dict(self):
    """Return a (config, state_dict) pair sufficient to rebuild this unet."""
    config = self._locals
    return config, self.state_dict()
|
| 1395 |
+
|
| 1396 |
+
# class method for rehydrating the unet from its config and state dict
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Instantiate the unet from its constructor kwargs and load weights."""
    model = klass(**config)
    model.load_state_dict(state_dict)
    return model
|
| 1403 |
+
|
| 1404 |
+
# methods for persisting unet to disk
def persist_to_file(self, path):
    """Serialize config + weights to `path`, creating parent dirs as needed."""
    path = Path(path)
    path.parents[0].mkdir(exist_ok = True, parents = True)

    config, state_dict = self.to_config_and_state_dict()
    torch.save(dict(config = config, state_dict = state_dict), str(path))
|
| 1413 |
+
|
| 1414 |
+
# class method for rehydrating the unet from file saved with `persist_to_file`

@classmethod
def hydrate_from_file(klass, path):
    """Load a unet previously saved with `persist_to_file`.

    Fix: dispatch through `klass` instead of the hardcoded 2d `Unet`
    (a copy-paste from imagen_pytorch.py) — hydrating a saved Unet3D
    file previously reconstructed the wrong class.
    """
    path = Path(path)
    assert path.exists()
    pkg = torch.load(str(path))

    assert 'config' in pkg and 'state_dict' in pkg
    config, state_dict = pkg['config'], pkg['state_dict']

    return klass.from_config_and_state_dict(config, state_dict)
|
| 1426 |
+
|
| 1427 |
+
# forward with classifier free guidance
def forward_with_cond_scale(
    self,
    *args,
    cond_scale = 1.,
    **kwargs
):
    """Classifier-free guidance: extrapolate from the unconditional output
    toward the conditional one by `cond_scale`. A scale of 1 is a plain
    conditional forward pass.
    """
    cond_out = self.forward(*args, **kwargs)

    if cond_scale == 1:
        return cond_out

    # second pass with all conditioning dropped
    uncond_out = self.forward(*args, cond_drop_prob = 1., **kwargs)
    return uncond_out + (cond_out - uncond_out) * cond_scale
|
| 1442 |
+
|
| 1443 |
+
def forward(
    self,
    x,
    time,
    *,
    lowres_cond_img = None,
    lowres_noise_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    self_cond = None,
    cond_drop_prob = 0.
):
    """Denoise a video tensor conditioned on time, and optionally text,
    a conditioning image, a low-res image (cascading), and self-conditioning.

    x: (batch, channels, frames, height, width); returns a tensor of the
    same spatial/temporal shape with `channels_out` channels.
    cond_drop_prob: probability of dropping text conditioning per-sample
    (used for classifier-free guidance training).
    """
    assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)'

    batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype

    # add self conditioning if needed

    if self.self_cond:
        # when no previous x0 prediction is supplied, condition on zeros
        self_cond = default(self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x, self_cond), dim = 1)

    # add low resolution conditioning, if present

    assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
    assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)

    # condition on input image

    assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

    if exists(cond_images):
        assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
        cond_images = resize_video_to(cond_images, x.shape[-1])
        x = torch.cat((cond_images, x), dim = 1)

    # get time relative positions
    # bias shared by all temporal attention layers at every resolution

    time_attn_bias = self.time_rel_pos_bias(frames, device = device, dtype = dtype)

    # initial convolution

    x = self.init_conv(x)

    x = self.init_temporal_peg(x)
    x = self.init_temporal_attn(x, attn_bias = time_attn_bias)

    # init conv residual

    if self.init_conv_to_final_conv_residual:
        init_conv_residual = x.clone()

    # time conditioning

    time_hiddens = self.to_time_hiddens(time)

    # derive time tokens

    time_tokens = self.to_time_tokens(time_hiddens)
    t = self.to_time_cond(time_hiddens)

    # add lowres time conditioning to time hiddens
    # and add lowres time tokens along sequence dimension for attention

    if self.lowres_cond:
        lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
        lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
        lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

        t = t + lowres_t
        time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

    # text conditioning

    text_tokens = None

    if exists(text_embeds) and self.cond_on_text:

        # conditional dropout
        # per-sample keep mask for classifier-free guidance

        text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

        text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
        text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

        # calculate text embeds

        text_tokens = self.text_to_cond(text_embeds)

        # truncate / pad to exactly max_text_len tokens
        text_tokens = text_tokens[:, :self.max_text_len]

        if exists(text_mask):
            text_mask = text_mask[:, :self.max_text_len]

        text_tokens_len = text_tokens.shape[1]
        remainder = self.max_text_len - text_tokens_len

        if remainder > 0:
            text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

        if exists(text_mask):
            if remainder > 0:
                text_mask = F.pad(text_mask, (0, remainder), value = False)

            # masked-out positions are treated like dropped conditioning
            text_mask = rearrange(text_mask, 'b n -> b n 1')
            text_keep_mask_embed = text_mask & text_keep_mask_embed

        null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

        # dropped samples/positions get the learned null embedding
        text_tokens = torch.where(
            text_keep_mask_embed,
            text_tokens,
            null_text_embed
        )

        if exists(self.attn_pool):
            text_tokens = self.attn_pool(text_tokens)

        # extra non-attention conditioning by projecting and then summing text embeddings to time
        # termed as text hiddens

        mean_pooled_text_tokens = text_tokens.mean(dim = -2)

        text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

        null_text_hidden = self.null_text_hidden.to(t.dtype)

        text_hiddens = torch.where(
            text_keep_mask_hidden,
            text_hiddens,
            null_text_hidden
        )

        t = t + text_hiddens

    # main conditioning tokens (c)

    c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

    # normalize conditioning tokens

    c = self.norm_cond(c)

    # initial resnet block (for memory efficient unet)

    if exists(self.init_resnet_block):
        x = self.init_resnet_block(x, t)

    # go through the layers of the unet, down and up

    hiddens = []

    for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, post_downsample in self.downs:
        if exists(pre_downsample):
            x = pre_downsample(x)

        x = init_block(x, t, c)

        for resnet_block in resnet_blocks:
            x = resnet_block(x, t)
            hiddens.append(x)

        x = attn_block(x, c)
        x = temporal_peg(x)
        x = temporal_attn(x, attn_bias = time_attn_bias)

        hiddens.append(x)

        if exists(post_downsample):
            x = post_downsample(x)

    x = self.mid_block1(x, t, c)

    if exists(self.mid_attn):
        x = self.mid_attn(x)

    x = self.mid_temporal_peg(x)
    x = self.mid_temporal_attn(x, attn_bias = time_attn_bias)

    x = self.mid_block2(x, t, c)

    # pop skip connections in LIFO order, scaled per unet-squared recipe
    add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

    up_hiddens = []

    for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, upsample in self.ups:
        x = add_skip_connection(x)
        x = init_block(x, t, c)

        for resnet_block in resnet_blocks:
            x = add_skip_connection(x)
            x = resnet_block(x, t)

        x = attn_block(x, c)
        x = temporal_peg(x)
        x = temporal_attn(x, attn_bias = time_attn_bias)

        up_hiddens.append(x.contiguous())
        x = upsample(x)

    # whether to combine all feature maps from upsample blocks

    x = self.upsample_combiner(x, up_hiddens)

    # final top-most residual if needed

    if self.init_conv_to_final_conv_residual:
        x = torch.cat((x, init_conv_residual), dim = 1)

    if exists(self.final_res_block):
        x = self.final_res_block(x, t)

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)

    return self.final_conv(x)
|
imagen_pytorch/joint_imagen.py
ADDED
|
@@ -0,0 +1,1942 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from contextlib import contextmanager, nullcontext
|
| 3 |
+
from functools import partial
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from random import random
|
| 6 |
+
from typing import List, Union
|
| 7 |
+
|
| 8 |
+
import kornia.augmentation as K
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torchvision.transforms as T
|
| 12 |
+
from einops import rearrange, reduce, repeat
|
| 13 |
+
from einops.layers.torch import Rearrange
|
| 14 |
+
from einops_exts import check_shape, rearrange_many
|
| 15 |
+
from einops_exts.torch import EinopsToAndFrom
|
| 16 |
+
from torch import nn
|
| 17 |
+
from torch.cuda.amp import autocast
|
| 18 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 19 |
+
from tqdm.auto import tqdm
|
| 20 |
+
from torch.special import expm1
|
| 21 |
+
|
| 22 |
+
from imagen_pytorch.imagen_pytorch import (
|
| 23 |
+
Attention, CrossEmbedLayer, Downsample, GaussianDiffusionContinuousTimes,
|
| 24 |
+
Identity, LearnedSinusoidalPosEmb, LinearAttentionTransformerBlock,
|
| 25 |
+
NullUnet, Parallel, PerceiverResampler, PixelShuffleUpsample,
|
| 26 |
+
Residual, ResnetBlock, TransformerBlock, Upsample, UpsampleCombiner,
|
| 27 |
+
cast_tuple, cast_uint8_images_to_float, default, eval_decorator,
|
| 28 |
+
exists, first, identity, is_float_dtype, maybe, module_device,
|
| 29 |
+
normalize_neg_one_to_one, pad_tuple_to_length, print_once, prob_mask_like,
|
| 30 |
+
resize_image_to, right_pad_dims_to, unnormalize_zero_to_one, zero_init_,
|
| 31 |
+
beta_linear_log_snr, alpha_cosine_log_snr, log, log_snr_to_alpha_sigma)
|
| 32 |
+
from imagen_pytorch.imagen_video.imagen_video import Unet3D, resize_video_to
|
| 33 |
+
from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim, t5_encode_text
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def log_1_min_a(a):
    """Return log(1 - exp(a)) elementwise.

    A tiny epsilon (1e-40) keeps the argument of log strictly positive
    when exp(a) is numerically 1.
    """
    return (1 - a.exp() + 1e-40).log()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def log_add_exp(a, b):
    """Numerically stable elementwise log(exp(a) + exp(b)).

    Shifts both terms by the elementwise maximum before exponentiating
    so neither exp overflows.
    """
    shift = torch.max(a, b)
    return shift + ((a - shift).exp() + (b - shift).exp()).log()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def extract(a, t, x_shape):
    """Gather per-sample entries of a 1-D schedule tensor and reshape for broadcasting.

    Args:
        a: 1-D tensor indexed by timestep.
        t: integer tensor of shape (batch,) with one index per sample.
        x_shape: shape of the tensor the result must broadcast against.

    Returns:
        Tensor of shape (batch, 1, 1, ...) with len(x_shape) dims.
    """
    batch = t.shape[0]
    picked = a.gather(-1, t)
    # append singleton dims so the result broadcasts against x_shape
    return picked.reshape(batch, *((1,) * (len(x_shape) - 1)))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def log_categorical(log_x_start, log_prob):
    """Expected log-probability of a categorical, reduced over the class axis (dim 1).

    When `log_x_start` is a log-one-hot, this picks out the log-probability of
    the true class per element.
    """
    return torch.sum(log_x_start.exp() * log_prob, dim=1)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def index_to_log_onehot(x, num_classes):
    """Convert integer class indices to a log-one-hot tensor with classes on dim 1.

    Args:
        x: integer tensor of indices; a (B, 1, H, W) tensor has its singleton
           channel squeezed first.
        num_classes: number of categories K.

    Returns:
        Float tensor with an extra class axis at dim 1; entries are 0 for the
        selected class and log(1e-30) elsewhere (clamped to avoid -inf).
    """
    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
    # drop a singleton channel dim so one_hot appends classes on the last axis
    if x.dim() == 4 and x.shape[1] == 1:
        x = x.squeeze(1)
    onehot = F.one_hot(x, num_classes)
    # relocate the class axis from last position to dim 1
    onehot = torch.movedim(onehot, -1, 1)
    return torch.log(onehot.float().clamp(min=1e-30))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def log_onehot_to_index(log_x):
    """Inverse of `index_to_log_onehot`: recover class indices from log-one-hot.

    Args:
        log_x: tensor with log-probabilities / log-one-hot values on dim 1.

    Returns:
        Integer tensor of argmax indices along dim 1, with that dim kept as
        a singleton so the result aligns with (B, 1, ...) label maps.
    """
    # `keepdim` is the documented parameter name of Tensor.argmax; the
    # original `keepdims` spelling only works via a numpy-compat alias.
    return log_x.argmax(1, keepdim=True)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def sum_except_batch(x, num_dims=1):
    '''
    Sums all dimensions except the first.

    Args:
        x: Tensor, shape (batch_size, ...)
        num_dims: int, number of batch dims (default=1)

    Returns:
        x_sum: Tensor, shape (batch_size,)
    '''
    # flatten everything past the batch dims, then reduce the flat axis
    flat = x.reshape(x.shape[:num_dims] + (-1,))
    return flat.sum(dim=-1)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@torch.jit.script
def alpha_cosine_p_log_snr(t, p: float = 0.8, s: float = 0.008):
    # Cosine noise schedule with an extra exponent p, expressed as log-SNR of
    # continuous time t in [0, 1]; `log` here is the project helper with an
    # `eps` clamp, not math.log. TorchScript-compiled, so t must be a Tensor.
    # not sure if this accounts for beta being clipped to 0.999 in discrete version
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** (-2 * p)) - 1, eps=1e-5)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class MultinomialDiffusion(nn.Module):
    """Multinomial (categorical) diffusion over discrete labels, in log space.

    Holds per-step log-alpha schedules derived from a continuous log-SNR
    schedule, and implements the forward process q(x_t | x_0), its one-step
    and posterior forms, and the variational-bound loss. All categorical
    distributions are carried as log-probabilities with the class axis on
    dim 1. Several continuous-time hooks are deliberately disabled with
    `raise NotImplementedError` (their dead bodies are kept for reference).
    """

    def __init__(self, num_classes, *, noise_schedule, p=1.0, timesteps=1000):
        """Build the discrete alpha schedule from a continuous log-SNR schedule.

        Args:
            num_classes: number of label categories K.
            noise_schedule: 'linear', 'cosine', or 'cosine_p'.
            p: exponent used only by the 'cosine_p' schedule.
            timesteps: number of discrete diffusion steps.
        """
        super().__init__()

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        elif noise_schedule == "cosine_p":
            self.log_snr = partial(alpha_cosine_p_log_snr, p=p)
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        # cumulative ᾱ_t: square of the alpha returned by log_snr_to_alpha_sigma
        # at each discrete time t/timesteps
        cumprod_alpha = torch.tensor([log_snr_to_alpha_sigma(self.log_snr(
            torch.tensor(t / timesteps)))[0] ** 2 for t in range(timesteps)])
        # per-step alpha_t = ᾱ_t / ᾱ_{t-1}; the pad shifts the sequence right,
        # filling position -1 with ᾱ_0 so alpha_0 == 1
        alphas = cumprod_alpha / F.pad(cumprod_alpha, (1, 0), value=cumprod_alpha[0])[:-1]
        # buffers so the schedule moves with .to(device) / state_dict
        self.register_buffer('log_alpha', torch.log(alphas))
        self.register_buffer('log_1_min_alpha', log_1_min_a(self.log_alpha))
        self.register_buffer('log_cumprod_alpha', torch.cumsum(self.log_alpha, axis=0))
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_a(self.log_cumprod_alpha))
        self.num_classes = num_classes
        self.num_timesteps = timesteps

    def change_times_dtype(self, t):
        """Coerce times to int64 step indices.

        Continuous times (any value strictly in (0, 1) anywhere in the tensor)
        are scaled by num_timesteps and floored; otherwise the tensor is just
        cast. NOTE(review): the continuous/discrete decision is made from the
        whole tensor via .any(), so mixed tensors are treated as continuous.
        """
        if t.dtype != torch.int64:
            if ((t < 1) * (t > 0)).any():
                return torch.floor(t * self.num_timesteps).to(torch.int64)
            else:
                return t.to(torch.int64)
        return t

    def get_times(self, batch_size, noise_level, *, device):
        # disabled: continuous-time API not supported by this class
        raise NotImplementedError
        return torch.full((batch_size,), noise_level, device=device, dtype=torch.float32)

    def sample_random_times(self, batch_size, max_thres=0.999, *, device):
        # disabled: continuous-time API not supported by this class
        raise NotImplementedError
        return torch.zeros((batch_size,), device=device).float().uniform_(0, max_thres)

    def get_condition(self, times):
        # disabled: continuous-time API not supported by this class
        raise NotImplementedError
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        # disabled: continuous-time API not supported by this class
        raise NotImplementedError
        times = torch.linspace(1., 0., self.num_timesteps + 1, device=device)
        times = repeat(times, 't -> b t', b=batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim=0)
        times = times.unbind(dim=-1)
        return times

    def q_pred(self, log_x_start, t):
        """log q(x_t | x_0) = log(ᾱ_t * x_0 + (1 - ᾱ_t) / K), computed stably.

        Expects `t` to already be an int64 step index tensor.
        """
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)
        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - math.log(self.num_classes)
        )
        return log_probs

    def log_sample_categorical(self, logits):
        """Draw a categorical sample via the Gumbel-max trick; return it as log-one-hot."""
        uniform = torch.rand_like(logits)
        # standard Gumbel noise; 1e-30 guards both logs against log(0)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
        sample = (gumbel_noise + logits).argmax(dim=1)
        log_sample = index_to_log_onehot(sample, self.num_classes)
        return log_sample

    def q_sample(self, log_x_start, t):
        """Sample x_t ~ q(x_t | x_0); accepts continuous or discrete t."""
        t = self.change_times_dtype(t)  # caused by continuous timesteps.
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        log_sample = self.log_sample_categorical(log_EV_qxt_x0)
        return log_sample

    def q_pred_one_timestep(self, log_x_t, t):
        """log q(x_t | x_{t-1}) for a single step t."""
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)
        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - math.log(self.num_classes)
        )
        return log_probs

    def q_posterior(self, log_x_start, log_x_t, t):
        """log q(x_{t-1} | x_t, x_0) via Bayes' rule, normalized over dim 1.

        For samples with t == 0, the result collapses to log_x_start
        (those entries of q(x_{t-1} | x_0) are overwritten below).
        """
        t = self.change_times_dtype(t)  # caused by continuous timesteps.
        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
        # where q(xt | xt-1, x0) = q(xt | xt-1).

        # EV_log_qxt_x0 = self.q_pred(log_x_start, t)

        # print('sum exp', EV_log_qxt_x0.exp().sum(1).mean())
        # assert False

        # log_qxt_x0 = (log_x_t.exp() * EV_log_qxt_x0).sum(dim=1)

        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        # broadcast t against the full tensor so t == 0 rows take log_x_start
        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)

        # unnormed_logprobs = log_EV_qxtmin_x0 +
        # log q_pred_one_timestep(x_t, t)
        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
        # Not very easy to see why this is true. But it is :)
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        # renormalize over the class axis so rows are valid log-distributions
        log_EV_xtmin_given_xt_given_xstart = \
            unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)

        return log_EV_xtmin_given_xt_given_xstart

    def q_sample_from_to(self, log_x_from, from_t, to_t):
        """Jump directly from time `from_t` to `to_t` in the forward process.

        `from_t` / `to_t` may be floats (broadcast to the batch) or tensors.
        At to_t == 0 the sample is taken deterministically via argmax instead
        of Gumbel sampling. Assumes 4-D input (B, K, H, W) — the mask below is
        expanded with three trailing singleton axes.
        """
        shape, device, dtype = log_x_from.shape, log_x_from.device, log_x_from.dtype
        batch = shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device=device, dtype=dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device=device, dtype=dtype)

        from_t = self.change_times_dtype(from_t)  # caused by continuous timesteps.
        to_t = self.change_times_dtype(to_t)  # caused by continuous timesteps.

        # ratio ᾱ_to / ᾱ_from plays the role of the per-jump alpha
        log_cumprod_alpha_to_t = extract(self.log_cumprod_alpha, to_t, log_x_from.shape)
        log_cumprod_alpha_from_t = extract(self.log_cumprod_alpha, from_t, log_x_from.shape)
        log_probs = log_add_exp(
            log_x_from + log_cumprod_alpha_to_t - log_cumprod_alpha_from_t,
            log_1_min_a(log_cumprod_alpha_to_t - log_cumprod_alpha_from_t) - math.log(self.num_classes)
        )

        # per-sample switch: deterministic argmax at t == 0, Gumbel sample otherwise
        mask = (to_t == torch.zeros_like(to_t)).float()[:, None, None, None]
        log_sample = index_to_log_onehot(log_probs.argmax(dim=1), self.num_classes) * mask \
            + self.log_sample_categorical(log_probs) * (1. - mask)

        return log_sample

    def predict_start_from_noise(self, x_t, t, noise):
        # not meaningful for categorical diffusion; kept to satisfy the
        # shared diffusion interface
        raise NotImplementedError

    # calculate loss

    def multinomial_kl(self, log_prob1, log_prob2):
        """KL(p1 || p2) between categoricals given as log-probs, reduced over dim 1."""
        return (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)

    def kl_prior(self, log_x_start):
        """KL between q(x_T | x_0) and the uniform prior over K classes."""
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device).long()

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        # uniform log-probability log(1/K); the '_half_' name is historical
        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)

    def loss_fn(self, target_log_lbl, pred_lbl, t, log_lbl):
        """Variational-bound loss in bits per dimension.

        Args:
            target_log_lbl: log q(x_{t-1} | x_t, x_0), the posterior target.
            pred_lbl: model's predicted log-probabilities (same shape).
            t: timesteps (continuous or discrete), one per sample.
            log_lbl: log-one-hot of the ground-truth labels x_0.

        Returns:
            Per-sample loss tensor of shape (batch,).
        """
        t = self.change_times_dtype(t)
        # uniform importance weight 1/T for the sampled timestep
        pt = torch.ones_like(t).float() / self.num_timesteps

        kl = self.multinomial_kl(target_log_lbl, pred_lbl)
        kl = sum_except_batch(kl)

        # at t == 0 the bound uses the decoder NLL instead of the KL term
        decoder_nll = -log_categorical(log_lbl, pred_lbl)
        decoder_nll = sum_except_batch(decoder_nll)
        mask = (t == torch.zeros_like(t)).float()
        kl = mask * decoder_nll + (1. - mask) * kl

        kl_prior = self.kl_prior(log_lbl)
        vb_loss = kl / pt + kl_prior

        # convert nats to bits and normalize by the number of label positions
        loss = vb_loss / (math.log(2) * pred_lbl.shape[1:].numel())
        return loss
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class LabelEmbedding(nn.Module):
    """Embed a (B, 1, H, W) map of integer class labels into (B, C, H, W) features."""

    def __init__(self, num_classes, channels):
        super().__init__()
        self.emb_layer = nn.Embedding(num_classes, channels)

    def forward(self, x):
        assert x.dim() == 4, f'x.shape should be (B, 1, H, W) but {x.shape}'
        assert x.size(1) == 1, f'x.shape should be (B, 1, H, W) but {x.shape}'
        indices = x.long().squeeze(1)       # (B, H, W) integer labels
        embedded = self.emb_layer(indices)  # (B, H, W, C)
        # bring the embedding channels forward to the conv layout (B, C, H, W)
        return embedded.movedim(-1, 1)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class JointUnet(nn.Module):
|
| 283 |
+
def __init__(
|
| 284 |
+
self,
|
| 285 |
+
*,
|
| 286 |
+
dim,
|
| 287 |
+
num_classes,
|
| 288 |
+
image_embed_dim=1024,
|
| 289 |
+
text_embed_dim=get_encoded_dim(DEFAULT_T5_NAME),
|
| 290 |
+
num_resnet_blocks=1,
|
| 291 |
+
cond_dim=None,
|
| 292 |
+
num_image_tokens=4,
|
| 293 |
+
num_time_tokens=2,
|
| 294 |
+
learned_sinu_pos_emb_dim=16,
|
| 295 |
+
out_dim=None,
|
| 296 |
+
dim_mults=(1, 2, 4, 8),
|
| 297 |
+
cond_images_channels=0,
|
| 298 |
+
channels=3,
|
| 299 |
+
channels_lbl=3,
|
| 300 |
+
channels_out=None,
|
| 301 |
+
attn_dim_head=64,
|
| 302 |
+
attn_heads=8,
|
| 303 |
+
ff_mult=2.,
|
| 304 |
+
lowres_cond=False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
| 305 |
+
layer_attns=True,
|
| 306 |
+
layer_attns_depth=1,
|
| 307 |
+
# whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
|
| 308 |
+
layer_attns_add_text_cond=True,
|
| 309 |
+
# whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
| 310 |
+
attend_at_middle=True,
|
| 311 |
+
layer_cross_attns=True,
|
| 312 |
+
use_linear_attn=False,
|
| 313 |
+
use_linear_cross_attn=False,
|
| 314 |
+
cond_on_text=True,
|
| 315 |
+
max_text_len=256,
|
| 316 |
+
init_dim=None,
|
| 317 |
+
resnet_groups=8,
|
| 318 |
+
init_conv_kernel_size=7, # kernel size of initial conv, if not using cross embed
|
| 319 |
+
init_cross_embed=True,
|
| 320 |
+
init_cross_embed_kernel_sizes=(3, 7, 15),
|
| 321 |
+
cross_embed_downsample=False,
|
| 322 |
+
cross_embed_downsample_kernel_sizes=(2, 4),
|
| 323 |
+
attn_pool_text=True,
|
| 324 |
+
attn_pool_num_latents=32,
|
| 325 |
+
dropout=0.,
|
| 326 |
+
memory_efficient=False,
|
| 327 |
+
init_conv_to_final_conv_residual=False,
|
| 328 |
+
use_global_context_attn=True,
|
| 329 |
+
scale_skip_connection=True,
|
| 330 |
+
final_resnet_block=True,
|
| 331 |
+
final_conv_kernel_size=3,
|
| 332 |
+
cosine_sim_attn=False,
|
| 333 |
+
self_cond=False,
|
| 334 |
+
combine_upsample_fmaps=False, # combine feature maps from all upsample blocks, used in unet squared successfully
|
| 335 |
+
pixel_shuffle_upsample=True # may address checkboard artifacts
|
| 336 |
+
):
|
| 337 |
+
super().__init__()
|
| 338 |
+
|
| 339 |
+
# guide researchers
|
| 340 |
+
|
| 341 |
+
assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'
|
| 342 |
+
|
| 343 |
+
if dim < 128:
|
| 344 |
+
print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')
|
| 345 |
+
|
| 346 |
+
# save locals to take care of some hyperparameters for cascading DDPM
|
| 347 |
+
|
| 348 |
+
self._locals = locals()
|
| 349 |
+
self._locals.pop('self', None)
|
| 350 |
+
self._locals.pop('__class__', None)
|
| 351 |
+
|
| 352 |
+
# determine dimensions
|
| 353 |
+
|
| 354 |
+
self.channels = channels
|
| 355 |
+
self.channels_out = default(channels_out, channels)
|
| 356 |
+
|
| 357 |
+
# label embedding
|
| 358 |
+
|
| 359 |
+
self.num_classes = num_classes
|
| 360 |
+
self.init_emb_seg = LabelEmbedding(self.num_classes, channels_lbl)
|
| 361 |
+
self.init_emb_seg_lowres = LabelEmbedding(self.num_classes, channels_lbl) if lowres_cond else None
|
| 362 |
+
|
| 363 |
+
# (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
| 364 |
+
# (2) in self conditioning, one appends the predict x0 (x_start)
|
| 365 |
+
# (3) in joint diffusion, label condition appends on image.
|
| 366 |
+
init_channels = (channels + channels_lbl) * (1 + int(lowres_cond) + int(self_cond)) # Joint Imagen
|
| 367 |
+
init_dim = default(init_dim, dim)
|
| 368 |
+
|
| 369 |
+
self.self_cond = self_cond
|
| 370 |
+
if self_cond:
|
| 371 |
+
self.self_cond_lbl_emb = LabelEmbedding(self.num_classes, channels_lbl)
|
| 372 |
+
|
| 373 |
+
# optional image conditioning
|
| 374 |
+
|
| 375 |
+
self.has_cond_image = cond_images_channels > 0
|
| 376 |
+
self.cond_images_channels = cond_images_channels
|
| 377 |
+
|
| 378 |
+
init_channels += cond_images_channels
|
| 379 |
+
|
| 380 |
+
# initial convolution
|
| 381 |
+
|
| 382 |
+
self.init_conv = CrossEmbedLayer(init_channels, dim_out=init_dim, kernel_sizes=init_cross_embed_kernel_sizes, stride=1) if init_cross_embed else nn.Conv2d(
|
| 383 |
+
init_channels, init_dim, init_conv_kernel_size, padding=init_conv_kernel_size // 2)
|
| 384 |
+
|
| 385 |
+
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
| 386 |
+
in_out = list(zip(dims[:-1], dims[1:]))
|
| 387 |
+
|
| 388 |
+
# time conditioning
|
| 389 |
+
|
| 390 |
+
cond_dim = default(cond_dim, dim)
|
| 391 |
+
time_cond_dim = dim * 4 * (2 if lowres_cond else 1)
|
| 392 |
+
|
| 393 |
+
# embedding time for log(snr) noise from continuous version
|
| 394 |
+
|
| 395 |
+
sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
|
| 396 |
+
sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
|
| 397 |
+
|
| 398 |
+
self.to_time_hiddens = nn.Sequential(
|
| 399 |
+
sinu_pos_emb,
|
| 400 |
+
nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
|
| 401 |
+
nn.SiLU()
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
self.to_time_cond = nn.Sequential(
|
| 405 |
+
nn.Linear(time_cond_dim, time_cond_dim)
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# project to time tokens as well as time hiddens
|
| 409 |
+
|
| 410 |
+
self.to_time_tokens = nn.Sequential(
|
| 411 |
+
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
|
| 412 |
+
Rearrange('b (r d) -> b r d', r=num_time_tokens)
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# low res aug noise conditioning
|
| 416 |
+
|
| 417 |
+
self.lowres_cond = lowres_cond
|
| 418 |
+
|
| 419 |
+
if lowres_cond:
|
| 420 |
+
self.to_lowres_time_hiddens = nn.Sequential(
|
| 421 |
+
LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
|
| 422 |
+
nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
|
| 423 |
+
nn.SiLU()
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
self.to_lowres_time_cond = nn.Sequential(
|
| 427 |
+
nn.Linear(time_cond_dim, time_cond_dim)
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
self.to_lowres_time_tokens = nn.Sequential(
|
| 431 |
+
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
|
| 432 |
+
Rearrange('b (r d) -> b r d', r=num_time_tokens)
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
# normalizations
|
| 436 |
+
|
| 437 |
+
self.norm_cond = nn.LayerNorm(cond_dim)
|
| 438 |
+
|
| 439 |
+
# text encoding conditioning (optional)
|
| 440 |
+
|
| 441 |
+
self.text_to_cond = None
|
| 442 |
+
|
| 443 |
+
if cond_on_text:
|
| 444 |
+
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
|
| 445 |
+
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
| 446 |
+
|
| 447 |
+
# finer control over whether to condition on text encodings
|
| 448 |
+
|
| 449 |
+
self.cond_on_text = cond_on_text
|
| 450 |
+
|
| 451 |
+
# attention pooling
|
| 452 |
+
|
| 453 |
+
self.attn_pool = PerceiverResampler(dim=cond_dim, depth=2, dim_head=attn_dim_head, heads=attn_heads,
|
| 454 |
+
num_latents=attn_pool_num_latents, cosine_sim_attn=cosine_sim_attn) if attn_pool_text else None
|
| 455 |
+
|
| 456 |
+
# for classifier free guidance
|
| 457 |
+
|
| 458 |
+
self.max_text_len = max_text_len
|
| 459 |
+
|
| 460 |
+
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
| 461 |
+
self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
|
| 462 |
+
|
| 463 |
+
# for non-attention based text conditioning at all points in the network where time is also conditioned
|
| 464 |
+
|
| 465 |
+
self.to_text_non_attn_cond = None
|
| 466 |
+
|
| 467 |
+
if cond_on_text:
|
| 468 |
+
self.to_text_non_attn_cond = nn.Sequential(
|
| 469 |
+
nn.LayerNorm(cond_dim),
|
| 470 |
+
nn.Linear(cond_dim, time_cond_dim),
|
| 471 |
+
nn.SiLU(),
|
| 472 |
+
nn.Linear(time_cond_dim, time_cond_dim)
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
# attention related params
|
| 476 |
+
|
| 477 |
+
attn_kwargs = dict(heads=attn_heads, dim_head=attn_dim_head, cosine_sim_attn=cosine_sim_attn)
|
| 478 |
+
|
| 479 |
+
num_layers = len(in_out)
|
| 480 |
+
|
| 481 |
+
# resnet block klass
|
| 482 |
+
|
| 483 |
+
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
|
| 484 |
+
resnet_groups = cast_tuple(resnet_groups, num_layers)
|
| 485 |
+
|
| 486 |
+
resnet_klass = partial(ResnetBlock, **attn_kwargs)
|
| 487 |
+
|
| 488 |
+
layer_attns = cast_tuple(layer_attns, num_layers)
|
| 489 |
+
layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
|
| 490 |
+
layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
|
| 491 |
+
|
| 492 |
+
use_linear_attn = cast_tuple(use_linear_attn, num_layers)
|
| 493 |
+
use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)
|
| 494 |
+
|
| 495 |
+
assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
|
| 496 |
+
|
| 497 |
+
# downsample klass
|
| 498 |
+
|
| 499 |
+
downsample_klass = Downsample
|
| 500 |
+
|
| 501 |
+
if cross_embed_downsample:
|
| 502 |
+
downsample_klass = partial(CrossEmbedLayer, kernel_sizes=cross_embed_downsample_kernel_sizes)
|
| 503 |
+
|
| 504 |
+
# initial resnet block (for memory efficient unet)
|
| 505 |
+
|
| 506 |
+
self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim=time_cond_dim,
|
| 507 |
+
groups=resnet_groups[0], use_gca=use_global_context_attn) if memory_efficient else None
|
| 508 |
+
|
| 509 |
+
# scale for resnet skip connections
|
| 510 |
+
|
| 511 |
+
self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
|
| 512 |
+
|
| 513 |
+
# layers
|
| 514 |
+
|
| 515 |
+
self.downs = nn.ModuleList([])
|
| 516 |
+
self.ups = nn.ModuleList([])
|
| 517 |
+
num_resolutions = len(in_out)
|
| 518 |
+
|
| 519 |
+
layer_params = [num_resnet_blocks, resnet_groups, layer_attns,
|
| 520 |
+
layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
|
| 521 |
+
reversed_layer_params = list(map(reversed, layer_params))
|
| 522 |
+
|
| 523 |
+
# downsampling layers
|
| 524 |
+
|
| 525 |
+
skip_connect_dims = [] # keep track of skip connection dimensions
|
| 526 |
+
|
| 527 |
+
for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
|
| 528 |
+
is_last = ind >= (num_resolutions - 1)
|
| 529 |
+
|
| 530 |
+
layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
|
| 531 |
+
|
| 532 |
+
if layer_attn:
|
| 533 |
+
transformer_block_klass = TransformerBlock
|
| 534 |
+
elif layer_use_linear_attn:
|
| 535 |
+
transformer_block_klass = LinearAttentionTransformerBlock
|
| 536 |
+
else:
|
| 537 |
+
transformer_block_klass = Identity
|
| 538 |
+
|
| 539 |
+
current_dim = dim_in
|
| 540 |
+
|
| 541 |
+
# whether to pre-downsample, from memory efficient unet
|
| 542 |
+
|
| 543 |
+
pre_downsample = None
|
| 544 |
+
|
| 545 |
+
if memory_efficient:
|
| 546 |
+
pre_downsample = downsample_klass(dim_in, dim_out)
|
| 547 |
+
current_dim = dim_out
|
| 548 |
+
|
| 549 |
+
skip_connect_dims.append(current_dim)
|
| 550 |
+
|
| 551 |
+
# whether to do post-downsample, for non-memory efficient unet
|
| 552 |
+
|
| 553 |
+
post_downsample = None
|
| 554 |
+
if not memory_efficient:
|
| 555 |
+
post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(
|
| 556 |
+
nn.Conv2d(dim_in, dim_out, 3, padding=1), nn.Conv2d(dim_in, dim_out, 1))
|
| 557 |
+
|
| 558 |
+
self.downs.append(nn.ModuleList([
|
| 559 |
+
pre_downsample,
|
| 560 |
+
resnet_klass(current_dim, current_dim, cond_dim=layer_cond_dim,
|
| 561 |
+
linear_attn=layer_use_linear_cross_attn, time_cond_dim=time_cond_dim, groups=groups),
|
| 562 |
+
nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim=time_cond_dim,
|
| 563 |
+
groups=groups, use_gca=use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
|
| 564 |
+
transformer_block_klass(dim=current_dim, depth=layer_attn_depth,
|
| 565 |
+
ff_mult=ff_mult, context_dim=cond_dim, **attn_kwargs),
|
| 566 |
+
post_downsample
|
| 567 |
+
]))
|
| 568 |
+
|
| 569 |
+
# middle layers
|
| 570 |
+
|
| 571 |
+
mid_dim = dims[-1]
|
| 572 |
+
|
| 573 |
+
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim=cond_dim,
|
| 574 |
+
time_cond_dim=time_cond_dim, groups=resnet_groups[-1])
|
| 575 |
+
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(
|
| 576 |
+
Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
| 577 |
+
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim=cond_dim,
|
| 578 |
+
time_cond_dim=time_cond_dim, groups=resnet_groups[-1])
|
| 579 |
+
|
| 580 |
+
# upsample klass
|
| 581 |
+
|
| 582 |
+
upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
|
| 583 |
+
|
| 584 |
+
# upsampling layers
|
| 585 |
+
|
| 586 |
+
upsample_fmap_dims = []
|
| 587 |
+
|
| 588 |
+
for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
|
| 589 |
+
is_last = ind == (len(in_out) - 1)
|
| 590 |
+
|
| 591 |
+
layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
|
| 592 |
+
|
| 593 |
+
if layer_attn:
|
| 594 |
+
transformer_block_klass = TransformerBlock
|
| 595 |
+
elif layer_use_linear_attn:
|
| 596 |
+
transformer_block_klass = LinearAttentionTransformerBlock
|
| 597 |
+
else:
|
| 598 |
+
transformer_block_klass = Identity
|
| 599 |
+
|
| 600 |
+
skip_connect_dim = skip_connect_dims.pop()
|
| 601 |
+
|
| 602 |
+
upsample_fmap_dims.append(dim_out)
|
| 603 |
+
|
| 604 |
+
self.ups.append(nn.ModuleList([
|
| 605 |
+
resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim=layer_cond_dim,
|
| 606 |
+
linear_attn=layer_use_linear_cross_attn, time_cond_dim=time_cond_dim, groups=groups),
|
| 607 |
+
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim=time_cond_dim,
|
| 608 |
+
groups=groups, use_gca=use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
|
| 609 |
+
transformer_block_klass(dim=dim_out, depth=layer_attn_depth, ff_mult=ff_mult,
|
| 610 |
+
context_dim=cond_dim, **attn_kwargs),
|
| 611 |
+
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity(),
|
| 612 |
+
]))
|
| 613 |
+
|
| 614 |
+
# whether to combine feature maps from all upsample blocks before final resnet block out
|
| 615 |
+
|
| 616 |
+
self.upsample_combiner = UpsampleCombiner(
|
| 617 |
+
dim=dim,
|
| 618 |
+
enabled=combine_upsample_fmaps,
|
| 619 |
+
dim_ins=upsample_fmap_dims,
|
| 620 |
+
dim_outs=dim
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
# whether to do a final residual from initial conv to the final resnet block out
|
| 624 |
+
|
| 625 |
+
self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
|
| 626 |
+
final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)
|
| 627 |
+
|
| 628 |
+
# final optional resnet block and convolution out
|
| 629 |
+
|
| 630 |
+
self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim=time_cond_dim,
|
| 631 |
+
groups=resnet_groups[0], use_gca=True) if final_resnet_block else None
|
| 632 |
+
|
| 633 |
+
final_conv_dim_in = dim if final_resnet_block else final_conv_dim
|
| 634 |
+
final_conv_dim_in += (channels + channels_lbl) if lowres_cond else 0
|
| 635 |
+
|
| 636 |
+
self.final_conv = nn.Conv2d(final_conv_dim_in, self.channels_out,
|
| 637 |
+
final_conv_kernel_size, padding=final_conv_kernel_size // 2)
|
| 638 |
+
self.final_conv_seg = nn.Conv2d(final_conv_dim_in, self.num_classes,
|
| 639 |
+
final_conv_kernel_size, padding=final_conv_kernel_size // 2)
|
| 640 |
+
|
| 641 |
+
zero_init_(self.final_conv)
|
| 642 |
+
zero_init_(self.final_conv_seg)
|
| 643 |
+
|
| 644 |
+
# if the current settings for the unet are not correct
|
| 645 |
+
# for cascading DDPM, then reinit the unet with the right settings
|
| 646 |
+
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """Return a unet matching the requested cascade settings.

    If every requested setting already matches this instance, return ``self``
    unchanged; otherwise re-instantiate the unet from its saved constructor
    arguments (``self._locals``) with the requested settings overriding.
    """
    requested = (lowres_cond, channels, cond_on_text, text_embed_dim, channels_out)
    current = (
        self.lowres_cond,
        self.channels,
        self.cond_on_text,
        self._locals['text_embed_dim'],
        self.channels_out,
    )

    if requested == current:
        return self

    # overrides win over the originally-recorded constructor kwargs
    overrides = {
        'lowres_cond': lowres_cond,
        'text_embed_dim': text_embed_dim,
        'channels': channels,
        'channels_out': channels_out,
        'cond_on_text': cond_on_text,
    }

    return self.__class__(**{**self._locals, **overrides})
|
| 671 |
+
|
| 672 |
+
# methods for returning the full unet config as well as its parameter state
|
| 673 |
+
|
| 674 |
+
def to_config_and_state_dict(self):
    """Return ``(config, state_dict)`` — the saved constructor kwargs plus the parameter state."""
    config = self._locals
    state = self.state_dict()
    return config, state
|
| 676 |
+
|
| 677 |
+
# class method for rehydrating the unet from its config and state dict
|
| 678 |
+
|
| 679 |
+
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Rehydrate a unet: construct from ``config`` kwargs, then load ``state_dict`` into it."""
    instance = klass(**config)
    instance.load_state_dict(state_dict)
    return instance
|
| 684 |
+
|
| 685 |
+
# methods for persisting unet to disk
|
| 686 |
+
|
| 687 |
+
def persist_to_file(self, path):
    """Save this unet's config and parameters to ``path`` via ``torch.save``.

    Creates the parent directory if needed. The saved package is a dict with
    ``config`` and ``state_dict`` keys, compatible with `hydrate_from_file`.
    """
    target = Path(path)
    target.parent.mkdir(exist_ok=True, parents=True)

    config, state_dict = self.to_config_and_state_dict()
    torch.save({'config': config, 'state_dict': state_dict}, str(target))
|
| 694 |
+
|
| 695 |
+
# class method for rehydrating the unet from file saved with `persist_to_file`
|
| 696 |
+
|
| 697 |
+
@classmethod
def hydrate_from_file(klass, path):
    """Load a unet previously saved with `persist_to_file`.

    Reads the ``{'config': ..., 'state_dict': ...}`` package from ``path``
    and rebuilds the unet on the class this method was called on.
    """
    path = Path(path)
    assert path.exists()
    pkg = torch.load(str(path))

    assert 'config' in pkg and 'state_dict' in pkg
    config, state_dict = pkg['config'], pkg['state_dict']

    # bugfix: use the receiving class rather than hard-coding JointUnet, so
    # subclasses (BaseJointUnet / SRJointUnet) rehydrate as themselves.
    # behavior for JointUnet.hydrate_from_file itself is unchanged.
    return klass.from_config_and_state_dict(config, state_dict)
|
| 707 |
+
|
| 708 |
+
# forward with classifier free guidance
|
| 709 |
+
|
| 710 |
+
def forward_with_cond_scale(
    self,
    *args,
    cond_scale=1.,
    **kwargs
):
    """Forward pass with classifier-free guidance.

    Runs the conditional forward pass; when ``cond_scale != 1`` also runs a
    fully condition-dropped pass and extrapolates from the unconditional
    output toward the conditional one. Returns an (image logits,
    segmentation logits) pair.
    """
    cond_out, cond_out_seg = self.forward(*args, **kwargs)

    if cond_scale == 1:
        return cond_out, cond_out_seg

    null_out, null_out_seg = self.forward(*args, cond_drop_prob=1., **kwargs)

    guided = null_out + (cond_out - null_out) * cond_scale
    # TODO: CFG of categorical is not clear.
    guided_seg = null_out_seg + (cond_out_seg - null_out_seg) * cond_scale

    return guided, guided_seg
|
| 729 |
+
|
| 730 |
+
def forward(
    self,
    x,
    lbl,
    time,
    *,
    lowres_cond_img=None,
    lowres_cond_lbl=None,
    lowres_noise_times=None,
    text_embeds=None,
    text_mask=None,
    cond_images=None,
    self_cond=None,
    self_cond_lbl=None,
    cond_drop_prob=0.
):
    """Joint denoising forward pass over an image and its segmentation labels.

    The label map is embedded (``init_emb_seg``) and concatenated to the image
    channels, then both run through one shared unet. Returns a tuple
    ``(image head output, segmentation head output)`` from two separate final
    convolutions over the same features.

    Args:
        x: image tensor, shape (batch, channels, h, w).
        lbl: segmentation labels; cast with ``.long()`` before embedding, so
            integer class indices are expected (shape beyond that — TODO confirm
            against callers).
        time: diffusion time conditioning input.
        lowres_cond_img / lowres_cond_lbl: low-resolution conditioning pair for
            super-resolution stages (both required together).
        lowres_noise_times: noise-level conditioning for the lowres pair.
        text_embeds / text_mask: optional text conditioning tokens and mask.
        cond_images: optional conditioning image, resized to match ``x``.
        self_cond / self_cond_lbl: previous-step self-conditioning inputs.
        cond_drop_prob: probability of dropping text conditioning (for CFG).
    """
    batch_size, device = x.shape[0], x.device

    # joint imagen: embed labels and stack them onto the image channels

    lbl = self.init_emb_seg(lbl.long())
    x = torch.cat((x, lbl), dim=1)

    # condition on self (self-conditioning: zeros when no previous prediction given)

    if self.self_cond:
        self_cond = default(self_cond, lambda: torch.zeros_like(x))
        if self_cond_lbl is None:
            self_cond_lbl = torch.zeros_like(lbl)
        else:
            self_cond_lbl = self.self_cond_lbl_emb(self_cond_lbl.long())
        x = torch.cat((x, self_cond, self_cond_lbl), dim=1)

    # add low resolution conditioning, if present

    assert not (self.lowres_cond and not exists(lowres_cond_img)
                ), 'low resolution conditioning image must be present'
    assert not (self.lowres_cond and not exists(lowres_noise_times)
                ), 'low resolution conditioning noise time must be present'

    if exists(lowres_cond_img) and exists(lowres_cond_lbl):
        # NOTE: lowres_cond_lbl is replaced by its embedding here and re-used
        # in embedded form for the final concat at the bottom of this method
        lowres_cond_lbl = self.init_emb_seg_lowres(lowres_cond_lbl.long())
        x = torch.cat((x, lowres_cond_img, lowres_cond_lbl), dim=1)

    # condition on input image

    assert not (self.has_cond_image ^ exists(cond_images)), \
        'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

    if exists(cond_images):
        assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
        cond_images = resize_image_to(cond_images, x.shape[-1])
        x = torch.cat((cond_images, x), dim=1)

    # initial convolution

    x = self.init_conv(x)

    # init conv residual (saved to concatenate back in before the final conv)

    if self.init_conv_to_final_conv_residual:
        init_conv_residual = x.clone()

    # time conditioning

    time_hiddens = self.to_time_hiddens(time)

    # derive time tokens (for attention) and the time conditioning vector `t`

    time_tokens = self.to_time_tokens(time_hiddens)
    t = self.to_time_cond(time_hiddens)

    # add lowres time conditioning to time hiddens
    # and add lowres time tokens along sequence dimension for attention

    if self.lowres_cond:
        lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
        lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
        lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

        t = t + lowres_t
        time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim=-2)

    # text conditioning

    text_tokens = None

    if exists(text_embeds) and self.cond_on_text:

        # conditional dropout: per-batch-element keep mask for classifier-free guidance

        text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)

        text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
        text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

        # calculate text embeds, truncated/padded to max_text_len

        text_tokens = self.text_to_cond(text_embeds)

        text_tokens = text_tokens[:, :self.max_text_len]

        if exists(text_mask):
            text_mask = text_mask[:, :self.max_text_len]

        text_tokens_len = text_tokens.shape[1]
        remainder = self.max_text_len - text_tokens_len

        if remainder > 0:
            text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

        if exists(text_mask):
            if remainder > 0:
                text_mask = F.pad(text_mask, (0, remainder), value=False)

            text_mask = rearrange(text_mask, 'b n -> b n 1')
            text_keep_mask_embed = text_mask & text_keep_mask_embed

        null_text_embed = self.null_text_embed.to(text_tokens.dtype)  # for some reason pytorch AMP not working

        # dropped positions are replaced by the learned null embedding
        text_tokens = torch.where(
            text_keep_mask_embed,
            text_tokens,
            null_text_embed
        )

        if exists(self.attn_pool):
            text_tokens = self.attn_pool(text_tokens)

        # extra non-attention conditioning by projecting and then summing text embeddings to time
        # termed as text hiddens

        mean_pooled_text_tokens = text_tokens.mean(dim=-2)

        text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

        null_text_hidden = self.null_text_hidden.to(t.dtype)

        text_hiddens = torch.where(
            text_keep_mask_hidden,
            text_hiddens,
            null_text_hidden
        )

        t = t + text_hiddens

    # main conditioning tokens (c): time tokens, plus text tokens when present

    c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim=-2)

    # normalize conditioning tokens

    c = self.norm_cond(c)

    # initial resnet block (for memory efficient unet)

    if exists(self.init_resnet_block):
        x = self.init_resnet_block(x, t)

    # go through the layers of the unet, down and up

    hiddens = []

    for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
        if exists(pre_downsample):
            x = pre_downsample(x)

        x = init_block(x, t, c)

        for resnet_block in resnet_blocks:
            x = resnet_block(x, t)
            hiddens.append(x)

        x = attn_block(x, c)
        hiddens.append(x)

        if exists(post_downsample):
            x = post_downsample(x)

    x = self.mid_block1(x, t, c)

    if exists(self.mid_attn):
        x = self.mid_attn(x)

    x = self.mid_block2(x, t, c)

    # skip connections are consumed LIFO (hiddens.pop()), scaled per config
    def add_skip_connection(x): return torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim=1)

    up_hiddens = []

    for init_block, resnet_blocks, attn_block, upsample in self.ups:
        x = add_skip_connection(x)
        x = init_block(x, t, c)

        for resnet_block in resnet_blocks:
            x = add_skip_connection(x)
            x = resnet_block(x, t)

        x = attn_block(x, c)
        up_hiddens.append(x.contiguous())

        x = upsample(x)

    # whether to combine all feature maps from upsample blocks

    x = self.upsample_combiner(x, up_hiddens)

    # final top-most residual if needed

    if self.init_conv_to_final_conv_residual:
        x = torch.cat((x, init_conv_residual), dim=1)

    if exists(self.final_res_block):
        x = self.final_res_block(x, t)

    if exists(lowres_cond_img) and exists(lowres_cond_lbl):
        x = torch.cat((x, lowres_cond_img, lowres_cond_lbl), dim=1)

    # two heads over the same features: image output and segmentation output
    return self.final_conv(x), self.final_conv_seg(x)
|
| 949 |
+
|
| 950 |
+
# predefined unets, with configs lining up with hyperparameters in appendix of paper
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
class BaseJointUnet(JointUnet):
    """Base-resolution joint unet with preset hyperparameters (per the Imagen appendix)."""

    def __init__(self, *args, **kwargs):
        # caller-supplied kwargs take precedence over these presets
        presets = {
            'dim': 128,
            'dim_mults': (1, 2, 4, 8),
            'num_resnet_blocks': (2, 4, 8, 8),
            'layer_attns': (False, False, False, True),
            'layer_cross_attns': (False, False, False, True),
            'attn_heads': 8,
            'ff_mult': 2.,
            'memory_efficient': True,
        }
        super().__init__(*args, **{**presets, **kwargs})
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
class SRJointUnet(JointUnet):
    """Super-resolution joint unet with preset hyperparameters (self-attention disabled)."""

    def __init__(self, *args, **kwargs):
        # caller-supplied kwargs take precedence over these presets
        presets = {
            'dim': 128,
            'dim_mults': (1, 2, 4, 8),
            'num_resnet_blocks': (2, 4, 8, 8),
            'layer_attns': False,
            'layer_cross_attns': (False, False, False, True),
            'attn_heads': 8,
            'ff_mult': 2.,
            'memory_efficient': True,
        }
        super().__init__(*args, **{**presets, **kwargs})
|
| 981 |
+
|
| 982 |
+
# main imagen ddpm class, which is a cascading DDPM from Ho et al.
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
class JointImagen(nn.Module):
|
| 986 |
+
def __init__(
    self,
    unets,
    *,
    image_sizes,  # for cascading ddpm, image size at each stage
    num_classes,
    text_encoder_name=DEFAULT_T5_NAME,
    text_embed_dim=None,
    channels=3,
    timesteps=1000,
    sample_timesteps=100,
    cond_drop_prob=0.1,
    loss_type='l2',
    noise_schedules='cosine',
    noise_schedules_lbl='cosine_p',
    cosine_p_lbl=1.0,
    pred_objectives='noise',
    random_crop_sizes=None,
    lowres_noise_schedule='linear',
    # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
    lowres_sample_noise_level=0.2,
    # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
    per_sample_random_aug_noise_level=False,
    lowres_max_thres=0.999,
    condition_on_text=True,
    # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
    auto_normalize_img=True,
    # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
    p2_loss_weight_gamma=0.5,
    p2_loss_weight_k=1,
    dynamic_thresholding=True,
    dynamic_thresholding_percentile=0.95,  # unsure what this was based on perusal of paper
    only_train_unet_number=None,
):
    """Configure a cascading joint (image + segmentation label) diffusion model.

    Pairs a continuous Gaussian diffusion over images with a multinomial
    diffusion over label maps, one scheduler pair per unet in the cascade.
    Most tuple-valued hyperparameters are broadcast to one entry per unet
    via `cast_tuple`.
    """
    super().__init__()

    # joint: number of segmentation classes for the multinomial diffusion

    self.num_classes = num_classes

    # loss

    if loss_type == 'l1':
        loss_fn = F.l1_loss
    elif loss_type == 'l2':
        loss_fn = F.mse_loss
    elif loss_type == 'huber':
        loss_fn = F.smooth_l1_loss
    else:
        raise NotImplementedError()

    self.loss_type = loss_type
    self.loss_fn = loss_fn

    # conditioning hparams

    self.condition_on_text = condition_on_text
    self.unconditional = not condition_on_text

    # channels

    self.channels = channels

    # automatically take care of ensuring that first unet is unconditional
    # while the rest of the unets are conditioned on the low resolution image produced by previous unet

    unets = cast_tuple(unets)
    num_unets = len(unets)

    # determine noise schedules per unet

    timesteps = cast_tuple(timesteps, num_unets)
    sample_timesteps = cast_tuple(sample_timesteps, num_unets)

    # make sure noise schedule defaults to 'cosine', 'cosine', and then 'linear' for rest of super-resoluting unets

    noise_schedules = cast_tuple(noise_schedules)
    noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine')
    noise_schedules = pad_tuple_to_length(noise_schedules, num_unets, 'linear')
    # same defaulting pattern for the label (multinomial) schedules
    noise_schedules_lbl = cast_tuple(noise_schedules_lbl)
    noise_schedules_lbl = pad_tuple_to_length(noise_schedules_lbl, 2, 'cosine_p')
    noise_schedules_lbl = pad_tuple_to_length(noise_schedules_lbl, num_unets, 'linear')

    # construct noise schedulers (training-time step counts)

    noise_scheduler_klass = GaussianDiffusionContinuousTimes
    noise_scheduler_lbl_klass = MultinomialDiffusion
    self.noise_schedulers = nn.ModuleList([])
    self.noise_schedulers_lbl = nn.ModuleList([])

    for timestep, noise_schedule, noise_schedule_lbl in zip(timesteps, noise_schedules, noise_schedules_lbl):
        noise_scheduler = noise_scheduler_klass(noise_schedule=noise_schedule, timesteps=timestep)
        self.noise_schedulers.append(noise_scheduler)
        noise_scheduler_lbl = noise_scheduler_lbl_klass(
            num_classes, noise_schedule=noise_schedule_lbl, timesteps=timestep, p=cosine_p_lbl)
        self.noise_schedulers_lbl.append(noise_scheduler_lbl)

    # separate scheduler set with (typically fewer) sampling-time steps

    self.noise_schedulers_sample = nn.ModuleList([])
    self.noise_schedulers_lbl_sample = nn.ModuleList([])

    for sample_timestep, noise_schedule, noise_schedule_lbl in zip(sample_timesteps, noise_schedules, noise_schedules_lbl):
        noise_scheduler_sample = noise_scheduler_klass(noise_schedule=noise_schedule, timesteps=sample_timestep)
        self.noise_schedulers_sample.append(noise_scheduler_sample)
        noise_scheduler_lbl_sample = noise_scheduler_lbl_klass(
            num_classes, noise_schedule=noise_schedule_lbl, timesteps=sample_timestep, p=cosine_p_lbl)
        self.noise_schedulers_lbl_sample.append(noise_scheduler_lbl_sample)

    # randomly cropping for upsampler training

    self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
    # each crop size, when given, must be an (h, w) pair
    assert all(map(lambda x: x is None or (isinstance(x, (tuple, list)) and len(x) == 2), self.random_crop_sizes))
    assert not exists(first(self.random_crop_sizes)), \
        'you should not need to randomly crop image during training for base unet, only for upsamplers '\
        '- so pass in `random_crop_sizes = (None, 128, 256)` as example'

    # lowres augmentation noise schedule

    self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule=lowres_noise_schedule)
    self.lowres_noise_schedule_lbl = MultinomialDiffusion(
        num_classes, noise_schedule=lowres_noise_schedule, p=cosine_p_lbl)

    # ddpm objectives - predicting noise by default

    self.pred_objectives = cast_tuple(pred_objectives, num_unets)

    # get text encoder

    self.text_encoder_name = text_encoder_name
    self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))

    self.encode_text = partial(t5_encode_text, name=text_encoder_name)

    # construct unets

    self.unets = nn.ModuleList([])

    self.unet_being_trained_index = -1  # keeps track of which unet is being trained at the moment
    self.only_train_unet_number = only_train_unet_number

    for ind, one_unet in enumerate(unets):
        assert isinstance(one_unet, (JointUnet, Unet3D, NullUnet))
        is_first = ind == 0

        # recast each unet so only the first is base (no lowres cond)
        one_unet = one_unet.cast_model_parameters(
            lowres_cond=not is_first,
            cond_on_text=self.condition_on_text,
            text_embed_dim=self.text_embed_dim if self.condition_on_text else None,
            channels=self.channels,
            channels_out=self.channels
        )

        self.unets.append(one_unet)

    # unet image sizes: each entry must be an (h, w) pair, one per unet

    self.image_sizes = cast_tuple(image_sizes)
    assert all(map(lambda x: isinstance(x, (tuple, list)) and len(x) == 2, self.image_sizes))

    assert num_unets == len(self.image_sizes), \
        f'you did not supply the correct number of u-nets ({len(unets)}) for resolutions {self.image_sizes}'

    self.sample_channels = cast_tuple(self.channels, num_unets)

    # determine whether we are training on images or video

    is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
    self.is_video = is_video

    # broadcast helper: pad a (b,) tensor to the data rank (4d images, 5d video)
    self.right_pad_dims_to_datatype = partial(rearrange, pattern=(
        'b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
    self.resize_to = resize_video_to if is_video else resize_image_to

    # cascading ddpm related stuff

    lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
    assert lowres_conditions == (False, *((True,) * (num_unets - 1))), \
        'the first unet must be unconditioned (by low resolution image), ' \
        'and the rest of the unets must have `lowres_cond` set to True'

    self.lowres_sample_noise_level = lowres_sample_noise_level
    self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level
    self.lowres_max_thres = lowres_max_thres

    # classifier free guidance

    self.cond_drop_prob = cond_drop_prob
    self.can_classifier_guidance = cond_drop_prob > 0.

    # normalize and unnormalize image functions

    self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
    self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
    self.input_image_range = (0. if auto_normalize_img else -1., 1.)

    # dynamic thresholding

    self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
    self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

    # p2 loss weight

    self.p2_loss_weight_k = p2_loss_weight_k
    self.p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)

    assert all([(gamma_value <= 2) for gamma_value in self.p2_loss_weight_gamma]), \
        'in paper, they noticed any gamma greater than 2 is harmful'

    # one temp parameter for keeping track of device

    self.register_buffer('_temp', torch.tensor([0.]), persistent=False)

    # default to device of unets passed in

    self.to(next(self.unets.parameters()).device)
|
| 1200 |
+
|
| 1201 |
+
def force_unconditional_(self):
    """Disable text conditioning in place, for this cascade and every unet in it."""
    self.unconditional = True
    self.condition_on_text = False

    for one_unet in self.unets:
        one_unet.cond_on_text = False
|
| 1207 |
+
|
| 1208 |
+
@property
def device(self):
    # the non-persistent `_temp` buffer moves with the module under `.to(...)`,
    # so its device always reflects the module's current device
    return self._temp.device
|
| 1211 |
+
|
| 1212 |
+
def get_unet(self, unet_number):
    """Return the 1-indexed unet, moving it to this module's device and all others to CPU.

    On first call, `self.unets` is converted from an nn.ModuleList to a plain
    Python list — presumably so the inactive (CPU-parked) unets are no longer
    registered submodules; verify against how state_dict/load_state_dict
    restore it via `reset_unets_all_one_device`.
    """
    assert 0 < unet_number <= len(self.unets)
    index = unet_number - 1

    if isinstance(self.unets, nn.ModuleList):
        unets_list = [unet for unet in self.unets]
        # delattr removes the ModuleList registration before rebinding as a plain list
        delattr(self, 'unets')
        self.unets = unets_list

    # only shuffle devices when switching to a different unet
    if index != self.unet_being_trained_index:
        for unet_index, unet in enumerate(self.unets):
            unet.to(self.device if unet_index == index else 'cpu')

    self.unet_being_trained_index = index
    return self.unets[index]
|
| 1227 |
+
|
| 1228 |
+
def reset_unets_all_one_device(self, device=None):
    """Re-register all unets as an nn.ModuleList on one device (default: this module's device).

    Undoes the plain-list conversion done by `get_unet` and clears the
    trained-unet tracker.
    """
    device = default(device, self.device)
    self.unets = nn.ModuleList([*self.unets])
    self.unets.to(device)

    self.unet_being_trained_index = -1
|
| 1234 |
+
|
| 1235 |
+
@contextmanager
def one_unet_in_gpu(self, unet_number=None, unet=None):
    """Context manager: keep exactly one unet on this module's device, rest on CPU.

    Pass either a 1-indexed `unet_number` or the unet itself (exactly one).
    On exit, every unet is restored to the device it was on before entry.
    """
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    # remember where each unet lives so we can restore on exit
    devices = [module_device(unet) for unet in self.unets]
    self.unets.cpu()
    unet.to(self.device)

    # bugfix: restore devices even if the body raises — previously an
    # exception inside the `with` block left all unets parked on CPU
    try:
        yield
    finally:
        for unet, device in zip(self.unets, devices):
            unet.to(device)
|
| 1250 |
+
|
| 1251 |
+
# overriding state dict functions
|
| 1252 |
+
|
| 1253 |
+
def state_dict(self, *args, **kwargs):
    # gather all unets back into a registered ModuleList on one device first,
    # so the serialized state includes every unet (see get_unet's list swap)
    self.reset_unets_all_one_device()
    return super().state_dict(*args, **kwargs)
|
| 1256 |
+
|
| 1257 |
+
def load_state_dict(self, *args, **kwargs):
    # mirror of state_dict: ensure unets are registered submodules before loading
    self.reset_unets_all_one_device()
    return super().load_state_dict(*args, **kwargs)
|
| 1260 |
+
|
| 1261 |
+
# gaussian diffusion methods
|
| 1262 |
+
|
| 1263 |
+
def p_mean_variance(
    self,
    unet: JointUnet,
    x,
    log_lbl,
    t,
    *,
    noise_scheduler: GaussianDiffusionContinuousTimes,
    noise_scheduler_lbl: MultinomialDiffusion,
    text_embeds=None,
    text_mask=None,
    cond_images=None,
    lowres_cond_img=None,
    lowres_cond_lbl=None,
    self_cond=None,
    self_cond_lbl=None,
    lowres_noise_times=None,
    cond_scale=1.,
    model_output=None,
    t_next=None,
    pred_objective='noise',
    dynamic_threshold=True,
):
    """Compute the reverse-diffusion posterior for image and label jointly.

    Returns ``(mean_and_variance, log_lbl, x_start, lbl_start)`` where
    `mean_and_variance` comes from the Gaussian scheduler's q_posterior,
    `log_lbl` is a categorical sample (log one-hot) from the multinomial
    scheduler's posterior, `x_start` is the (thresholded) predicted clean
    image, and `lbl_start` is always None (label x0 prediction is an open
    TODO — see the commented lines below).
    """
    assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
    # unet consumes index labels, not the log one-hot representation
    lbl = log_onehot_to_index(log_lbl)
    pred, pred_lbl = default(model_output, lambda: unet.forward_with_cond_scale(
        x, lbl, noise_scheduler.get_condition(t),
        text_embeds=text_embeds, text_mask=text_mask,
        cond_images=cond_images, cond_scale=cond_scale,
        lowres_cond_img=lowres_cond_img, lowres_cond_lbl=lowres_cond_lbl,
        self_cond=self_cond, self_cond_lbl=self_cond_lbl,
        lowres_noise_times=self.lowres_noise_schedule.get_condition(lowres_noise_times)))
    # label head output -> log-probabilities over classes -> multinomial posterior
    pred_lbl = F.log_softmax(pred_lbl, dim=1)
    pred_lbl = noise_scheduler_lbl.q_posterior(pred_lbl, log_lbl, t)

    if pred_objective == 'noise':
        x_start = noise_scheduler.predict_start_from_noise(x, t=t, noise=pred)
        # lbl_start = noise_scheduler_lbl.predict_start_from_noise(log_lbl, t=t, noise=pred_lbl) # TODO ???
    elif pred_objective == 'x_start':
        x_start = pred
        # lbl_start = pred_lbl
    else:
        raise ValueError(f'unknown objective {pred_objective}')
    lbl_start = None

    if dynamic_threshold:
        # following pseudocode in appendix
        # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
        s = torch.quantile(
            rearrange(x_start, 'b ... -> b (...)').abs(),
            self.dynamic_thresholding_percentile,
            dim=-1
        )

        s.clamp_(min=1.)
        s = right_pad_dims_to(x_start, s)
        x_start = x_start.clamp(-s, s) / s
    else:
        # static clamp to the canonical [-1, 1] image range (in place)
        x_start.clamp_(-1., 1.)

    mean_and_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t, t_next=t_next)
    log_lbl = noise_scheduler_lbl.log_sample_categorical(pred_lbl)
    return mean_and_variance, log_lbl, x_start, lbl_start
|
| 1326 |
+
|
| 1327 |
+
    @torch.no_grad()
    def p_sample(
        self,
        unet,
        x,
        log_lbl,
        t,
        *,
        noise_scheduler,
        noise_scheduler_lbl,
        t_next=None,
        text_embeds=None,
        text_mask=None,
        cond_images=None,
        cond_scale=1.,
        self_cond=None,
        self_cond_lbl=None,
        lowres_cond_img=None,
        lowres_cond_lbl=None,
        lowres_noise_times=None,
        pred_objective='noise',
        dynamic_threshold=True,
    ):
        """Perform a single reverse-diffusion step jointly for the image and
        the (log-one-hot) label map.

        Delegates to ``self.p_mean_variance`` for the Gaussian posterior of the
        image and the categorical sample of the label, then draws the next
        image sample ``x_{t_next}`` from that posterior.

        Returns a 4-tuple ``(pred, pred_lbl, x_start, lbl_start)`` —
        the next image sample, the next log-label sample, and the model's
        x0 / label-start estimates for self-conditioning (``lbl_start`` is
        currently always ``None`` upstream — see p_mean_variance).
        """
        b, *_, device = *x.shape, x.device
        (model_mean, _, model_log_variance), pred_lbl, x_start, lbl_start = self.p_mean_variance(
            unet, x=x, log_lbl=log_lbl, t=t, t_next=t_next,
            noise_scheduler=noise_scheduler, noise_scheduler_lbl=noise_scheduler_lbl,
            text_embeds=text_embeds, text_mask=text_mask,
            cond_images=cond_images, cond_scale=cond_scale,
            lowres_cond_img=lowres_cond_img, lowres_cond_lbl=lowres_cond_lbl,
            self_cond=self_cond, self_cond_lbl=self_cond_lbl,
            lowres_noise_times=lowres_noise_times,
            pred_objective=pred_objective, dynamic_threshold=dynamic_threshold)
        noise = torch.randn_like(x)
        # no noise when t == 0 (continuous-time schedulers signal the final
        # step via t_next == 0, discrete ones via t == 0)
        is_last_sampling_timestep = (t_next == 0) if isinstance(
            noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
        # mask of shape (b, 1, 1, ...) that zeroes the noise term on the last step
        nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred, pred_lbl, x_start, lbl_start
|
| 1367 |
+
|
| 1368 |
+
    @torch.no_grad()
    def p_sample_loop(
        self,
        unet,
        shape,
        *,
        noise_scheduler: GaussianDiffusionContinuousTimes,
        noise_scheduler_lbl: MultinomialDiffusion,
        lowres_cond_img=None,
        lowres_cond_lbl=None,
        lowres_noise_times=None,
        text_embeds=None,
        text_mask=None,
        cond_images=None,
        inpaint_images=None,
        inpaint_labels=None,
        inpaint_masks=None,
        inpaint_resample_times=5,
        init_images=None,
        init_labels=None,
        skip_steps=None,
        cond_scale=1,
        pred_objective='noise',
        dynamic_threshold=True,
        use_tqdm=True
    ):
        """Run the full reverse-diffusion sampling loop for one unet.

        Starts from Gaussian noise for the image and a uniform categorical
        sample for the label map, then iterates ``p_sample`` over all
        sampling timesteps.  Supports RePaint-style inpainting (resampling
        ``inpaint_resample_times`` times per step) on both the image and the
        label map.

        Returns ``(unnormalized_image, label_index_map)``.
        """
        assert init_labels is None, 'not implemented yet'
        device = self.device

        batch, _, h, w = shape
        img = torch.randn(shape, device=device)
        # uniform logits -> uniform categorical prior over the label classes
        uniform_logits = torch.zeros((batch, self.num_classes) + (h, w), device=device)
        log_lbl = noise_scheduler_lbl.log_sample_categorical(uniform_logits)

        # for initialization with an image or video

        if exists(init_images):
            img += init_images
            # TODO init_labels

        # keep track of x0, for self conditioning

        x_start = None
        lbl_start = None

        # prepare inpainting

        has_inpainting = exists(inpaint_images) and exists(inpaint_labels) and exists(inpaint_masks)
        resample_times = inpaint_resample_times if has_inpainting else 1

        if has_inpainting:
            # channel 0 masks the image, channel 1 masks the label map
            assert inpaint_masks.shape[1] == 2, \
                f'inpaint mask is a tuple of (mask_image, mask_label) but now:\n{inpaint_labels}'
            inpaint_images = self.normalize_img(inpaint_images)
            inpaint_images = self.resize_to(inpaint_images, shape[-2:])

            log_inpaint_labels = index_to_log_onehot(inpaint_labels.long(), self.num_classes)
            log_inpaint_labels = self.resize_to(log_inpaint_labels, shape[-2:])

            inpaint_masks_image = self.resize_to(inpaint_masks[:, [0]], shape[-2:]).bool()
            inpaint_masks_label = self.resize_to(inpaint_masks[:, [1]], shape[-2:]).bool()

        # time

        timesteps = noise_scheduler.get_sampling_timesteps(batch, device=device)
        # clamp times to strictly below 1 (avoids degenerate t == 1)
        # NOTE(review): assumes each element of `timesteps` supports tensor
        # comparison/arithmetic — confirm against get_sampling_timesteps
        timesteps = [t * (t < 1.) + (1 - 1e-7) * (t >= 1.) for t in timesteps]

        # whether to skip any steps

        skip_steps = default(skip_steps, 0)
        timesteps = timesteps[skip_steps:]

        for times, times_next in tqdm(timesteps, desc='sampling loop time step', total=len(timesteps), disable=not use_tqdm):
            is_last_timestep = times_next == 0

            for r in reversed(range(resample_times)):
                is_last_resample_step = r == 0

                if has_inpainting:
                    # paste a freshly-noised copy of the known region over the sample
                    noised_inpaint_images, _ = noise_scheduler.q_sample(inpaint_images, t=times)
                    img = img * ~inpaint_masks_image + noised_inpaint_images * inpaint_masks_image
                    log_noised_inpaint_labels = noise_scheduler_lbl.q_sample(log_inpaint_labels, t=times)
                    log_lbl = log_lbl * ~inpaint_masks_label + log_noised_inpaint_labels * inpaint_masks_label

                self_cond = x_start if unet.self_cond else None
                self_cond_lbl = lbl_start if unet.self_cond else None

                img, log_lbl, x_start, lbl_start = self.p_sample(
                    unet,
                    img,
                    log_lbl,
                    times,
                    t_next=times_next,
                    text_embeds=text_embeds,
                    text_mask=text_mask,
                    cond_images=cond_images,
                    cond_scale=cond_scale,
                    self_cond=self_cond,
                    self_cond_lbl=self_cond_lbl,
                    lowres_cond_img=lowres_cond_img,
                    lowres_cond_lbl=lowres_cond_lbl,
                    lowres_noise_times=lowres_noise_times,
                    noise_scheduler=noise_scheduler,
                    noise_scheduler_lbl=noise_scheduler_lbl,
                    pred_objective=pred_objective,
                    dynamic_threshold=dynamic_threshold,
                )

                # RePaint: re-noise back to the current step before resampling,
                # except on the final resample step / final timestep
                if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
                    renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times)
                    img = torch.where(
                        self.right_pad_dims_to_datatype(is_last_timestep),
                        img,
                        renoised_img
                    )
                    renoised_log_lbl = noise_scheduler_lbl.q_sample_from_to(log_lbl, times_next, times)
                    log_lbl = torch.where(
                        self.right_pad_dims_to_datatype(is_last_timestep),
                        log_lbl,
                        renoised_log_lbl
                    )

        img.clamp_(-1., 1.)

        # final inpainting

        if has_inpainting:
            img = img * ~inpaint_masks_image + inpaint_images * inpaint_masks_image
            log_lbl = log_lbl * ~inpaint_masks_label + log_inpaint_labels * inpaint_masks_label

        unnormalize_img = self.unnormalize_img(img)
        lbl = log_onehot_to_index(log_lbl)
        return unnormalize_img, lbl
|
| 1501 |
+
|
| 1502 |
+
    @torch.no_grad()
    @eval_decorator
    def sample(
        self,
        texts: List[str] = None,
        text_masks=None,
        text_embeds=None,
        video_frames=None,
        cond_images=None,
        inpaint_images=None,
        inpaint_labels=None,
        inpaint_masks=None,
        inpaint_resample_times=5,
        init_images=None,
        init_labels=None,
        skip_steps=None,
        batch_size=1,
        cond_scale=1.,
        lowres_sample_noise_level=None,
        start_at_unet_number=1,
        start_image_or_video=None,
        start_label_or_video=None,
        stop_at_unet_number=None,
        return_all_unet_outputs=False,
        return_pil_images=False,
        device=None,
        use_tqdm=True
    ):
        """Sample (image, label) pairs from the full cascade of unets.

        Encodes ``texts`` to T5 embeddings when needed, validates all the
        conditioning / inpainting inputs, then runs ``p_sample_loop`` for each
        unet in order, feeding each stage's output to the next stage as the
        low-resolution conditioning pair.

        Returns either the last stage's ``(img, lbl)`` tensors (on CPU), all
        stages' outputs when ``return_all_unet_outputs``, or PIL images when
        ``return_pil_images`` (labels are not converted to PIL yet — see TODO).
        """
        device = default(device, self.device)
        self.reset_unets_all_one_device(device=device)

        cond_images = maybe(cast_uint8_images_to_float)(cond_images)

        if exists(texts) and not exists(text_embeds) and not self.unconditional:
            assert all([*map(len, texts)]), 'text cannot be empty'

            # text encoding runs in full precision
            with autocast(enabled=False):
                text_embeds, text_masks = self.encode_text(texts, return_attn_mask=True)

            text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))

        if not self.unconditional:
            assert exists(text_embeds), \
                'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'

            text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim=-1))
            batch_size = text_embeds.shape[0]

        if exists(inpaint_images) and exists(inpaint_labels):
            if self.unconditional:
                if batch_size == 1:  # assume researcher wants to broadcast along inpainted images
                    batch_size = inpaint_images.shape[0]

            assert inpaint_images.shape[0] == batch_size, \
                'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
            assert inpaint_labels.shape[0] == batch_size, \
                'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
            assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), \
                'number of inpainting images must be equal to the number of text to be conditioned on'
            assert not (self.condition_on_text and inpaint_labels.shape[0] != text_embeds.shape[0]), \
                'number of inpainting images must be equal to the number of text to be conditioned on'

        assert not (self.condition_on_text and not exists(text_embeds)), \
            'text or text encodings must be passed into imagen if specified'
        assert not (not self.condition_on_text and exists(text_embeds)), \
            'imagen specified not to be conditioned on text, yet it is presented'
        assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), \
            f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

        # inpainting requires all three of images / labels / masks
        assert (not (exists(inpaint_images) or exists(inpaint_labels) or exists(inpaint_masks))) \
            or (exists(inpaint_images) and exists(inpaint_labels) and exists(inpaint_masks)), \
            'inpaint images, labels and masks must be both passed in to do inpainting'

        outputs = []

        is_cuda = next(self.parameters()).is_cuda
        device = next(self.parameters()).device

        lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)

        num_unets = len(self.unets)

        # condition scaling

        cond_scale = cast_tuple(cond_scale, num_unets)

        # add frame dimension for video

        assert not (self.is_video and not exists(video_frames)
                    ), 'video_frames must be passed in on sample time if training on video'

        frame_dims = (video_frames,) if self.is_video else tuple()

        # for initial image and skipping steps

        init_images = cast_tuple(init_images, num_unets)
        init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]
        init_labels = cast_tuple(init_labels, num_unets)

        skip_steps = cast_tuple(skip_steps, num_unets)

        # handle starting at a unet greater than 1, for training only-upscaler training

        if start_at_unet_number > 1:
            assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
            assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
            assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
            assert exists(start_label_or_video), 'starting image or video must be supplied if only doing upscaling'

            prev_image_size = self.image_sizes[start_at_unet_number - 2]
            img = self.resize_to(start_image_or_video, prev_image_size)
            lbl = self.resize_to(start_label_or_video, prev_image_size)

        # go through each unet in cascade

        for unet_number, unet, channel, image_size, noise_scheduler, noise_scheduler_lbl, pred_objective, \
                dynamic_threshold, unet_cond_scale, unet_init_images, unet_init_labels, unet_skip_steps \
                in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes,
                            self.noise_schedulers_sample, self.noise_schedulers_lbl_sample, self.pred_objectives,
                            self.dynamic_thresholding, cond_scale, init_images, init_labels, skip_steps),
                        disable=not use_tqdm):

            if unet_number < start_at_unet_number:
                continue

            assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets'

            # move only the active unet to the GPU when running on CUDA
            context = self.one_unet_in_gpu(unet=unet) if is_cuda else nullcontext()

            with context:
                lowres_cond_img = lowres_cond_lbl = lowres_noise_times = None
                shape = (batch_size, channel, *frame_dims, *image_size)

                if unet.lowres_cond:
                    # noise-augment the previous stage's output to the fixed
                    # sampling noise level before conditioning on it
                    lowres_noise_times = self.lowres_noise_schedule.get_times(
                        batch_size, lowres_sample_noise_level, device=device)

                    lowres_cond_img = self.resize_to(img, image_size)
                    lowres_cond_lbl = self.resize_to(lbl.float(), image_size)

                    lowres_cond_img = self.normalize_img(lowres_cond_img)
                    lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(
                        x_start=lowres_cond_img, t=lowres_noise_times, noise=torch.randn_like(lowres_cond_img))
                    lowres_cond_log_lbl = index_to_log_onehot(lowres_cond_lbl.long(), self.num_classes)
                    lowres_cond_log_lbl_noisy = self.lowres_noise_schedule_lbl.q_sample(
                        lowres_cond_log_lbl, t=lowres_noise_times)
                    lowres_cond_lbl_noisy = log_onehot_to_index(lowres_cond_log_lbl_noisy)
                    lowres_cond_lbl = lowres_cond_lbl_noisy  # change just naming

                if exists(unet_init_images) and exists(unet_init_labels):
                    unet_init_images = self.resize_to(unet_init_images, image_size)
                    unet_init_labels = self.resize_to(unet_init_labels, image_size)

                shape = (batch_size, self.channels, *frame_dims, *image_size)

                img, lbl = self.p_sample_loop(
                    unet,
                    shape,
                    text_embeds=text_embeds,
                    text_mask=text_masks,
                    cond_images=cond_images,
                    inpaint_images=inpaint_images,
                    inpaint_labels=inpaint_labels,
                    inpaint_masks=inpaint_masks,
                    inpaint_resample_times=inpaint_resample_times,
                    init_images=unet_init_images,
                    init_labels=unet_init_labels,
                    skip_steps=unet_skip_steps,
                    cond_scale=unet_cond_scale,
                    lowres_cond_img=lowres_cond_img,
                    lowres_cond_lbl=lowres_cond_lbl,
                    lowres_noise_times=lowres_noise_times,
                    noise_scheduler=noise_scheduler,
                    noise_scheduler_lbl=noise_scheduler_lbl,
                    pred_objective=pred_objective,
                    dynamic_threshold=dynamic_threshold,
                    use_tqdm=use_tqdm
                )

                outputs.append((img.cpu(), lbl.cpu()))

            if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
                break

        # either return last unet output or all unet outputs
        output_index = -1 if not return_all_unet_outputs else slice(None)

        if not return_pil_images:
            return outputs[output_index]

        if not return_all_unet_outputs:
            outputs = outputs[-1:]

        assert not self.is_video, 'converting sampled video tensor to video file is not supported yet'

        # TODO lbl pil_images
        pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim=0))), outputs))

        # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
        return pil_images[output_index]
|
| 1702 |
+
|
| 1703 |
+
def p_losses(
|
| 1704 |
+
self,
|
| 1705 |
+
unet: Union[JointUnet, Unet3D, NullUnet, DistributedDataParallel],
|
| 1706 |
+
x_start,
|
| 1707 |
+
lbl_start,
|
| 1708 |
+
times,
|
| 1709 |
+
*,
|
| 1710 |
+
noise_scheduler: GaussianDiffusionContinuousTimes,
|
| 1711 |
+
noise_scheduler_lbl: MultinomialDiffusion,
|
| 1712 |
+
lowres_cond_img=None,
|
| 1713 |
+
lowres_cond_lbl=None,
|
| 1714 |
+
lowres_aug_times=None,
|
| 1715 |
+
text_embeds=None,
|
| 1716 |
+
text_mask=None,
|
| 1717 |
+
cond_images=None,
|
| 1718 |
+
noise=None,
|
| 1719 |
+
noise_lbl=None,
|
| 1720 |
+
times_next=None,
|
| 1721 |
+
pred_objective='noise',
|
| 1722 |
+
p2_loss_weight_gamma=0.,
|
| 1723 |
+
random_crop_size=None
|
| 1724 |
+
):
|
| 1725 |
+
is_video = x_start.ndim == 5
|
| 1726 |
+
|
| 1727 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 1728 |
+
# noise_lbl = default(noise_lbl, lambda: torch.randn_like(x_start)) # TODO
|
| 1729 |
+
|
| 1730 |
+
# normalize to [-1, 1]
|
| 1731 |
+
|
| 1732 |
+
x_start = self.normalize_img(x_start)
|
| 1733 |
+
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
| 1734 |
+
|
| 1735 |
+
# random cropping during training
|
| 1736 |
+
# for upsamplers
|
| 1737 |
+
|
| 1738 |
+
if exists(random_crop_size):
|
| 1739 |
+
if is_video:
|
| 1740 |
+
frames = x_start.shape[2]
|
| 1741 |
+
x_start, lowres_cond_img, noise = rearrange_many(
|
| 1742 |
+
(x_start, lowres_cond_img, noise), 'b c f h w -> (b f) c h w')
|
| 1743 |
+
|
| 1744 |
+
aug = K.RandomCrop(random_crop_size, p=1.)
|
| 1745 |
+
|
| 1746 |
+
# make sure low res conditioner and image both get augmented the same way
|
| 1747 |
+
# detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
|
| 1748 |
+
x_start = aug(x_start)
|
| 1749 |
+
lbl_start = aug(lbl_start, params=aug._params)
|
| 1750 |
+
lowres_cond_img = aug(lowres_cond_img, params=aug._params)
|
| 1751 |
+
lowres_cond_lbl = aug(lowres_cond_lbl, params=aug._params)
|
| 1752 |
+
noise = aug(noise, params=aug._params)
|
| 1753 |
+
|
| 1754 |
+
if is_video:
|
| 1755 |
+
x_start, lowres_cond_img, noise = rearrange_many(
|
| 1756 |
+
(x_start, lowres_cond_img, noise), '(b f) c h w -> b c f h w', f=frames)
|
| 1757 |
+
|
| 1758 |
+
# get x_t
|
| 1759 |
+
|
| 1760 |
+
x_noisy, log_snr = noise_scheduler.q_sample(x_start=x_start, t=times, noise=noise)
|
| 1761 |
+
log_lbl_start = index_to_log_onehot(lbl_start.long(), self.num_classes)
|
| 1762 |
+
log_lbl_noisy = noise_scheduler_lbl.q_sample(log_lbl_start, t=times)
|
| 1763 |
+
lbl_noisy = log_onehot_to_index(log_lbl_noisy)
|
| 1764 |
+
|
| 1765 |
+
# also noise the lowres conditioning image
|
| 1766 |
+
# at sample time, they then fix the noise level of 0.1 - 0.3
|
| 1767 |
+
|
| 1768 |
+
lowres_cond_img_noisy = None
|
| 1769 |
+
lowres_cond_lbl_noisy = None
|
| 1770 |
+
if exists(lowres_cond_img) and exists(lowres_cond_lbl):
|
| 1771 |
+
lowres_aug_times = default(lowres_aug_times, times)
|
| 1772 |
+
lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(
|
| 1773 |
+
x_start=lowres_cond_img, t=lowres_aug_times, noise=torch.randn_like(lowres_cond_img))
|
| 1774 |
+
lowres_cond_log_lbl = index_to_log_onehot(lowres_cond_lbl.long(), self.num_classes)
|
| 1775 |
+
lowres_cond_log_lbl_noisy = self.lowres_noise_schedule_lbl.q_sample(
|
| 1776 |
+
lowres_cond_log_lbl, t=lowres_aug_times)
|
| 1777 |
+
lowres_cond_lbl_noisy = log_onehot_to_index(lowres_cond_log_lbl_noisy)
|
| 1778 |
+
|
| 1779 |
+
# time condition
|
| 1780 |
+
|
| 1781 |
+
noise_cond = noise_scheduler.get_condition(times)
|
| 1782 |
+
|
| 1783 |
+
# unet kwargs
|
| 1784 |
+
|
| 1785 |
+
unet_kwargs = dict(
|
| 1786 |
+
text_embeds=text_embeds,
|
| 1787 |
+
text_mask=text_mask,
|
| 1788 |
+
cond_images=cond_images,
|
| 1789 |
+
lowres_noise_times=self.lowres_noise_schedule.get_condition(lowres_aug_times),
|
| 1790 |
+
lowres_cond_img=lowres_cond_img_noisy,
|
| 1791 |
+
lowres_cond_lbl=lowres_cond_lbl_noisy,
|
| 1792 |
+
cond_drop_prob=self.cond_drop_prob,
|
| 1793 |
+
)
|
| 1794 |
+
|
| 1795 |
+
# self condition if needed
|
| 1796 |
+
|
| 1797 |
+
# Because 'unet' can be an instance of DistributedDataParallel coming from the
|
| 1798 |
+
# ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
|
| 1799 |
+
# access the member 'module' of the wrapped unet instance.
|
| 1800 |
+
self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond
|
| 1801 |
+
|
| 1802 |
+
if self_cond and random() < 0.5:
|
| 1803 |
+
with torch.no_grad():
|
| 1804 |
+
pred, pred_lbl = unet.forward(
|
| 1805 |
+
x_noisy,
|
| 1806 |
+
lbl_noisy,
|
| 1807 |
+
noise_cond,
|
| 1808 |
+
**unet_kwargs
|
| 1809 |
+
).detach()
|
| 1810 |
+
pred_lbl = F.log_softmax(pred_lbl, dim=1)
|
| 1811 |
+
pred_lbl = noise_scheduler_lbl.q_posterior(pred_lbl, log_lbl_noisy, times)
|
| 1812 |
+
|
| 1813 |
+
x_start = noise_scheduler.predict_start_from_noise(
|
| 1814 |
+
x_noisy, t=times, noise=pred) if pred_objective == 'noise' else pred
|
| 1815 |
+
# lbl_start = noise_scheduler_lbl.predict_start_from_noise(
|
| 1816 |
+
# lbl_noisy, t=times, noise=pred_lbl) if pred_objective == 'noise' else pred_lbl # TODO ???
|
| 1817 |
+
lbl_start = None
|
| 1818 |
+
|
| 1819 |
+
unet_kwargs = {**unet_kwargs, 'self_cond': x_start, 'self_cond_lbl': lbl_start}
|
| 1820 |
+
|
| 1821 |
+
# get prediction
|
| 1822 |
+
|
| 1823 |
+
pred, pred_lbl = unet.forward(
|
| 1824 |
+
x_noisy,
|
| 1825 |
+
lbl_noisy,
|
| 1826 |
+
noise_cond,
|
| 1827 |
+
**unet_kwargs
|
| 1828 |
+
)
|
| 1829 |
+
pred_lbl = F.log_softmax(pred_lbl, dim=1)
|
| 1830 |
+
pred_lbl_post = noise_scheduler_lbl.q_posterior(pred_lbl, log_lbl_noisy, times)
|
| 1831 |
+
|
| 1832 |
+
# prediction objective
|
| 1833 |
+
|
| 1834 |
+
if pred_objective == 'noise':
|
| 1835 |
+
target = noise
|
| 1836 |
+
elif pred_objective == 'x_start':
|
| 1837 |
+
target = x_start
|
| 1838 |
+
else:
|
| 1839 |
+
raise ValueError(f'unknown objective {pred_objective}')
|
| 1840 |
+
target_log_lbl = noise_scheduler_lbl.q_posterior(log_lbl_start, log_lbl_noisy, times)
|
| 1841 |
+
|
| 1842 |
+
# losses
|
| 1843 |
+
|
| 1844 |
+
losses = self.loss_fn(pred, target, reduction='none')
|
| 1845 |
+
losses = reduce(losses, 'b ... -> b', 'mean')
|
| 1846 |
+
losses_lbl = noise_scheduler_lbl.loss_fn(target_log_lbl, pred_lbl_post, times, log_lbl_start)
|
| 1847 |
+
|
| 1848 |
+
# p2 loss reweighting
|
| 1849 |
+
|
| 1850 |
+
if p2_loss_weight_gamma > 0:
|
| 1851 |
+
loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
|
| 1852 |
+
losses = losses * loss_weight
|
| 1853 |
+
losses_lbl = losses_lbl * loss_weight
|
| 1854 |
+
|
| 1855 |
+
return losses.mean(), losses_lbl.mean()
|
| 1856 |
+
|
| 1857 |
+
def forward(
|
| 1858 |
+
self,
|
| 1859 |
+
images,
|
| 1860 |
+
labels,
|
| 1861 |
+
unet: Union[JointUnet, Unet3D, NullUnet, DistributedDataParallel] = None,
|
| 1862 |
+
texts: List[str] = None,
|
| 1863 |
+
text_embeds=None,
|
| 1864 |
+
text_masks=None,
|
| 1865 |
+
unet_number=None,
|
| 1866 |
+
cond_images=None
|
| 1867 |
+
):
|
| 1868 |
+
# assert images.shape[-1] == images.shape[-2], \
|
| 1869 |
+
# f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
|
| 1870 |
+
assert not (len(self.unets) > 1 and not exists(unet_number)), \
|
| 1871 |
+
f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
| 1872 |
+
unet_number = default(unet_number, 1)
|
| 1873 |
+
assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, \
|
| 1874 |
+
'you can only train on unet #{self.only_train_unet_number}'
|
| 1875 |
+
|
| 1876 |
+
images = cast_uint8_images_to_float(images)
|
| 1877 |
+
cond_images = maybe(cast_uint8_images_to_float)(cond_images)
|
| 1878 |
+
|
| 1879 |
+
assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
|
| 1880 |
+
|
| 1881 |
+
unet_index = unet_number - 1
|
| 1882 |
+
|
| 1883 |
+
unet = default(unet, lambda: self.get_unet(unet_number))
|
| 1884 |
+
|
| 1885 |
+
assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
|
| 1886 |
+
|
| 1887 |
+
noise_scheduler = self.noise_schedulers[unet_index]
|
| 1888 |
+
noise_scheduler_lbl = self.noise_schedulers_lbl[unet_index]
|
| 1889 |
+
p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
|
| 1890 |
+
pred_objective = self.pred_objectives[unet_index]
|
| 1891 |
+
target_image_size = self.image_sizes[unet_index]
|
| 1892 |
+
random_crop_size = self.random_crop_sizes[unet_index]
|
| 1893 |
+
prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
|
| 1894 |
+
|
| 1895 |
+
b, c, *_, h, w, device, is_video = *images.shape, images.device, images.ndim == 5
|
| 1896 |
+
|
| 1897 |
+
check_shape(images, 'b c ...', c=self.channels)
|
| 1898 |
+
assert h >= target_image_size[0] and w >= target_image_size[1]
|
| 1899 |
+
|
| 1900 |
+
frames = images.shape[2] if is_video else None
|
| 1901 |
+
|
| 1902 |
+
times = noise_scheduler.sample_random_times(b, device=device)
|
| 1903 |
+
|
| 1904 |
+
if exists(texts) and not exists(text_embeds) and not self.unconditional:
|
| 1905 |
+
assert all([*map(len, texts)]), 'text cannot be empty'
|
| 1906 |
+
assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
|
| 1907 |
+
|
| 1908 |
+
with autocast(enabled=False):
|
| 1909 |
+
text_embeds, text_masks = self.encode_text(texts, return_attn_mask=True)
|
| 1910 |
+
|
| 1911 |
+
text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
|
| 1912 |
+
|
| 1913 |
+
if not self.unconditional:
|
| 1914 |
+
text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim=-1))
|
| 1915 |
+
|
| 1916 |
+
assert not (self.condition_on_text and not exists(text_embeds)
|
| 1917 |
+
), 'text or text encodings must be passed into decoder if specified'
|
| 1918 |
+
assert not (not self.condition_on_text and exists(text_embeds)
|
| 1919 |
+
), 'decoder specified not to be conditioned on text, yet it is presented'
|
| 1920 |
+
|
| 1921 |
+
assert not (exists(
|
| 1922 |
+
text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
|
| 1923 |
+
|
| 1924 |
+
lowres_cond_img = lowres_cond_lbl = lowres_aug_times = None
|
| 1925 |
+
if exists(prev_image_size):
|
| 1926 |
+
lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range=self.input_image_range)
|
| 1927 |
+
lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range=self.input_image_range)
|
| 1928 |
+
lowres_cond_lbl = self.resize_to(labels, prev_image_size, clamp_range=None)
|
| 1929 |
+
lowres_cond_lbl = self.resize_to(lowres_cond_lbl, target_image_size, clamp_range=None)
|
| 1930 |
+
|
| 1931 |
+
if self.per_sample_random_aug_noise_level:
|
| 1932 |
+
lowres_aug_times = self.lowres_noise_schedule.sample_random_times(
|
| 1933 |
+
b, self.lowres_max_thres, device=device)
|
| 1934 |
+
else:
|
| 1935 |
+
lowres_aug_time = self.lowres_noise_schedule.sample_random_times(
|
| 1936 |
+
1, self.lowres_max_thres, device=device)
|
| 1937 |
+
lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b=b)
|
| 1938 |
+
|
| 1939 |
+
images = self.resize_to(images, target_image_size)
|
| 1940 |
+
labels = self.resize_to(labels, target_image_size)
|
| 1941 |
+
|
| 1942 |
+
return self.p_losses(unet, images, labels, times, text_embeds=text_embeds, text_mask=text_masks, cond_images=cond_images, noise_scheduler=noise_scheduler, noise_scheduler_lbl=noise_scheduler_lbl, lowres_cond_img=lowres_cond_img, lowres_cond_lbl=lowres_cond_lbl, lowres_aug_times=lowres_aug_times, pred_objective=pred_objective, p2_loss_weight_gamma=p2_loss_weight_gamma, random_crop_size=random_crop_size)
|
imagen_pytorch/t5.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import transformers
|
| 3 |
+
from typing import List
|
| 4 |
+
from transformers import T5Tokenizer, T5EncoderModel, T5Config
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
transformers.logging.set_verbosity_error()
|
| 8 |
+
|
| 9 |
+
def exists(val):
    """True when *val* is an actual value rather than None."""
    return not (val is None)
|
| 11 |
+
|
| 12 |
+
def default(val, d):
    """Return *val* unless it is None, in which case fall back to *d*.

    A callable fallback is invoked lazily, so expensive defaults are only
    computed when actually needed.
    """
    if val is None:
        return d() if callable(d) else d
    return val
|
| 16 |
+
|
| 17 |
+
# config
|
| 18 |
+
|
| 19 |
+
MAX_LENGTH = 256
|
| 20 |
+
|
| 21 |
+
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
|
| 22 |
+
|
| 23 |
+
T5_CONFIGS = {}
|
| 24 |
+
|
| 25 |
+
# singleton globals
|
| 26 |
+
|
| 27 |
+
def get_tokenizer(name):
    # Load the T5 tokenizer for `name`, capping inputs at MAX_LENGTH tokens;
    # downloaded files are cached under ./checkpoints.
    tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH, cache_dir="checkpoints")
    return tokenizer
|
| 30 |
+
|
| 31 |
+
def get_model(name):
    # Load the T5 encoder-only model for `name`; weights are cached under
    # ./checkpoints.
    model = T5EncoderModel.from_pretrained(name, cache_dir="checkpoints")
    return model
|
| 34 |
+
|
| 35 |
+
def get_model_and_tokenizer(name):
    # Memoized accessor: loads the model and tokenizer for `name` at most once
    # and caches them in the module-level T5_CONFIGS dict.  The entry may
    # already exist holding only a 'config' (see get_encoded_dim), hence the
    # separate membership checks for 'model' and 'tokenizer'.
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
|
| 46 |
+
|
| 47 |
+
def get_encoded_dim(name):
    """Return the embedding dimension (``d_model``) of the T5 encoder *name*.

    When the model has not been instantiated yet, only its config is loaded
    (avoiding a full weight download) and cached in T5_CONFIGS.

    Raises:
        ValueError: if a cached entry exists but carries neither a config
            nor a model (should be unreachable in normal use).
    """
    if name not in T5_CONFIGS:
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name, cache_dir="checkpoints")
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        # fix: `assert False` is stripped under `python -O` and gives no
        # diagnostic; raise an explicit, descriptive error instead
        raise ValueError(f'cached entry for "{name}" holds neither a config nor a model')
    return config.d_model
|
| 59 |
+
|
| 60 |
+
# encoding text
|
| 61 |
+
|
| 62 |
+
def t5_tokenize(
    texts: List[str],
    name = DEFAULT_T5_NAME
):
    """Tokenize *texts* with the T5 tokenizer and return (input_ids, attention_mask)
    tensors placed on the encoder's device (GPU when available)."""
    model, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        model = model.cuda()

    target_device = next(model.parameters()).device

    batch = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    return batch.input_ids.to(target_device), batch.attention_mask.to(target_device)
|
| 84 |
+
|
| 85 |
+
def t5_encode_tokenized_text(
    token_ids,
    attn_mask = None,
    pad_id = None,
    name = DEFAULT_T5_NAME
):
    """Run pre-tokenized ids through the T5 encoder and zero the embeddings of
    padded positions.

    One of *attn_mask* / *pad_id* must be supplied: when no mask is given,
    it is derived from the positions not equal to *pad_id*.
    """
    assert exists(attn_mask) or exists(pad_id)
    encoder, _ = get_model_and_tokenizer(name)

    if not exists(attn_mask):
        attn_mask = (token_ids != pad_id).long()

    encoder.eval()

    with torch.no_grad():
        hidden = encoder(input_ids = token_ids, attention_mask = attn_mask).last_hidden_state.detach()

    pad_positions = ~rearrange(attn_mask.bool(), '... -> ... 1')
    # just force all embeddings that is padding to be equal to 0.
    return hidden.masked_fill(pad_positions, 0.)
|
| 106 |
+
|
| 107 |
+
def t5_encode_text(
    texts: List[str],
    name = DEFAULT_T5_NAME,
    return_attn_mask = False
):
    """Encode raw *texts* with T5; optionally also return the boolean attention mask."""
    ids, mask = t5_tokenize(texts, name = name)
    embeddings = t5_encode_tokenized_text(ids, attn_mask = mask, name = name)

    if return_attn_mask:
        return embeddings, mask.bool()

    return embeddings
|
imagen_pytorch/trainer.py
ADDED
|
@@ -0,0 +1,1782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import copy
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from math import ceil
|
| 7 |
+
from contextlib import contextmanager, nullcontext
|
| 8 |
+
from functools import partial, wraps
|
| 9 |
+
from collections.abc import Iterable
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch.utils.data import random_split, DataLoader
|
| 15 |
+
from torch.optim import Adam
|
| 16 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
|
| 17 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 18 |
+
|
| 19 |
+
import pytorch_warmup as warmup
|
| 20 |
+
|
| 21 |
+
from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
|
| 22 |
+
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
|
| 23 |
+
from imagen_pytorch.joint_imagen import JointImagen
|
| 24 |
+
from imagen_pytorch.data import cycle
|
| 25 |
+
|
| 26 |
+
from imagen_pytorch.version import __version__
|
| 27 |
+
from packaging import version
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
from ema_pytorch import EMA
|
| 32 |
+
|
| 33 |
+
from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
|
| 34 |
+
|
| 35 |
+
from fsspec.core import url_to_fs
|
| 36 |
+
from fsspec.implementations.local import LocalFileSystem
|
| 37 |
+
|
| 38 |
+
# helper functions
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def default(val, d):
    """Return *val* unless it is None, otherwise the fallback *d* (invoked when callable)."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def cast_tuple(val, length=1):
    """Coerce *val* into a tuple: tuples pass through, lists are converted,
    and any other value is repeated *length* times."""
    if isinstance(val, tuple):
        return val
    if isinstance(val, list):
        return tuple(val)
    return (val,) * length
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def find_first(fn, arr):
    """Return the index of the first element of *arr* satisfying *fn*, or -1 when none does.

    NOTE(review): a later ``find_first`` definition in this module shadows this
    one and returns the element itself rather than its index.
    """
    index = 0
    for item in arr:
        if fn(item):
            return index
        index += 1
    return -1
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def pick_and_pop(keys, d):
    """Remove *keys* from dict *d* in place and return them as a new {key: value} dict.

    Raises KeyError if any key is absent from *d*.
    """
    return {key: d.pop(key) for key in keys}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def group_dict_by_key(cond, d):
    """Split *d* into two dicts: entries whose key satisfies *cond*, then the rest."""
    matched, unmatched = dict(), dict()
    for key, value in d.items():
        if cond(key):
            matched[key] = value
        else:
            unmatched[key] = value
    return matched, unmatched
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def string_begins_with(prefix, text):
    """Return True when *text* starts with *prefix*.

    Fix: the second parameter was named ``str``, shadowing the builtin; it is
    only ever supplied positionally (via ``functools.partial`` in this module),
    so the rename is backward-compatible for all callers.
    """
    return text.startswith(prefix)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def group_by_key_prefix(prefix, d):
    """Partition dict *d* into (entries whose keys start with *prefix*, everything else)."""
    has_prefix = partial(string_begins_with, prefix)
    return group_dict_by_key(has_prefix, d)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def groupby_prefix_and_trim(prefix, d):
    """Extract entries of *d* whose keys start with *prefix*, strip that prefix
    from their keys, and return (stripped_entries, remaining_entries)."""
    with_prefix, without_prefix = group_dict_by_key(partial(string_begins_with, prefix), d)
    trimmed = {key[len(prefix):]: value for key, value in with_prefix.items()}
    return trimmed, without_prefix
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def num_to_groups(num, divisor):
    """Break *num* into a list of chunks of size *divisor*, appending a smaller
    trailing chunk for any remainder."""
    full_chunks, remainder = divmod(num, divisor)
    out = [divisor] * full_chunks
    if remainder:
        out.append(remainder)
    return out
|
| 100 |
+
|
| 101 |
+
# url to fs, bucket, path - for checkpointing to cloud
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def url_to_bucket(url):
    """Extract the bucket name from a ``gs://`` or ``s3://`` url; plain paths
    (no scheme) are returned unchanged.

    Raises ValueError for any other scheme.
    """
    if '://' not in url:
        return url

    scheme, remainder = url.split('://')

    if scheme not in {'gs', 's3'}:
        raise ValueError(f'storage type prefix "{scheme}" is not supported yet')

    return remainder.split('/')[0]
|
| 114 |
+
|
| 115 |
+
# decorators
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def eval_decorator(fn):
    """Wrap *fn* so the model is switched into eval mode for the duration of the
    call and restored to its prior training mode afterwards."""
    def wrapper(model, *args, **kwargs):
        previous_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(previous_mode)
        return result
    return wrapper
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def cast_torch_tensor(fn, cast_fp16=False):
    """Decorator for model methods: normalizes all tensor-like arguments before
    calling *fn*.

    numpy arrays are converted to torch tensors; tensors are moved to the
    model's device (unless ``_cast_device=False`` is passed at call time) and
    optionally cast to fp16 when *cast_fp16* is set and the model trains in
    half precision. The special kwargs ``_device`` / ``_cast_device`` are
    consumed here and never forwarded to *fn*.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        # call-time overrides, popped so they are not forwarded to fn
        device = kwargs.pop('_device', model.device)
        cast_device = kwargs.pop('_cast_device', True)

        should_cast_fp16 = cast_fp16 and model.cast_half_at_training

        # flatten args + kwargs values into one tuple so every value passes
        # through the same conversion pipeline; remember where kwargs start
        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        if should_cast_fp16:
            # bool tensors (e.g. attention masks) must keep their dtype
            all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(
                t, torch.Tensor) and t.dtype != torch.bool else t, all_args))

        # re-split the flattened tuple back into positional args and kwargs
        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner
|
| 154 |
+
|
| 155 |
+
# gradient accumulation functions
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def split_iterable(it, split_size):
    """Slice the sequence *it* into consecutive pieces of at most *split_size* elements."""
    return [it[start:start + split_size] for start in range(0, len(it), split_size)]
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def split(t, split_size=None):
    """Split *t* into chunks of at most *split_size* along its first dimension.

    Tensors are split with ``Tensor.split``; other iterables with
    ``split_iterable``. With ``split_size=None`` the input is returned
    untouched.

    Raises:
        TypeError: when *t* is neither a tensor nor an iterable.
    """
    if not exists(split_size):
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim=0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    # fix: the original `return TypeError` handed the exception *class* back to
    # the caller instead of raising it
    raise TypeError(f'cannot split object of type {type(t).__name__}')
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def find_first(cond, arr):
    """Return the first element of *arr* for which *cond* is truthy, or None when none matches."""
    return next((el for el in arr if cond(el)), None)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def split_args_and_kwargs(*args, split_size=None, **kwargs):
    """Yield ``(chunk_fraction, (chunked_args, chunked_kwargs))`` pairs for
    gradient-accumulation-style sub-batching.

    The batch size is read from the first tensor argument found among
    args/kwargs; tensor and iterable arguments are chunked in lockstep along
    dim 0, while non-splittable values are repeated once per chunk.
    ``chunk_fraction`` is the chunk's share of the full batch (useful for
    weighting accumulated losses).
    """
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    # the first tensor found determines the batch size for every argument
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    # positional args occupy all_args[:split_kwargs_index]; kwargs values the rest
    split_kwargs_index = len_all_args - dict_len

    # chunk what can be chunked; replicate everything else once per chunk
    split_all_args = [split(arg, split_size=split_size) if exists(arg) and isinstance(
        arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    chunk_sizes = tuple(map(len, split_all_args[0]))

    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        # re-split the flattened per-chunk tuple back into args / kwargs
        chunked_args, chunked_kwargs_values = chunked_all_args[:
                                                              split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
| 210 |
+
|
| 211 |
+
# imagen trainer
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def imagen_sample_in_chunks(fn):
    """Decorator for sampling methods: when ``max_batch_size`` is passed at
    call time, run *fn* in sub-batches of at most that size and concatenate
    the per-chunk results along dim 0."""
    @wraps(fn)
    def inner(self, *args, max_batch_size=None, **kwargs):
        # no cap requested - run the wrapped sampler once on the full batch
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        if self.imagen.unconditional:
            # unconditional sampling has no per-sample arguments to chunk;
            # just call repeatedly with smaller batch_size values
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            # conditional sampling: chunk every tensor/iterable argument in lockstep
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs)
                       in split_args_and_kwargs(*args, split_size=max_batch_size, **kwargs)]

        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim=0)

        # tuple/list outputs: concatenate each output position across chunks
        return list(map(lambda t: torch.cat(t, dim=0), list(zip(*outputs))))

    return inner
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def restore_parts(state_dict_target, state_dict_from):
    """Copy every same-shaped parameter from *state_dict_from* into
    *state_dict_target* in place.

    Keys absent from the target are ignored; shape mismatches are reported
    and skipped, so partially-compatible checkpoints can still be loaded.
    Returns the (mutated) target state dict.
    """
    for name, param in state_dict_from.items():

        if name not in state_dict_target:
            continue

        if param.size() == state_dict_target[name].size():
            state_dict_target[name].copy_(param)
        else:
            # fix: the original message was missing its closing parenthesis
            print(f"layer {name} ({param.size()}) different than target: {state_dict_target[name].size()}")

    return state_dict_target
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def load_unet_from_trainer(trainer, checkpoint_path, src_unet_idx, tgt_unet_idx, only_model=True):
    """Load a single unet's model and EMA weights from a trainer checkpoint
    into unet *tgt_unet_idx* of *trainer*.

    Only model weights are supported for now (no optimizer / scheduler state).
    """
    assert only_model == True  # TODO optimizer, scheduler ...
    ckpt = torch.load(checkpoint_path, map_location='cpu')

    # fix: include the trailing '.' in the prefix so e.g. src_unet_idx=1 does
    # not also match keys of 'unets.10...', and strip by the prefix length
    # instead of the hard-coded offsets [8:]/[2:] that only worked for
    # single-digit indices
    model_prefix = f'unets.{src_unet_idx}.'
    state_dict = OrderedDict(
        (key[len(model_prefix):], val)
        for key, val in ckpt['model'].items()
        if key.startswith(model_prefix)
    )
    trainer.imagen.unets[tgt_unet_idx].load_state_dict(state_dict)

    ema_prefix = f'{src_unet_idx}.'
    state_dict = OrderedDict(
        (key[len(ema_prefix):], val)
        for key, val in ckpt['ema'].items()
        if key.startswith(ema_prefix)
    )
    trainer.ema_unets[tgt_unet_idx].load_state_dict(state_dict)
    return
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class ImagenTrainer(nn.Module):
|
| 267 |
+
locked = False
|
| 268 |
+
|
| 269 |
+
def __init__(
|
| 270 |
+
self,
|
| 271 |
+
imagen=None,
|
| 272 |
+
imagen_checkpoint_path=None,
|
| 273 |
+
use_ema=True,
|
| 274 |
+
lr=1e-4,
|
| 275 |
+
eps=1e-8,
|
| 276 |
+
beta1=0.9,
|
| 277 |
+
beta2=0.99,
|
| 278 |
+
max_grad_norm=None,
|
| 279 |
+
group_wd_params=True,
|
| 280 |
+
warmup_steps=None,
|
| 281 |
+
cosine_decay_max_steps=None,
|
| 282 |
+
only_train_unet_number=None,
|
| 283 |
+
fp16=False,
|
| 284 |
+
precision=None,
|
| 285 |
+
split_batches=True,
|
| 286 |
+
dl_tuple_output_keywords_names=('images', 'texts'),
|
| 287 |
+
verbose=True,
|
| 288 |
+
split_valid_fraction=0.025,
|
| 289 |
+
split_valid_from_train=False,
|
| 290 |
+
split_random_seed=42,
|
| 291 |
+
checkpoint_path=None,
|
| 292 |
+
checkpoint_every=None,
|
| 293 |
+
checkpoint_fs=None,
|
| 294 |
+
fs_kwargs: dict = None,
|
| 295 |
+
max_checkpoints_keep=20,
|
| 296 |
+
**kwargs
|
| 297 |
+
):
|
| 298 |
+
super().__init__()
|
| 299 |
+
assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
|
| 300 |
+
assert exists(imagen) ^ exists(
|
| 301 |
+
imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'
|
| 302 |
+
|
| 303 |
+
# determine filesystem, using fsspec, for saving to local filesystem or cloud
|
| 304 |
+
|
| 305 |
+
self.fs = checkpoint_fs
|
| 306 |
+
|
| 307 |
+
if not exists(self.fs):
|
| 308 |
+
fs_kwargs = default(fs_kwargs, {})
|
| 309 |
+
self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)
|
| 310 |
+
|
| 311 |
+
assert isinstance(imagen, (Imagen, ElucidatedImagen))
|
| 312 |
+
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
| 313 |
+
|
| 314 |
+
# elucidated or not
|
| 315 |
+
|
| 316 |
+
self.is_elucidated = isinstance(imagen, ElucidatedImagen)
|
| 317 |
+
|
| 318 |
+
# create accelerator instance
|
| 319 |
+
|
| 320 |
+
accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)
|
| 321 |
+
|
| 322 |
+
assert not (fp16 and exists(precision)
|
| 323 |
+
), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
|
| 324 |
+
accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')
|
| 325 |
+
|
| 326 |
+
self.accelerator = Accelerator(**{
|
| 327 |
+
'split_batches': split_batches,
|
| 328 |
+
'mixed_precision': accelerator_mixed_precision,
|
| 329 |
+
'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters=True)], **accelerate_kwargs})
|
| 330 |
+
|
| 331 |
+
ImagenTrainer.locked = self.is_distributed
|
| 332 |
+
|
| 333 |
+
# cast data to fp16 at training time if needed
|
| 334 |
+
|
| 335 |
+
self.cast_half_at_training = accelerator_mixed_precision == 'fp16'
|
| 336 |
+
|
| 337 |
+
# grad scaler must be managed outside of accelerator
|
| 338 |
+
|
| 339 |
+
grad_scaler_enabled = fp16
|
| 340 |
+
|
| 341 |
+
# imagen, unets and ema unets
|
| 342 |
+
|
| 343 |
+
self.imagen = imagen
|
| 344 |
+
self.num_unets = len(self.imagen.unets)
|
| 345 |
+
|
| 346 |
+
self.use_ema = use_ema and self.is_main
|
| 347 |
+
self.ema_unets = nn.ModuleList([])
|
| 348 |
+
|
| 349 |
+
# keep track of what unet is being trained on
|
| 350 |
+
# only going to allow 1 unet training at a time
|
| 351 |
+
|
| 352 |
+
self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on
|
| 353 |
+
|
| 354 |
+
# data related functions
|
| 355 |
+
|
| 356 |
+
self.train_dl_iter = None
|
| 357 |
+
self.train_dl = None
|
| 358 |
+
|
| 359 |
+
self.valid_dl_iter = None
|
| 360 |
+
self.valid_dl = None
|
| 361 |
+
|
| 362 |
+
self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names
|
| 363 |
+
|
| 364 |
+
# auto splitting validation from training, if dataset is passed in
|
| 365 |
+
|
| 366 |
+
self.split_valid_from_train = split_valid_from_train
|
| 367 |
+
|
| 368 |
+
assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
|
| 369 |
+
self.split_valid_fraction = split_valid_fraction
|
| 370 |
+
self.split_random_seed = split_random_seed
|
| 371 |
+
|
| 372 |
+
# be able to finely customize learning rate, weight decay
|
| 373 |
+
# per unet
|
| 374 |
+
|
| 375 |
+
lr, eps, warmup_steps, cosine_decay_max_steps = map(
|
| 376 |
+
partial(cast_tuple, length=self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))
|
| 377 |
+
|
| 378 |
+
for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
|
| 379 |
+
optimizer = Adam(
|
| 380 |
+
unet.parameters(),
|
| 381 |
+
lr=unet_lr,
|
| 382 |
+
eps=unet_eps,
|
| 383 |
+
betas=(beta1, beta2),
|
| 384 |
+
**kwargs
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
if self.use_ema:
|
| 388 |
+
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
| 389 |
+
|
| 390 |
+
scaler = GradScaler(enabled=grad_scaler_enabled)
|
| 391 |
+
|
| 392 |
+
scheduler = warmup_scheduler = None
|
| 393 |
+
|
| 394 |
+
if exists(unet_cosine_decay_max_steps):
|
| 395 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=unet_cosine_decay_max_steps)
|
| 396 |
+
|
| 397 |
+
if exists(unet_warmup_steps):
|
| 398 |
+
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=unet_warmup_steps)
|
| 399 |
+
|
| 400 |
+
if not exists(scheduler):
|
| 401 |
+
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
|
| 402 |
+
|
| 403 |
+
# set on object
|
| 404 |
+
|
| 405 |
+
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
|
| 406 |
+
setattr(self, f'scaler{ind}', scaler)
|
| 407 |
+
setattr(self, f'scheduler{ind}', scheduler)
|
| 408 |
+
setattr(self, f'warmup{ind}', warmup_scheduler)
|
| 409 |
+
|
| 410 |
+
# gradient clipping if needed
|
| 411 |
+
|
| 412 |
+
self.max_grad_norm = max_grad_norm
|
| 413 |
+
|
| 414 |
+
# step tracker and misc
|
| 415 |
+
|
| 416 |
+
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
|
| 417 |
+
|
| 418 |
+
self.verbose = verbose
|
| 419 |
+
|
| 420 |
+
# automatic set devices based on what accelerator decided
|
| 421 |
+
|
| 422 |
+
self.imagen.to(self.device)
|
| 423 |
+
self.to(self.device)
|
| 424 |
+
|
| 425 |
+
# checkpointing
|
| 426 |
+
|
| 427 |
+
assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
|
| 428 |
+
self.checkpoint_path = checkpoint_path
|
| 429 |
+
self.checkpoint_every = checkpoint_every
|
| 430 |
+
self.max_checkpoints_keep = max_checkpoints_keep
|
| 431 |
+
|
| 432 |
+
self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main
|
| 433 |
+
|
| 434 |
+
if exists(checkpoint_path) and self.can_checkpoint:
|
| 435 |
+
bucket = url_to_bucket(checkpoint_path)
|
| 436 |
+
|
| 437 |
+
if not self.fs.exists(bucket):
|
| 438 |
+
self.fs.mkdir(bucket)
|
| 439 |
+
|
| 440 |
+
self.load_from_checkpoint_folder()
|
| 441 |
+
|
| 442 |
+
# only allowing training for unet
|
| 443 |
+
|
| 444 |
+
self.only_train_unet_number = only_train_unet_number
|
| 445 |
+
self.validate_and_set_unet_being_trained(only_train_unet_number)
|
| 446 |
+
|
| 447 |
+
# computed values
|
| 448 |
+
|
| 449 |
+
@property
|
| 450 |
+
def device(self):
|
| 451 |
+
return self.accelerator.device
|
| 452 |
+
|
| 453 |
+
@property
|
| 454 |
+
def is_distributed(self):
|
| 455 |
+
return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
|
| 456 |
+
|
| 457 |
+
@property
|
| 458 |
+
def is_main(self):
|
| 459 |
+
return self.accelerator.is_main_process
|
| 460 |
+
|
| 461 |
+
@property
|
| 462 |
+
def is_local_main(self):
|
| 463 |
+
return self.accelerator.is_local_main_process
|
| 464 |
+
|
| 465 |
+
@property
|
| 466 |
+
def unwrapped_unet(self):
|
| 467 |
+
return self.accelerator.unwrap_model(self.unet_being_trained)
|
| 468 |
+
|
| 469 |
+
# optimizer helper functions
|
| 470 |
+
|
| 471 |
+
def get_lr(self, unet_number):
|
| 472 |
+
self.validate_unet_number(unet_number)
|
| 473 |
+
unet_index = unet_number - 1
|
| 474 |
+
|
| 475 |
+
optim = getattr(self, f'optim{unet_index}')
|
| 476 |
+
|
| 477 |
+
return optim.param_groups[0]['lr']
|
| 478 |
+
|
| 479 |
+
# function for allowing only one unet from being trained at a time
|
| 480 |
+
|
| 481 |
+
def validate_and_set_unet_being_trained(self, unet_number=None):
|
| 482 |
+
if exists(unet_number):
|
| 483 |
+
self.validate_unet_number(unet_number)
|
| 484 |
+
|
| 485 |
+
assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'
|
| 486 |
+
|
| 487 |
+
self.only_train_unet_number = unet_number
|
| 488 |
+
self.imagen.only_train_unet_number = unet_number
|
| 489 |
+
|
| 490 |
+
if not exists(unet_number):
|
| 491 |
+
return
|
| 492 |
+
|
| 493 |
+
self.wrap_unet(unet_number)
|
| 494 |
+
|
| 495 |
+
def wrap_unet(self, unet_number):
|
| 496 |
+
if hasattr(self, 'one_unet_wrapped'):
|
| 497 |
+
return
|
| 498 |
+
|
| 499 |
+
unet = self.imagen.get_unet(unet_number)
|
| 500 |
+
self.unet_being_trained = self.accelerator.prepare(unet)
|
| 501 |
+
unet_index = unet_number - 1
|
| 502 |
+
|
| 503 |
+
optimizer = getattr(self, f'optim{unet_index}')
|
| 504 |
+
scheduler = getattr(self, f'scheduler{unet_index}')
|
| 505 |
+
|
| 506 |
+
optimizer = self.accelerator.prepare(optimizer)
|
| 507 |
+
|
| 508 |
+
if exists(scheduler):
|
| 509 |
+
scheduler = self.accelerator.prepare(scheduler)
|
| 510 |
+
|
| 511 |
+
setattr(self, f'optim{unet_index}', optimizer)
|
| 512 |
+
setattr(self, f'scheduler{unet_index}', scheduler)
|
| 513 |
+
|
| 514 |
+
self.one_unet_wrapped = True
|
| 515 |
+
|
| 516 |
+
# hacking accelerator due to not having separate gradscaler per optimizer
|
| 517 |
+
|
| 518 |
+
def set_accelerator_scaler(self, unet_number):
|
| 519 |
+
unet_number = self.validate_unet_number(unet_number)
|
| 520 |
+
scaler = getattr(self, f'scaler{unet_number - 1}')
|
| 521 |
+
|
| 522 |
+
self.accelerator.scaler = scaler
|
| 523 |
+
for optimizer in self.accelerator._optimizers:
|
| 524 |
+
optimizer.scaler = scaler
|
| 525 |
+
|
| 526 |
+
# helper print
|
| 527 |
+
|
| 528 |
+
def print(self, msg):
|
| 529 |
+
if not self.is_main:
|
| 530 |
+
return
|
| 531 |
+
|
| 532 |
+
if not self.verbose:
|
| 533 |
+
return
|
| 534 |
+
|
| 535 |
+
return self.accelerator.print(msg)
|
| 536 |
+
|
| 537 |
+
# validating the unet number
|
| 538 |
+
|
| 539 |
+
def validate_unet_number(self, unet_number=None):
|
| 540 |
+
if self.num_unets == 1:
|
| 541 |
+
unet_number = default(unet_number, 1)
|
| 542 |
+
|
| 543 |
+
assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
|
| 544 |
+
return unet_number
|
| 545 |
+
|
| 546 |
+
# number of training steps taken
|
| 547 |
+
|
| 548 |
+
def num_steps_taken(self, unet_number=None):
|
| 549 |
+
if self.num_unets == 1:
|
| 550 |
+
unet_number = default(unet_number, 1)
|
| 551 |
+
|
| 552 |
+
return self.steps[unet_number - 1].item()
|
| 553 |
+
|
| 554 |
+
def print_untrained_unets(self):
|
| 555 |
+
print_final_error = False
|
| 556 |
+
|
| 557 |
+
for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
|
| 558 |
+
if steps > 0 or isinstance(unet, NullUnet):
|
| 559 |
+
continue
|
| 560 |
+
|
| 561 |
+
self.print(f'unet {ind + 1} has not been trained')
|
| 562 |
+
print_final_error = True
|
| 563 |
+
|
| 564 |
+
if print_final_error:
|
| 565 |
+
self.print(
|
| 566 |
+
'when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')
|
| 567 |
+
|
| 568 |
+
# data related functions
|
| 569 |
+
|
| 570 |
+
def add_train_dataloader(self, dl=None):
|
| 571 |
+
if not exists(dl):
|
| 572 |
+
return
|
| 573 |
+
|
| 574 |
+
assert not exists(self.train_dl), 'training dataloader was already added'
|
| 575 |
+
self.train_dl = self.accelerator.prepare(dl)
|
| 576 |
+
|
| 577 |
+
def add_valid_dataloader(self, dl):
|
| 578 |
+
if not exists(dl):
|
| 579 |
+
return
|
| 580 |
+
|
| 581 |
+
assert not exists(self.valid_dl), 'validation dataloader was already added'
|
| 582 |
+
self.valid_dl = self.accelerator.prepare(dl)
|
| 583 |
+
|
| 584 |
+
    def add_train_dataset(self, ds=None, *, batch_size, **dl_kwargs):
        """Wrap a dataset in a DataLoader and register it for training.

        When `split_valid_from_train` is on, a seeded `random_split` carves off
        `split_valid_fraction` of the dataset for validation first, so the split
        is reproducible across runs with the same `split_random_seed`.
        """
        if not exists(ds):
            return

        assert not exists(self.train_dl), 'training dataloader was already added'

        valid_ds = None
        if self.split_valid_from_train:
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            # seeded generator keeps the train/valid partition deterministic
            ds, valid_ds = random_split(ds, [train_size, valid_size],
                                        generator=torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples '
                       f'and validating with randomly splitted {len(valid_ds)} samples')

        dl = DataLoader(ds, batch_size=batch_size, **dl_kwargs)
        self.train_dl = self.accelerator.prepare(dl)

        if not self.split_valid_from_train:
            return

        # register the held-out portion with the same loader settings
        self.add_valid_dataset(valid_ds, batch_size=batch_size, **dl_kwargs)

    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        """Wrap a dataset in a DataLoader and register it for validation."""
        if not exists(ds):
            return

        assert not exists(self.valid_dl), 'validation dataloader was already added'

        dl = DataLoader(ds, batch_size=batch_size, **dl_kwargs)
        self.valid_dl = self.accelerator.prepare(dl)
|
| 616 |
+
|
| 617 |
+
def create_train_iter(self):
|
| 618 |
+
assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'
|
| 619 |
+
|
| 620 |
+
if exists(self.train_dl_iter):
|
| 621 |
+
return
|
| 622 |
+
|
| 623 |
+
self.train_dl_iter = cycle(self.train_dl)
|
| 624 |
+
|
| 625 |
+
def create_valid_iter(self):
|
| 626 |
+
assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'
|
| 627 |
+
|
| 628 |
+
if exists(self.valid_dl_iter):
|
| 629 |
+
return
|
| 630 |
+
|
| 631 |
+
self.valid_dl_iter = cycle(self.valid_dl)
|
| 632 |
+
|
| 633 |
+
def train_step(self, unet_number=None, **kwargs):
|
| 634 |
+
self.create_train_iter()
|
| 635 |
+
loss = self.step_with_dl_iter(self.train_dl_iter, unet_number=unet_number, **kwargs)
|
| 636 |
+
self.update(unet_number=unet_number)
|
| 637 |
+
return loss
|
| 638 |
+
|
| 639 |
+
@torch.no_grad()
|
| 640 |
+
@eval_decorator
|
| 641 |
+
def valid_step(self, **kwargs):
|
| 642 |
+
self.create_valid_iter()
|
| 643 |
+
|
| 644 |
+
context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext
|
| 645 |
+
|
| 646 |
+
with context():
|
| 647 |
+
loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
|
| 648 |
+
return loss
|
| 649 |
+
|
| 650 |
+
def step_with_dl_iter(self, dl_iter, **kwargs):
|
| 651 |
+
dl_tuple_output = cast_tuple(next(dl_iter))
|
| 652 |
+
model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
|
| 653 |
+
loss = self.forward(**{**kwargs, **model_input})
|
| 654 |
+
return loss
|
| 655 |
+
|
| 656 |
+
# checkpointing functions
|
| 657 |
+
|
| 658 |
+
@property
|
| 659 |
+
def all_checkpoints_sorted(self):
|
| 660 |
+
glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
|
| 661 |
+
checkpoints = self.fs.glob(glob_pattern)
|
| 662 |
+
sorted_checkpoints = sorted(checkpoints, key=lambda x: int(str(x).split('.')[-2]), reverse=True)
|
| 663 |
+
return sorted_checkpoints
|
| 664 |
+
|
| 665 |
+
def load_from_checkpoint_folder(self, last_total_steps=-1):
|
| 666 |
+
if last_total_steps != -1:
|
| 667 |
+
filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
|
| 668 |
+
self.load(filepath)
|
| 669 |
+
return
|
| 670 |
+
|
| 671 |
+
sorted_checkpoints = self.all_checkpoints_sorted
|
| 672 |
+
|
| 673 |
+
if len(sorted_checkpoints) == 0:
|
| 674 |
+
self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
|
| 675 |
+
return
|
| 676 |
+
|
| 677 |
+
last_checkpoint = sorted_checkpoints[0]
|
| 678 |
+
self.load(last_checkpoint)
|
| 679 |
+
|
| 680 |
+
def save_to_checkpoint_folder(self):
|
| 681 |
+
self.accelerator.wait_for_everyone()
|
| 682 |
+
|
| 683 |
+
if not self.can_checkpoint:
|
| 684 |
+
return
|
| 685 |
+
|
| 686 |
+
total_steps = int(self.steps.sum().item())
|
| 687 |
+
filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')
|
| 688 |
+
|
| 689 |
+
self.save(filepath)
|
| 690 |
+
|
| 691 |
+
if self.max_checkpoints_keep <= 0:
|
| 692 |
+
return
|
| 693 |
+
|
| 694 |
+
sorted_checkpoints = self.all_checkpoints_sorted
|
| 695 |
+
checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]
|
| 696 |
+
|
| 697 |
+
for checkpoint in checkpoints_to_discard:
|
| 698 |
+
self.fs.rm(checkpoint)
|
| 699 |
+
|
| 700 |
+
    # saving and loading functions

    def save(
        self,
        path,
        overwrite=True,
        without_optim_and_sched=False,
        **kwargs
    ):
        """Serialize model, steps, and (optionally) per-unet optimizer state to `path`.

        Extra `**kwargs` are stored verbatim in the checkpoint dict. Only the
        checkpointing process writes; others return immediately.
        """
        # self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite)

        # gather EMA unets onto one device so their state dict is consistent
        self.reset_ema_unets_all_one_device()

        save_obj = dict(
            model=self.imagen.state_dict(),
            version=__version__,
            steps=self.steps.cpu(),
            **kwargs
        )

        # empty iterable skips all optimizer/scheduler state when requested
        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            # schedulers are optional per unet - only persist the ones that exist
            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # determine if imagen config is available

        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')

            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        # save to path

        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')
|
| 768 |
+
|
| 769 |
+
def load(self, path, only_model=False, strict=True, noop_if_not_exist=False):
|
| 770 |
+
fs = self.fs
|
| 771 |
+
|
| 772 |
+
if noop_if_not_exist and not fs.exists(path):
|
| 773 |
+
self.print(f'trainer checkpoint not found at {str(path)}')
|
| 774 |
+
return
|
| 775 |
+
|
| 776 |
+
assert fs.exists(path), f'{path} does not exist'
|
| 777 |
+
|
| 778 |
+
self.reset_ema_unets_all_one_device()
|
| 779 |
+
|
| 780 |
+
# to avoid extra GPU memory usage in main process when using Accelerate
|
| 781 |
+
|
| 782 |
+
with fs.open(path) as f:
|
| 783 |
+
loaded_obj = torch.load(f, map_location='cpu')
|
| 784 |
+
|
| 785 |
+
if version.parse(__version__) != version.parse(loaded_obj['version']):
|
| 786 |
+
self.print(
|
| 787 |
+
f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')
|
| 788 |
+
|
| 789 |
+
try:
|
| 790 |
+
self.imagen.load_state_dict(loaded_obj['model'], strict=strict)
|
| 791 |
+
except RuntimeError:
|
| 792 |
+
print("Failed loading state dict. Trying partial load")
|
| 793 |
+
self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
|
| 794 |
+
loaded_obj['model']))
|
| 795 |
+
|
| 796 |
+
if only_model:
|
| 797 |
+
return loaded_obj
|
| 798 |
+
|
| 799 |
+
self.steps.copy_(loaded_obj['steps'])
|
| 800 |
+
|
| 801 |
+
for ind in range(0, self.num_unets):
|
| 802 |
+
scaler_key = f'scaler{ind}'
|
| 803 |
+
optimizer_key = f'optim{ind}'
|
| 804 |
+
scheduler_key = f'scheduler{ind}'
|
| 805 |
+
warmup_scheduler_key = f'warmup{ind}'
|
| 806 |
+
|
| 807 |
+
scaler = getattr(self, scaler_key)
|
| 808 |
+
optimizer = getattr(self, optimizer_key)
|
| 809 |
+
scheduler = getattr(self, scheduler_key)
|
| 810 |
+
warmup_scheduler = getattr(self, warmup_scheduler_key)
|
| 811 |
+
|
| 812 |
+
if exists(scheduler) and scheduler_key in loaded_obj:
|
| 813 |
+
scheduler.load_state_dict(loaded_obj[scheduler_key])
|
| 814 |
+
|
| 815 |
+
if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
|
| 816 |
+
warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])
|
| 817 |
+
|
| 818 |
+
if exists(optimizer):
|
| 819 |
+
try:
|
| 820 |
+
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
| 821 |
+
scaler.load_state_dict(loaded_obj[scaler_key])
|
| 822 |
+
except:
|
| 823 |
+
self.print(
|
| 824 |
+
'could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')
|
| 825 |
+
|
| 826 |
+
if self.use_ema:
|
| 827 |
+
assert 'ema' in loaded_obj
|
| 828 |
+
try:
|
| 829 |
+
self.ema_unets.load_state_dict(loaded_obj['ema'], strict=strict)
|
| 830 |
+
except RuntimeError:
|
| 831 |
+
print("Failed loading state dict. Trying partial load")
|
| 832 |
+
self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
|
| 833 |
+
loaded_obj['ema']))
|
| 834 |
+
|
| 835 |
+
self.print(f'checkpoint loaded from {path}')
|
| 836 |
+
return loaded_obj
|
| 837 |
+
|
| 838 |
+
    # managing ema unets and their devices

    @property
    def unets(self):
        """Fresh ModuleList of the EMA shadow models (rebuilt on every access)."""
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def get_ema_unet(self, unet_number=None):
        """Return the EMA wrapper for the given unet, moving it to the active device.

        All other EMA unets are pushed to CPU so only one occupies the training
        device at a time. Returns None when EMA is disabled.
        """
        if not self.use_ema:
            return

        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        # NOTE(review): `self.unets` always returns an nn.ModuleList, so this branch
        # appears to always run - it demotes `ema_unets` from a registered ModuleList
        # to a plain python list (presumably to detach it from the module tree); confirm
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        # only shuffle devices when the active EMA unet actually changes
        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]
|
| 862 |
+
|
| 863 |
+
    def reset_ema_unets_all_one_device(self, device=None):
        """Re-wrap the EMA unets in a ModuleList and move them all to one device."""
        if not self.use_ema:
            return

        device = default(device, self.device)
        # re-registers ema_unets as a ModuleList in case get_ema_unet demoted it to a list
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        # -1 marks "no single ema unet active on the training device"
        self.ema_unet_being_trained_index = -1

    @torch.no_grad()
    @contextmanager
    def use_ema_unets(self):
        """Context manager that temporarily swaps the EMA unets in for sampling.

        On exit the trainable unets are restored and each EMA model is moved
        back to its original device.

        NOTE(review): `@torch.no_grad()` on a generator function may not disable
        grad for the caller's with-block on all torch versions - confirm.
        NOTE(review): `output = yield` receives only what a caller `send()`s
        (normally None), and a return value from a @contextmanager generator is
        discarded - the `return output` lines look vestigial; confirm.
        """
        if not self.use_ema:
            output = yield
            return output

        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        self.unets.eval()

        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets  # swap in exponential moving averaged unets for sampling

        output = yield

        self.imagen.unets = trainable_unets  # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output
|
| 897 |
+
|
| 898 |
+
def print_unet_devices(self):
|
| 899 |
+
self.print('unet devices:')
|
| 900 |
+
for i, unet in enumerate(self.imagen.unets):
|
| 901 |
+
device = next(unet.parameters()).device
|
| 902 |
+
self.print(f'\tunet {i}: {device}')
|
| 903 |
+
|
| 904 |
+
if not self.use_ema:
|
| 905 |
+
return
|
| 906 |
+
|
| 907 |
+
self.print('\nema unet devices:')
|
| 908 |
+
for i, ema_unet in enumerate(self.ema_unets):
|
| 909 |
+
device = next(ema_unet.parameters()).device
|
| 910 |
+
self.print(f'\tema unet {i}: {device}')
|
| 911 |
+
|
| 912 |
+
# overriding state dict functions
|
| 913 |
+
|
| 914 |
+
def state_dict(self, *args, **kwargs):
|
| 915 |
+
self.reset_ema_unets_all_one_device()
|
| 916 |
+
return super().state_dict(*args, **kwargs)
|
| 917 |
+
|
| 918 |
+
def load_state_dict(self, *args, **kwargs):
|
| 919 |
+
self.reset_ema_unets_all_one_device()
|
| 920 |
+
return super().load_state_dict(*args, **kwargs)
|
| 921 |
+
|
| 922 |
+
# encoding text functions
|
| 923 |
+
|
| 924 |
+
def encode_text(self, text, **kwargs):
|
| 925 |
+
return self.imagen.encode_text(text, **kwargs)
|
| 926 |
+
|
| 927 |
+
    # forwarding functions and gradient step updates

    def update(self, unet_number=None):
        """Apply one optimizer step for the given unet, then EMA / scheduler / checkpoint.

        Gradients are expected to already be populated (backward happens inside
        `forward`). Also saves a checkpoint every `checkpoint_every` total steps
        when a checkpoint path is configured.
        """
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        index = unet_number - 1
        unet = self.unet_being_trained

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')
        scheduler = getattr(self, f'scheduler{index}')
        warmup_scheduler = getattr(self, f'warmup{index}')

        # set the grad scaler on the accelerator, since we are managing one per u-net

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.get_ema_unet(unet_number)
            ema_unet.update()

        # scheduler, if needed

        # warmup dampening scales the lr while inside its context
        maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

        with maybe_warmup_context:
            if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped:  # recommended in the docs
                scheduler.step()

        # one-hot increment bumps only this unet's step counter
        self.steps += F.one_hot(torch.tensor(unet_number - 1, device=self.steps.device), num_classes=len(self.steps))

        if not exists(self.checkpoint_path):
            return

        total_steps = int(self.steps.sum().item())

        # truthy remainder means we are between checkpoints - only save on multiples
        if total_steps % self.checkpoint_every:
            return

        self.save_to_checkpoint_folder()
|
| 973 |
+
|
| 974 |
+
    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        """Sample from the wrapped imagen, through the EMA unets by default.

        Pass `use_non_ema=True` to sample with the raw training unets instead.
        Non-main processes suppress the tqdm progress bar.
        """
        # pop `use_non_ema` so it is not forwarded to imagen.sample
        context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

        self.print_untrained_unets()

        if not self.is_main:
            kwargs['use_tqdm'] = False

        with context():
            output = self.imagen.sample(*args, device=self.device, **kwargs)

        return output
|
| 989 |
+
|
| 990 |
+
    @partial(cast_torch_tensor, cast_fp16=True)
    def forward(
        self,
        *args,
        unet_number=None,
        max_batch_size=None,
        **kwargs
    ):
        """Compute the (chunked) training loss for one unet; backward runs here.

        The batch is split into chunks of at most `max_batch_size`; each chunk's
        loss is scaled by its fraction of the batch, backpropagated immediately
        when in training mode, and accumulated into a python float that is
        returned. The matching `update()` call applies the optimizer step.
        """
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        assert not exists(
            self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

        total_loss = 0.

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size=max_batch_size, **kwargs):
            with self.accelerator.autocast():
                loss = self.imagen(*chunked_args, unet=self.unet_being_trained,
                                   unet_number=unet_number, **chunked_kwargs)
                # weight each chunk by its share of the full batch
                loss = loss * chunk_size_frac

            total_loss += loss.item()

            # backward per chunk implements gradient accumulation across chunks
            if self.training:
                self.accelerator.backward(loss)

        return total_loss
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
class JointImagenTrainer(nn.Module):
|
| 1022 |
+
locked = False
|
| 1023 |
+
|
| 1024 |
+
    def __init__(
        self,
        imagen=None,
        imagen_checkpoint_path=None,
        use_ema=True,
        lr=1e-4,
        eps=1e-8,
        beta1=0.9,
        beta2=0.99,
        max_grad_norm=None,
        group_wd_params=True,
        warmup_steps=None,
        cosine_decay_max_steps=None,
        only_train_unet_number=None,
        fp16=False,
        precision=None,
        split_batches=True,
        dl_tuple_output_keywords_names=('images', 'labels', 'texts'),
        verbose=True,
        split_valid_fraction=0.025,
        split_valid_from_train=False,
        split_random_seed=42,
        checkpoint_path=None,
        checkpoint_every=None,
        checkpoint_fs=None,
        fs_kwargs: dict = None,
        max_checkpoints_keep=20,
        lambdas=(1., 1.),  # lambdas for image / label losses
        **kwargs
    ):
        """Trainer for a JointImagen cascade (joint image / label diffusion).

        Scalar hyperparameters (lr, eps, warmup_steps, cosine_decay_max_steps)
        are broadcast per unet; each unet gets its own optimizer, grad scaler,
        scheduler and optional warmup. Remaining `**kwargs` are split by prefix:
        `ema_*` goes to the EMA wrappers, `accelerate_*` to the Accelerator, and
        the rest to the Adam constructors.
        """
        super().__init__()
        # class-level lock: only one distributed trainer per process
        assert not JointImagenTrainer.locked, 'JointImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
        assert exists(imagen) ^ exists(
            imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'

        # save lambdas for backward

        self.lambdas = lambdas

        # determine filesystem, using fsspec, for saving to local filesystem or cloud

        self.fs = checkpoint_fs

        if not exists(self.fs):
            fs_kwargs = default(fs_kwargs, {})
            self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)

        assert isinstance(imagen, (JointImagen, ))  # ElucidatedImagen is not implemented yet
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        # elucidated or not

        # NOTE(review): the assert above restricts `imagen` to JointImagen, so this
        # flag appears to always be False here - confirm whether that is intended
        self.is_elucidated = isinstance(imagen, ElucidatedImagen)

        # create accelerator instance

        accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)

        assert not (fp16 and exists(precision)
                    ), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
        accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')

        self.accelerator = Accelerator(**{
            'split_batches': split_batches,
            'mixed_precision': accelerator_mixed_precision,
            'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters=True)], **accelerate_kwargs})

        # lock only when actually distributed; single-process trainers stay re-creatable
        JointImagenTrainer.locked = self.is_distributed

        # cast data to fp16 at training time if needed

        self.cast_half_at_training = accelerator_mixed_precision == 'fp16'

        # grad scaler must be managed outside of accelerator

        grad_scaler_enabled = fp16

        # imagen, unets and ema unets

        self.imagen = imagen
        self.num_unets = len(self.imagen.unets)

        # EMA copies only maintained on the main process
        self.use_ema = use_ema and self.is_main
        self.ema_unets = nn.ModuleList([])

        # keep track of what unet is being trained on
        # only going to allow 1 unet training at a time

        self.ema_unet_being_trained_index = -1  # keeps track of which ema unet is being trained on

        # data related functions

        self.train_dl_iter = None
        self.train_dl = None

        self.valid_dl_iter = None
        self.valid_dl = None

        self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names

        # auto splitting validation from training, if dataset is passed in

        self.split_valid_from_train = split_valid_from_train

        assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
        self.split_valid_fraction = split_valid_fraction
        self.split_random_seed = split_random_seed

        # be able to finely customize learning rate, weight decay
        # per unet

        # broadcast scalars into one value per unet
        lr, eps, warmup_steps, cosine_decay_max_steps = map(
            partial(cast_tuple, length=self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))

        for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
            optimizer = Adam(
                unet.parameters(),
                lr=unet_lr,
                eps=unet_eps,
                betas=(beta1, beta2),
                **kwargs
            )

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled=grad_scaler_enabled)

            scheduler = warmup_scheduler = None

            if exists(unet_cosine_decay_max_steps):
                scheduler = CosineAnnealingLR(optimizer, T_max=unet_cosine_decay_max_steps)

            if exists(unet_warmup_steps):
                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=unet_warmup_steps)

            # constant-lr fallback so downstream code can assume a scheduler exists
            if not exists(scheduler):
                scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

            # set on object

            setattr(self, f'optim{ind}', optimizer)  # cannot use pytorch ModuleList for some reason with optimizers
            setattr(self, f'scaler{ind}', scaler)
            setattr(self, f'scheduler{ind}', scheduler)
            setattr(self, f'warmup{ind}', warmup_scheduler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        # step tracker and misc

        # one step counter per unet, persisted in checkpoints via the buffer
        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        self.verbose = verbose

        # automatic set devices based on what accelerator decided

        self.imagen.to(self.device)
        self.to(self.device)

        # checkpointing

        # path and interval must be supplied together
        assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
        self.checkpoint_path = checkpoint_path
        self.checkpoint_every = checkpoint_every
        self.max_checkpoints_keep = max_checkpoints_keep

        self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main

        if exists(checkpoint_path) and self.can_checkpoint:
            bucket = url_to_bucket(checkpoint_path)

            if not self.fs.exists(bucket):
                self.fs.mkdir(bucket)

            # resume from the newest checkpoint, if any
            self.load_from_checkpoint_folder()

        # only allowing training for unet

        self.only_train_unet_number = only_train_unet_number
        self.validate_and_set_unet_being_trained(only_train_unet_number)
|
| 1206 |
+
|
| 1207 |
+
# computed values
|
| 1208 |
+
|
| 1209 |
+
@property
|
| 1210 |
+
def device(self):
|
| 1211 |
+
return self.accelerator.device
|
| 1212 |
+
|
| 1213 |
+
@property
|
| 1214 |
+
def is_distributed(self):
|
| 1215 |
+
return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
|
| 1216 |
+
|
| 1217 |
+
@property
|
| 1218 |
+
def is_main(self):
|
| 1219 |
+
return self.accelerator.is_main_process
|
| 1220 |
+
|
| 1221 |
+
@property
|
| 1222 |
+
def is_local_main(self):
|
| 1223 |
+
return self.accelerator.is_local_main_process
|
| 1224 |
+
|
| 1225 |
+
@property
|
| 1226 |
+
def unwrapped_unet(self):
|
| 1227 |
+
return self.accelerator.unwrap_model(self.unet_being_trained)
|
| 1228 |
+
|
| 1229 |
+
# optimizer helper functions
|
| 1230 |
+
|
| 1231 |
+
def get_lr(self, unet_number):
|
| 1232 |
+
self.validate_unet_number(unet_number)
|
| 1233 |
+
unet_index = unet_number - 1
|
| 1234 |
+
|
| 1235 |
+
optim = getattr(self, f'optim{unet_index}')
|
| 1236 |
+
|
| 1237 |
+
return optim.param_groups[0]['lr']
|
| 1238 |
+
|
| 1239 |
+
# function for allowing only one unet from being trained at a time
|
| 1240 |
+
|
| 1241 |
+
def validate_and_set_unet_being_trained(self, unet_number=None):
|
| 1242 |
+
if exists(unet_number):
|
| 1243 |
+
self.validate_unet_number(unet_number)
|
| 1244 |
+
|
| 1245 |
+
assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'
|
| 1246 |
+
|
| 1247 |
+
self.only_train_unet_number = unet_number
|
| 1248 |
+
self.imagen.only_train_unet_number = unet_number
|
| 1249 |
+
|
| 1250 |
+
if not exists(unet_number):
|
| 1251 |
+
return
|
| 1252 |
+
|
| 1253 |
+
self.wrap_unet(unet_number)
|
| 1254 |
+
|
| 1255 |
+
def wrap_unet(self, unet_number):
|
| 1256 |
+
if hasattr(self, 'one_unet_wrapped'):
|
| 1257 |
+
return
|
| 1258 |
+
|
| 1259 |
+
unet = self.imagen.get_unet(unet_number)
|
| 1260 |
+
self.unet_being_trained = self.accelerator.prepare(unet)
|
| 1261 |
+
unet_index = unet_number - 1
|
| 1262 |
+
|
| 1263 |
+
optimizer = getattr(self, f'optim{unet_index}')
|
| 1264 |
+
scheduler = getattr(self, f'scheduler{unet_index}')
|
| 1265 |
+
|
| 1266 |
+
optimizer = self.accelerator.prepare(optimizer)
|
| 1267 |
+
|
| 1268 |
+
if exists(scheduler):
|
| 1269 |
+
scheduler = self.accelerator.prepare(scheduler)
|
| 1270 |
+
|
| 1271 |
+
setattr(self, f'optim{unet_index}', optimizer)
|
| 1272 |
+
setattr(self, f'scheduler{unet_index}', scheduler)
|
| 1273 |
+
|
| 1274 |
+
self.one_unet_wrapped = True
|
| 1275 |
+
|
| 1276 |
+
# hacking accelerator due to not having separate gradscaler per optimizer
|
| 1277 |
+
|
| 1278 |
+
def set_accelerator_scaler(self, unet_number):
|
| 1279 |
+
unet_number = self.validate_unet_number(unet_number)
|
| 1280 |
+
scaler = getattr(self, f'scaler{unet_number - 1}')
|
| 1281 |
+
|
| 1282 |
+
self.accelerator.scaler = scaler
|
| 1283 |
+
for optimizer in self.accelerator._optimizers:
|
| 1284 |
+
optimizer.scaler = scaler
|
| 1285 |
+
|
| 1286 |
+
# helper print

def print(self, msg):
    """Print through the accelerator, but only on the main process and only when verbose."""
    if self.is_main and self.verbose:
        return self.accelerator.print(msg)
+
# validating the unet number

def validate_unet_number(self, unet_number=None):
    """Resolve a 1-indexed unet number (defaulting to 1 for single-unet trainers) and bounds-check it.

    Raises:
        AssertionError: if the number is missing for a multi-unet trainer or out of range.
    """
    if self.num_unets == 1:
        unet_number = default(unet_number, 1)

    # previously, omitting unet_number on a multi-unet trainer raised an opaque
    # TypeError from the chained comparison below - fail with a clear message instead
    assert exists(unet_number), f'you must specify a unet number (1 to {self.num_unets})'
    assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
    return unet_number
|
| 1306 |
+
# number of training steps taken

def num_steps_taken(self, unet_number=None):
    """Return how many optimizer steps the given unet has taken so far."""
    if self.num_unets == 1:
        unet_number = default(unet_number, 1)
    return self.steps[unet_number - 1].item()
+
|
| 1314 |
+
def print_untrained_unets(self):
    """Warn (on the main process) about any non-null unets that have taken zero training steps."""
    any_untrained = False

    for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
        if steps > 0 or isinstance(unet, NullUnet):
            continue

        self.print(f'unet {ind + 1} has not been trained')
        any_untrained = True

    if any_untrained:
        self.print(
            'when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')
+
|
| 1328 |
+
# data related functions

def add_train_dataloader(self, dl=None):
    """Register (and accelerator-prepare) the training dataloader, if one is given."""
    if dl is None:
        return

    assert not exists(self.train_dl), 'training dataloader was already added'
    self.train_dl = self.accelerator.prepare(dl)
+
|
| 1337 |
+
def add_valid_dataloader(self, dl=None):
    """Register (and accelerator-prepare) the validation dataloader, if one is given.

    `dl` now defaults to None for symmetry with `add_train_dataloader`; the body
    already treated a None argument as a no-op.
    """
    if not exists(dl):
        return

    assert not exists(self.valid_dl), 'validation dataloader was already added'
    self.valid_dl = self.accelerator.prepare(dl)
+
|
| 1344 |
+
def add_train_dataset(self, ds=None, *, batch_size, **dl_kwargs):
    """Build and register a training dataloader from a dataset.

    When `self.split_valid_from_train` is set, a validation subset is split off
    deterministically (seeded by `self.split_random_seed`) and registered as well.
    """
    if not exists(ds):
        return

    assert not exists(self.train_dl), 'training dataloader was already added'

    valid_ds = None
    if self.split_valid_from_train:
        # carve out the last `split_valid_fraction` of samples, with a fixed seed
        # so the split is reproducible across runs
        train_size = int((1 - self.split_valid_fraction) * len(ds))
        valid_size = len(ds) - train_size

        ds, valid_ds = random_split(ds, [train_size, valid_size],
                                    generator=torch.Generator().manual_seed(self.split_random_seed))
        self.print(
            f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples')

    dl = DataLoader(ds, batch_size=batch_size, **dl_kwargs)
    self.train_dl = self.accelerator.prepare(dl)

    if not self.split_valid_from_train:
        return

    # same batch size and dataloader kwargs for the validation split
    self.add_valid_dataset(valid_ds, batch_size=batch_size, **dl_kwargs)
+
|
| 1368 |
+
def add_valid_dataset(self, ds=None, *, batch_size, **dl_kwargs):
    """Build and register a validation dataloader from a dataset.

    `ds` now defaults to None for symmetry with `add_train_dataset`; the body
    already treated a None argument as a no-op.
    """
    if not exists(ds):
        return

    assert not exists(self.valid_dl), 'validation dataloader was already added'

    dl = DataLoader(ds, batch_size=batch_size, **dl_kwargs)
    self.valid_dl = self.accelerator.prepare(dl)
+
|
| 1377 |
+
def create_train_iter(self):
    """Lazily create an endless iterator over the training dataloader."""
    assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

    if self.train_dl_iter is None:
        self.train_dl_iter = cycle(self.train_dl)
+
|
| 1385 |
+
def create_valid_iter(self):
    """Lazily create an endless iterator over the validation dataloader."""
    assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

    if self.valid_dl_iter is None:
        self.valid_dl_iter = cycle(self.valid_dl)
+
|
| 1393 |
+
def train_step(self, unet_number=None, **kwargs):
    """Run one forward/backward on the next training batch, then apply the optimizer update."""
    self.create_train_iter()
    batch_loss = self.step_with_dl_iter(self.train_dl_iter, unet_number=unet_number, **kwargs)
    self.update(unet_number=unet_number)
    return batch_loss
+
|
| 1399 |
+
@torch.no_grad()
@eval_decorator
def valid_step(self, **kwargs):
    """Run one gradient-free forward pass on the next validation batch.

    Pass `use_ema_unets=True` to evaluate through the EMA copies of the unets.
    """
    self.create_valid_iter()

    evaluate_with_ema = kwargs.pop('use_ema_unets', False)
    context = self.use_ema_unets if evaluate_with_ema else nullcontext

    with context():
        return self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
+
|
| 1410 |
+
def step_with_dl_iter(self, dl_iter, **kwargs):
    """Pull the next batch from the iterator, map it to keyword names, and run a forward pass."""
    batch = cast_tuple(next(dl_iter))
    model_input = dict(zip(self.dl_tuple_output_keywords_names, batch))
    return self.forward(**{**kwargs, **model_input})
+
|
| 1416 |
+
# checkpointing functions

@property
def all_checkpoints_sorted(self):
    """All checkpoint paths in the checkpoint folder, newest (highest step count) first."""
    pattern = os.path.join(self.checkpoint_path, '*.pt')
    found = self.fs.glob(pattern)
    # filenames look like 'checkpoint.<total_steps>.pt' - sort by the step count
    return sorted(found, key=lambda p: int(str(p).split('.')[-2]), reverse=True)
+
|
| 1425 |
+
def load_from_checkpoint_folder(self, last_total_steps=-1):
    """Load the checkpoint at a specific total step count, or the most recent one when -1."""
    if last_total_steps != -1:
        self.load(os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt'))
        return

    checkpoints = self.all_checkpoints_sorted

    if not checkpoints:
        self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
        return

    # list is sorted newest-first
    self.load(checkpoints[0])
+
|
| 1440 |
+
def save_to_checkpoint_folder(self):
    """Save a checkpoint named by total steps, then prune old checkpoints beyond the keep limit."""
    self.accelerator.wait_for_everyone()

    # only the designated process writes checkpoints
    if not self.can_checkpoint:
        return

    total_steps = int(self.steps.sum().item())
    self.save(os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt'))

    # non-positive keep count means keep everything
    if self.max_checkpoints_keep <= 0:
        return

    for stale in self.all_checkpoints_sorted[self.max_checkpoints_keep:]:
        self.fs.rm(stale)
+
|
| 1460 |
+
# saving and loading functions

def save(
    self,
    path,
    overwrite=True,
    without_optim_and_sched=False,
    **kwargs
):
    """Serialize trainer state to `path` via the trainer's filesystem abstraction.

    Saved: model weights, package version, step counts, per-unet scaler / optimizer /
    scheduler / warmup state (unless `without_optim_and_sched`), EMA weights (if used),
    and the imagen config when available. Extra `kwargs` are stored alongside.
    """
    # self.accelerator.wait_for_everyone()

    # only the designated process writes checkpoints
    if not self.can_checkpoint:
        return

    fs = self.fs

    assert not (fs.exists(path) and not overwrite)

    # gather EMA unets onto one device before taking state dicts
    self.reset_ema_unets_all_one_device()

    save_obj = dict(
        model=self.imagen.state_dict(),
        version=__version__,
        steps=self.steps.cpu(),
        **kwargs
    )

    # skip optimizer/scheduler state entirely when requested
    save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

    for ind in save_optim_and_sched_iter:
        scaler_key = f'scaler{ind}'
        optimizer_key = f'optim{ind}'
        scheduler_key = f'scheduler{ind}'
        warmup_scheduler_key = f'warmup{ind}'

        scaler = getattr(self, scaler_key)
        optimizer = getattr(self, optimizer_key)
        scheduler = getattr(self, scheduler_key)
        warmup_scheduler = getattr(self, warmup_scheduler_key)

        # schedulers and warmup schedulers are optional per unet
        if exists(scheduler):
            save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

        if exists(warmup_scheduler):
            save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

        save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

    if self.use_ema:
        save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

    # determine if imagen config is available - makes the checkpoint usable from the CLI
    if hasattr(self.imagen, '_config'):
        self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')

        save_obj = {
            **save_obj,
            'imagen_type': 'elucidated' if self.is_elucidated else 'original',
            'imagen_params': self.imagen._config
        }

    # save to path
    with fs.open(path, 'wb') as f:
        torch.save(save_obj, f)

    self.print(f'checkpoint saved to {path}')
+
|
| 1529 |
+
def load(self, path, only_model=False, strict=True, noop_if_not_exist=False):
    """Restore trainer state from a checkpoint saved by `save`.

    Args:
        path: checkpoint location on the trainer's filesystem abstraction.
        only_model: stop after loading model weights (skip steps/optim/sched/EMA).
        strict: forwarded to `load_state_dict`; on failure a partial load is attempted.
        noop_if_not_exist: silently return when the checkpoint is missing.

    Returns:
        the raw loaded checkpoint dict, or None when noop'd.
    """
    fs = self.fs

    if noop_if_not_exist and not fs.exists(path):
        self.print(f'trainer checkpoint not found at {str(path)}')
        return

    assert fs.exists(path), f'{path} does not exist'

    self.reset_ema_unets_all_one_device()

    # to avoid extra GPU memory usage in main process when using Accelerate
    with fs.open(path) as f:
        loaded_obj = torch.load(f, map_location='cpu')

    if version.parse(__version__) != version.parse(loaded_obj['version']):
        self.print(f'loading saved imagen at version {loaded_obj["version"]}, '
                   f'but current package version is {__version__}')

    try:
        self.imagen.load_state_dict(loaded_obj['model'], strict=strict)
    except RuntimeError:
        print("Failed loading state dict. Trying partial load")
        self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                  loaded_obj['model']))

    if only_model:
        return loaded_obj

    self.steps.copy_(loaded_obj['steps'])

    for ind in range(0, self.num_unets):
        scaler_key = f'scaler{ind}'
        optimizer_key = f'optim{ind}'
        scheduler_key = f'scheduler{ind}'
        warmup_scheduler_key = f'warmup{ind}'

        scaler = getattr(self, scaler_key)
        optimizer = getattr(self, optimizer_key)
        scheduler = getattr(self, scheduler_key)
        warmup_scheduler = getattr(self, warmup_scheduler_key)

        if exists(scheduler) and scheduler_key in loaded_obj:
            scheduler.load_state_dict(loaded_obj[scheduler_key])

        if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
            warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

        if exists(optimizer):
            try:
                optimizer.load_state_dict(loaded_obj[optimizer_key])
                scaler.load_state_dict(loaded_obj[scaler_key])
            # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
            except Exception:
                self.print('could not load optimizer and scaler, '
                           'possibly because you have turned on mixed precision training since the last run. '
                           'resuming with new optimizer and scalers')

    if self.use_ema:
        assert 'ema' in loaded_obj
        try:
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict=strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                         loaded_obj['ema']))

    self.print(f'checkpoint loaded from {path}')
    return loaded_obj
+
|
| 1599 |
+
# managing ema unets and their devices

@property
def unets(self):
    """The EMA shadow copies of the unets, as a fresh ModuleList."""
    return nn.ModuleList(ema.ema_model for ema in self.ema_unets)
+
|
| 1605 |
+
def get_ema_unet(self, unet_number=None):
    """Return the EMA wrapper for the given unet, moving it to the trainer device.

    All other EMA unets are moved to CPU so only one occupies accelerator memory
    at a time. Returns None when EMA is disabled.
    """
    if not self.use_ema:
        return

    unet_number = self.validate_unet_number(unet_number)
    index = unet_number - 1

    # unwrap from ModuleList into a plain list so per-unet .to(device) moves
    # are not undone by module-level device management
    if isinstance(self.unets, nn.ModuleList):
        unets_list = [unet for unet in self.ema_unets]
        delattr(self, 'ema_unets')
        self.ema_unets = unets_list

    # shuffle devices only when switching to a different unet
    if index != self.ema_unet_being_trained_index:
        for unet_index, unet in enumerate(self.ema_unets):
            unet.to(self.device if unet_index == index else 'cpu')

    self.ema_unet_being_trained_index = index
    return self.ema_unets[index]
+
|
| 1624 |
+
def reset_ema_unets_all_one_device(self, device=None):
    """Collect every EMA unet back into a ModuleList on a single device."""
    if not self.use_ema:
        return

    target_device = default(device, self.device)
    self.ema_unets = nn.ModuleList([*self.ema_unets])
    self.ema_unets.to(target_device)

    # -1 marks that no single unet currently owns the accelerator device
    self.ema_unet_being_trained_index = -1
+
|
| 1634 |
+
@torch.no_grad()
@contextmanager
def use_ema_unets(self):
    """Context manager that temporarily swaps the EMA unets into the imagen cascade.

    Inside the `with` block, `self.imagen.unets` are the EMA models (in eval mode);
    on exit the trainable unets are restored and EMA models are moved back to their
    original devices. A no-op passthrough when EMA is disabled.
    """
    if not self.use_ema:
        # nothing to swap - just yield control back to the caller
        output = yield
        return output

    # gather both EMA and trainable unets onto single devices before swapping
    self.reset_ema_unets_all_one_device()
    self.imagen.reset_unets_all_one_device()

    self.unets.eval()

    trainable_unets = self.imagen.unets
    self.imagen.unets = self.unets  # swap in exponential moving averaged unets for sampling

    output = yield

    self.imagen.unets = trainable_unets  # restore original training unets

    # cast the ema_model unets back to original device
    for ema in self.ema_unets:
        ema.restore_ema_model_device()

    return output
+
|
| 1659 |
+
def print_unet_devices(self):
    """Log which device each unet (and EMA unet, if enabled) currently lives on."""
    self.print('unet devices:')
    for i, unet in enumerate(self.imagen.unets):
        # a parameter's device tells us where the whole module sits
        self.print(f'\tunet {i}: {next(unet.parameters()).device}')

    if not self.use_ema:
        return

    self.print('\nema unet devices:')
    for i, ema_unet in enumerate(self.ema_unets):
        self.print(f'\tema unet {i}: {next(ema_unet.parameters()).device}')
+
|
| 1673 |
+
# overriding state dict functions

def state_dict(self, *args, **kwargs):
    """Gather EMA unets onto one device before serializing trainer state."""
    self.reset_ema_unets_all_one_device()
    return super().state_dict(*args, **kwargs)
+
|
| 1679 |
+
def load_state_dict(self, *args, **kwargs):
    """Gather EMA unets onto one device before restoring trainer state."""
    self.reset_ema_unets_all_one_device()
    return super().load_state_dict(*args, **kwargs)
+
|
| 1683 |
+
# encoding text functions

def encode_text(self, text, **kwargs):
    """Delegate text encoding to the wrapped imagen model."""
    return self.imagen.encode_text(text, **kwargs)
+
|
| 1688 |
+
# forwarding functions and gradient step updates

def update(self, unet_number=None):
    """Apply one optimizer step for the given unet: clip grads, step, EMA update,
    scheduler step, bump the step counter, and checkpoint when due."""
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    self.set_accelerator_scaler(unet_number)

    index = unet_number - 1
    unet = self.unet_being_trained

    optimizer = getattr(self, f'optim{index}')
    scaler = getattr(self, f'scaler{index}')
    scheduler = getattr(self, f'scheduler{index}')
    warmup_scheduler = getattr(self, f'warmup{index}')

    # set the grad scaler on the accelerator, since we are managing one per u-net

    if exists(self.max_grad_norm):
        self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

    optimizer.step()
    optimizer.zero_grad()

    # EMA shadow weights track the just-updated parameters
    if self.use_ema:
        ema_unet = self.get_ema_unet(unet_number)
        ema_unet.update()

    # scheduler, if needed - dampened while the warmup scheduler is active
    maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

    with maybe_warmup_context:
        if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped:  # recommended in the docs
            scheduler.step()

    # increment only this unet's step counter
    self.steps += F.one_hot(torch.tensor(unet_number - 1, device=self.steps.device), num_classes=len(self.steps))

    if not exists(self.checkpoint_path):
        return

    total_steps = int(self.steps.sum().item())

    # only checkpoint on exact multiples of checkpoint_every
    if total_steps % self.checkpoint_every:
        return

    self.save_to_checkpoint_folder()
+
|
| 1735 |
+
@torch.no_grad()
@cast_torch_tensor
@imagen_sample_in_chunks
def sample(self, *args, **kwargs):
    """Sample from the cascade, using the EMA unets unless `use_non_ema=True` is passed."""
    sample_with_non_ema = kwargs.pop('use_non_ema', False)
    context = nullcontext if sample_with_non_ema else self.use_ema_unets

    self.print_untrained_unets()

    # only the main process shows a progress bar
    if not self.is_main:
        kwargs['use_tqdm'] = False

    with context():
        return self.imagen.sample(*args, device=self.device, **kwargs)
+
|
| 1751 |
+
@partial(cast_torch_tensor, cast_fp16=True)
def forward(
    self,
    *args,
    unet_number=None,
    max_batch_size=None,
    **kwargs
):
    """Run the joint imagen forward pass, accumulating gradients when training.

    The batch is optionally split into chunks of at most `max_batch_size`;
    each chunk's losses are scaled by its fraction of the batch so the totals
    match an unchunked pass.

    Returns:
        (total_loss, total_loss_seg): the two accumulated scalar losses as floats.
    """
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    self.set_accelerator_scaler(unet_number)

    assert not exists(
        self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

    total_loss = 0.
    total_loss_seg = 0.

    for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size=max_batch_size, **kwargs):
        with self.accelerator.autocast():
            # joint model returns an image loss and a segmentation loss
            loss, loss_seg = self.imagen(*chunked_args, unet=self.unet_being_trained,
                                         unet_number=unet_number, **chunked_kwargs)
            loss = loss * chunk_size_frac
            loss_seg = loss_seg * chunk_size_frac

        total_loss += loss.item()
        total_loss_seg += loss_seg.item()

        # backward per chunk, weighting the two losses by self.lambdas
        if self.training:
            self.accelerator.backward(loss * self.lambdas[0] + loss_seg * self.lambdas[1])

    return total_loss, total_loss_seg
|
imagen_pytorch/utils.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from functools import reduce
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
|
| 7 |
+
from ema_pytorch import EMA
|
| 8 |
+
|
| 9 |
+
def exists(val):
    """Return True when `val` is not None."""
    return val is not None
+
|
| 12 |
+
def safeget(dictionary, keys, default = None):
    """Fetch a nested value by dotted key path (e.g. 'a.b.c'), yielding `default` on any miss."""
    current = dictionary
    for key in keys.split('.'):
        # non-dict intermediates fall back to the default, matching the original reduce
        current = current.get(key, default) if isinstance(current, dict) else default
    return current
+
|
| 15 |
+
def load_imagen_from_checkpoint(
    checkpoint_path,
    load_weights = True,
    load_ema_if_available = False
):
    """Reconstruct an Imagen / ElucidatedImagen model from a trainer checkpoint.

    Args:
        checkpoint_path: path to a checkpoint that was saved with its config
            (`imagen_type` and `imagen_params` keys).
        load_weights: when False, return a freshly initialized model of the right config.
        load_ema_if_available: prefer the EMA weights when the checkpoint contains them.

    Returns:
        the reconstructed imagen model.
    """
    model_path = Path(checkpoint_path)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'checkpoint not found at {full_model_path}'
    loaded = torch.load(str(model_path), map_location='cpu')

    imagen_params = safeget(loaded, 'imagen_params')
    imagen_type = safeget(loaded, 'imagen_type')

    # check for presence first: previously this assert sat below the dispatch, so a
    # checkpoint saved without its config raised a confusing "unknown imagen type None"
    # ValueError instead of this clear message
    assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint'

    if imagen_type == 'original':
        imagen_klass = ImagenConfig
    elif imagen_type == 'elucidated':
        imagen_klass = ElucidatedImagenConfig
    else:
        raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')

    imagen = imagen_klass(**imagen_params).create()

    if not load_weights:
        return imagen

    has_ema = 'ema' in loaded
    should_load_ema = has_ema and load_ema_if_available

    imagen.load_state_dict(loaded['model'])

    if not should_load_ema:
        print('loading non-EMA version of unets')
        return imagen

    # rebuild EMA wrappers so the EMA state dict keys line up, then copy the
    # EMA shadow weights into the actual unets
    ema_unets = nn.ModuleList([])
    for unet in imagen.unets:
        ema_unets.append(EMA(unet))

    ema_unets.load_state_dict(loaded['ema'])

    for unet, ema_unet in zip(imagen.unets, ema_unets):
        unet.load_state_dict(ema_unet.ema_model.state_dict())

    print('loaded EMA version of unets')
    return imagen
|
imagen_pytorch/version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__version__ = '1.11.14'
|
pyproject.toml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.autopep8]
|
| 2 |
+
max_line_length = 120
|
| 3 |
+
ignore = ["E402"]
|
repaint/LICENSES/LICENSE
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
you may not use this file except in compliance with the License.
|
| 4 |
+
You may obtain a copy of the License at
|
| 5 |
+
|
| 6 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
|
| 8 |
+
The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
See the License for the specific language governing permissions and
|
| 13 |
+
limitations under the License.
|
repaint/LICENSES/LICENSE_guided_diffusion
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2021 OpenAI
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
repaint/LICENSES/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# License and Acknowledgement
|
| 2 |
+
|
| 3 |
+
A big thanks to the following contributors who open-sourced their code and thereby helped us a lot in developing RePaint!
|
| 4 |
+
|
| 5 |
+
This repository was forked from:
|
| 6 |
+
https://github.com/openai/guided-diffusion
|
| 7 |
+
|
| 8 |
+
It contains code from:
|
| 9 |
+
https://github.com/hojonathanho/diffusion
|
| 10 |
+
|
| 11 |
+
If we missed a contribution, please contact us.
|
repaint/README.md
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RePaint
|
| 2 |
+
**Inpainting using Denoising Diffusion Probabilistic Models**
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
CVPR 2022 [[Paper]](https://bit.ly/3b1ABEb)
|
| 6 |
+
|
| 7 |
+
[](#)
|
| 8 |
+
|
| 9 |
+
## Setup
|
| 10 |
+
|
| 11 |
+
### 1. Code
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
git clone https://github.com/andreas128/RePaint.git
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
### 2. Environment
|
| 18 |
+
```bash
|
| 19 |
+
pip install numpy torch blobfile tqdm pyYaml pillow # e.g. torch 1.7.1+cu110.
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### 3. Download models and data
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
pip install --upgrade gdown && bash ./download.sh
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
That downloads the models for ImageNet, CelebA-HQ, and Places2, as well as the face example and example masks.
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
### 4. Run example
|
| 32 |
+
```bash
|
| 33 |
+
python test.py --conf_path confs/face_example.yml
|
| 34 |
+
```
|
| 35 |
+
Find the output in `./log/face_example/inpainted`
|
| 36 |
+
|
| 37 |
+
*Note: After refactoring the code, we did not reevaluate all experiments.*
|
| 38 |
+
|
| 39 |
+
<br>
|
| 40 |
+
|
| 41 |
+
# RePaint fills a missing image part using diffusion models
|
| 42 |
+
|
| 43 |
+
<table border="0" cellspacing="0" cellpadding="0">
|
| 44 |
+
<tr>
|
| 45 |
+
<td><img alt="RePaint Inpainting using Denoising Diffusion Probabilistic Models Demo 1" src="https://user-images.githubusercontent.com/11280511/150766080-9f3d7bc9-99f2-472e-9e5d-b6ed456340d1.gif"></td>
|
| 46 |
+
<td><img alt="RePaint Inpainting using Denoising Diffusion Probabilistic Models Demo 2" src="https://user-images.githubusercontent.com/11280511/150766125-adf5a3cb-17f2-432c-a8f6-ce0b97122819.gif"></td>
|
| 47 |
+
</tr>
|
| 48 |
+
</table>
|
| 49 |
+
|
| 50 |
+
**What are the blue parts?** <br>
|
| 51 |
+
Those parts are missing and therefore have to be filled by RePaint. <br> RePaint generates the missing parts inspired by the known parts.
|
| 52 |
+
|
| 53 |
+
**How does it work?** <br>
|
| 54 |
+
RePaint starts from pure noise. Then the image is denoised step-by-step. <br> It uses the known part to fill the unknown part in each step.
|
| 55 |
+
|
| 56 |
+
**Why does the noise level fluctuate during generation?** <br>
|
| 57 |
+
Our noise schedule improves the harmony between the generated and <br> the known part [[4.2 Resampling]](https://bit.ly/3b1ABEb).
|
| 58 |
+
|
| 59 |
+
<br>
|
| 60 |
+
|
| 61 |
+
## Details on data
|
| 62 |
+
|
| 63 |
+
**Which datasets and masks have a ready-to-use config file?**
|
| 64 |
+
|
| 65 |
+
We provide config files for ImageNet (inet256), CelebA-HQ (c256) and Places2 (p256) for the masks "thin", "thick", "every second line", "super-resolution", "expand" and "half" in [`./confs`](https://github.com/andreas128/RePaint/tree/main/confs). You can use them as shown in the example above.
|
| 66 |
+
|
| 67 |
+
**How to prepare the test data?**
|
| 68 |
+
|
| 69 |
+
We use [LaMa](https://github.com/saic-mdal/lama) for validation and testing. Follow their instructions and add the images as specified in the config files. When you download the data using `download.sh`, you can see examples of masks we used.
|
| 70 |
+
|
| 71 |
+
**How to apply it to other images?**
|
| 72 |
+
|
| 73 |
+
Copy the config file for the dataset that matches your data best (for faces aligned like CelebA-HQ `_c256`, for diverse images `_inet256`). Then set the [`gt_path`](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/confs/face_example.yml#L70) and [`mask_path`](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/confs/face_example.yml#L71) to where your input is. The masks have the value 255 for known regions and 0 for unknown areas (the ones that get generated).
|
| 74 |
+
|
| 75 |
+
**How to apply it for other datasets?**
|
| 76 |
+
|
| 77 |
+
If you work with other data than faces, places or general images, train a model using the [guided-diffusion](https://github.com/openai/guided-diffusion) repository. Note that RePaint is an inference scheme. We do not train or finetune the diffusion model but condition pre-trained models.
|
| 78 |
+
|
| 79 |
+
## Adapt the code
|
| 80 |
+
|
| 81 |
+
**How to design a new schedule?**
|
| 82 |
+
|
| 83 |
+
Fill in your own parameters in this [line](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/guided_diffusion/scheduler.py#L180) to visualize the schedule using `python guided_diffusion/scheduler.py`. Then copy a config file, set your parameters in these [lines](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/confs/face_example.yml#L61-L65) and run the inference using `python test.py --conf_path confs/my_schedule.yml`.
|
| 84 |
+
|
| 85 |
+
**How to speed up the inference?**
|
| 86 |
+
|
| 87 |
+
The following settings are in the [schedule_jump_params](https://github.com/andreas128/RePaint/blob/0fea066b52346c331cdf1bf7aed616c8c8896714/confs/face_example.yml#L61) key in the config files. You can visualize them as described above.
|
| 88 |
+
|
| 89 |
+
- Reduce `t_T`, the total number of steps (without resampling). The lower it is, the more noise gets removed per step.
|
| 90 |
+
- Reduce `jump_n_sample` to resample fewer times.
|
| 91 |
+
- Apply resampling not from the beginning but only after a specific time by setting `start_resampling`.
|
| 92 |
+
|
| 93 |
+
## Code overview
|
| 94 |
+
|
| 95 |
+
- **Schedule:** The list of diffusion times t which will be traversed are obtained in this [line](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L503). e.g. times = [249, 248, 249, 248, 247, 248, 247, 248, 247, 246, ...]
|
| 96 |
+
- **Denoise:** Reverse diffusion steps from x<sub>t</sub> (more noise) to a x<sub>t-1</sub> (less noisy) are done below this [line](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L515).
|
| 97 |
+
- **Predict:** The model is called [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L237) and obtains x<sub>t</sub> and the time t to predict a tensor with 6 channels containing information about the mean and variance of x<sub>t-1</sub>. Then the value range of the variance is adjusted [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L252). The mean of x<sub>t-1</sub> is obtained by the weighted sum of the estimated [x<sub>0</sub>](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L270) and x<sub>t</sub> [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L189). The obtained mean and variance is used [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L402) to sample x<sub>t-1</sub>. (This is the original reverse step from [guided-diffusion](https://github.com/openai/guided-diffusion.git). )
|
| 98 |
+
- **Condition:** The known part of the input image needs to have the same amount of noise as the part that the diffusion model generates to join them. The required amount of noise is calculated [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L368) and added to the known part [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L371). The generated and sampled parts get joined using a mask [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L373).
|
| 99 |
+
- **Undo:** The forward diffusion steps from x<sub>t-1</sub> to x<sub>t</sub> are done after this [line](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L536). The noise gets added to x<sub>t-1</sub> [here](https://github.com/andreas128/RePaint/blob/76cb5b49d3f28715980f6e809c6859b148be9867/guided_diffusion/gaussian_diffusion.py#L176).
|
| 100 |
+
|
| 101 |
+
## Issues
|
| 102 |
+
|
| 103 |
+
**Do you have further questions?**
|
| 104 |
+
|
| 105 |
+
Please open an [issue](https://github.com/andreas128/RePaint/issues), and we will try to help you.
|
| 106 |
+
|
| 107 |
+
**Did you find a mistake?**
|
| 108 |
+
|
| 109 |
+
Please create a pull request. For example, by clicking the pencil button at the top right of the GitHub page.
|
| 110 |
+
|
| 111 |
+
<br>
|
| 112 |
+
|
| 113 |
+
# RePaint on diverse content and shapes of missing regions
|
| 114 |
+
|
| 115 |
+
The blue region is unknown and filled by RePaint:
|
| 116 |
+
|
| 117 |
+

|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
**Note: RePaint creates many meaningful fillings.** <br>
|
| 121 |
+
1) **Face:** Expressions and features like an earring or a mole. <br>
|
| 122 |
+
2) **Computer:** The computer screen shows different images, text, and even a logo. <br>
|
| 123 |
+
3) **Greens:** RePaint makes sense of the tiny known part and incorporates it in a beetle, spaghetti, and plants. <br>
|
| 124 |
+
4) **Garden:** From simple filling like a curtain to complex filling like a human. <br>
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
<br>
|
| 128 |
+
|
| 129 |
+
# Extreme Case 1: Generate every second line
|
| 130 |
+

|
| 131 |
+
|
| 132 |
+
- Every second line of the input image is unknown.
|
| 133 |
+
- Most inpainting methods fail on such masks.
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
<br>
|
| 137 |
+
|
| 138 |
+
# Extreme Case 2: Upscale an image
|
| 139 |
+

|
| 140 |
+
|
| 141 |
+
- The inpainting only knows pixels with a strided access of 2.
|
| 142 |
+
- A ratio of 3/4 of the image has to be filled.
|
| 143 |
+
- This is equivalent to Super-Resolution with the Nearest Neighbor kernel.
|
| 144 |
+
|
| 145 |
+
<br>
|
| 146 |
+
|
| 147 |
+
# RePaint conditions the diffusion model on the known part
|
| 148 |
+
|
| 149 |
+
- RePaint uses unconditionally trained Denoising Diffusion Probabilistic Models.
|
| 150 |
+
- We condition during inference on the given image content.
|
| 151 |
+
|
| 152 |
+

|
| 153 |
+
|
| 154 |
+
**Intuition of one conditioned denoising step:**
|
| 155 |
+
1) **Sample the known part:** Add gaussian noise to the known regions of the image. <br> We obtain a noisy image that follows the denoising process exactly.
|
| 156 |
+
2) **Denoise one step:** Denoise the previous image for one step. This generates <br> content for the unknown region conditioned on the known region.
|
| 157 |
+
3) **Join:** Merge the images from both steps.
|
| 158 |
+
|
| 159 |
+
Details are in Algorithm 1 on Page 5. [[Paper]](https://bit.ly/3b1ABEb)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
<br>
|
| 163 |
+
|
| 164 |
+
# How to harmonize the generated with the known part?
|
| 165 |
+
|
| 166 |
+
- **Fail:** When using only the algorithm above, the filling is not well harmonized with the known part (n=1).
|
| 167 |
+
- **Fix:** When applying the [[4.2 Resampling]](https://bit.ly/3b1ABEb) technique, the images are better harmonized (n>1).
|
| 168 |
+
|
| 169 |
+
<img width="1577" alt="Diffusion Model Resampling" src="https://user-images.githubusercontent.com/11280511/150822917-737c00b0-b6bb-439d-a5bf-e73238d30990.png">
|
| 170 |
+
|
| 171 |
+
<br>
|
| 172 |
+
|
| 173 |
+
# RePaint Fails
|
| 174 |
+
- The ImageNet model is biased towards inpainting dogs.
|
| 175 |
+
- This is due to the high ratio of dog images in ImageNet.
|
| 176 |
+
|
| 177 |
+
<img width="1653" alt="RePaint Fails" src="https://user-images.githubusercontent.com/11280511/150853163-b965f59c-5ad4-485b-816e-4391e77b5199.png">
|
| 178 |
+
|
| 179 |
+
<br>
|
| 180 |
+
|
| 181 |
+
# User Study State-of-the-Art Comparison
|
| 182 |
+
|
| 183 |
+
- Outperforms autoregression-based and GAN-based SOTA methods, <br> with 95% significance for all masks except for two inconclusive cases.
|
| 184 |
+
- The user study was done for six different masks on three datasets.
|
| 185 |
+
- RePaint outperformed SOTA methods in 42 of 44 cases. [[Paper]](https://bit.ly/3b1ABEb)
|
| 186 |
+
|
| 187 |
+
<br>
|
| 188 |
+
|
| 189 |
+
# Explore the Visual Examples
|
| 190 |
+
- Datasets: CelebA-HQ, ImageNet, Places2
|
| 191 |
+
- Masks: Random strokes, half image, huge, sparse
|
| 192 |
+
- Explore more examples like this in the [[Appendix]](https://bit.ly/3b1ABEb).
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
<img width="1556" alt="Denosing Diffusion Inpainting Examples" src="https://user-images.githubusercontent.com/11280511/150864677-0eb482ae-c114-4b0b-b1e0-9be9574da307.png">
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
<br>
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# Acknowledgement
|
| 202 |
+
|
| 203 |
+
This work was supported by the ETH Zürich Fund (OK), a Huawei Technologies Oy (Finland) project, and an Nvidia GPU grant.
|
| 204 |
+
|
| 205 |
+
This repository is based on [guided-diffusion](https://github.com/openai/guided-diffusion.git) from OpenAI.
|
repaint/conf_mgt/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
from conf_mgt.conf_base import Default_Conf
|
repaint/conf_mgt/conf_base.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
from functools import lru_cache
|
| 18 |
+
import os
|
| 19 |
+
import torch
|
| 20 |
+
from utils import imwrite
|
| 21 |
+
|
| 22 |
+
from collections import defaultdict
|
| 23 |
+
from os.path import isfile, expanduser
|
| 24 |
+
|
| 25 |
+
def to_file_ext(img_names, ext):
    """Return *img_names* with each file extension replaced by *ext*.

    Every name must consist of exactly one stem and one extension separated
    by a single '.'; any other shape raises RuntimeError.
    """
    renamed = []
    for name in img_names:
        parts = name.split('.')
        if len(parts) != 2:
            raise RuntimeError("File name needs exactly one '.':", name)
        renamed.append(f"{parts[0]}.{ext}")
    return renamed
|
| 34 |
+
|
| 35 |
+
def write_images(imgs, img_names, dir_path):
    """Write each image in *imgs* into *dir_path* under the matching name.

    Creates *dir_path* (and parents) when missing. Images and names are
    paired positionally; surplus entries of the longer sequence are ignored.
    """
    os.makedirs(dir_path, exist_ok=True)
    for name, img in zip(img_names, imgs):
        # imwrite comes from the project-local utils module.
        imwrite(img=img, path=os.path.join(dir_path, name))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class NoneDict(defaultdict):
    """Dict whose missing keys resolve to None, with attribute-style access.

    Both ``d['missing']`` and ``d.missing`` yield None instead of raising,
    so configuration code can probe optional entries freely.
    """

    def __init__(self):
        # The default factory makes every absent key produce None.
        super().__init__(self.return_None)

    def __getattr__(self, attr):
        # Attribute access falls back to a plain dict lookup; missing -> None.
        return dict.get(self, attr)

    @staticmethod
    def return_None():
        return None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Default_Conf(NoneDict):
    """Configuration container with None-defaulting key/attribute lookups.

    Inherits NoneDict so missing entries resolve to None instead of raising,
    which lets config files (YAML) omit optional settings.
    """

    def __init__(self):
        # Intentionally does NOT call super().__init__(): entries are filled
        # in externally (e.g. from a parsed YAML file) via dict updates.
        pass

    def get_dataloader(self, dset='train', dsName=None, batch_size=None, return_dataset=False):
        """Build the dataloader described by self['data'][dset][dsName].

        Only mask-based inpainting loaders are supported; any other dataset
        config raises NotImplementedError.
        """
        if batch_size is None:
            batch_size = self.batch_size

        candidates = self['data'][dset]
        ds_conf = candidates[dsName].copy()

        if ds_conf.get('mask_loader', False):
            # Imported lazily to avoid a hard dependency at module load time.
            from guided_diffusion.image_datasets import load_data_inpa
            return load_data_inpa(**ds_conf, conf=self)
        raise NotImplementedError()

    def get_debug_variance_path(self):
        """Return the expanded path for debug variance dumps.

        NOTE(review): get_default_eval_conf is not defined on this class —
        possibly get_default_eval_name was intended; confirm before use.
        """
        return os.path.expanduser(
            os.path.join(self.get_default_eval_conf()['paths']['root'],
                         'debug/debug_variance'))

    @staticmethod
    def device():
        """Return 'cuda' when a GPU is available, else 'cpu'."""
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def eval_imswrite(self, srs=None, img_names=None, dset=None, name=None,
                      ext='png', lrs=None, gts=None, gt_keep_masks=None,
                      verify_same=True):
        """Write evaluation image groups (inpainted, masks, gts, masked inputs).

        Each group is written only when supplied (and, for gts, when a target
        path is configured). All file names get their extension replaced by
        *ext*. ``verify_same`` is currently unused but kept for compatibility.
        """
        img_names = to_file_ext(img_names, ext)

        if dset is None:
            dset = self.get_default_eval_name()

        # All output locations live under this dataset entry's 'paths'.
        paths = self['data'][dset][name]['paths']

        if srs is not None:
            write_images(srs, img_names, expanduser(paths['srs']))

        if gt_keep_masks is not None:
            write_images(gt_keep_masks, img_names,
                         expanduser(paths['gt_keep_masks']))

        gts_path = paths.get('gts')
        if gts is not None and gts_path:
            write_images(gts, img_names, expanduser(gts_path))

        if lrs is not None:
            write_images(lrs, img_names, expanduser(paths['lrs']))

    def get_default_eval_name(self):
        """Return the single dataset name under data.eval; error if ambiguous."""
        candidates = self['data']['eval'].keys()
        if len(candidates) != 1:
            raise RuntimeError(
                f"Need exactly one candidate for {self.name}: {candidates}")
        return list(candidates)[0]

    def pget(self, name, default=None):
        """Dotted-path lookup, e.g. ``pget('schedule_jump_params.t_T')``.

        Returns *default* when any segment of the path is missing or when an
        intermediate segment resolves to a non-dict value (the original code
        raised AttributeError in that case).
        """
        keys = name.split('.') if '.' in name else [name]

        node = self
        for key in keys:
            if not isinstance(node, dict):
                # An intermediate segment is not a mapping; path is invalid.
                return default
            node = node.get(key, default)

        # 'is None' (identity), not '== None' (equality).
        if node is None:
            return default
        return node
|
repaint/confs/face_example.yml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
latex_name: RePaint
|
| 53 |
+
method_name: Repaint
|
| 54 |
+
image_size: 256
|
| 55 |
+
model_path: ./data/pretrained/celeba256_250000.pt
|
| 56 |
+
name: face_example
|
| 57 |
+
inpa_inj_sched_prev: true
|
| 58 |
+
n_jobs: 1
|
| 59 |
+
print_estimated_vars: true
|
| 60 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 61 |
+
schedule_jump_params:
|
| 62 |
+
t_T: 250
|
| 63 |
+
n_sample: 5 # for GCDP, 1 for image2layout and 5 for layout2image
|
| 64 |
+
jump_length: 10
|
| 65 |
+
jump_n_sample: 10
|
| 66 |
+
data:
|
| 67 |
+
eval:
|
| 68 |
+
paper_face_mask:
|
| 69 |
+
mask_loader: true
|
| 70 |
+
gt_path: ./data/datasets/gts/gcdp
|
| 71 |
+
mask_path: ./data/datasets/gt_keep_masks/gcdp
|
| 72 |
+
image_size: 256
|
| 73 |
+
class_cond: false
|
| 74 |
+
deterministic: true
|
| 75 |
+
random_crop: false
|
| 76 |
+
random_flip: false
|
| 77 |
+
return_dict: true
|
| 78 |
+
drop_last: false
|
| 79 |
+
batch_size: 1
|
| 80 |
+
return_dataloader: true
|
| 81 |
+
offset: 0
|
| 82 |
+
max_len: 8
|
| 83 |
+
paths:
|
| 84 |
+
srs: ./log/face_example/inpainted
|
| 85 |
+
lrs: ./log/face_example/gt_masked
|
| 86 |
+
gts: ./log/face_example/gt
|
| 87 |
+
gt_keep_masks: ./log/face_example/gt_keep_mask
|
repaint/confs/test_c256_ev2li.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: ./data/pretrained/celeba256_250000.pt
|
| 54 |
+
name: test_c256_ev2li
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_c256_ev2li_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/c256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/ev2li
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: fix_ev2li_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_c256_ev2li/inpainted
|
| 84 |
+
lrs: ./log/test_c256_ev2li/gt_masked
|
| 85 |
+
gts: ./log/test_c256_ev2li/gt
|
| 86 |
+
gt_keep_masks: ./log/test_c256_ev2li/gt_keep_mask
|
repaint/confs/test_c256_ex64.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
|
| 54 |
+
name: test_c256_ex64
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_c256_ex64_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/c256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/ex64
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: fix_ex64_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_c256_ex64/inpainted
|
| 84 |
+
lrs: ./log/test_c256_ex64/gt_masked
|
| 85 |
+
gts: ./log/test_c256_ex64/gt
|
| 86 |
+
gt_keep_masks: ./log/test_c256_ex64/gt_keep_mask
|
repaint/confs/test_c256_genhalf.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
|
| 54 |
+
name: test_c256_genhalf
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_c256_genhalf_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/c256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/genhalf
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: fix_genhalf_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_c256_genhalf/inpainted
|
| 84 |
+
lrs: ./log/test_c256_genhalf/gt_masked
|
| 85 |
+
gts: ./log/test_c256_genhalf/gt
|
| 86 |
+
gt_keep_masks: ./log/test_c256_genhalf/gt_keep_mask
|
repaint/confs/test_c256_nn2.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
|
| 54 |
+
name: test_c256_nn2
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_c256_nn2_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/c256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/nn2
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: fix_nn2_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_c256_nn2/inpainted
|
| 84 |
+
lrs: ./log/test_c256_nn2/gt_masked
|
| 85 |
+
gts: ./log/test_c256_nn2/gt
|
| 86 |
+
gt_keep_masks: ./log/test_c256_nn2/gt_keep_mask
|
repaint/confs/test_c256_thick.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
|
| 54 |
+
name: test_c256_thick
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_c256_thick_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/c256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/thick
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: random_thick_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_c256_thick/inpainted
|
| 84 |
+
lrs: ./log/test_c256_thick/gt_masked
|
| 85 |
+
gts: ./log/test_c256_thick/gt
|
| 86 |
+
gt_keep_masks: ./log/test_c256_thick/gt_keep_mask
|
repaint/confs/test_c256_thin.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: /cluster/work/cvl/gudiff/guided-diffusion/models/celeba256_diffsteps1000_4gpus/ema_0.9999_250000.pt
|
| 54 |
+
name: test_c256_thin
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_c256_thin_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/c256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/thin
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: random_thin_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_c256_thin/inpainted
|
| 84 |
+
lrs: ./log/test_c256_thin/gt_masked
|
| 85 |
+
gts: ./log/test_c256_thin/gt
|
| 86 |
+
gt_keep_masks: ./log/test_c256_thin/gt_keep_mask
|
repaint/confs/test_inet256_ev2li.yml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: true
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: true
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 1.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
classifier_path: ./data/pretrained/256x256_classifier.pt
|
| 54 |
+
model_path: ./data/pretrained/256x256_diffusion.pt
|
| 55 |
+
name: test_inet256_ev2li
|
| 56 |
+
inpa_inj_sched_prev: true
|
| 57 |
+
n_jobs: 25
|
| 58 |
+
print_estimated_vars: true
|
| 59 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 60 |
+
schedule_jump_params:
|
| 61 |
+
t_T: 250
|
| 62 |
+
n_sample: 1
|
| 63 |
+
jump_length: 10
|
| 64 |
+
jump_n_sample: 10
|
| 65 |
+
data:
|
| 66 |
+
eval:
|
| 67 |
+
lama_inet256_ev2li_n100_test:
|
| 68 |
+
mask_loader: true
|
| 69 |
+
gt_path: ./data/datasets/gts/inet256
|
| 70 |
+
mask_path: ./data/datasets/gt_keep_masks/ev2li
|
| 71 |
+
image_size: 256
|
| 72 |
+
class_cond: false
|
| 73 |
+
deterministic: true
|
| 74 |
+
random_crop: false
|
| 75 |
+
random_flip: false
|
| 76 |
+
return_dict: true
|
| 77 |
+
drop_last: false
|
| 78 |
+
batch_size: 4
|
| 79 |
+
return_dataloader: true
|
| 80 |
+
ds_conf:
|
| 81 |
+
name: random_ev2li_256
|
| 82 |
+
max_len: 100
|
| 83 |
+
paths:
|
| 84 |
+
srs: ./log/test_inet256_ev2li/inpainted
|
| 85 |
+
lrs: ./log/test_inet256_ev2li/gt_masked
|
| 86 |
+
gts: ./log/test_inet256_ev2li/gt
|
| 87 |
+
gt_keep_masks: ./log/test_inet256_ev2li/gt_keep_mask
|
repaint/confs/test_inet256_ex64.yml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: true
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: true
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 1.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
classifier_path: ./data/pretrained/256x256_classifier.pt
|
| 54 |
+
model_path: ./data/pretrained/256x256_diffusion.pt
|
| 55 |
+
name: test_inet256_ex64
|
| 56 |
+
inpa_inj_sched_prev: true
|
| 57 |
+
n_jobs: 25
|
| 58 |
+
print_estimated_vars: true
|
| 59 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 60 |
+
schedule_jump_params:
|
| 61 |
+
t_T: 250
|
| 62 |
+
n_sample: 1
|
| 63 |
+
jump_length: 10
|
| 64 |
+
jump_n_sample: 10
|
| 65 |
+
data:
|
| 66 |
+
eval:
|
| 67 |
+
lama_inet256_ex64_n100_test:
|
| 68 |
+
mask_loader: true
|
| 69 |
+
gt_path: ./data/datasets/gts/inet256
|
| 70 |
+
mask_path: ./data/datasets/gt_keep_masks/ex64
|
| 71 |
+
image_size: 256
|
| 72 |
+
class_cond: false
|
| 73 |
+
deterministic: true
|
| 74 |
+
random_crop: false
|
| 75 |
+
random_flip: false
|
| 76 |
+
return_dict: true
|
| 77 |
+
drop_last: false
|
| 78 |
+
batch_size: 4
|
| 79 |
+
return_dataloader: true
|
| 80 |
+
ds_conf:
|
| 81 |
+
name: random_ex64_256
|
| 82 |
+
max_len: 100
|
| 83 |
+
paths:
|
| 84 |
+
srs: ./log/test_inet256_ex64/inpainted
|
| 85 |
+
lrs: ./log/test_inet256_ex64/gt_masked
|
| 86 |
+
gts: ./log/test_inet256_ex64/gt
|
| 87 |
+
gt_keep_masks: ./log/test_inet256_ex64/gt_keep_mask
|
repaint/confs/test_inet256_genhalf.yml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: true
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: true
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 1.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
classifier_path: ./data/pretrained/256x256_classifier.pt
|
| 54 |
+
model_path: ./data/pretrained/256x256_diffusion.pt
|
| 55 |
+
name: test_inet256_genhalf
|
| 56 |
+
inpa_inj_sched_prev: true
|
| 57 |
+
n_jobs: 25
|
| 58 |
+
print_estimated_vars: true
|
| 59 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 60 |
+
schedule_jump_params:
|
| 61 |
+
t_T: 250
|
| 62 |
+
n_sample: 1
|
| 63 |
+
jump_length: 10
|
| 64 |
+
jump_n_sample: 10
|
| 65 |
+
data:
|
| 66 |
+
eval:
|
| 67 |
+
lama_inet256_genhalf_n100_test:
|
| 68 |
+
mask_loader: true
|
| 69 |
+
gt_path: ./data/datasets/gts/inet256
|
| 70 |
+
mask_path: ./data/datasets/gt_keep_masks/genhalf
|
| 71 |
+
image_size: 256
|
| 72 |
+
class_cond: false
|
| 73 |
+
deterministic: true
|
| 74 |
+
random_crop: false
|
| 75 |
+
random_flip: false
|
| 76 |
+
return_dict: true
|
| 77 |
+
drop_last: false
|
| 78 |
+
batch_size: 4
|
| 79 |
+
return_dataloader: true
|
| 80 |
+
ds_conf:
|
| 81 |
+
name: random_genhalf_256
|
| 82 |
+
max_len: 100
|
| 83 |
+
paths:
|
| 84 |
+
srs: ./log/test_inet256_genhalf/inpainted
|
| 85 |
+
lrs: ./log/test_inet256_genhalf/gt_masked
|
| 86 |
+
gts: ./log/test_inet256_genhalf/gt
|
| 87 |
+
gt_keep_masks: ./log/test_inet256_genhalf/gt_keep_mask
|
repaint/confs/test_inet256_nn2.yml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: true
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: true
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 1.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
classifier_path: ./data/pretrained/256x256_classifier.pt
|
| 54 |
+
model_path: ./data/pretrained/256x256_diffusion.pt
|
| 55 |
+
name: test_inet256_nn2
|
| 56 |
+
inpa_inj_sched_prev: true
|
| 57 |
+
n_jobs: 25
|
| 58 |
+
print_estimated_vars: true
|
| 59 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 60 |
+
schedule_jump_params:
|
| 61 |
+
t_T: 250
|
| 62 |
+
n_sample: 1
|
| 63 |
+
jump_length: 10
|
| 64 |
+
jump_n_sample: 10
|
| 65 |
+
data:
|
| 66 |
+
eval:
|
| 67 |
+
lama_inet256_nn2_n100_test:
|
| 68 |
+
mask_loader: true
|
| 69 |
+
gt_path: ./data/datasets/gts/inet256
|
| 70 |
+
mask_path: ./data/datasets/gt_keep_masks/nn2
|
| 71 |
+
image_size: 256
|
| 72 |
+
class_cond: false
|
| 73 |
+
deterministic: true
|
| 74 |
+
random_crop: false
|
| 75 |
+
random_flip: false
|
| 76 |
+
return_dict: true
|
| 77 |
+
drop_last: false
|
| 78 |
+
batch_size: 4
|
| 79 |
+
return_dataloader: true
|
| 80 |
+
ds_conf:
|
| 81 |
+
name: random_nn2_256
|
| 82 |
+
max_len: 100
|
| 83 |
+
paths:
|
| 84 |
+
srs: ./log/test_inet256_nn2/inpainted
|
| 85 |
+
lrs: ./log/test_inet256_nn2/gt_masked
|
| 86 |
+
gts: ./log/test_inet256_nn2/gt
|
| 87 |
+
gt_keep_masks: ./log/test_inet256_nn2/gt_keep_mask
|
repaint/confs/test_inet256_thick.yml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: true
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: true
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 1.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
classifier_path: ./data/pretrained/256x256_classifier.pt
|
| 54 |
+
model_path: ./data/pretrained/256x256_diffusion.pt
|
| 55 |
+
name: test_inet256_thick
|
| 56 |
+
inpa_inj_sched_prev: true
|
| 57 |
+
n_jobs: 25
|
| 58 |
+
print_estimated_vars: true
|
| 59 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 60 |
+
schedule_jump_params:
|
| 61 |
+
t_T: 250
|
| 62 |
+
n_sample: 1
|
| 63 |
+
jump_length: 10
|
| 64 |
+
jump_n_sample: 10
|
| 65 |
+
data:
|
| 66 |
+
eval:
|
| 67 |
+
lama_inet256_thick_n100_test:
|
| 68 |
+
mask_loader: true
|
| 69 |
+
gt_path: ./data/datasets/gts/inet256
|
| 70 |
+
mask_path: ./data/datasets/gt_keep_masks/thick
|
| 71 |
+
image_size: 256
|
| 72 |
+
class_cond: false
|
| 73 |
+
deterministic: true
|
| 74 |
+
random_crop: false
|
| 75 |
+
random_flip: false
|
| 76 |
+
return_dict: true
|
| 77 |
+
drop_last: false
|
| 78 |
+
batch_size: 4
|
| 79 |
+
return_dataloader: true
|
| 80 |
+
ds_conf:
|
| 81 |
+
name: random_thick_256
|
| 82 |
+
max_len: 100
|
| 83 |
+
paths:
|
| 84 |
+
srs: ./log/test_inet256_thick/inpainted
|
| 85 |
+
lrs: ./log/test_inet256_thick/gt_masked
|
| 86 |
+
gts: ./log/test_inet256_thick/gt
|
| 87 |
+
gt_keep_masks: ./log/test_inet256_thick/gt_keep_mask
|
repaint/confs/test_inet256_thin.yml
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: true
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: true
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 1.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
classifier_path: ./data/pretrained/256x256_classifier.pt
|
| 54 |
+
model_path: ./data/pretrained/256x256_diffusion.pt
|
| 55 |
+
name: test_inet256_thin
|
| 56 |
+
inpa_inj_sched_prev: true
|
| 57 |
+
n_jobs: 25
|
| 58 |
+
print_estimated_vars: true
|
| 59 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 60 |
+
schedule_jump_params:
|
| 61 |
+
t_T: 250
|
| 62 |
+
n_sample: 1
|
| 63 |
+
jump_length: 10
|
| 64 |
+
jump_n_sample: 10
|
| 65 |
+
data:
|
| 66 |
+
eval:
|
| 67 |
+
lama_inet256_thin_n100_test:
|
| 68 |
+
mask_loader: true
|
| 69 |
+
gt_path: ./data/datasets/gts/inet256
|
| 70 |
+
mask_path: ./data/datasets/gt_keep_masks/thin
|
| 71 |
+
image_size: 256
|
| 72 |
+
class_cond: false
|
| 73 |
+
deterministic: true
|
| 74 |
+
random_crop: false
|
| 75 |
+
random_flip: false
|
| 76 |
+
return_dict: true
|
| 77 |
+
drop_last: false
|
| 78 |
+
batch_size: 4
|
| 79 |
+
return_dataloader: true
|
| 80 |
+
ds_conf:
|
| 81 |
+
name: random_thin_256
|
| 82 |
+
max_len: 100
|
| 83 |
+
paths:
|
| 84 |
+
srs: ./log/test_inet256_thin/inpainted
|
| 85 |
+
lrs: ./log/test_inet256_thin/gt_masked
|
| 86 |
+
gts: ./log/test_inet256_thin/gt
|
| 87 |
+
gt_keep_masks: ./log/test_inet256_thin/gt_keep_mask
|
repaint/confs/test_p256_ev2li.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: ./data/pretrained/places256_300000.pt
|
| 54 |
+
name: test_p256_ev2li
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_p256_ev2li_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/p256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/ev2li
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: random_ev2li_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_p256_ev2li/inpainted
|
| 84 |
+
lrs: ./log/test_p256_ev2li/gt_masked
|
| 85 |
+
gts: ./log/test_p256_ev2li/gt
|
| 86 |
+
gt_keep_masks: ./log/test_p256_ev2li/gt_keep_mask
|
repaint/confs/test_p256_ex64.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: ./data/pretrained/places256_300000.pt
|
| 54 |
+
name: test_p256_ex64
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_p256_ex64_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/p256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/ex64
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: random_ex64_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_p256_ex64/inpainted
|
| 84 |
+
lrs: ./log/test_p256_ex64/gt_masked
|
| 85 |
+
gts: ./log/test_p256_ex64/gt
|
| 86 |
+
gt_keep_masks: ./log/test_p256_ex64/gt_keep_mask
|
repaint/confs/test_p256_genhalf.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: ./data/pretrained/places256_300000.pt
|
| 54 |
+
name: test_p256_genhalf
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_p256_genhalf_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/p256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/genhalf
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: random_genhalf_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_p256_genhalf/inpainted
|
| 84 |
+
lrs: ./log/test_p256_genhalf/gt_masked
|
| 85 |
+
gts: ./log/test_p256_genhalf/gt
|
| 86 |
+
gt_keep_masks: ./log/test_p256_genhalf/gt_keep_mask
|
repaint/confs/test_p256_nn2.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: ./data/pretrained/places256_300000.pt
|
| 54 |
+
name: test_p256_nn2
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_p256_nn2_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/p256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/nn2
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: random_nn2_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_p256_nn2/inpainted
|
| 84 |
+
lrs: ./log/test_p256_nn2/gt_masked
|
| 85 |
+
gts: ./log/test_p256_nn2/gt
|
| 86 |
+
gt_keep_masks: ./log/test_p256_nn2/gt_keep_mask
|
repaint/confs/test_p256_thick.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: ./data/pretrained/places256_300000.pt
|
| 54 |
+
name: test_p256_thick
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_p256_thick_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/p256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/thick
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: random_thick_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_p256_thick/inpainted
|
| 84 |
+
lrs: ./log/test_p256_thick/gt_masked
|
| 85 |
+
gts: ./log/test_p256_thick/gt
|
| 86 |
+
gt_keep_masks: ./log/test_p256_thick/gt_keep_mask
|
repaint/confs/test_p256_thin.yml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
attention_resolutions: 32,16,8
|
| 18 |
+
class_cond: false
|
| 19 |
+
diffusion_steps: 1000
|
| 20 |
+
learn_sigma: true
|
| 21 |
+
noise_schedule: linear
|
| 22 |
+
num_channels: 256
|
| 23 |
+
num_head_channels: 64
|
| 24 |
+
num_heads: 4
|
| 25 |
+
num_res_blocks: 2
|
| 26 |
+
resblock_updown: true
|
| 27 |
+
use_fp16: false
|
| 28 |
+
use_scale_shift_norm: true
|
| 29 |
+
classifier_scale: 4.0
|
| 30 |
+
lr_kernel_n_std: 2
|
| 31 |
+
num_samples: 100
|
| 32 |
+
show_progress: true
|
| 33 |
+
timestep_respacing: '250'
|
| 34 |
+
use_kl: false
|
| 35 |
+
predict_xstart: false
|
| 36 |
+
rescale_timesteps: false
|
| 37 |
+
rescale_learned_sigmas: false
|
| 38 |
+
classifier_use_fp16: false
|
| 39 |
+
classifier_width: 128
|
| 40 |
+
classifier_depth: 2
|
| 41 |
+
classifier_attention_resolutions: 32,16,8
|
| 42 |
+
classifier_use_scale_shift_norm: true
|
| 43 |
+
classifier_resblock_updown: true
|
| 44 |
+
classifier_pool: attention
|
| 45 |
+
num_heads_upsample: -1
|
| 46 |
+
channel_mult: ''
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
use_checkpoint: false
|
| 49 |
+
use_new_attention_order: false
|
| 50 |
+
clip_denoised: true
|
| 51 |
+
use_ddim: false
|
| 52 |
+
image_size: 256
|
| 53 |
+
model_path: ./data/pretrained/places256_300000.pt
|
| 54 |
+
name: test_p256_thin
|
| 55 |
+
inpa_inj_sched_prev: true
|
| 56 |
+
n_jobs: 25
|
| 57 |
+
print_estimated_vars: true
|
| 58 |
+
inpa_inj_sched_prev_cumnoise: false
|
| 59 |
+
schedule_jump_params:
|
| 60 |
+
t_T: 250
|
| 61 |
+
n_sample: 1
|
| 62 |
+
jump_length: 10
|
| 63 |
+
jump_n_sample: 10
|
| 64 |
+
data:
|
| 65 |
+
eval:
|
| 66 |
+
lama_p256_thin_n100_test:
|
| 67 |
+
mask_loader: true
|
| 68 |
+
gt_path: ./data/datasets/gts/p256
|
| 69 |
+
mask_path: ./data/datasets/gt_keep_masks/thin
|
| 70 |
+
image_size: 256
|
| 71 |
+
class_cond: false
|
| 72 |
+
deterministic: true
|
| 73 |
+
random_crop: false
|
| 74 |
+
random_flip: false
|
| 75 |
+
return_dict: true
|
| 76 |
+
drop_last: false
|
| 77 |
+
batch_size: 4
|
| 78 |
+
return_dataloader: true
|
| 79 |
+
ds_conf:
|
| 80 |
+
name: random_thin_256
|
| 81 |
+
max_len: 100
|
| 82 |
+
paths:
|
| 83 |
+
srs: ./log/test_p256_thin/inpainted
|
| 84 |
+
lrs: ./log/test_p256_thin/gt_masked
|
| 85 |
+
gts: ./log/test_p256_thin/gt
|
| 86 |
+
gt_keep_masks: ./log/test_p256_thin/gt_keep_mask
|
repaint/download.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Fetch the pretrained guided-diffusion checkpoints into data/pretrained.
(
mkdir -p data/pretrained
cd data/pretrained

wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_classifier.pt # Trained by OpenAI
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion.pt # Trained by OpenAI

# Additional checkpoints hosted on Google Drive (requires the `gdown` tool).
# NOTE(review): presumably the places256_300000.pt model referenced by the
# confs — verify the Drive IDs still resolve.
gdown https://drive.google.com/uc?id=1norNWWGYP3EZ_o05DmoW1ryKuKMmhlCX
gdown https://drive.google.com/uc?id=1QEl-btGbzQz6IwkXiFGd49uQNTUtTHsk
)

# data
(
# Download and unpack the evaluation data archive (ground truths + keep-masks),
# then remove the zip to save disk space.
gdown https://drive.google.com/uc?id=1Q_dxuyI41AAmSv9ti3780BwaJQqwvwMv
unzip data.zip
rm data.zip
)
|
repaint/guided_diffusion/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
Based on "Improved Denoising Diffusion Probabilistic Models".
|
| 19 |
+
"""
|
repaint/guided_diffusion/dist_util.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 7 |
+
#
|
| 8 |
+
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
Helpers for distributed training.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import io
|
| 22 |
+
|
| 23 |
+
import blobfile as bf
|
| 24 |
+
import torch as th
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def dev(device):
    """
    Resolve the torch device to use.

    :param device: an explicit device spec (``str`` or ``torch.device``),
        or ``None`` to auto-select.
    :return: a ``torch.device`` — CUDA when available and no explicit
        device was requested, otherwise CPU or the requested device.
    """
    if device is None:
        # Auto-select: prefer the default CUDA device when one is present.
        if th.cuda.is_available():
            # Plain string literal: the original used an f-string with no
            # placeholders (lint F541), which is equivalent but misleading.
            return th.device("cuda")
        return th.device("cpu")
    return th.device(device)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_state_dict(path, backend=None, **kwargs):
    """
    Read a checkpoint file via blobfile and deserialize it with torch.

    :param path: local or blob-store path to the checkpoint.
    :param backend: unused; retained so existing call sites keep working.
    :param kwargs: forwarded to ``th.load`` (e.g. ``map_location``).
    :return: the deserialized object (typically a model state dict).
    """
    # Pull the whole file into memory first so th.load reads from a
    # seekable in-memory buffer rather than the (possibly remote) stream.
    with bf.BlobFile(path, "rb") as handle:
        raw = handle.read()
    buffer = io.BytesIO(raw)
    return th.load(buffer, **kwargs)
|
| 42 |
+
|
| 43 |
+
|