devanshsrivastav committed
Commit 6f02465 · verified · 1 parent: 0a37f67

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. cache/models--openai--clip-vit-large-patch14/blobs/a2bf730a0c7debf160f7a6b50b3aaf3703e7e88ac73de7a314903141db026dcb +3 -0
  2. cache/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/model.safetensors +3 -0
  3. stable_diffusion/assets/a-painting-of-a-fire.png +3 -0
  4. stable_diffusion/assets/a-photograph-of-a-fire.png +3 -0
  5. stable_diffusion/assets/a-shirt-with-a-fire-printed-on-it.png +3 -0
  6. stable_diffusion/assets/a-shirt-with-the-inscription-'fire'.png +3 -0
  7. stable_diffusion/assets/a-watercolor-painting-of-a-fire.png +3 -0
  8. stable_diffusion/assets/birdhouse.png +3 -0
  9. stable_diffusion/assets/fire.png +3 -0
  10. stable_diffusion/assets/inpainting.png +3 -0
  11. stable_diffusion/assets/rdm-preview.jpg +3 -0
  12. stable_diffusion/assets/reconstruction1.png +3 -0
  13. stable_diffusion/assets/rick.jpeg +3 -0
  14. stable_diffusion/assets/stable-samples/img2img/mountains-1.png +3 -0
  15. stable_diffusion/assets/stable-samples/img2img/mountains-2.png +3 -0
  16. stable_diffusion/assets/stable-samples/img2img/mountains-3.png +3 -0
  17. stable_diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg +3 -0
  18. stable_diffusion/assets/stable-samples/txt2img/000002025.png +3 -0
  19. stable_diffusion/assets/stable-samples/txt2img/000002035.png +3 -0
  20. stable_diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png +3 -0
  21. stable_diffusion/assets/txt2img-convsample.png +3 -0
  22. stable_diffusion/ldm/data/base.py +41 -0
  23. stable_diffusion/ldm/data/imagenet.py +394 -0
  24. stable_diffusion/ldm/data/lsun.py +92 -0
  25. stable_diffusion/ldm/extras.py +77 -0
  26. stable_diffusion/ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  27. stable_diffusion/ldm/models/autoencoder.py +196 -0
  28. stable_diffusion/ldm/models/diffusion/__init__.py +0 -0
  29. stable_diffusion/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  30. stable_diffusion/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  31. stable_diffusion/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  32. stable_diffusion/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc +0 -0
  33. stable_diffusion/ldm/models/diffusion/classifier.py +269 -0
  34. stable_diffusion/ldm/models/diffusion/ddim.py +337 -0
  35. stable_diffusion/ldm/models/diffusion/ddpm.py +1946 -0
  36. stable_diffusion/ldm/models/diffusion/plms.py +258 -0
  37. stable_diffusion/ldm/models/diffusion/sampling_util.py +23 -0
  38. stable_diffusion/ldm/modules/image_degradation/utils/test.png +3 -0
  39. stable_diffusion/ldm/thirdp/psp/__pycache__/helpers.cpython-38.pyc +0 -0
  40. stable_diffusion/ldm/thirdp/psp/__pycache__/id_loss.cpython-38.pyc +0 -0
  41. stable_diffusion/ldm/thirdp/psp/__pycache__/model_irse.cpython-38.pyc +0 -0
  42. stable_diffusion/ldm/thirdp/psp/helpers.py +121 -0
  43. stable_diffusion/ldm/thirdp/psp/id_loss.py +25 -0
  44. stable_diffusion/ldm/thirdp/psp/model_irse.py +87 -0
  45. weights/Abstractionism.pth +3 -0
  46. weights/Artist_Sketch.pth +3 -0
  47. weights/Blossom_Season.pth +3 -0
  48. weights/Bricks.pth +3 -0
  49. weights/Cats.pth +3 -0
  50. weights/Color_Fantasy.pth +3 -0
cache/models--openai--clip-vit-large-patch14/blobs/a2bf730a0c7debf160f7a6b50b3aaf3703e7e88ac73de7a314903141db026dcb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2bf730a0c7debf160f7a6b50b3aaf3703e7e88ac73de7a314903141db026dcb
+ size 1710540580
cache/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2bf730a0c7debf160f7a6b50b3aaf3703e7e88ac73de7a314903141db026dcb
+ size 1710540580
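
The two blobs above are Git LFS pointer files rather than the weights themselves; as a small illustrative sketch (the pointer text is copied from the diff), such a pointer splits into its version/oid/size fields like this:

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:a2bf730a0c7debf160f7a6b50b3aaf3703e7e88ac73de7a314903141db026dcb
size 1710540580
"""

# each line is "key value"; split once on the first space
fields = dict(line.split(" ", 1) for line in pointer.strip().splitlines())
print(fields["oid"], int(fields["size"]))  # the ~1.7 GB CLIP ViT-L/14 weights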
stable_diffusion/assets/a-painting-of-a-fire.png ADDED

Git LFS Details

  • SHA256: c75723177e8435a568baefd9c910b812a830b12b5d994d23187b97c8e872fdee
  • Pointer size: 131 Bytes
  • Size of remote file: 667 kB
stable_diffusion/assets/a-photograph-of-a-fire.png ADDED

Git LFS Details

  • SHA256: 5bab36cd3f762cf7913b585a15309d72e5cf884da17219551400fd4abf03fd6a
  • Pointer size: 131 Bytes
  • Size of remote file: 611 kB
stable_diffusion/assets/a-shirt-with-a-fire-printed-on-it.png ADDED

Git LFS Details

  • SHA256: 34edb946f340bf1c0ff2549e527b79d17221653bb4853f499c2096867021c9f8
  • Pointer size: 131 Bytes
  • Size of remote file: 624 kB
stable_diffusion/assets/a-shirt-with-the-inscription-'fire'.png ADDED

Git LFS Details

  • SHA256: 39e2467c508c87c83bf4665e0b35d65ce597d058153e302f35b7fe65a849a839
  • Pointer size: 131 Bytes
  • Size of remote file: 561 kB
stable_diffusion/assets/a-watercolor-painting-of-a-fire.png ADDED

Git LFS Details

  • SHA256: f6e4edf4ddfe7f9eaa74b61f923e804335cd10dfe6550ee72b908f90618de1ad
  • Pointer size: 131 Bytes
  • Size of remote file: 722 kB
stable_diffusion/assets/birdhouse.png ADDED

Git LFS Details

  • SHA256: 05b741598653c2964a7cc1610d121db3f94a069c02ab8135589b0aa77855f053
  • Pointer size: 131 Bytes
  • Size of remote file: 775 kB
stable_diffusion/assets/fire.png ADDED

Git LFS Details

  • SHA256: f16fac5dc01921f3838d2ec9959d64941cab29ab408909b2c3585c1bdc85c752
  • Pointer size: 131 Bytes
  • Size of remote file: 626 kB
stable_diffusion/assets/inpainting.png ADDED

Git LFS Details

  • SHA256: 2eecc2b72319ac95de5a0f1524b6a1c087e73d3c97183b642f7c78c972ea1205
  • Pointer size: 131 Bytes
  • Size of remote file: 319 kB
stable_diffusion/assets/rdm-preview.jpg ADDED

Git LFS Details

  • SHA256: c29d3eb98a39f91559a6831e970573a6acc02b499382d28d7ab1f7bb48b2c680
  • Pointer size: 131 Bytes
  • Size of remote file: 327 kB
stable_diffusion/assets/reconstruction1.png ADDED

Git LFS Details

  • SHA256: 9c76a8d687565726afe12b0e244ca2ef4194c4ec0bb35b7a1d5d6a5c74302152
  • Pointer size: 131 Bytes
  • Size of remote file: 807 kB
stable_diffusion/assets/rick.jpeg ADDED

Git LFS Details

  • SHA256: 6eaa54f68ca3eaf97ab64292c3cd86353049cf1973254d553fcaf85cf84015e0
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
stable_diffusion/assets/stable-samples/img2img/mountains-1.png ADDED

Git LFS Details

  • SHA256: db02e9545f40d083f5202c4ce7ab948273a5399e5032cc0f7ed037b7e3fcfa93
  • Pointer size: 131 Bytes
  • Size of remote file: 625 kB
stable_diffusion/assets/stable-samples/img2img/mountains-2.png ADDED

Git LFS Details

  • SHA256: ffcfe5095f1bd7e2cc244f8ea55a03b45610942699d6b52aefbdce57054d1642
  • Pointer size: 131 Bytes
  • Size of remote file: 658 kB
stable_diffusion/assets/stable-samples/img2img/mountains-3.png ADDED

Git LFS Details

  • SHA256: 2b08c5eb8790bb06369c56635b57115e97b032af657b53a571e16c4438c7261a
  • Pointer size: 131 Bytes
  • Size of remote file: 656 kB
stable_diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg ADDED

Git LFS Details

  • SHA256: fcce9c9c090fd6ea797ed2bfb7dd21b38241a435170461f59b7834fb3af6bf3a
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
stable_diffusion/assets/stable-samples/txt2img/000002025.png ADDED

Git LFS Details

  • SHA256: 427e10286097e0334a7721d9d61aed74a7386fcca97a3d49d8a41fbe93ecbb72
  • Pointer size: 131 Bytes
  • Size of remote file: 968 kB
stable_diffusion/assets/stable-samples/txt2img/000002035.png ADDED

Git LFS Details

  • SHA256: 8a72b7473d8c95252df70e84e9ed214e6c61ceacf07864783d9d08a2abd9bc6a
  • Pointer size: 131 Bytes
  • Size of remote file: 996 kB
stable_diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png ADDED

Git LFS Details

  • SHA256: d9f67469d4fa49964c13e2a226bf50363d3f4525e340031932c7f24a29477217
  • Pointer size: 131 Bytes
  • Size of remote file: 678 kB
stable_diffusion/assets/txt2img-convsample.png ADDED

Git LFS Details

  • SHA256: cf914aa2e2f4c2ea3bdf0b2a210d2dc018f5bfabeb93703194bff2b842bb5afb
  • Pointer size: 131 Bytes
  • Size of remote file: 310 kB
stable_diffusion/ldm/data/base.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ from abc import abstractmethod
+
+ import numpy as np
+ from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+
+
+ class Txt2ImgIterableBaseDataset(IterableDataset):
+     '''
+     Defines an interface that makes the IterableDatasets for text2img data chainable.
+     '''
+     def __init__(self, num_records=0, valid_ids=None, size=256):
+         super().__init__()
+         self.num_records = num_records
+         self.valid_ids = valid_ids
+         self.sample_ids = valid_ids
+         self.size = size
+
+         print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+     def __len__(self):
+         return self.num_records
+
+     @abstractmethod
+     def __iter__(self):
+         pass
+
+
+ class PRNGMixin(object):
+     """
+     Adds a prng property which is a numpy RandomState that gets
+     reinitialized whenever the pid changes, to avoid synchronized sampling
+     behavior when used in conjunction with multiprocessing.
+     """
+     @property
+     def prng(self):
+         currentpid = os.getpid()
+         if getattr(self, "_initpid", None) != currentpid:
+             self._initpid = currentpid
+             self._prng = np.random.RandomState()
+         return self._prng
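
To make PRNGMixin's purpose concrete, a minimal hypothetical sketch (class and names are illustrative; it assumes the repository root is on sys.path and the default fork start method): generator state created in the parent is not blindly inherited by workers, because the pid check re-creates the RandomState in each child.

import multiprocessing as mp

from stable_diffusion.ldm.data.base import PRNGMixin


class Sampler(PRNGMixin):
    def draw(self):
        return float(self.prng.uniform())


sampler = Sampler()
sampler.draw()  # initializes _prng in the parent process


def worker(_):
    # in a forked child, os.getpid() differs, so PRNGMixin re-seeds here
    return sampler.draw()


if __name__ == "__main__":
    with mp.Pool(processes=2) as pool:
        print(pool.map(worker, [0, 1]))  # independent draws per worker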
stable_diffusion/ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
+ import os, yaml, pickle, shutil, tarfile, glob
+ import cv2
+ import albumentations
+ import PIL
+ import numpy as np
+ import torchvision.transforms.functional as TF
+ from omegaconf import OmegaConf
+ from functools import partial
+ from PIL import Image
+ from tqdm import tqdm
+ from torch.utils.data import Dataset, Subset
+
+ import taming.data.utils as tdu
+ from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
+ from taming.data.imagenet import ImagePaths
+
+ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
+
+
+ def synset2idx(path_to_yaml="data/index_synset.yaml"):
+     with open(path_to_yaml) as f:
+         di2s = yaml.safe_load(f)
+     return dict((v, k) for k, v in di2s.items())
+
+
+ class ImageNetBase(Dataset):
+     def __init__(self, config=None):
+         self.config = config or OmegaConf.create()
+         if not isinstance(self.config, dict):
+             self.config = OmegaConf.to_container(self.config)
+         self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
+         self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
+         self._prepare()
+         self._prepare_synset_to_human()
+         self._prepare_idx_to_synset()
+         self._prepare_human_to_integer_label()
+         self._load()
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, i):
+         return self.data[i]
+
+     def _prepare(self):
+         raise NotImplementedError()
+
+     def _filter_relpaths(self, relpaths):
+         ignore = set([
+             "n06596364_9591.JPEG",
+         ])
+         relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+         if "sub_indices" in self.config:
+             indices = str_to_indices(self.config["sub_indices"])
+             synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
+             self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
+             files = []
+             for rpath in relpaths:
+                 syn = rpath.split("/")[0]
+                 if syn in synsets:
+                     files.append(rpath)
+             return files
+         else:
+             return relpaths
+
+     def _prepare_synset_to_human(self):
+         SIZE = 2655750
+         URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+         self.human_dict = os.path.join(self.root, "synset_human.txt")
+         if (not os.path.exists(self.human_dict) or
+                 not os.path.getsize(self.human_dict) == SIZE):
+             download(URL, self.human_dict)
+
+     def _prepare_idx_to_synset(self):
+         URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+         self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+         if not os.path.exists(self.idx2syn):
+             download(URL, self.idx2syn)
+
+     def _prepare_human_to_integer_label(self):
+         URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
+         self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
+         if not os.path.exists(self.human2integer):
+             download(URL, self.human2integer)
+         with open(self.human2integer, "r") as f:
+             lines = f.read().splitlines()
+             assert len(lines) == 1000
+             self.human2integer_dict = dict()
+             for line in lines:
+                 value, key = line.split(":")
+                 self.human2integer_dict[key] = int(value)
+
+     def _load(self):
+         with open(self.txt_filelist, "r") as f:
+             self.relpaths = f.read().splitlines()
+             l1 = len(self.relpaths)
+             self.relpaths = self._filter_relpaths(self.relpaths)
+             print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+         self.synsets = [p.split("/")[0] for p in self.relpaths]
+         self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+         unique_synsets = np.unique(self.synsets)
+         class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+         if not self.keep_orig_class_label:
+             self.class_labels = [class_dict[s] for s in self.synsets]
+         else:
+             self.class_labels = [self.synset2idx[s] for s in self.synsets]
+
+         with open(self.human_dict, "r") as f:
+             human_dict = f.read().splitlines()
+             human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+         self.human_labels = [human_dict[s] for s in self.synsets]
+
+         labels = {
+             "relpath": np.array(self.relpaths),
+             "synsets": np.array(self.synsets),
+             "class_label": np.array(self.class_labels),
+             "human_label": np.array(self.human_labels),
+         }
+
+         if self.process_images:
+             self.size = retrieve(self.config, "size", default=256)
+             self.data = ImagePaths(self.abspaths,
+                                    labels=labels,
+                                    size=self.size,
+                                    random_crop=self.random_crop,
+                                    )
+         else:
+             self.data = self.abspaths
+
+
+ class ImageNetTrain(ImageNetBase):
+     NAME = "ILSVRC2012_train"
+     URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+     AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+     FILES = [
+         "ILSVRC2012_img_train.tar",
+     ]
+     SIZES = [
+         147897477120,
+     ]
+
+     def __init__(self, process_images=True, data_root=None, **kwargs):
+         self.process_images = process_images
+         self.data_root = data_root
+         super().__init__(**kwargs)
+
+     def _prepare(self):
+         if self.data_root:
+             self.root = os.path.join(self.data_root, self.NAME)
+         else:
+             cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+             self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+
+         self.datadir = os.path.join(self.root, "data")
+         self.txt_filelist = os.path.join(self.root, "filelist.txt")
+         self.expected_length = 1281167
+         self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+                                     default=True)
+         if not tdu.is_prepared(self.root):
+             # prep
+             print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+             datadir = self.datadir
+             if not os.path.exists(datadir):
+                 path = os.path.join(self.root, self.FILES[0])
+                 if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
+                     import academictorrents as at
+                     atpath = at.get(self.AT_HASH, datastore=self.root)
+                     assert atpath == path
+
+                 print("Extracting {} to {}".format(path, datadir))
+                 os.makedirs(datadir, exist_ok=True)
+                 with tarfile.open(path, "r:") as tar:
+                     tar.extractall(path=datadir)
+
+                 print("Extracting sub-tars.")
+                 subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+                 for subpath in tqdm(subpaths):
+                     subdir = subpath[:-len(".tar")]
+                     os.makedirs(subdir, exist_ok=True)
+                     with tarfile.open(subpath, "r:") as tar:
+                         tar.extractall(path=subdir)
+
+             filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+             filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+             filelist = sorted(filelist)
+             filelist = "\n".join(filelist) + "\n"
+             with open(self.txt_filelist, "w") as f:
+                 f.write(filelist)
+
+             tdu.mark_prepared(self.root)
+
+
+ class ImageNetValidation(ImageNetBase):
+     NAME = "ILSVRC2012_validation"
+     URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+     AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+     VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+     FILES = [
+         "ILSVRC2012_img_val.tar",
+         "validation_synset.txt",
+     ]
+     SIZES = [
+         6744924160,
+         1950000,
+     ]
+
+     def __init__(self, process_images=True, data_root=None, **kwargs):
+         self.data_root = data_root
+         self.process_images = process_images
+         super().__init__(**kwargs)
+
+     def _prepare(self):
+         if self.data_root:
+             self.root = os.path.join(self.data_root, self.NAME)
+         else:
+             cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+             self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+         self.datadir = os.path.join(self.root, "data")
+         self.txt_filelist = os.path.join(self.root, "filelist.txt")
+         self.expected_length = 50000
+         self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+                                     default=False)
+         if not tdu.is_prepared(self.root):
+             # prep
+             print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+             datadir = self.datadir
+             if not os.path.exists(datadir):
+                 path = os.path.join(self.root, self.FILES[0])
+                 if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
+                     import academictorrents as at
+                     atpath = at.get(self.AT_HASH, datastore=self.root)
+                     assert atpath == path
+
+                 print("Extracting {} to {}".format(path, datadir))
+                 os.makedirs(datadir, exist_ok=True)
+                 with tarfile.open(path, "r:") as tar:
+                     tar.extractall(path=datadir)
+
+                 vspath = os.path.join(self.root, self.FILES[1])
+                 if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
+                     download(self.VS_URL, vspath)
+
+                 with open(vspath, "r") as f:
+                     synset_dict = f.read().splitlines()
+                     synset_dict = dict(line.split() for line in synset_dict)
+
+                 print("Reorganizing into synset folders")
+                 synsets = np.unique(list(synset_dict.values()))
+                 for s in synsets:
+                     os.makedirs(os.path.join(datadir, s), exist_ok=True)
+                 for k, v in synset_dict.items():
+                     src = os.path.join(datadir, k)
+                     dst = os.path.join(datadir, v)
+                     shutil.move(src, dst)
+
+             filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+             filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+             filelist = sorted(filelist)
+             filelist = "\n".join(filelist) + "\n"
+             with open(self.txt_filelist, "w") as f:
+                 f.write(filelist)
+
+             tdu.mark_prepared(self.root)
+
+
+ class ImageNetSR(Dataset):
+     def __init__(self, size=None,
+                  degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
+                  random_crop=True):
+         """
+         ImageNet Super-resolution Dataloader
+         Performs the following ops in order:
+         1. crops a crop of size s from the image, either as a random or a center crop
+         2. resizes the crop to size with cv2.area_interpolation
+         3. degrades the resized crop with degradation_fn
+
+         :param size: resizing to size after cropping
+         :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
+         :param downscale_f: low-resolution downsample factor
+         :param min_crop_f: determines crop size s,
+             where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
+         :param max_crop_f: ""
+         :param data_root:
+         :param random_crop:
+         """
+         self.base = self.get_base()
+         assert size
+         assert (size / downscale_f).is_integer()
+         self.size = size
+         self.LR_size = int(size / downscale_f)
+         self.min_crop_f = min_crop_f
+         self.max_crop_f = max_crop_f
+         assert max_crop_f <= 1.
+         self.center_crop = not random_crop
+
+         self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
+
+         self.pil_interpolation = False  # gets reset later in case interp_op is from pillow
+
+         if degradation == "bsrgan":
+             self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
+
+         elif degradation == "bsrgan_light":
+             self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
+
+         else:
+             interpolation_fn = {
+                 "cv_nearest": cv2.INTER_NEAREST,
+                 "cv_bilinear": cv2.INTER_LINEAR,
+                 "cv_bicubic": cv2.INTER_CUBIC,
+                 "cv_area": cv2.INTER_AREA,
+                 "cv_lanczos": cv2.INTER_LANCZOS4,
+                 "pil_nearest": PIL.Image.NEAREST,
+                 "pil_bilinear": PIL.Image.BILINEAR,
+                 "pil_bicubic": PIL.Image.BICUBIC,
+                 "pil_box": PIL.Image.BOX,
+                 "pil_hamming": PIL.Image.HAMMING,
+                 "pil_lanczos": PIL.Image.LANCZOS,
+             }[degradation]
+
+             self.pil_interpolation = degradation.startswith("pil_")
+
+             if self.pil_interpolation:
+                 self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
+
+             else:
+                 self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
+                                                                           interpolation=interpolation_fn)
+
+     def __len__(self):
+         return len(self.base)
+
+     def __getitem__(self, i):
+         example = self.base[i]
+         image = Image.open(example["file_path_"])
+
+         if not image.mode == "RGB":
+             image = image.convert("RGB")
+
+         image = np.array(image).astype(np.uint8)
+
+         min_side_len = min(image.shape[:2])
+         crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
+         crop_side_len = int(crop_side_len)
+
+         if self.center_crop:
+             self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
+
+         else:
+             self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
+
+         image = self.cropper(image=image)["image"]
+         image = self.image_rescaler(image=image)["image"]
+
+         if self.pil_interpolation:
+             image_pil = PIL.Image.fromarray(image)
+             LR_image = self.degradation_process(image_pil)
+             LR_image = np.array(LR_image).astype(np.uint8)
+
+         else:
+             LR_image = self.degradation_process(image=image)["image"]
+
+         example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+         example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)
+
+         return example
+
+
+ class ImageNetSRTrain(ImageNetSR):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def get_base(self):
+         with open("data/imagenet_train_hr_indices.p", "rb") as f:
+             indices = pickle.load(f)
+         dset = ImageNetTrain(process_images=False,)
+         return Subset(dset, indices)
+
+
+ class ImageNetSRValidation(ImageNetSR):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def get_base(self):
+         with open("data/imagenet_val_hr_indices.p", "rb") as f:
+             indices = pickle.load(f)
+         dset = ImageNetValidation(process_images=False,)
+         return Subset(dset, indices)
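
For orientation, a self-contained sketch (synthetic image, illustrative sizes) of the three-step pipeline that ImageNetSR's docstring describes, using the same albumentations/cv2 operations as the plain "cv_area" branch rather than the bsrgan degradations:

import albumentations
import cv2
import numpy as np

size, downscale_f = 256, 4
lr_size = size // downscale_f

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

# 1. crop a square whose side is a random fraction of the shorter side
min_side = min(image.shape[:2])
crop = int(min_side * np.random.uniform(0.5, 1.0))
image = albumentations.RandomCrop(height=crop, width=crop)(image=image)["image"]

# 2. resize the crop so its shorter side equals `size`, with area interpolation
rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
image = rescaler(image=image)["image"]

# 3. degrade to the low-resolution size (plain resize stands in for bsrgan here)
degrade = albumentations.SmallestMaxSize(max_size=lr_size, interpolation=cv2.INTER_AREA)
lr = degrade(image=image)["image"]

print(image.shape, lr.shape)  # e.g. (256, 256, 3) and (64, 64, 3)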
stable_diffusion/ldm/data/lsun.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import numpy as np
+ import PIL
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+
+ class LSUNBase(Dataset):
+     def __init__(self,
+                  txt_file,
+                  data_root,
+                  size=None,
+                  interpolation="bicubic",
+                  flip_p=0.5
+                  ):
+         self.data_paths = txt_file
+         self.data_root = data_root
+         with open(self.data_paths, "r") as f:
+             self.image_paths = f.read().splitlines()
+         self._length = len(self.image_paths)
+         self.labels = {
+             "relative_file_path_": [l for l in self.image_paths],
+             "file_path_": [os.path.join(self.data_root, l)
+                            for l in self.image_paths],
+         }
+
+         self.size = size
+         self.interpolation = {"linear": PIL.Image.LINEAR,
+                               "bilinear": PIL.Image.BILINEAR,
+                               "bicubic": PIL.Image.BICUBIC,
+                               "lanczos": PIL.Image.LANCZOS,
+                               }[interpolation]
+         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+     def __len__(self):
+         return self._length
+
+     def __getitem__(self, i):
+         example = dict((k, self.labels[k][i]) for k in self.labels)
+         image = Image.open(example["file_path_"])
+         if not image.mode == "RGB":
+             image = image.convert("RGB")
+
+         # default to score-sde preprocessing
+         img = np.array(image).astype(np.uint8)
+         crop = min(img.shape[0], img.shape[1])
+         h, w = img.shape[0], img.shape[1]
+         img = img[(h - crop) // 2:(h + crop) // 2,
+                   (w - crop) // 2:(w + crop) // 2]
+
+         image = Image.fromarray(img)
+         if self.size is not None:
+             image = image.resize((self.size, self.size), resample=self.interpolation)
+
+         image = self.flip(image)
+         image = np.array(image).astype(np.uint8)
+         example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+         return example
+
+
+ class LSUNChurchesTrain(LSUNBase):
+     def __init__(self, **kwargs):
+         super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
+
+
+ class LSUNChurchesValidation(LSUNBase):
+     def __init__(self, flip_p=0., **kwargs):
+         super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
+                          flip_p=flip_p, **kwargs)
+
+
+ class LSUNBedroomsTrain(LSUNBase):
+     def __init__(self, **kwargs):
+         super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
+
+
+ class LSUNBedroomsValidation(LSUNBase):
+     def __init__(self, flip_p=0.0, **kwargs):
+         super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
+                          flip_p=flip_p, **kwargs)
+
+
+ class LSUNCatsTrain(LSUNBase):
+     def __init__(self, **kwargs):
+         super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
+
+
+ class LSUNCatsValidation(LSUNBase):
+     def __init__(self, flip_p=0., **kwargs):
+         super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
+                          flip_p=flip_p, **kwargs)
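
The center-crop-and-normalize steps in LSUNBase.__getitem__ can be exercised on their own; a small sketch with a synthetic PIL image (sizes are illustrative):

import numpy as np
from PIL import Image

image = Image.new("RGB", (640, 480), color=(128, 64, 32))
img = np.array(image).astype(np.uint8)

# center-crop to a square of the shorter side
crop = min(img.shape[0], img.shape[1])
h, w = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
          (w - crop) // 2:(w + crop) // 2]

# resize and map uint8 [0, 255] to float32 [-1, 1], as the dataset does
out = Image.fromarray(img).resize((256, 256), resample=Image.BICUBIC)
out = (np.array(out).astype(np.float32) / 127.5) - 1.0
print(out.shape, out.min(), out.max())  # (256, 256, 3), values in [-1, 1]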
stable_diffusion/ldm/extras.py ADDED
@@ -0,0 +1,77 @@
+ import logging
+ import sys
+ from contextlib import contextmanager
+ from pathlib import Path
+
+ import torch
+ from omegaconf import OmegaConf
+
+ sys.path.append('.')
+ from ldm.util import instantiate_from_config
+
+
+ @contextmanager
+ def all_logging_disabled(highest_level=logging.CRITICAL):
+     """
+     A context manager that will prevent any logging messages
+     triggered during the body from being processed.
+     :param highest_level: the maximum logging level in use.
+         This would only need to be changed if a custom level greater than CRITICAL
+         is defined.
+     https://gist.github.com/simon-weber/7853144
+     """
+     # two kind-of hacks here:
+     # * can't get the highest logging level in effect => delegate to the user
+     # * can't get the current module-level override => use an undocumented
+     #   (but non-private!) interface
+
+     previous_level = logging.root.manager.disable
+
+     logging.disable(highest_level)
+
+     try:
+         yield
+     finally:
+         logging.disable(previous_level)
+
+
+ def load_training_dir(train_dir, device, epoch="last"):
+     """Load a checkpoint and config from a training directory."""
+     train_dir = Path(train_dir)
+     ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
+     assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
+     config = list(train_dir.rglob("*-project.yaml"))
+     assert len(config) > 0, f"didn't find any config in {train_dir}"
+     if len(config) > 1:
+         print(f"found {len(config)} matching config files")
+         config = sorted(config)[-1]
+         print(f"selecting {config}")
+     else:
+         config = config[0]
+
+     config = OmegaConf.load(config)
+     return load_model_from_config(config, ckpt[0], device)
+
+
+ def load_model_from_config(config, ckpt, device="cpu", verbose=False):
+     """Loads a model from a config and a ckpt;
+     if config is a path, it is loaded with OmegaConf.
+     """
+     if isinstance(config, (str, Path)):
+         config = OmegaConf.load(config)
+
+     with all_logging_disabled():
+         print(f"Loading model from {ckpt}")
+         pl_sd = torch.load(ckpt, map_location="cpu")
+         global_step = pl_sd["global_step"]
+         sd = pl_sd["state_dict"]
+         model = instantiate_from_config(config.model)
+         m, u = model.load_state_dict(sd, strict=False)
+         if len(m) > 0 and verbose:
+             print("missing keys:")
+             print(m)
+         if len(u) > 0 and verbose:
+             print("unexpected keys:")
+             print(u)
+         model.to(device)
+         model.eval()
+         model.cond_stage_model.device = device
+         return model
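
A hypothetical usage sketch for extras.py (the log-directory path is a placeholder, and the import path assumes the repository root is on sys.path):

import torch

from stable_diffusion.ldm.extras import load_training_dir

# Resolves "*last.ckpt" and the newest "*-project.yaml" under the run
# directory, then builds the model in eval mode on the chosen device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_training_dir("logs/2022-01-01T00-00-00_my-run", device, epoch="last")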
stable_diffusion/ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (6.2 kB)
 
stable_diffusion/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,196 @@
+ import torch
+ import pytorch_lightning as pl
+ import torch.nn.functional as F
+
+ import sys
+ sys.path.append('.')
+
+ from stable_diffusion.ldm.modules.diffusionmodules.model import Encoder, Decoder
+ from stable_diffusion.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+ from stable_diffusion.ldm.util import instantiate_from_config
+
+
+ class AutoencoderKL(pl.LightningModule):
+     def __init__(self,
+                  ddconfig,
+                  lossconfig,  # torch.nn.Identity
+                  embed_dim,  # embed_dim = 4
+                  ckpt_path=None,
+                  ignore_keys=[],
+                  image_key="image",
+                  colorize_nlabels=None,  # this is None
+                  monitor=None,  # val/rec_loss
+                  ):
+         super().__init__()
+         self.image_key = image_key  # 'image'
+
+         # The encoder and decoder mirror each other, as in the VQ-VAE.
+         # The encoder maps the image to a latent feature map that parameterizes a Gaussian distribution.
+         self.encoder = Encoder(**ddconfig)
+         # Note: the encoder output is NOT fed directly into the decoder. Its channel count is
+         # 2 * z_channels (ddconfig['double_z']), because it carries the mean and log-variance
+         # used to construct the Gaussian distribution.
+         # The decoder decodes the latent space back to an image.
+         self.decoder = Decoder(**ddconfig)
+
+         # torch.nn.Identity
+         self.loss = instantiate_from_config(lossconfig)  # identity function
+
+         # double_z = True
+         assert ddconfig["double_z"]
+
+         # z_channels = 4
+         self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+
+         # embed_dim = 4
+         self.embed_dim = embed_dim
+
+         # colorize_nlabels is None
+         if colorize_nlabels is not None:
+             assert type(colorize_nlabels) == int
+             self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+
+         # monitor = val/rec_loss
+         if monitor is not None:
+             self.monitor = monitor
+
+         # ckpt_path = None; for Stable Diffusion the checkpoint is loaded outside this class
+         if ckpt_path is not None:
+             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+     def init_from_ckpt(self, path, ignore_keys=list()):
+         sd = torch.load(path, map_location="cpu")["state_dict"]
+         keys = list(sd.keys())
+         for k in keys:
+             for ik in ignore_keys:
+                 if k.startswith(ik):
+                     print("Deleting key {} from state_dict.".format(k))
+                     del sd[k]
+         self.load_state_dict(sd, strict=False)
+         print(f"Restored from {path}")
+
+     def encode(self, x):
+         # x: [bs, 3, 256, 256], h: [bs, 8, 32, 32]
+         h = self.encoder(x)
+         # moments serve as the mean and variance of the Gaussian (the channel dim is halved later)
+         # moments: [bs, 8, 32, 32]
+         moments = self.quant_conv(h)
+         posterior = DiagonalGaussianDistribution(moments)
+         return posterior
+
+     def decode(self, z):
+         # z: [bs, 4, 32, 32]
+         z = self.post_quant_conv(z)
+         # z: [bs, 4, 32, 32]
+         dec = self.decoder(z)
+         # dec: [bs, 3, 256, 256]
+         return dec
+
+     def forward(self, input, sample_posterior=True):
+         posterior = self.encode(input)
+         if sample_posterior:
+             z = posterior.sample()  # a normal sampling
+         else:
+             z = posterior.mode()  # returns the mean
+         dec = self.decode(z)
+         return dec, posterior
+
+     def get_input(self, batch, k):
+         x = batch[k]
+         if len(x.shape) == 3:
+             x = x[..., None]
+         x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+         return x
+
+     def training_step(self, batch, batch_idx, optimizer_idx):  # in Stable Diffusion the VAE is pretrained and frozen
+         inputs = self.get_input(batch, self.image_key)
+         reconstructions, posterior = self(inputs)
+
+         if optimizer_idx == 0:
+             # train encoder+decoder+logvar
+             aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                             last_layer=self.get_last_layer(), split="train")
+             self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+             self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+             return aeloss
+
+         if optimizer_idx == 1:
+             # train the discriminator
+             discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                                 last_layer=self.get_last_layer(), split="train")
+
+             self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+             self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+             return discloss
+
+     def validation_step(self, batch, batch_idx):
+         inputs = self.get_input(batch, self.image_key)
+         reconstructions, posterior = self(inputs)
+         aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+                                         last_layer=self.get_last_layer(), split="val")
+
+         discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+                                             last_layer=self.get_last_layer(), split="val")
+
+         self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+         self.log_dict(log_dict_ae)
+         self.log_dict(log_dict_disc)
+         return self.log_dict
+
+     def configure_optimizers(self):
+         lr = self.learning_rate
+         opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
+                                   list(self.decoder.parameters()) +
+                                   list(self.quant_conv.parameters()) +
+                                   list(self.post_quant_conv.parameters()),
+                                   lr=lr, betas=(0.5, 0.9))
+         opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                     lr=lr, betas=(0.5, 0.9))
+         return [opt_ae, opt_disc], []
+
+     def get_last_layer(self):
+         return self.decoder.conv_out.weight
+
+     @torch.no_grad()
+     def log_images(self, batch, only_inputs=False, **kwargs):
+         log = dict()
+         x = self.get_input(batch, self.image_key)
+         x = x.to(self.device)
+         if not only_inputs:
+             xrec, posterior = self(x)
+             if x.shape[1] > 3:
+                 # colorize with random projection
+                 assert xrec.shape[1] > 3
+                 x = self.to_rgb(x)
+                 xrec = self.to_rgb(xrec)
+             log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+             log["reconstructions"] = xrec
+         log["inputs"] = x
+         return log
+
+     def to_rgb(self, x):
+         assert self.image_key == "segmentation"
+         if not hasattr(self, "colorize"):
+             self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+         x = F.conv2d(x, weight=self.colorize)
+         x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+         return x
+
+
+ class IdentityFirstStage(torch.nn.Module):
+     def __init__(self, *args, vq_interface=False, **kwargs):
+         self.vq_interface = vq_interface  # TODO: should be True by default, but check to not break older stuff
+         super().__init__()
+
+     def encode(self, x, *args, **kwargs):
+         return x
+
+     def decode(self, x, *args, **kwargs):
+         return x
+
+     def quantize(self, x, *args, **kwargs):
+         if self.vq_interface:
+             return x, None, [None, None, None]
+         return x
+
+     def forward(self, x, *args, **kwargs):
+         return x
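
A round-trip sketch for AutoencoderKL; the ddconfig below mirrors the usual KL-f8 first-stage settings and is an assumption here, as is the import path, while lossconfig is torch.nn.Identity, matching the in-code comments:

import torch

from stable_diffusion.ldm.models.autoencoder import AutoencoderKL

# assumed KL-f8 config: 256px images -> 4-channel 32x32 latents
ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                attn_resolutions=[], dropout=0.0)
lossconfig = {"target": "torch.nn.Identity"}

vae = AutoencoderKL(ddconfig, lossconfig, embed_dim=4).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)    # stand-in for an RGB batch
    posterior = vae.encode(x)          # DiagonalGaussianDistribution
    z = posterior.mode()               # [1, 4, 32, 32] latent (the mean)
    rec = vae.decode(z)                # [1, 3, 256, 256] reconstruction
print(z.shape, rec.shape)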
stable_diffusion/ldm/models/diffusion/__init__.py ADDED
File without changes
stable_diffusion/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (207 Bytes)
 
stable_diffusion/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (9.42 kB)
 
stable_diffusion/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc ADDED
Binary file (56.7 kB)
 
stable_diffusion/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc ADDED
Binary file (1.13 kB)
 
stable_diffusion/ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,269 @@
+ import os
+ import torch
+ import pytorch_lightning as pl
+ from omegaconf import OmegaConf
+ from torch.nn import functional as F
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import LambdaLR
+ from copy import deepcopy
+ from einops import rearrange
+ from glob import glob
+ from natsort import natsorted
+ import sys
+ sys.path.append('.')
+
+ from stable_diffusion.ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
+ from stable_diffusion.ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
+
+ __models__ = {
+     'class_label': EncoderUNetModel,
+     'segmentation': UNetModel
+ }
+
+
+ def disabled_train(self, mode=True):
+     """Overwrite model.train with this function to make sure train/eval mode
+     does not change anymore."""
+     return self
+
+
+ class NoisyLatentImageClassifier(pl.LightningModule):
+
+     def __init__(self,
+                  diffusion_path,
+                  num_classes,
+                  ckpt_path=None,
+                  pool='attention',
+                  label_key=None,
+                  diffusion_ckpt_path=None,
+                  scheduler_config=None,
+                  weight_decay=1.e-2,
+                  log_steps=10,
+                  monitor='val/loss',
+                  *args,
+                  **kwargs):
+         super().__init__(*args, **kwargs)
+         self.num_classes = num_classes
+         # get the latest config of the diffusion model
+         diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
+         self.diffusion_config = OmegaConf.load(diffusion_config).model
+         self.diffusion_config.params.ckpt = diffusion_ckpt_path
+         self.load_diffusion()
+
+         self.monitor = monitor
+         self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
+         self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
+         self.log_steps = log_steps
+
+         self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
+             else self.diffusion_model.cond_stage_key
+
+         assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
+
+         if self.label_key not in __models__:
+             raise NotImplementedError()
+
+         self.load_classifier(ckpt_path, pool)
+
+         self.scheduler_config = scheduler_config
+         self.use_scheduler = self.scheduler_config is not None
+         self.weight_decay = weight_decay
+
+     def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+         sd = torch.load(path, map_location="cpu")
+         if "state_dict" in list(sd.keys()):
+             sd = sd["state_dict"]
+         keys = list(sd.keys())
+         for k in keys:
+             for ik in ignore_keys:
+                 if k.startswith(ik):
+                     print("Deleting key {} from state_dict.".format(k))
+                     del sd[k]
+         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+             sd, strict=False)
+         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+         if len(missing) > 0:
+             print(f"Missing Keys: {missing}")
+         if len(unexpected) > 0:
+             print(f"Unexpected Keys: {unexpected}")
+
+     def load_diffusion(self):
+         model = instantiate_from_config(self.diffusion_config)
+         self.diffusion_model = model.eval()
+         self.diffusion_model.train = disabled_train
+         for param in self.diffusion_model.parameters():
+             param.requires_grad = False
+
+     def load_classifier(self, ckpt_path, pool):
+         model_config = deepcopy(self.diffusion_config.params.unet_config.params)
+         model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
+         model_config.out_channels = self.num_classes
+         if self.label_key == 'class_label':
+             model_config.pool = pool
+
+         self.model = __models__[self.label_key](**model_config)
+         if ckpt_path is not None:
+             print('#####################################################################')
+             print(f'load from ckpt "{ckpt_path}"')
+             print('#####################################################################')
+             self.init_from_ckpt(ckpt_path)
+
+     @torch.no_grad()
+     def get_x_noisy(self, x, t, noise=None):
+         noise = default(noise, lambda: torch.randn_like(x))
+         continuous_sqrt_alpha_cumprod = None
+         if self.diffusion_model.use_continuous_noise:
+             continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
+             # todo: make sure t+1 is correct here
+
+         return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
+                                              continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
+
+     def forward(self, x_noisy, t, *args, **kwargs):
+         return self.model(x_noisy, t)
+
+     @torch.no_grad()
+     def get_input(self, batch, k):
+         x = batch[k]
+         if len(x.shape) == 3:
+             x = x[..., None]
+         x = rearrange(x, 'b h w c -> b c h w')
+         x = x.to(memory_format=torch.contiguous_format).float()
+         return x
+
+     @torch.no_grad()
+     def get_conditioning(self, batch, k=None):
+         if k is None:
+             k = self.label_key
+         assert k is not None, 'Needs to provide label key'
+
+         targets = batch[k].to(self.device)
+
+         if self.label_key == 'segmentation':
+             targets = rearrange(targets, 'b h w c -> b c h w')
+             for down in range(self.numd):
+                 h, w = targets.shape[-2:]
+                 targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
+
+             # targets = rearrange(targets, 'b c h w -> b h w c')
+
+         return targets
+
+     def compute_top_k(self, logits, labels, k, reduction="mean"):
+         _, top_ks = torch.topk(logits, k, dim=1)
+         if reduction == "mean":
+             return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
+         elif reduction == "none":
+             return (top_ks == labels[:, None]).float().sum(dim=-1)
+
+     def on_train_epoch_start(self):
+         # save some memory
+         self.diffusion_model.model.to('cpu')
+
+     @torch.no_grad()
+     def write_logs(self, loss, logits, targets):
+         log_prefix = 'train' if self.training else 'val'
+         log = {}
+         log[f"{log_prefix}/loss"] = loss.mean()
+         log[f"{log_prefix}/acc@1"] = self.compute_top_k(
+             logits, targets, k=1, reduction="mean"
+         )
+         log[f"{log_prefix}/acc@5"] = self.compute_top_k(
+             logits, targets, k=5, reduction="mean"
+         )
+
+         self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
+         self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
+         self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
+         lr = self.optimizers().param_groups[0]['lr']
+         self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
+
+     def shared_step(self, batch, t=None):
+         x, *_ = self.diffusion_model.get_input(batch, key=self.diffusion_model.first_stage_key)
+         targets = self.get_conditioning(batch)
+         if targets.dim() == 4:
+             targets = targets.argmax(dim=1)
+         if t is None:
+             t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
+         else:
+             t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
+         x_noisy = self.get_x_noisy(x, t)
+         logits = self(x_noisy, t)
+
+         loss = F.cross_entropy(logits, targets, reduction='none')
+
+         self.write_logs(loss.detach(), logits.detach(), targets.detach())
+
+         loss = loss.mean()
+         return loss, logits, x_noisy, targets
+
+     def training_step(self, batch, batch_idx):
+         loss, *_ = self.shared_step(batch)
+         return loss
+
+     def reset_noise_accs(self):
+         self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
+                           range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
+
+     def on_validation_start(self):
+         self.reset_noise_accs()
+
+     @torch.no_grad()
+     def validation_step(self, batch, batch_idx):
+         loss, *_ = self.shared_step(batch)
+
+         for t in self.noisy_acc:
+             _, logits, _, targets = self.shared_step(batch, t)
+             self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
+             self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
+
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
+
+         if self.use_scheduler:
+             scheduler = instantiate_from_config(self.scheduler_config)
+
+             print("Setting up LambdaLR scheduler...")
+             scheduler = [
+                 {
+                     'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
+                     'interval': 'step',
+                     'frequency': 1
+                 }]
+             return [optimizer], scheduler
+
+         return optimizer
+
+     @torch.no_grad()
+     def log_images(self, batch, N=8, *args, **kwargs):
+         log = dict()
+         x = self.get_input(batch, self.diffusion_model.first_stage_key)
+         log['inputs'] = x
+
+         y = self.get_conditioning(batch)
+
+         if self.label_key == 'class_label':
+             y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+             log['labels'] = y
+
+         if ismap(y):
+             log['labels'] = self.diffusion_model.to_rgb(y)
+
+             for step in range(self.log_steps):
+                 current_time = step * self.log_time_interval
+
+                 _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
+
+                 log[f'inputs@t{current_time}'] = x_noisy
+
+                 pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
+                 pred = rearrange(pred, 'b h w c -> b c h w')
+
+                 log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
+
+         for key in log:
+             log[key] = log[key][:N]
+
+         return log
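
The top-k metric logged above is easy to sanity-check in isolation; here is a standalone restatement of compute_top_k with a tiny hand-made batch:

import torch


def compute_top_k(logits, labels, k, reduction="mean"):
    # fraction of rows whose true label is among the k largest logits
    _, top_ks = torch.topk(logits, k, dim=1)
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)
    return hits.mean().item() if reduction == "mean" else hits


logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1]])
labels = torch.tensor([1, 2])
print(compute_top_k(logits, labels, k=1))  # 0.5: only the first row is correct
print(compute_top_k(logits, labels, k=3))  # 1.0: the label is always in the top 3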
stable_diffusion/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import sys
7
+ from einops import rearrange
8
+
9
+ sys.path.append('.')
10
+
11
+ from stable_diffusion.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
12
+ extract_into_tensor
13
+ from stable_diffusion.ldm.models.diffusion.sampling_util import norm_thresholding
14
+
15
+
16
+ class DDIMSampler(object):
17
+ def __init__(self, model, schedule="linear", **kwargs):
18
+ super().__init__()
19
+ self.model = model
20
+ self.ddpm_num_timesteps = model.num_timesteps
21
+ self.schedule = schedule
22
+
23
+ def to(self, device):
24
+ """Same as to in torch module
25
+ Don't really underestand why this isn't a module in the first place"""
26
+ for k, v in self.__dict__.items():
27
+ if isinstance(v, torch.Tensor):
28
+ new_v = getattr(self, k).to(device)
29
+ setattr(self, k, new_v)
30
+
31
+ def register_buffer(self, name, attr):
32
+ if type(attr) == torch.Tensor:
33
+ if attr.device != torch.device("cuda"):
34
+ attr = attr.to(torch.device("cuda"))
35
+ setattr(self, name, attr)
36
+
37
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
38
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
39
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
40
+ alphas_cumprod = self.model.alphas_cumprod
41
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
42
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
43
+
44
+ self.register_buffer('betas', to_torch(self.model.betas))
45
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
46
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
47
+
48
+ # calculations for diffusion q(x_t | x_{t-1}) and others
49
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
50
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
51
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
52
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
53
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
54
+
55
+ # ddim sampling parameters
56
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
57
+ ddim_timesteps=self.ddim_timesteps,
58
+ eta=ddim_eta,verbose=verbose)
59
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
60
+ self.register_buffer('ddim_alphas', ddim_alphas)
61
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
62
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
63
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
64
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
65
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
66
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
67
+
68
+ # @torch.no_grad()
69
+ def sample(self,
70
+ S,
71
+ batch_size,
72
+ shape,
73
+ conditioning=None,
74
+ callback=None,
75
+ normals_sequence=None,
76
+ img_callback=None,
77
+ quantize_x0=False,
78
+ eta=0.,
79
+ mask=None,
80
+ x0=None,
81
+ temperature=1.,
82
+ noise_dropout=0.,
83
+ score_corrector=None,
84
+ corrector_kwargs=None,
85
+ verbose=True,
86
+ x_T=None,
87
+ t_start=-1,
88
+ log_every_t=100,
89
+ unconditional_guidance_scale=1.,
90
+ unconditional_conditioning=None,
91
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
92
+ dynamic_threshold=None,
93
+ till_T=None,
94
+ verbose_iter=False,
95
+ **kwargs
96
+ ):
97
+ if conditioning is not None:
98
+ if isinstance(conditioning, dict):
99
+ ctmp = conditioning[list(conditioning.keys())[0]]
100
+ while isinstance(ctmp, list): ctmp = ctmp[0]
101
+ cbs = ctmp.shape[0]
102
+ if cbs != batch_size:
103
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
104
+ # else:
105
+ # if conditioning.shape[0] != batch_size:
106
+ # print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
107
+
108
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
109
+ # sampling
110
+ C, H, W = shape
111
+ size = (batch_size, C, H, W)
112
+ if verbose_iter:
113
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
114
+ samples, intermediates = self.ddim_sampling(conditioning, size,
115
+ callback=callback,
116
+ img_callback=img_callback,
117
+ quantize_denoised=quantize_x0,
118
+ mask=mask, x0=x0,
119
+ ddim_use_original_steps=False,
120
+ noise_dropout=noise_dropout,
121
+ temperature=temperature,
122
+ score_corrector=score_corrector,
123
+ corrector_kwargs=corrector_kwargs,
124
+ x_T=x_T,
125
+ log_every_t=log_every_t,
126
+ unconditional_guidance_scale=unconditional_guidance_scale,
127
+ unconditional_conditioning=unconditional_conditioning,
128
+ dynamic_threshold=dynamic_threshold,
129
+ till_T=till_T,
130
+ verbose_iter=verbose_iter,
131
+ t_start=t_start
132
+ )
133
+ return samples, intermediates
134
+
+    # @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      t_start=-1, till_T=None, verbose_iter=True):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        timesteps = timesteps[:t_start]
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        if verbose_iter:
+            print(f"Running DDIM Sampling with {total_steps} timesteps")
+            iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+        else:
+            iterator = time_range
+        if till_T is not None:
+            till = till_T
+        else:
+            till = 0
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0 = outs
+            if callback:
+                img = callback(i, img, pred_x0)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    # @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            e_t = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                # print(f'C: {c}')
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [torch.cat([
+                            unconditional_conditioning[k][i],
+                            c[k][i]]) for i in range(len(c[k]))]
+                    else:
+                        c_in[k] = torch.cat([
+                            unconditional_conditioning[k],
+                            c[k]])
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps"
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+        # current prediction for x_0
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
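In equations, `p_sample_ddim` is the standard DDIM update: with $\epsilon_\theta$ the (guided) noise estimate and $\bar\alpha$ the cumulative alphas gathered at `index`,

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}}, \qquad x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta + \sigma_t z, \quad z \sim \mathcal{N}(0, I),$$

while the guidance branch above implements $\tilde\epsilon = \epsilon_\text{uncond} + s\,(\epsilon_\text{cond} - \epsilon_\text{uncond})$. Setting $\sigma_t = 0$ (i.e. `eta=0`) makes the update deterministic.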
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None):
+        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
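`stochastic_encode` is a single draw from the closed-form forward process,

$$q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\,I\right),$$

which is why it is fast but, unlike the deterministic `encode` loop above, not exactly invertible.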
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False):
+
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+        return x_dec
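Together, `stochastic_encode` and `decode` give the usual img2img recipe: noise a latent part-way along the schedule, then denoise it back under new conditioning. A hedged sketch (assuming `sampler.make_schedule(ddim_num_steps=50, ...)` has already run, `z0` is a first-stage latent, and `c`/`uc` are conditioning tensors; `strength` is illustrative):

```python
# Hypothetical img2img driver, not part of this file.
strength = 0.6
t_enc = int(strength * 50)                  # how far along the 50-step schedule to noise
t = torch.tensor([t_enc] * z0.shape[0], device=z0.device)
z_noisy = sampler.stochastic_encode(z0, t)  # single q(x_t | x_0) draw
z_out = sampler.decode(z_noisy, c, t_enc,
                       unconditional_guidance_scale=7.5,
                       unconditional_conditioning=uc)
```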
stable_diffusion/ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1946 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+import sys
+sys.path.append('.')
+
+from stable_diffusion.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from stable_diffusion.ldm.modules.ema import LitEma
+from stable_diffusion.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from stable_diffusion.ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from stable_diffusion.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler
+from stable_diffusion.ldm.modules.attention import CrossAttention
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+                         'crossattn': 'c_crossattn',
+                         'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+    return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+    # classic DDPM with Gaussian diffusion, in image space
+    def __init__(self,
+                 unet_config,
+                 timesteps=1000,
+                 beta_schedule="linear",
+                 loss_type="l2",
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 load_only_unet=False,
+                 monitor="val/loss",
+                 use_ema=True,
+                 first_stage_key="image",
+                 image_size=256,
+                 channels=3,
+                 log_every_t=100,
+                 clip_denoised=True,
+                 linear_start=1e-4,
+                 linear_end=2e-2,
+                 cosine_s=8e-3,
+                 given_betas=None,
+                 original_elbo_weight=0.,
+                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+                 l_simple_weight=1.,
+                 conditioning_key=None,
+                 parameterization="eps",  # all assuming fixed variance schedules
+                 scheduler_config=None,
+                 use_positional_encodings=False,
+                 learn_logvar=False,
+                 logvar_init=0.,
+                 make_it_fit=False,
+                 ucg_training=None,
+                 load_ema=False,
+                 ):
+        super().__init__()
+        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
+        self.parameterization = parameterization
+        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+        self.cond_stage_model = None
+        self.clip_denoised = clip_denoised
+        self.log_every_t = log_every_t
+        self.first_stage_key = first_stage_key
+        self.image_size = image_size  # try conv?
+        self.channels = channels
+        self.use_positional_encodings = use_positional_encodings
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        self.use_scheduler = scheduler_config is not None
+        if self.use_scheduler:
+            self.scheduler_config = scheduler_config
+
+        self.v_posterior = v_posterior
+        self.original_elbo_weight = original_elbo_weight
+        self.l_simple_weight = l_simple_weight
+
+        if monitor is not None:
+            self.monitor = monitor
+        self.make_it_fit = make_it_fit
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+        self.loss_type = loss_type
+
+        self.learn_logvar = learn_logvar
+        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+        self.ucg_training = ucg_training or dict()
+        if self.ucg_training:
+            self.ucg_prng = np.random.RandomState()
+
+
+    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        if exists(given_betas):
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                       cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+                1. - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        if self.parameterization == "eps":
+            lvlb_weights = self.betas ** 2 / (
+                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+        elif self.parameterization == "x0":
+            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+        else:
+            raise NotImplementedError("mu not supported")
+        # TODO how to choose this term
+        lvlb_weights[0] = lvlb_weights[1]
+        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).all()
+
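For reference, the buffers registered above are the closed-form quantities of the Gaussian posterior from Ho et al. (2020):

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\ \tilde\mu_t(x_t, x_0),\ \tilde\beta_t I\right), \qquad \tilde\mu_t = \frac{\beta_t \sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_t}\,x_0 + \frac{(1-\bar\alpha_{t-1})\sqrt{\alpha_t}}{1-\bar\alpha_t}\,x_t, \qquad \tilde\beta_t = \frac{(1-\bar\alpha_{t-1})\,\beta_t}{1-\bar\alpha_t},$$

where the two coefficients of $\tilde\mu_t$ are exactly `posterior_mean_coef1` and `posterior_mean_coef2`, and `v_posterior` interpolates the variance as $\sigma_t^2 = (1-v)\,\tilde\beta_t + v\,\beta_t$, as the constructor comment states.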
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.model.parameters())
+            self.model_ema.copy_to(self.model)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.model.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    @torch.no_grad()
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        if self.make_it_fit:
+            n_params = len([name for name, _ in
+                            itertools.chain(self.named_parameters(),
+                                            self.named_buffers())])
+            for name, param in tqdm(
+                    itertools.chain(self.named_parameters(),
+                                    self.named_buffers()),
+                    desc="Fitting old weights to new weights",
+                    total=n_params
+            ):
+                if name not in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape) == len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+
+                    sd[name] = new_param
+
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+            sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+
+    def q_mean_variance(self, x_start, t):
+        """
+        Get the distribution q(x_t | x_0).
+        :param x_start: the [N x C x ...] tensor of noiseless inputs.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+        """
+        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
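`predict_start_from_noise` is the algebraic inversion of the forward process,

$$x_0 = \sqrt{\tfrac{1}{\bar\alpha_t}}\,x_t - \sqrt{\tfrac{1}{\bar\alpha_t}-1}\;\epsilon,$$

which is exactly what the `sqrt_recip_alphas_cumprod` and `sqrt_recipm1_alphas_cumprod` buffers precompute.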
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        model_out = self.model(x, t)
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
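`p_sample` performs one ancestral step,

$$x_{t-1} = \mu_\theta(x_t, t) + \mathbb{1}[t > 0]\;e^{\frac{1}{2}\log\sigma_t^2}\,z, \qquad z \sim \mathcal{N}(0, I),$$

where the indicator (`nonzero_mask` in the code) suppresses the noise at the final step $t = 0$.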
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        intermediates = [img]
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+                                clip_denoised=self.clip_denoised)
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(img)
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        image_size = self.image_size
+        channels = self.channels
+        return self.p_sample_loop((batch_size, channels, image_size, image_size),
+                                  return_intermediates=return_intermediates)
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == 'l1':
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == 'l2':
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+        else:
+            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+        return loss
+
+    def p_losses(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.model(x_noisy, t)
+
+        loss_dict = {}
+        if self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "x0":
+            target = x_start
+        else:
+            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+        log_prefix = 'train' if self.training else 'val'
+
+        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+        loss_simple = loss.mean() * self.l_simple_weight
+
+        loss_vlb = (self.lvlb_weights[t] * loss).mean()
+        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+        loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+        loss_dict.update({f'{log_prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def forward(self, x, *args, **kwargs):
+        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        # if len(x.shape) == 3:
+        #     x = x[..., None]
+        # x = rearrange(x, 'b h w c -> b c h w')
+        # x = x.to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def shared_step(self, batch):
+        x = self.get_input(batch, self.first_stage_key)
+        loss, loss_dict = self(x)
+        return loss, loss_dict
+
+    def training_step(self, batch, batch_idx):
+        for k in self.ucg_training:
+            p = self.ucg_training[k]["p"]
+            val = self.ucg_training[k]["val"]
+            if val is None:
+                val = ""
+            for i in range(len(batch[k])):
+                if self.ucg_prng.choice(2, p=[1 - p, p]):
+                    batch[k][i] = val
+        loss, loss_dict = self.shared_step(batch)
+
+        self.log_dict(loss_dict, prog_bar=True,
+                      logger=True, on_step=True, on_epoch=True)
+
+        self.log("global_step", self.global_step,
+                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        if self.use_scheduler:
+            lr = self.optimizers().param_groups[0]['lr']
+            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        return loss
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        _, loss_dict_no_ema = self.shared_step(batch)
+        with self.ema_scope():
+            _, loss_dict_ema = self.shared_step(batch)
+            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self.model)
+
+    def _get_rows_from_list(self, samples):
+        n_imgs_per_row = len(samples)
+        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.first_stage_key)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        x = x.to(self.device)[:N]
+        log["inputs"] = x
+
+        # get diffusion row
+        diffusion_row = list()
+        x_start = x[:n_row]
+
+        for t in range(self.num_timesteps):
+            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                t = t.to(self.device).long()
+                noise = torch.randn_like(x_start)
+                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+                diffusion_row.append(x_noisy)
+
+        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+        if sample:
+            # get denoise row
+            with self.ema_scope("Plotting"):
+                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+            log["samples"] = samples
+            log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.learn_logvar:
+            params = params + [self.logvar]
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+
+class LatentDiffusion(DDPM):
+    """main class"""
+    def __init__(self,
+                 first_stage_config,
+                 cond_stage_config,
+                 num_timesteps_cond=None,
+                 cond_stage_key="image",
+                 cond_stage_trainable=False,
+                 concat_mode=True,
+                 cond_stage_forward=None,
+                 conditioning_key=None,
+                 scale_factor=1.0,
+                 scale_by_std=False,
+                 unet_trainable=True,
+                 *args, **kwargs):
+        self.num_timesteps_cond = default(num_timesteps_cond, 1)
+        self.scale_by_std = scale_by_std
+        assert self.num_timesteps_cond <= kwargs['timesteps']
+        # for backwards compatibility after implementation of DiffusionWrapper
+        if conditioning_key is None:
+            conditioning_key = 'concat' if concat_mode else 'crossattn'
+        if cond_stage_config == '__is_unconditional__':
+            conditioning_key = None
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        self.concat_mode = concat_mode
+        self.cond_stage_trainable = cond_stage_trainable
+        self.unet_trainable = unet_trainable
+        self.cond_stage_key = cond_stage_key
+        try:
+            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+            self.num_downs = 0
+        if not scale_by_std:
+            self.scale_factor = scale_factor
+        else:
+            self.register_buffer('scale_factor', torch.tensor(scale_factor))
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
+        self.cond_stage_forward = cond_stage_forward
+        self.clip_denoised = False
+        self.bbox_tokenizer = None
+
+        self.restarted_from_ckpt = False
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+            self.restarted_from_ckpt = True
+
+    def make_cond_schedule(self):
+        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+        self.cond_ids[:self.num_timesteps_cond] = ids
+
+    @rank_zero_only
+    @torch.no_grad()
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        # only for very first batch
+        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+            # set rescale weight to 1./std of encodings
+            print("### USING STD-RESCALING ###")
+            x = super().get_input(batch, self.first_stage_key)
+            x = x.to(self.device)
+            encoder_posterior = self.encode_first_stage(x)
+            z = self.get_first_stage_encoding(encoder_posterior).detach()
+            del self.scale_factor
+            self.register_buffer('scale_factor', 1. / z.flatten().std())
+            print(f"setting self.scale_factor to {self.scale_factor}")
+            print("### USING STD-RESCALING ###")
+
+    def register_schedule(self,
+                          given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+        self.shorten_cond_schedule = self.num_timesteps_cond > 1
+        if self.shorten_cond_schedule:
+            self.make_cond_schedule()
+
+    def instantiate_first_stage(self, config):
+        model = instantiate_from_config(config)
+        self.first_stage_model = model.eval()
+        self.first_stage_model.train = disabled_train
+        for param in self.first_stage_model.parameters():
+            param.requires_grad = False
+
+    def instantiate_cond_stage(self, config):
+        if not self.cond_stage_trainable:
+            if config == "__is_first_stage__":
+                print("Using first stage also as cond stage.")
+                self.cond_stage_model = self.first_stage_model
+            elif config == "__is_unconditional__":
+                print(f"Training {self.__class__.__name__} as an unconditional model.")
+                self.cond_stage_model = None
+                # self.be_unconditional = True
+            else:
+                model = instantiate_from_config(config)
+                self.cond_stage_model = model.eval()
+                # self.cond_stage_model.train = disabled_train
+                for param in self.cond_stage_model.parameters():
+                    param.requires_grad = False
+        else:
+            assert config != '__is_first_stage__'
+            assert config != '__is_unconditional__'
+            model = instantiate_from_config(config)
+            self.cond_stage_model = model
+
+    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+        denoise_row = []
+        for zd in tqdm(samples, desc=desc):
+            denoise_row.append(self.decode_first_stage(zd.to(self.device),
+                                                       force_not_quantize=force_no_decoder_quantization))
+        n_imgs_per_row = len(denoise_row)
+        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
+        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    def get_first_stage_encoding(self, encoder_posterior):
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample()
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+        return self.scale_factor * z
+
+    def get_learned_conditioning(self, c):
+        if self.cond_stage_forward is None:
+            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+                c = self.cond_stage_model.encode(c)
+                if isinstance(c, DiagonalGaussianDistribution):
+                    c = c.mode()
+            else:
+                c = self.cond_stage_model(c)
+        else:
+            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+        return c
+
+    def meshgrid(self, h, w):
+        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+        arr = torch.cat([y, x], dim=-1)
+        return arr
+
+    def delta_border(self, h, w):
+        """
+        :param h: height
+        :param w: width
+        :return: normalized distance to image border,
+         with min distance = 0 at border and max dist = 0.5 at image center
+        """
+        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+        arr = self.meshgrid(h, w) / lower_right_corner
+        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+        return edge_dist
+
+    def get_weighting(self, h, w, Ly, Lx, device):
+        weighting = self.delta_border(h, w)
+        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+                               self.split_input_params["clip_max_weight"], )
+        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+        if self.split_input_params["tie_braker"]:
+            L_weighting = self.delta_border(Ly, Lx)
+            L_weighting = torch.clip(L_weighting,
+                                     self.split_input_params["clip_min_tie_weight"],
+                                     self.split_input_params["clip_max_tie_weight"])
+
+            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+            weighting = weighting * L_weighting
+        return weighting
+
+    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
+        """
+        :param x: img of size (bs, c, h, w)
+        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+        """
+        bs, nc, h, w = x.shape
+
+        # number of crops in image
+        Ly = (h - kernel_size[0]) // stride[0] + 1
+        Lx = (w - kernel_size[1]) // stride[1] + 1
+
+        if uf == 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+        elif uf > 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+                                dilation=1, padding=0,
+                                stride=(stride[0] * uf, stride[1] * uf))
+            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+        elif df > 1 and uf == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+                                dilation=1, padding=0,
+                                stride=(stride[0] // df, stride[1] // df))
+            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+        else:
+            raise NotImplementedError
+
+        return fold, unfold, normalization, weighting
+
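The crop counts computed at the top of `get_fold_unfold` follow the usual sliding-window arithmetic,

$$L_y = \left\lfloor \frac{h - k_h}{s_h} \right\rfloor + 1, \qquad L_x = \left\lfloor \frac{w - k_w}{s_w} \right\rfloor + 1,$$

and `normalization = fold(weighting)` accumulates, per output pixel, the total weight of all crops covering it, so dividing the folded output by it blends overlapping patches back to unit weight.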
+    @torch.no_grad()
+    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+                  cond_key=None, return_original_cond=False, bs=None, uncond=0.05, return_x=False):
+        x = super().get_input(batch, k)
+        if bs is not None:
+            x = x[:bs]
+        x = x.to(self.device)
+        encoder_posterior = self.encode_first_stage(x)
+        z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+        # if self.model.conditioning_key is not None:
+        #     if cond_key is None:
+        #         cond_key = self.cond_stage_key
+        #     if cond_key != self.first_stage_key:
+        #         if cond_key in ['caption', 'coordinates_bbox']:
+        #         if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+        #             xc = batch[cond_key]
+        #         elif cond_key == 'class_label':
+        #             xc = batch
+        #         else:
+        #             xc = super().get_input(batch, cond_key).to(self.device)
+        #     else:
+        #         xc = x
+        #     if not self.cond_stage_trainable or force_c_encode:
+        #         if isinstance(xc, dict) or isinstance(xc, list):
+        #             # import pudb; pudb.set_trace()
+        #             c = self.get_learned_conditioning(xc)
+        #         else:
+        #             c = self.get_learned_conditioning(xc.to(self.device))
+        #     else:
+        #         c = xc
+        #     if bs is not None:
+        #         c = c[:bs]
+        #     if self.use_positional_encodings:
+        #         pos_x, pos_y = self.compute_latent_shifts(batch)
+        #         ckey = __conditioning_keys__[self.model.conditioning_key]
+        #         c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+        # else:
+        #     c = None
+        #     xc = None
+        #     if self.use_positional_encodings:
+        #         pos_x, pos_y = self.compute_latent_shifts(batch)
+        #         c = {'pos_x': pos_x, 'pos_y': pos_y}
+        # out = [z, c]
+
+        if cond_key is None:
+            cond_key = self.cond_stage_key
+        xc = super().get_input(batch, cond_key)
+        if bs is not None:  # bs is None
+            xc["c_crossattn"] = xc["c_crossattn"][:bs]
+
+        # To support classifier-free guidance, randomly drop out only text conditioning 5%,
+        # only image conditioning 5%, and both 5%.
+        random = torch.rand(x.size(0), device=x.device)
+        prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
+
+        # LatentDiffusion.get_learned_conditioning
+        # null_prompt shape [1, 77, 768], 77 is the maximum token number, 768 is the hidden_state_dim
+        null_prompt = self.get_learned_conditioning([""])
+        cond = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
+
+        out = [z, cond]
+        if return_first_stage_outputs:
+            xrec = self.decode_first_stage(z)
+            out.extend([x, xrec])
+        if return_x:
+            out.extend([x])
+        if return_original_cond:
+            out.append(xc)
+        return out
+
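Note that the comment in `get_input` describes a three-way (text / image / both) drop-out scheme, but the code path here only masks the text prompt: each sample's prompt embedding is replaced by the empty-prompt embedding with probability $2 \cdot \texttt{uncond}$, i.e. 10% at the default `uncond=0.05`. This is the standard classifier-free-guidance training trick that makes the unconditional branch used at sampling time meaningful.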
+    @torch.no_grad()
+    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+        z = 1. / self.scale_factor * z
+
+        if hasattr(self, "split_input_params"):
+            if self.split_input_params["patch_distributed_vq"]:
+                ks = self.split_input_params["ks"]  # eg. (128, 128)
+                stride = self.split_input_params["stride"]  # eg. (64, 64)
+                uf = self.split_input_params["vqf"]
+                bs, nc, h, w = z.shape
+                if ks[0] > h or ks[1] > w:
+                    ks = (min(ks[0], h), min(ks[1], w))
+                    print("reducing Kernel")
+
+                if stride[0] > h or stride[1] > w:
+                    stride = (min(stride[0], h), min(stride[1], w))
+                    print("reducing stride")
+
+                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+                z = unfold(z)  # (bn, nc * prod(**ks), L)
+                # 1. Reshape to img shape
+                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+
+                output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+                               for i in range(z.shape[-1])]
+
+                o = torch.stack(output_list, axis=-1)  # (bn, nc, ks[0], ks[1], L)
+                o = o * weighting
+                # Reverse 1. reshape to img shape
+                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
+                # stitch crops together
+                decoded = fold(o)
+                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
+                return decoded
+            else:
+                return self.first_stage_model.decode(z)
+
+        else:
+            return self.first_stage_model.decode(z)
+
+    @torch.no_grad()
+    def encode_first_stage(self, x):
+        if hasattr(self, "split_input_params"):
+            if self.split_input_params["patch_distributed_vq"]:
+                ks = self.split_input_params["ks"]  # eg. (128, 128)
+                stride = self.split_input_params["stride"]  # eg. (64, 64)
+                df = self.split_input_params["vqf"]
+                self.split_input_params['original_image_size'] = x.shape[-2:]
+                bs, nc, h, w = x.shape
+                if ks[0] > h or ks[1] > w:
+                    ks = (min(ks[0], h), min(ks[1], w))
+                    print("reducing Kernel")
+
+                if stride[0] > h or stride[1] > w:
+                    stride = (min(stride[0], h), min(stride[1], w))
+                    print("reducing stride")
+
+                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+                z = unfold(x)  # (bn, nc * prod(**ks), L)
+                # Reshape to img shape
+                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+
+                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+                               for i in range(z.shape[-1])]
+
+                o = torch.stack(output_list, axis=-1)
+                o = o * weighting
+
+                # Reverse reshape to img shape
+                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
+                # stitch crops together
+                decoded = fold(o)
+                decoded = decoded / normalization
+                return decoded
+
+            else:
+                return self.first_stage_model.encode(x)
+        else:
+            return self.first_stage_model.encode(x)
+
+    def shared_step(self, batch, **kwargs):
+        x, c = self.get_input(batch, self.first_stage_key)
+        loss = self(x, c)
+        return loss
+
+    def forward(self, x, c, *args, **kwargs):
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c = self.get_learned_conditioning(c)
+            if self.shorten_cond_schedule:  # TODO: drop this option
+                tc = self.cond_ids[t].to(self.device)
+                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+        return self.p_losses(x, c, t, *args, **kwargs)
+
+    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
+        def rescale_bbox(bbox):
+            x0 = torch.clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+            y0 = torch.clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+            return x0, y0, w, h
+
+        return [rescale_bbox(b) for b in bboxes]
+
+    def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+        if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+            pass
+        else:
+            if not isinstance(cond, list):
+                cond = [cond]
+            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+            cond = {key: cond}
+
+        if hasattr(self, "split_input_params"):
+            assert len(cond) == 1  # todo can only deal with one conditioning atm
+            assert not return_ids
+            ks = self.split_input_params["ks"]  # eg. (128, 128)
+            stride = self.split_input_params["stride"]  # eg. (64, 64)
+
+            h, w = x_noisy.shape[-2:]
+
+            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)
+            # Reshape to img shape
+            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+            if self.cond_stage_key in ["image", "LR_image", "segmentation",
+                                       'bbox_img'] and self.model.conditioning_key:  # todo check for completeness
+                c_key = next(iter(cond.keys()))  # get key
+                c = next(iter(cond.values()))  # get value
+                assert (len(c) == 1)  # todo extend to list with more than one elem
+                c = c[0]  # get element
+
+                c = unfold(c)
+                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+
+                cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+            elif self.cond_stage_key == 'coordinates_bbox':
+                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
+
+                # assuming padding of unfold is always 0 and its dilation is always 1
+                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+                full_img_h, full_img_w = self.split_input_params['original_image_size']
+                # as we are operating on latents, we need the factor from the original image size to the
+                # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+                num_downs = self.first_stage_model.encoder.num_resolutions - 1
+                rescale_latent = 2 ** (num_downs)
+
+                # get top left positions of patches as conforming for the bbox tokenizer, therefore we
+                # need to rescale the tl patch coordinates to be in between (0, 1)
+                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+                                        for patch_nr in range(z.shape[-1])]
+
+                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+                patch_limits = [(x_tl, y_tl,
+                                 rescale_latent * ks[0] / full_img_w,
+                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+                # tokenize crop coordinates for the bounding boxes of the respective patches
+                patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+                                      for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)
+                print(patch_limits_tknzd[0].shape)
+                # cut tknzd crop position from conditioning
+                assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+                print(cut_cond.shape)
+
+                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+                print(adapted_cond.shape)
+                adapted_cond = self.get_learned_conditioning(adapted_cond)
+                print(adapted_cond.shape)
+                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+                print(adapted_cond.shape)
+
+                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+            else:
+                cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient
+
+            # apply model by loop over crops
+            output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+            assert not isinstance(output_list[0],
+                                  tuple)  # todo cant deal with multiple model outputs check this never happens
+
+            o = torch.stack(output_list, axis=-1)
+            o = o * weighting
+            # Reverse reshape to img shape
+            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
+            # stitch crops together
+            x_recon = fold(o) / normalization
+
+        else:
+            x_recon = self.model(x_noisy, t, **cond)
+
+        if isinstance(x_recon, tuple) and not return_ids:
+            return x_recon[0]
+        else:
+            return x_recon
+
+    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+    def _prior_bpd(self, x_start):
+        """
+        Get the prior KL term for the variational lower-bound, measured in
+        bits-per-dim.
+        This term can't be optimized, as it only depends on the encoder.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :return: a batch of [N] KL values (in bits), one per batch element.
+        """
+        batch_size = x_start.shape[0]
+        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+        return mean_flat(kl_prior) / np.log(2.0)
+
+    def p_losses(self, x_start, cond, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_output = self.apply_model(x_noisy, t, cond)
+
+        loss_dict = {}
+        prefix = 'train' if self.training else 'val'
+
+        if self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "eps":
+            target = noise
+        else:
+            raise NotImplementedError()
+
+        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+        # logvar_t = self.logvar[t].to(self.device)
+        device = x_start.device
+        t = t.to(device).long()
+        logvar_t = self.logvar.to(device)[t]
+        loss = loss_simple / torch.exp(logvar_t) + logvar_t
+        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+        if self.learn_logvar:
+            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+            loss_dict.update({'logvar': self.logvar.data.mean()})
+
+        loss = self.l_simple_weight * loss.mean()
+
+        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+        loss += (self.original_elbo_weight * loss_vlb)
+        loss_dict.update({f'{prefix}/loss': loss})
+
+        return loss, loss_dict
+
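Written out, the training objective assembled above is

$$\mathcal{L} = \lambda_\text{simple}\;\overline{\left(\frac{\ell_t}{e^{\gamma_t}} + \gamma_t\right)} \;+\; \lambda_\text{elbo}\;\overline{w_t\,\ell_t}, \qquad \ell_t = \text{per-sample mean of } \|\text{target} - \text{model\_output}\|^2,$$

with $\gamma_t$ = `logvar[t]` (learned only if `learn_logvar`), $w_t$ = `lvlb_weights[t]`, $\lambda_\text{simple}$ = `l_simple_weight`, and $\lambda_\text{elbo}$ = `original_elbo_weight` (0 by default, so the VLB term is logged but not optimized).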
1064
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1065
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1066
+ t_in = t
1067
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1068
+
1069
+ if score_corrector is not None:
1070
+ assert self.parameterization == "eps"
1071
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1072
+
1073
+ if return_codebook_ids:
1074
+ model_out, logits = model_out
1075
+
1076
+ if self.parameterization == "eps":
1077
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1078
+ elif self.parameterization == "x0":
1079
+ x_recon = model_out
1080
+ else:
1081
+ raise NotImplementedError()
1082
+
1083
+ if clip_denoised:
1084
+ x_recon.clamp_(-1., 1.)
1085
+ if quantize_denoised:
1086
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1087
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1088
+ if return_codebook_ids:
1089
+ return model_mean, posterior_variance, posterior_log_variance, logits
1090
+ elif return_x0:
1091
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1092
+ else:
1093
+ return model_mean, posterior_variance, posterior_log_variance
1094
+
1095
+ @torch.no_grad()
1096
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1097
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1098
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1099
+ b, *_, device = *x.shape, x.device
1100
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1101
+ return_codebook_ids=return_codebook_ids,
1102
+ quantize_denoised=quantize_denoised,
1103
+ return_x0=return_x0,
1104
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1105
+ if return_codebook_ids:
1106
+ raise DeprecationWarning("Support dropped.")
1107
+ model_mean, _, model_log_variance, logits = outputs
1108
+ elif return_x0:
1109
+ model_mean, _, model_log_variance, x0 = outputs
1110
+ else:
1111
+ model_mean, _, model_log_variance = outputs
1112
+
1113
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1114
+ if noise_dropout > 0.:
1115
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1116
+ # no noise when t == 0
1117
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1118
+
1119
+ if return_codebook_ids:
1120
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1121
+ if return_x0:
1122
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1123
+ else:
1124
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
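A sketch of the ancestral update computed above, with $\mu_\theta$ and $\log\tilde\beta_t$ coming from `p_mean_variance`, temperature $\tau$, and the noise masked out at the final step:

    x_{t-1} = \mu_\theta(x_t, c, t) + \mathbb{1}[t > 0]\; e^{\frac{1}{2}\log\tilde\beta_t}\; \tau\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)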
+
+     @torch.no_grad()
+     def p_sample_edit(self, x, c, t, clip_denoised=False, repeat_noise=False,
+                       return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+         b, *_, device = *x.shape, x.device
+         outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+                                        return_codebook_ids=return_codebook_ids,
+                                        quantize_denoised=quantize_denoised,
+                                        return_x0=return_x0,
+                                        score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+         if return_codebook_ids:
+             raise DeprecationWarning("Support dropped.")
+             model_mean, _, model_log_variance, logits = outputs
+         elif return_x0:
+             model_mean, _, model_log_variance, x0 = outputs
+         else:
+             model_mean, _, model_log_variance = outputs
+
+         noise = noise_like(x.shape, device, repeat_noise) * temperature
+         if noise_dropout > 0.:
+             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+         # no noise when t == 0
+         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+         if return_codebook_ids:
+             return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+         if return_x0:
+             return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+         else:
+             return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, noise
+
+     @torch.no_grad()
+     def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+                               img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+                               score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+                               log_every_t=None):
+         if not log_every_t:
+             log_every_t = self.log_every_t
+         timesteps = self.num_timesteps
+         if batch_size is not None:
+             b = batch_size if batch_size is not None else shape[0]
+             shape = [batch_size] + list(shape)
+         else:
+             b = batch_size = shape[0]
+         if x_T is None:
+             img = torch.randn(shape, device=self.device)
+         else:
+             img = x_T
+         intermediates = []
+         if cond is not None:
+             if isinstance(cond, dict):
+                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                         list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+             else:
+                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+         if start_T is not None:
+             timesteps = min(timesteps, start_T)
+         iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+                         total=timesteps) if verbose else reversed(
+             range(0, timesteps))
+         if type(temperature) == float:
+             temperature = [temperature] * timesteps
+
+         for i in iterator:
+             ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+             if self.shorten_cond_schedule:
+                 assert self.model.conditioning_key != 'hybrid'
+                 tc = self.cond_ids[ts].to(cond.device)
+                 cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+             img, x0_partial = self.p_sample(img, cond, ts,
+                                             clip_denoised=self.clip_denoised,
+                                             quantize_denoised=quantize_denoised, return_x0=True,
+                                             temperature=temperature[i], noise_dropout=noise_dropout,
+                                             score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+             if mask is not None:
+                 assert x0 is not None
+                 img_orig = self.q_sample(x0, ts)
+                 img = img_orig * mask + (1. - mask) * img
+
+             if i % log_every_t == 0 or i == timesteps - 1:
+                 intermediates.append(x0_partial)
+             if callback: callback(i)
+             if img_callback: img_callback(img, i)
+         return img, intermediates
+
+     @torch.no_grad()
+     def p_sample_loop(self, cond, shape, return_intermediates=False,
+                       x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+                       mask=None, x0=None, img_callback=None, start_T=None,
+                       log_every_t=None, till_T=None):
+
+         if not log_every_t:
+             log_every_t = self.log_every_t
+         device = self.betas.device
+         b = shape[0]
+         if x_T is None:
+             img = torch.randn(shape, device=device)
+         else:
+             img = x_T
+
+         intermediates = [img]
+         if timesteps is None:
+             timesteps = self.num_timesteps
+
+         if start_T is not None:
+             timesteps = min(timesteps, start_T)
+
+         if till_T is not None:
+             till = till_T
+         else:
+             till = 0
+         iterator = tqdm(reversed(range(till, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+             range(till, timesteps))
+
+         if mask is not None:
+             assert x0 is not None
+             assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match
+
+         for i in iterator:
+             ts = torch.full((b,), i, device=device, dtype=torch.long)
+             if self.shorten_cond_schedule:
+                 assert self.model.conditioning_key != 'hybrid'
+                 tc = self.cond_ids[ts].to(cond.device)
+                 cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+             img = self.p_sample(img, cond, ts,
+                                 clip_denoised=self.clip_denoised,
+                                 quantize_denoised=quantize_denoised)
+             if mask is not None:
+                 img_orig = self.q_sample(x0, ts)
+                 img = img_orig * mask + (1. - mask) * img
+
+             if i % log_every_t == 0 or i == timesteps - 1:
+                 intermediates.append(img)
+             if callback: callback(i)
+             if img_callback: img_callback(img, i)
+
+         if return_intermediates:
+             return img, intermediates
+         return img
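A minimal usage sketch for the loop above (illustrative only: `model` is assumed to be a trained LatentDiffusion instance and `c` a conditioning already produced by its `get_learned_conditioning`; neither name appears in this diff):

    import torch

    shape = (4, model.channels, model.image_size, model.image_size)  # batch of 4 latents
    with torch.no_grad():
        z, intermediates = model.p_sample_loop(cond=c, shape=shape, return_intermediates=True)
        x = model.decode_first_stage(z)  # decode latents back to image space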
+
+     @torch.no_grad()
+     def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+                verbose=True, timesteps=None, quantize_denoised=False,
+                mask=None, x0=None, till_T=None, shape=None, **kwargs):
+         if shape is None:
+             shape = (batch_size, self.channels, self.image_size, self.image_size)
+         if cond is not None:
+             if isinstance(cond, dict):
+                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                         list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+             else:
+                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+         return self.p_sample_loop(cond,
+                                   shape,
+                                   return_intermediates=return_intermediates, x_T=x_T,
+                                   verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+                                   mask=mask, x0=x0, till_T=till_T)
+
+     @torch.no_grad()
+     def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+
+         if ddim:
+             ddim_sampler = DDIMSampler(self)
+             shape = (self.channels, self.image_size, self.image_size)
+             samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+                                                          shape, cond, verbose=False, **kwargs)
+
+         else:
+             samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+                                                  return_intermediates=True, **kwargs)
+
+         return samples, intermediates
+
+     @torch.no_grad()
+     def get_unconditional_conditioning(self, batch_size, null_label=None):
+         if null_label is not None:
+             xc = null_label
+             if isinstance(xc, ListConfig):
+                 xc = list(xc)
+             if isinstance(xc, dict) or isinstance(xc, list):
+                 c = self.get_learned_conditioning(xc)
+             else:
+                 if hasattr(xc, "to"):
+                     xc = xc.to(self.device)
+                 c = self.get_learned_conditioning(xc)
+         else:
+             # todo: get null label from cond_stage_model
+             raise NotImplementedError()
+         c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+         return c
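The null conditioning built here is what the logging and sampling code pairs with the conditional one for classifier-free guidance; with guidance scale $s$ the two noise predictions are combined as

    \hat\epsilon = \epsilon_\theta(x_t, \varnothing) + s \cdot \big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\big)

which is exactly the `e_t_uncond + scale * (e_t - e_t_uncond)` form used in the PLMS sampler later in this commit.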
+
+     @torch.no_grad()
+     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                    quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
+                    plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                    use_ema_scope=True,
+                    **kwargs):
+         ema_scope = self.ema_scope if use_ema_scope else nullcontext
+
+         use_ddim = ddim_steps is not None
+
+         log = dict()
+         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+                                            return_first_stage_outputs=True,
+                                            force_c_encode=True,
+                                            return_original_cond=True,
+                                            bs=N)
+         N = min(x.shape[0], N)
+         n_row = min(x.shape[0], n_row)
+         log["inputs"] = x
+         log["reconstruction"] = xrec
+         if self.model.conditioning_key is not None:
+             if hasattr(self.cond_stage_model, "decode"):
+                 xc = self.cond_stage_model.decode(c)
+                 log["conditioning"] = xc
+             elif self.cond_stage_key in ["caption", "txt"]:
+                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                 log["conditioning"] = xc
+             elif self.cond_stage_key == 'class_label':
+                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                 log['conditioning'] = xc
+             elif isimage(xc):
+                 log["conditioning"] = xc
+             if ismap(xc):
+                 log["original_conditioning"] = self.to_rgb(xc)
+
+         if plot_diffusion_rows:
+             # get diffusion row
+             diffusion_row = list()
+             z_start = z[:n_row]
+             for t in range(self.num_timesteps):
+                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                     t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                     t = t.to(self.device).long()
+                     noise = torch.randn_like(z_start)
+                     z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                     diffusion_row.append(self.decode_first_stage(z_noisy))
+
+             diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+             diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+             diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+             diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+             log["diffusion_row"] = diffusion_grid
+
+         if sample:
+             # get denoise row
+             with ema_scope("Sampling"):
+                 samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                          ddim_steps=ddim_steps, eta=ddim_eta)
+                 # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+             x_samples = self.decode_first_stage(samples)
+             log["samples"] = x_samples
+             if plot_denoise_rows:
+                 denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                 log["denoise_row"] = denoise_grid
+
+             if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+                     self.first_stage_model, IdentityFirstStage):
+                 # also display when quantizing x0 while sampling
+                 with ema_scope("Plotting Quantized Denoised"):
+                     samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                              ddim_steps=ddim_steps, eta=ddim_eta,
+                                                              quantize_denoised=True)
+                     # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+                     #                                      quantize_denoised=True)
+                 x_samples = self.decode_first_stage(samples.to(self.device))
+                 log["samples_x0_quantized"] = x_samples
+
+             if unconditional_guidance_scale > 1.0:
+                 uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+                 # uc = torch.zeros_like(c)
+                 with ema_scope("Sampling with classifier-free guidance"):
+                     samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                      ddim_steps=ddim_steps, eta=ddim_eta,
+                                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                                      unconditional_conditioning=uc,
+                                                      )
+                 x_samples_cfg = self.decode_first_stage(samples_cfg)
+                 log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+             if inpaint:
+                 # make a simple center square
+                 b, h, w = z.shape[0], z.shape[2], z.shape[3]
+                 mask = torch.ones(N, h, w).to(self.device)
+                 # zeros will be filled in
+                 mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+                 mask = mask[:, None, ...]
+                 with ema_scope("Plotting Inpaint"):
+                     samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                                  ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+                 x_samples = self.decode_first_stage(samples.to(self.device))
+                 log["samples_inpainting"] = x_samples
+                 log["mask"] = mask
+
+                 # outpaint
+                 mask = 1. - mask
+                 with ema_scope("Plotting Outpaint"):
+                     samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                                  ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+                 x_samples = self.decode_first_stage(samples.to(self.device))
+                 log["samples_outpainting"] = x_samples
+
+         if plot_progressive_rows:
+             with ema_scope("Plotting Progressives"):
+                 img, progressives = self.progressive_denoising(c,
+                                                                shape=(self.channels, self.image_size, self.image_size),
+                                                                batch_size=N)
+             prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+             log["progressive_row"] = prog_row
+
+         if return_keys:
+             if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                 return log
+             else:
+                 return {key: log[key] for key in return_keys}
+         return log
+
+     def configure_optimizers(self):
+         lr = self.learning_rate
+         params = []
+         if self.unet_trainable == "attn":
+             print("Training only unet attention layers")
+             for n, m in self.model.named_modules():
+                 if isinstance(m, CrossAttention) and n.endswith('attn2'):
+                     params.extend(m.parameters())
+         elif self.unet_trainable is True or self.unet_trainable == "all":
+             print("Training the full unet")
+             params = list(self.model.parameters())
+         else:
+             raise ValueError(f"Unrecognised setting for unet_trainable: {self.unet_trainable}")
+
+         if self.cond_stage_trainable:
+             print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+             params = params + list(self.cond_stage_model.parameters())
+         if self.learn_logvar:
+             print('Diffusion model optimizing logvar')
+             params.append(self.logvar)
+         opt = torch.optim.AdamW(params, lr=lr)
+         if self.use_scheduler:
+             assert 'target' in self.scheduler_config
+             scheduler = instantiate_from_config(self.scheduler_config)
+
+             print("Setting up LambdaLR scheduler...")
+             scheduler = [
+                 {
+                     'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+                     'interval': 'step',
+                     'frequency': 1
+                 }]
+             return [opt], scheduler
+         return opt
+
+     @torch.no_grad()
+     def to_rgb(self, x):
+         x = x.float()
+         if not hasattr(self, "colorize"):
+             self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+         x = nn.functional.conv2d(x, weight=self.colorize)
+         x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+         return x
+
+
+ class DiffusionWrapper(pl.LightningModule):
+     def __init__(self, diff_model_config, conditioning_key):
+         super().__init__()
+         self.diffusion_model = instantiate_from_config(diff_model_config)
+         self.conditioning_key = conditioning_key
+         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']
+
+     def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+         if self.conditioning_key is None:
+             out = self.diffusion_model(x, t)
+         elif self.conditioning_key == 'concat':
+             xc = torch.cat([x] + c_concat, dim=1)
+             out = self.diffusion_model(xc, t)
+         elif self.conditioning_key == 'crossattn':
+             cc = torch.cat(c_crossattn, 1)
+             out = self.diffusion_model(x, t, context=cc)
+         elif self.conditioning_key == 'hybrid':
+             xc = torch.cat([x] + c_concat, dim=1)
+             cc = torch.cat(c_crossattn, 1)
+             out = self.diffusion_model(xc, t, context=cc)
+         elif self.conditioning_key == 'hybrid-adm':
+             assert c_adm is not None
+             xc = torch.cat([x] + c_concat, dim=1)
+             cc = torch.cat(c_crossattn, 1)
+             out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+         elif self.conditioning_key == 'adm':
+             cc = c_crossattn[0]
+             out = self.diffusion_model(x, t, y=cc)
+         else:
+             raise NotImplementedError()
+
+         return out
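Illustrative call patterns for the dispatch above (a sketch; `unet_config`, `x`, `t`, `txt_emb` and `lr_latent` are placeholder names, not part of this diff):

    wrapper = DiffusionWrapper(unet_config, conditioning_key='crossattn')
    out = wrapper(x, t, c_crossattn=[txt_emb])                        # text conditioning via cross-attention

    wrapper = DiffusionWrapper(unet_config, conditioning_key='hybrid')
    out = wrapper(x, t, c_concat=[lr_latent], c_crossattn=[txt_emb])  # channel-concat plus cross-attention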
+
+
+ class LatentUpscaleDiffusion(LatentDiffusion):
+     def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs):
+         super().__init__(*args, **kwargs)
+         # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+         assert not self.cond_stage_trainable
+         self.instantiate_low_stage(low_scale_config)
+         self.low_scale_key = low_scale_key
+
+     def instantiate_low_stage(self, config):
+         model = instantiate_from_config(config)
+         self.low_scale_model = model.eval()
+         self.low_scale_model.train = disabled_train
+         for param in self.low_scale_model.parameters():
+             param.requires_grad = False
+
+     @torch.no_grad()
+     def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+         if not log_mode:
+             z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+         else:
+             z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                                   force_c_encode=True, return_original_cond=True, bs=bs)
+         x_low = batch[self.low_scale_key][:bs]
+         x_low = rearrange(x_low, 'b h w c -> b c h w')
+         x_low = x_low.to(memory_format=torch.contiguous_format).float()
+         zx, noise_level = self.low_scale_model(x_low)
+         all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+         # import pudb; pu.db
+         if log_mode:
+             # TODO: maybe disable if too expensive
+             interpretability = False
+             if interpretability:
+                 zx = zx[:, :, ::2, ::2]
+             x_low_rec = self.low_scale_model.decode(zx)
+             return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+         return z, all_conds
+
+     @torch.no_grad()
+     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                    plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+                    unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+                    **kwargs):
+         ema_scope = self.ema_scope if use_ema_scope else nullcontext
+         use_ddim = ddim_steps is not None
+
+         log = dict()
+         z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+                                                                           log_mode=True)
+         N = min(x.shape[0], N)
+         n_row = min(x.shape[0], n_row)
+         log["inputs"] = x
+         log["reconstruction"] = xrec
+         log["x_lr"] = x_low
+         log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+         if self.model.conditioning_key is not None:
+             if hasattr(self.cond_stage_model, "decode"):
+                 xc = self.cond_stage_model.decode(c)
+                 log["conditioning"] = xc
+             elif self.cond_stage_key in ["caption", "txt"]:
+                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                 log["conditioning"] = xc
+             elif self.cond_stage_key == 'class_label':
+                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                 log['conditioning'] = xc
+             elif isimage(xc):
+                 log["conditioning"] = xc
+             if ismap(xc):
+                 log["original_conditioning"] = self.to_rgb(xc)
+
+         if plot_diffusion_rows:
+             # get diffusion row
+             diffusion_row = list()
+             z_start = z[:n_row]
+             for t in range(self.num_timesteps):
+                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                     t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                     t = t.to(self.device).long()
+                     noise = torch.randn_like(z_start)
+                     z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                     diffusion_row.append(self.decode_first_stage(z_noisy))
+
+             diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+             diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+             diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+             diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+             log["diffusion_row"] = diffusion_grid
+
+         if sample:
+             # get denoise row
+             with ema_scope("Sampling"):
+                 samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                          ddim_steps=ddim_steps, eta=ddim_eta)
+                 # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+             x_samples = self.decode_first_stage(samples)
+             log["samples"] = x_samples
+             if plot_denoise_rows:
+                 denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                 log["denoise_row"] = denoise_grid
+
+         if unconditional_guidance_scale > 1.0:
+             uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+             # TODO explore better "unconditional" choices for the other keys
+             # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+             uc = dict()
+             for k in c:
+                 if k == "c_crossattn":
+                     assert isinstance(c[k], list) and len(c[k]) == 1
+                     uc[k] = [uc_tmp]
+                 elif k == "c_adm":  # todo: only run with text-based guidance?
+                     assert isinstance(c[k], torch.Tensor)
+                     uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+                 elif isinstance(c[k], list):
+                     uc[k] = [c[k][i] for i in range(len(c[k]))]
+                 else:
+                     uc[k] = c[k]
+
+             with ema_scope("Sampling with classifier-free guidance"):
+                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                  ddim_steps=ddim_steps, eta=ddim_eta,
+                                                  unconditional_guidance_scale=unconditional_guidance_scale,
+                                                  unconditional_conditioning=uc,
+                                                  )
+             x_samples_cfg = self.decode_first_stage(samples_cfg)
+             log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+         if plot_progressive_rows:
+             with ema_scope("Plotting Progressives"):
+                 img, progressives = self.progressive_denoising(c,
+                                                                shape=(self.channels, self.image_size, self.image_size),
+                                                                batch_size=N)
+             prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+             log["progressive_row"] = prog_row
+
+         return log
+
+
+ class LatentInpaintDiffusion(LatentDiffusion):
+     """
+     can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+     e.g. mask as concat and text via cross-attn.
+     To disable finetuning mode, set finetune_keys to None
+     """
+     def __init__(self,
+                  finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+                                 "model_ema.diffusion_modelinput_blocks00weight"
+                                 ),
+                  concat_keys=("mask", "masked_image"),
+                  masked_image_key="masked_image",
+                  keep_finetune_dims=4,  # if model was trained without concat mode before and we would like to keep these channels
+                  c_concat_log_start=None,  # to log reconstruction of c_concat codes
+                  c_concat_log_end=None,
+                  *args, **kwargs
+                  ):
+         ckpt_path = kwargs.pop("ckpt_path", None)
+         ignore_keys = kwargs.pop("ignore_keys", list())
+         super().__init__(*args, **kwargs)
+         self.masked_image_key = masked_image_key
+         assert self.masked_image_key in concat_keys
+         self.finetune_keys = finetune_keys
+         self.concat_keys = concat_keys
+         self.keep_dims = keep_finetune_dims
+         self.c_concat_log_start = c_concat_log_start
+         self.c_concat_log_end = c_concat_log_end
+         if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+         if exists(ckpt_path):
+             self.init_from_ckpt(ckpt_path, ignore_keys)
+
+     def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+         sd = torch.load(path, map_location="cpu")
+         if "state_dict" in list(sd.keys()):
+             sd = sd["state_dict"]
+         keys = list(sd.keys())
+         for k in keys:
+             for ik in ignore_keys:
+                 if k.startswith(ik):
+                     print("Deleting key {} from state_dict.".format(k))
+                     del sd[k]
+
+             # make it explicit, finetune by including extra input channels
+             if exists(self.finetune_keys) and k in self.finetune_keys:
+                 new_entry = None
+                 for name, param in self.named_parameters():
+                     if name in self.finetune_keys:
+                         print(f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+                         new_entry = torch.zeros_like(param)  # zero init
+                 assert exists(new_entry), 'did not find matching parameter to modify'
+                 new_entry[:, :self.keep_dims, ...] = sd[k]
+                 sd[k] = new_entry
+
+         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
+         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+         if len(missing) > 0:
+             print(f"Missing Keys: {missing}")
+         if len(unexpected) > 0:
+             print(f"Unexpected Keys: {unexpected}")
+
+     @torch.no_grad()
+     def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+         # note: restricted to non-trainable encoders currently
+         assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+         z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                               force_c_encode=True, return_original_cond=True, bs=bs)
+
+         assert exists(self.concat_keys)
+         c_cat = list()
+         for ck in self.concat_keys:
+             cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+             if bs is not None:
+                 cc = cc[:bs]
+                 cc = cc.to(self.device)
+             bchw = z.shape
+             if ck != self.masked_image_key:
+                 cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+             else:
+                 cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+             c_cat.append(cc)
+         c_cat = torch.cat(c_cat, dim=1)
+         all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+         if return_first_stage_outputs:
+             return z, all_conds, x, xrec, xc
+         return z, all_conds
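A shape sketch of the conditioning assembled above, assuming the usual 4-channel latent space (the channel counts are an assumption, not stated in this file):

    # z:            (B, 4, H/8, W/8)  latent of the full image
    # mask:         (B, 1, H/8, W/8)  nearest-resized from the pixel-space mask
    # masked_image: (B, 4, H/8, W/8)  encoded with the first-stage model
    # c_cat -> (B, 5, H/8, W/8); with keep_finetune_dims=4 the UNet then
    # sees 4 + 5 = 9 input channels, only the first 4 of which are
    # initialized from the non-inpainting checkpoint.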
+
+     @torch.no_grad()
+     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                    quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                    plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                    use_ema_scope=True,
+                    **kwargs):
+         ema_scope = self.ema_scope if use_ema_scope else nullcontext
+         use_ddim = ddim_steps is not None
+
+         log = dict()
+         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+         c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+         N = min(x.shape[0], N)
+         n_row = min(x.shape[0], n_row)
+         log["inputs"] = x
+         log["reconstruction"] = xrec
+         if self.model.conditioning_key is not None:
+             if hasattr(self.cond_stage_model, "decode"):
+                 xc = self.cond_stage_model.decode(c)
+                 log["conditioning"] = xc
+             elif self.cond_stage_key in ["caption", "txt"]:
+                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                 log["conditioning"] = xc
+             elif self.cond_stage_key == 'class_label':
+                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                 log['conditioning'] = xc
+             elif isimage(xc):
+                 log["conditioning"] = xc
+             if ismap(xc):
+                 log["original_conditioning"] = self.to_rgb(xc)
+
+         if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+             log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+
+         if plot_diffusion_rows:
+             # get diffusion row
+             diffusion_row = list()
+             z_start = z[:n_row]
+             for t in range(self.num_timesteps):
+                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                     t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                     t = t.to(self.device).long()
+                     noise = torch.randn_like(z_start)
+                     z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                     diffusion_row.append(self.decode_first_stage(z_noisy))
+
+             diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+             diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+             diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+             diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+             log["diffusion_row"] = diffusion_grid
+
+         if sample:
+             # get denoise row
+             with ema_scope("Sampling"):
+                 samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                          batch_size=N, ddim=use_ddim,
+                                                          ddim_steps=ddim_steps, eta=ddim_eta)
+                 # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+             x_samples = self.decode_first_stage(samples)
+             log["samples"] = x_samples
+             if plot_denoise_rows:
+                 denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                 log["denoise_row"] = denoise_grid
+
+         if unconditional_guidance_scale > 1.0:
+             uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+             uc_cat = c_cat
+             uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+             with ema_scope("Sampling with classifier-free guidance"):
+                 samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                  batch_size=N, ddim=use_ddim,
+                                                  ddim_steps=ddim_steps, eta=ddim_eta,
+                                                  unconditional_guidance_scale=unconditional_guidance_scale,
+                                                  unconditional_conditioning=uc_full,
+                                                  )
+             x_samples_cfg = self.decode_first_stage(samples_cfg)
+             log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+         log["masked_image"] = rearrange(batch["masked_image"],
+                                         'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+         return log
+
+
+ class Layout2ImgDiffusion(LatentDiffusion):
+     # TODO: move all layout-specific hacks to this class
+     def __init__(self, cond_stage_key, *args, **kwargs):
+         assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
+         super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+
+     def log_images(self, batch, N=8, *args, **kwargs):
+         logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+
+         key = 'train' if self.training else 'validation'
+         dset = self.trainer.datamodule.datasets[key]
+         mapper = dset.conditional_builders[self.cond_stage_key]
+
+         bbox_imgs = []
+         map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
+         for tknzd_bbox in batch[self.cond_stage_key][:N]:
+             bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
+             bbox_imgs.append(bboximg)
+
+         cond_img = torch.stack(bbox_imgs, dim=0)
+         logs['bbox_image'] = cond_img
+         return logs
+
+
+ class SimpleUpscaleDiffusion(LatentDiffusion):
+     def __init__(self, *args, low_scale_key="LR", **kwargs):
+         super().__init__(*args, **kwargs)
+         # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+         assert not self.cond_stage_trainable
+         self.low_scale_key = low_scale_key
+
+     @torch.no_grad()
+     def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+         if not log_mode:
+             z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+         else:
+             z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                                   force_c_encode=True, return_original_cond=True, bs=bs)
+         x_low = batch[self.low_scale_key][:bs]
+         x_low = rearrange(x_low, 'b h w c -> b c h w')
+         x_low = x_low.to(memory_format=torch.contiguous_format).float()
+
+         encoder_posterior = self.encode_first_stage(x_low)
+         zx = self.get_first_stage_encoding(encoder_posterior).detach()
+         all_conds = {"c_concat": [zx], "c_crossattn": [c]}
+
+         if log_mode:
+             # TODO: maybe disable if too expensive
+             interpretability = False
+             if interpretability:
+                 zx = zx[:, :, ::2, ::2]
+             return z, all_conds, x, xrec, xc, x_low
+         return z, all_conds
+
+     @torch.no_grad()
+     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                    plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+                    unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+                    **kwargs):
+         ema_scope = self.ema_scope if use_ema_scope else nullcontext
+         use_ddim = ddim_steps is not None
+
+         log = dict()
+         z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True)
+         N = min(x.shape[0], N)
+         n_row = min(x.shape[0], n_row)
+         log["inputs"] = x
+         log["reconstruction"] = xrec
+         log["x_lr"] = x_low
+
+         if self.model.conditioning_key is not None:
+             if hasattr(self.cond_stage_model, "decode"):
+                 xc = self.cond_stage_model.decode(c)
+                 log["conditioning"] = xc
+             elif self.cond_stage_key in ["caption", "txt"]:
+                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                 log["conditioning"] = xc
+             elif self.cond_stage_key == 'class_label':
+                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                 log['conditioning'] = xc
+             elif isimage(xc):
+                 log["conditioning"] = xc
+             if ismap(xc):
+                 log["original_conditioning"] = self.to_rgb(xc)
+
+         if sample:
+             # get denoise row
+             with ema_scope("Sampling"):
+                 samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                          ddim_steps=ddim_steps, eta=ddim_eta)
+                 # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+             x_samples = self.decode_first_stage(samples)
+             log["samples"] = x_samples
+
+         if unconditional_guidance_scale > 1.0:
+             uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+             uc = dict()
+             for k in c:
+                 if k == "c_crossattn":
+                     assert isinstance(c[k], list) and len(c[k]) == 1
+                     uc[k] = [uc_tmp]
+                 elif isinstance(c[k], list):
+                     uc[k] = [c[k][i] for i in range(len(c[k]))]
+                 else:
+                     uc[k] = c[k]
+
+             with ema_scope("Sampling with classifier-free guidance"):
+                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                  ddim_steps=ddim_steps, eta=ddim_eta,
+                                                  unconditional_guidance_scale=unconditional_guidance_scale,
+                                                  unconditional_conditioning=uc,
+                                                  )
+             x_samples_cfg = self.decode_first_stage(samples_cfg)
+             log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+         return log
stable_diffusion/ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,258 @@
+ """SAMPLING ONLY."""
+
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ import sys
+ sys.path.append('.')
+ from stable_diffusion.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+ from stable_diffusion.ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+ class PLMSSampler(object):
+     def __init__(self, model, schedule="linear", **kwargs):
+         super().__init__()
+         self.model = model
+         self.ddpm_num_timesteps = model.num_timesteps
+         self.schedule = schedule
+
+     def register_buffer(self, name, attr):
+         if type(attr) == torch.Tensor:
+             if attr.device != torch.device("cuda"):
+                 attr = attr.to(torch.device("cuda"))
+         setattr(self, name, attr)
+
+     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+         if ddim_eta != 0:
+             raise ValueError('ddim_eta must be 0 for PLMS')
+         self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                   num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+         alphas_cumprod = self.model.alphas_cumprod
+         assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+         self.register_buffer('betas', to_torch(self.model.betas))
+         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+         self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+         self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+         # ddim sampling parameters
+         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                    ddim_timesteps=self.ddim_timesteps,
+                                                                                    eta=ddim_eta, verbose=verbose)
+         self.register_buffer('ddim_sigmas', ddim_sigmas)
+         self.register_buffer('ddim_alphas', ddim_alphas)
+         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+         self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                     1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+         self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+     @torch.no_grad()
+     def sample(self,
+                S,
+                batch_size,
+                shape,
+                conditioning=None,
+                callback=None,
+                normals_sequence=None,
+                img_callback=None,
+                quantize_x0=False,
+                eta=0.,
+                mask=None,
+                x0=None,
+                temperature=1.,
+                noise_dropout=0.,
+                score_corrector=None,
+                corrector_kwargs=None,
+                verbose=True,
+                x_T=None,
+                log_every_t=100,
+                unconditional_guidance_scale=1.,
+                unconditional_conditioning=None,
+                # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+                dynamic_threshold=None,
+                **kwargs
+                ):
+         if conditioning is not None:
+             if isinstance(conditioning, dict):
+                 ctmp = conditioning[list(conditioning.keys())[0]]
+                 while isinstance(ctmp, list): ctmp = ctmp[0]
+                 cbs = ctmp.shape[0]
+                 if cbs != batch_size:
+                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+             else:
+                 if conditioning.shape[0] != batch_size:
+                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+         self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+         # sampling
+         C, H, W = shape
+         size = (batch_size, C, H, W)
+         print(f'Data shape for PLMS sampling is {size}')
+
+         samples, intermediates = self.plms_sampling(conditioning, size,
+                                                     callback=callback,
+                                                     img_callback=img_callback,
+                                                     quantize_denoised=quantize_x0,
+                                                     mask=mask, x0=x0,
+                                                     ddim_use_original_steps=False,
+                                                     noise_dropout=noise_dropout,
+                                                     temperature=temperature,
+                                                     score_corrector=score_corrector,
+                                                     corrector_kwargs=corrector_kwargs,
+                                                     x_T=x_T,
+                                                     log_every_t=log_every_t,
+                                                     unconditional_guidance_scale=unconditional_guidance_scale,
+                                                     unconditional_conditioning=unconditional_conditioning,
+                                                     dynamic_threshold=dynamic_threshold,
+                                                     )
+         return samples, intermediates
+
+     @torch.no_grad()
+     def plms_sampling(self, cond, shape,
+                       x_T=None, ddim_use_original_steps=False,
+                       callback=None, timesteps=None, quantize_denoised=False,
+                       mask=None, x0=None, img_callback=None, log_every_t=100,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None,
+                       dynamic_threshold=None):
+         device = self.model.betas.device
+         b = shape[0]
+         if x_T is None:
+             img = torch.randn(shape, device=device)
+         else:
+             img = x_T
+
+         if timesteps is None:
+             timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+         elif timesteps is not None and not ddim_use_original_steps:
+             subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+             timesteps = self.ddim_timesteps[:subset_end]
+
+         intermediates = {'x_inter': [img], 'pred_x0': [img]}
+         time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+         print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+         iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+         old_eps = []
+
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full((b,), step, device=device, dtype=torch.long)
+             ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+             if mask is not None:
+                 assert x0 is not None
+                 img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                 img = img_orig * mask + (1. - mask) * img
+
+             outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                       quantize_denoised=quantize_denoised, temperature=temperature,
+                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                       corrector_kwargs=corrector_kwargs,
+                                       unconditional_guidance_scale=unconditional_guidance_scale,
+                                       unconditional_conditioning=unconditional_conditioning,
+                                       old_eps=old_eps, t_next=ts_next,
+                                       dynamic_threshold=dynamic_threshold)
+             img, pred_x0, e_t = outs
+             old_eps.append(e_t)
+             if len(old_eps) >= 4:
+                 old_eps.pop(0)
+             if callback: callback(i)
+             if img_callback: img_callback(pred_x0, i)
+
+             if index % log_every_t == 0 or index == total_steps - 1:
+                 intermediates['x_inter'].append(img)
+                 intermediates['pred_x0'].append(pred_x0)
+
+         return img, intermediates
+
+     @torch.no_grad()
+     def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+                       dynamic_threshold=None):
+         b, *_, device = *x.shape, x.device
+
+         def get_model_output(x, t):
+             if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+                 e_t = self.model.apply_model(x, t, c)
+             else:
+                 x_in = torch.cat([x] * 2)
+                 t_in = torch.cat([t] * 2)
+                 if isinstance(c, dict):
+                     assert isinstance(unconditional_conditioning, dict)
+                     c_in = dict()
+                     for k in c:
+                         if isinstance(c[k], list):
+                             c_in[k] = [torch.cat([
+                                 unconditional_conditioning[k][i],
+                                 c[k][i]]) for i in range(len(c[k]))]
+                         else:
+                             c_in[k] = torch.cat([
+                                 unconditional_conditioning[k],
+                                 c[k]])
+                 else:
+                     c_in = torch.cat([unconditional_conditioning, c])
+                 e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+                 e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+             if score_corrector is not None:
+                 assert self.model.parameterization == "eps"
+                 e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+             return e_t
+
+         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+         alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+         sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+         sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+         def get_x_prev_and_pred_x0(e_t, index):
+             # select parameters corresponding to the currently considered timestep
+             a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+             a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+             sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+             sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+             # current prediction for x_0
+             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+             if quantize_denoised:
+                 pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+             if dynamic_threshold is not None:
+                 pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+             # direction pointing to x_t
+             dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+             noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+             if noise_dropout > 0.:
+                 noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+             x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+             return x_prev, pred_x0
+
+         e_t = get_model_output(x, t)
+         if len(old_eps) == 0:
+             # Pseudo Improved Euler (2nd order)
+             x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+             e_t_next = get_model_output(x_prev, t_next)
+             e_t_prime = (e_t + e_t_next) / 2
+         elif len(old_eps) == 1:
+             # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+             e_t_prime = (3 * e_t - old_eps[-1]) / 2
+         elif len(old_eps) == 2:
+             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+             e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+         elif len(old_eps) >= 3:
+             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+             e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+         x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+         return x_prev, pred_x0, e_t
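The multistep branches above apply standard Adams-Bashforth coefficients to the stored noise predictions; at full (4th) order the estimate handed to `get_x_prev_and_pred_x0` is

    \epsilon' = \frac{55\,\epsilon_t - 59\,\epsilon_{t-1} + 37\,\epsilon_{t-2} - 9\,\epsilon_{t-3}}{24}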
stable_diffusion/ldm/models/diffusion/sampling_util.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ import numpy as np
+ from einops import rearrange
+
+
+ def append_dims(x, target_dims):
+     """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+     From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+     dims_to_append = target_dims - x.ndim
+     if dims_to_append < 0:
+         raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+     return x[(...,) + (None,) * dims_to_append]
+
+
+ def norm_thresholding(x0, value):
+     s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+     return x0 * (value / s)
+
+
+ def spatial_norm_thresholding(x0, value):
+     # b c h w
+     s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+     return x0 * (value / s)
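A minimal sketch of applying the threshold above (illustrative tensor values):

    import torch

    x0 = torch.randn(2, 4, 64, 64) * 3.0  # over-saturated prediction
    x0 = norm_thresholding(x0, value=1.0)  # samples whose per-sample RMS exceeds 1.0 are rescaled to RMS 1.0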
stable_diffusion/ldm/modules/image_degradation/utils/test.png ADDED

Git LFS Details

  • SHA256: 92e516278f0d3e85e84cfb55b43338e12d5896a0ee3833aafdf378025457d753
  • Pointer size: 131 Bytes
  • Size of remote file: 441 kB
stable_diffusion/ldm/thirdp/psp/__pycache__/helpers.cpython-38.pyc ADDED
Binary file (4.11 kB).
 
stable_diffusion/ldm/thirdp/psp/__pycache__/id_loss.cpython-38.pyc ADDED
Binary file (1.28 kB).
 
stable_diffusion/ldm/thirdp/psp/__pycache__/model_irse.cpython-38.pyc ADDED
Binary file (2.93 kB).
 
stable_diffusion/ldm/thirdp/psp/helpers.py ADDED
@@ -0,0 +1,121 @@
+ # https://github.com/eladrich/pixel2style2pixel
+
+ from collections import namedtuple
+ import torch
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+ """
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+ """
+
+
+ class Flatten(Module):
+     def forward(self, input):
+         return input.view(input.size(0), -1)
+
+
+ def l2_norm(input, axis=1):
+     norm = torch.norm(input, 2, axis, True)
+     output = torch.div(input, norm)
+     return output
+
+
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+     """ A named tuple describing a ResNet block. """
+
+
+ def get_block(in_channel, depth, num_units, stride=2):
+     return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+ def get_blocks(num_layers):
+     if num_layers == 50:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=4),
+             get_block(in_channel=128, depth=256, num_units=14),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     elif num_layers == 100:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=13),
+             get_block(in_channel=128, depth=256, num_units=30),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     elif num_layers == 152:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=8),
+             get_block(in_channel=128, depth=256, num_units=36),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     else:
+         raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+     return blocks
+
+
+ class SEModule(Module):
+     def __init__(self, channels, reduction):
+         super(SEModule, self).__init__()
+         self.avg_pool = AdaptiveAvgPool2d(1)
+         self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+         self.relu = ReLU(inplace=True)
+         self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+         self.sigmoid = Sigmoid()
+
+     def forward(self, x):
+         module_input = x
+         x = self.avg_pool(x)
+         x = self.fc1(x)
+         x = self.relu(x)
+         x = self.fc2(x)
+         x = self.sigmoid(x)
+         return module_input * x
+
+
+ class bottleneck_IR(Module):
+     def __init__(self, in_channel, depth, stride):
+         super(bottleneck_IR, self).__init__()
+         if in_channel == depth:
+             self.shortcut_layer = MaxPool2d(1, stride)
+         else:
+             self.shortcut_layer = Sequential(
+                 Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                 BatchNorm2d(depth)
+             )
+         self.res_layer = Sequential(
+             BatchNorm2d(in_channel),
+             Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+             Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+         )
+
+     def forward(self, x):
+         shortcut = self.shortcut_layer(x)
+         res = self.res_layer(x)
+         return res + shortcut
+
+
+ class bottleneck_IR_SE(Module):
+     def __init__(self, in_channel, depth, stride):
+         super(bottleneck_IR_SE, self).__init__()
+         if in_channel == depth:
+             self.shortcut_layer = MaxPool2d(1, stride)
+         else:
+             self.shortcut_layer = Sequential(
+                 Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                 BatchNorm2d(depth)
+             )
+         self.res_layer = Sequential(
+             BatchNorm2d(in_channel),
+             Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+             PReLU(depth),
+             Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+             BatchNorm2d(depth),
+             SEModule(depth, 16)
+         )
+
+     def forward(self, x):
+         shortcut = self.shortcut_layer(x)
+         res = self.res_layer(x)
+         return res + shortcut
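A quick sketch of how these pieces compose (mirroring what `Backbone` in model_irse.py below does with them):

    blocks = get_blocks(50)  # 3 + 4 + 14 + 3 = 24 bottleneck specs across 4 stages
    units = [bottleneck_IR_SE(b.in_channel, b.depth, b.stride) for stage in blocks for b in stage]
    # stacking `units` sequentially gives the ir_se-50 trunk used by the ID loss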
stable_diffusion/ldm/thirdp/psp/id_loss.py ADDED
@@ -0,0 +1,25 @@
+# https://github.com/eladrich/pixel2style2pixel
+import torch
+from torch import nn
+import sys
+sys.path.append('.')
+from stable_diffusion.ldm.thirdp.psp.model_irse import Backbone
+
+
+class IDFeatures(nn.Module):
+    def __init__(self, model_path):
+        super(IDFeatures, self).__init__()
+        print('Loading ResNet ArcFace')
+        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+        self.facenet.load_state_dict(torch.load(model_path))
+        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+        self.facenet.eval()
+
+    def forward(self, x, crop=False):
+        # Not sure of the image range here
+        if crop:
+            # Resize to 256x256, then take a fixed 188x188 face box before pooling to 112x112
+            x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
+            x = x[:, :, 35:223, 32:220]
+        x = self.face_pool(x)
+        x_feats = self.facenet(x)
+        return x_feats
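A hedged usage sketch: the backbone ends in l2_norm (see model_irse.py below), so the dot product of two embeddings is their cosine similarity, which mirrors how pixel2style2pixel builds its identity loss. The checkpoint path below is a placeholder for IR-SE50 ArcFace weights:

import torch

id_model = IDFeatures("model_ir_se50.pth")  # hypothetical path to ArcFace weights
img_a = torch.rand(4, 3, 256, 256)          # note the range ambiguity flagged in the comment above
img_b = torch.rand(4, 3, 256, 256)
feats_a = id_model(img_a, crop=True)
feats_b = id_model(img_b, crop=True)
# 1 - cosine similarity, averaged over the batch
loss = (1 - (feats_a * feats_b).sum(dim=-1)).mean()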
stable_diffusion/ldm/thirdp/psp/model_irse.py ADDED
@@ -0,0 +1,87 @@
+# https://github.com/eladrich/pixel2style2pixel
+import sys
+sys.path.append(".")
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+from stable_diffusion.ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+"""
+Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Backbone(Module):
+    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+        super(Backbone, self).__init__()
+        assert input_size in [112, 224], "input_size should be 112 or 224"
+        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+        blocks = get_blocks(num_layers)
+        if mode == 'ir':
+            unit_module = bottleneck_IR
+        elif mode == 'ir_se':
+            unit_module = bottleneck_IR_SE
+        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+                                      BatchNorm2d(64),
+                                      PReLU(64))
+        # Four stride-2 stages reduce 112 -> 7 (or 224 -> 14), hence the flatten sizes below
+        if input_size == 112:
+            self.output_layer = Sequential(BatchNorm2d(512),
+                                           Dropout(drop_ratio),
+                                           Flatten(),
+                                           Linear(512 * 7 * 7, 512),
+                                           BatchNorm1d(512, affine=affine))
+        else:
+            self.output_layer = Sequential(BatchNorm2d(512),
+                                           Dropout(drop_ratio),
+                                           Flatten(),
+                                           Linear(512 * 14 * 14, 512),
+                                           BatchNorm1d(512, affine=affine))
+
+        modules = []
+        for block in blocks:
+            for bottleneck in block:
+                modules.append(unit_module(bottleneck.in_channel,
+                                           bottleneck.depth,
+                                           bottleneck.stride))
+        self.body = Sequential(*modules)
+
+    def forward(self, x):
+        x = self.input_layer(x)
+        x = self.body(x)
+        x = self.output_layer(x)
+        return l2_norm(x)
+
+
+def IR_50(input_size):
+    """Constructs an ir-50 model."""
+    model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
+    return model
+
+
+def IR_101(input_size):
+    """Constructs an ir-101 model."""
+    model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
+    return model
+
+
+def IR_152(input_size):
+    """Constructs an ir-152 model."""
+    model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
+    return model
+
+
+def IR_SE_50(input_size):
+    """Constructs an ir_se-50 model."""
+    model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
+    return model
+
+
+def IR_SE_101(input_size):
+    """Constructs an ir_se-101 model."""
+    model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
+    return model
+
+
+def IR_SE_152(input_size):
+    """Constructs an ir_se-152 model."""
+    model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
+    return model
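A minimal smoke test for the constructors above (assumption: run from the repo root so the helpers import resolves): a 112x112 crop passes through four stride-2 stages (112 -> 56 -> 28 -> 14 -> 7, hence the 512 * 7 * 7 flatten) and comes out as a unit-norm 512-d embedding.

import torch

model = IR_SE_50(input_size=112)   # the variant IDFeatures wraps
model.eval()
with torch.no_grad():
    emb = model(torch.randn(2, 3, 112, 112))
print(emb.shape)          # torch.Size([2, 512])
print(emb.norm(dim=-1))   # ~1.0 per row, because forward() ends in l2_norm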
weights/Abstractionism.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68e87b6b779a8eebf79ee88f15ad397abf589ad021bc87834285f76466bae279
+size 4265501162
weights/Artist_Sketch.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6370f52d6c6923d52474c309b3e40f74ff71f7436c8981569a21c4e96b930438
+size 4265500016
weights/Blossom_Season.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5db2a5f7b0c9e2c4bf56290eacadc59cff828cfc7d31684a6f931833bb6d0424
+size 4265501162
weights/Bricks.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30c61896482c62889fa27955b56fa4675418adfbf1da14db8f3ec442d9617495
+size 4265486170
weights/Cats.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:603a1f4ffe491466b4e12e80efee4cda2049f240aa4ef8bc9ba8980863d58420
+size 4265417638
weights/Color_Fantasy.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:441ea11c216298f7e95678c9d6194a2b138d5f562d2cb4ffc2b1e7f3c152c85c
+size 4265500016
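The weights entries above are Git LFS pointer files (spec v1), not the checkpoints themselves: the actual ~4.3 GB .pth blobs are fetched by `git lfs pull`, and the `oid sha256` line records their checksum. A small stdlib sketch for verifying a downloaded file against its pointer:

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file so multi-GB checkpoints do not have to fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# e.g. for weights/Abstractionism.pth, the pointer above records:
expected = "68e87b6b779a8eebf79ee88f15ad397abf589ad021bc87834285f76466bae279"
assert sha256_of("weights/Abstractionism.pth") == expected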