BiliSakura commited on
Commit
cd99f4e
·
verified ·
1 Parent(s): 744c773

Add files using upload-large-folder tool

Browse files
README.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: diffusers
4
+ pipeline_tag: image-to-image
5
+ tags:
6
+ - zoomldm
7
+ - histopathology
8
+ - brca
9
+ - latent-diffusion
10
+ - custom-pipeline
11
+ - arxiv:2411.16969
12
+ ---
13
+
14
+ # BiliSakura/ZoomLDM-brca
15
+
16
+ Diffusers-format **BRCA** variant of ZoomLDM with a bundled custom pipeline and local `ldm` modules.
17
+
18
+ ## Model Description
19
+
20
+ - **Architecture:** ZoomLDM latent diffusion pipeline (`UNet + VAE + conditioning encoder`)
21
+ - **Domain:** Histopathology (BRCA)
22
+ - **Conditioning:** UNI-style SSL feature maps + magnification level (`0..4`)
23
+ - **Format:** Self-contained local folder for `DiffusionPipeline.from_pretrained(...)`
24
+
25
+ ## Intended Use
26
+
27
+ Use this model for conditional multi-scale BRCA patch generation when you have compatible pre-extracted SSL features.
28
+
29
+ ## Out-of-Scope Use
30
+
31
+ - Not intended for diagnosis, treatment planning, or other clinical decisions.
32
+ - Not a general-purpose text-to-image model.
33
+ - Not validated for data outside the expected acquisition/distribution range.
34
+
35
+ ## Files
36
+
37
+ - `unet/`, `vae/`, `conditioning_encoder/`, `scheduler/`
38
+ - `model_index.json`
39
+ - `pipeline_zoomldm.py`
40
+ - `ldm/` (bundled dependency modules)
41
+
42
+ ## Usage
43
+
44
+ ```python
45
+ import torch
46
+ from diffusers import DiffusionPipeline
47
+
48
+ pipe = DiffusionPipeline.from_pretrained(
49
+ "BiliSakura/ZoomLDM-brca",
50
+ custom_pipeline="pipeline_zoomldm.py",
51
+ trust_remote_code=True,
52
+ ).to("cuda")
53
+
54
+ out = pipe(
55
+ ssl_features=ssl_feat_tensor.to("cuda"), # BRCA UNI-style SSL embeddings
56
+ magnification=torch.tensor([0]).to("cuda"), # 0..4
57
+ num_inference_steps=50,
58
+ guidance_scale=2.0,
59
+ )
60
+ images = out.images
61
+ ```
62
+
63
+ ## Limitations
64
+
65
+ - Requires correctly precomputed BRCA conditioning features.
66
+ - Magnification conditioning must match expected integer codes.
67
+ - Generated content may reflect biases and artifacts from training data.
68
+
69
+ ## Citation
70
+
71
+ ```bibtex
72
+ @InProceedings{Yellapragada_2025_CVPR,
73
+ author = {Yellapragada, Srikar and Graikos, Alexandros and Triaridis, Kostas and Prasanna, Prateek and Gupta, Rajarsi and Saltz, Joel and Samaras, Dimitris},
74
+ title = {ZoomLDM: Latent Diffusion Model for Multi-scale Image Generation},
75
+ booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
76
+ month = {June},
77
+ year = {2025},
78
+ pages = {23453-23463}
79
+ }
80
+ ```
__pycache__/pipeline_zoomldm.cpython-312.pyc ADDED
Binary file (24.3 kB). View file
 
conditioning_encoder/config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_key": "ssl_feat",
3
+ "mag_key": "mag",
4
+ "num_layers": 12,
5
+ "input_channels": 1024,
6
+ "hidden_channels": 512,
7
+ "vit_mlp_dim": 2048,
8
+ "p_uncond": 0.1
9
+ }
conditioning_encoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1f294255b4b26a33fc72eb1819cf98032e9bcc0cfb2e308cf8376d58921bed8
3
+ size 154641952
ldm/data/__init__.py ADDED
File without changes
ldm/data/brca.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+ from torch.utils.data import Dataset
4
+ from PIL import Image
5
+ import h5py
6
+ import io
7
+ import torch.nn.functional as F
8
+ import torch
9
+
10
+
11
# Magnification label -> integer conditioning code fed to the model.
MAG_DICT = {
    "20x": 0,
    "10x": 1,
    "5x": 2,
    "2_5x": 3,
    "1_25x": 4,
    "0_625x": 5,
    "0_3125x": 6,
    "0_15625x": 7,
}

# Number of stored patches available at each magnification level.
MAG_NUM_IMGS = {
    "20x": 12_509_760,
    "10x": 3_036_288,
    "5x": 752_000,
    "2_5x": 187_280,
    "1_25x": 57_090,
    "0_625x": 20_679,
    "0_3125x": 7_923,
    "0_15625x": 2489,
}


class TCGADataset(Dataset):
    """TCGA-BRCA patch dataset serving precomputed VAE latents and UNI SSL
    feature grids at a fixed or randomly sampled magnification.

    Config keys:
        root:             directory containing per-magnification subfolders
        mag:              fixed magnification key from MAG_DICT, or None to
                          sample a magnification per item
        feat_target_size: if != -1, average-pool SSL grids larger than this
        return_image:     also decode and return the raw image patch
        normalize_ssl:    per-channel standardize the SSL features
    """

    def __init__(self, config=None):
        self.root = Path(config.get("root"))
        self.mag = config.get("mag", None)

        self.keys = list(MAG_DICT.keys())
        self.feat_target_size = config.get("feat_target_size", -1)
        self.return_image = config.get("return_image", False)
        self.normalize_ssl = config.get("normalize_ssl", False)

    def __len__(self):
        # With a fixed magnification, the split size is exact; otherwise the
        # nominal epoch length is the largest split (20x) and items are
        # resampled randomly in __getitem__ anyway.
        if self.mag:
            return MAG_NUM_IMGS[self.mag]
        return MAG_NUM_IMGS["20x"]

    def __getitem__(self, idx):
        """Return dict with image (optional), VAE latent, SSL grid, idx and mag code.

        Corrupt/missing VAE latents are skipped by resampling a new index.
        The original implementation recursed on failure under a bare
        ``except:`` — rewritten as a retry loop catching ``Exception`` so
        KeyboardInterrupt/SystemExit propagate and deep recursion is avoided.
        """
        while True:
            if self.mag:
                mag_choice = self.mag
            else:
                mag_choice = np.random.choice(self.keys)
                # pick a random index within the chosen magnification
                idx = np.random.randint(0, MAG_NUM_IMGS[mag_choice])

            # load VAE latent; patches are sharded into folders of 1M files
            folder = str(idx // 1_000_000)
            folder_path = self.root / f"{mag_choice}/{folder}"

            try:
                vae_feat = np.load(folder_path / f"{idx}_vae.npy").astype(np.float16)
                if vae_feat.shape != (3, 64, 64):
                    ### TEMPORARY FIX ### reject malformed latents on disk
                    raise ValueError(f"vae shape {vae_feat.shape} for idx {idx}")
            except Exception:
                # bad sample on disk: resample and retry
                idx = np.random.randint(len(self))
                continue
            break

        # load SSL (UNI) feature grid
        ssl_feat = np.load(folder_path / f"{idx}_uni_grid.npy").astype(np.float16)

        if len(ssl_feat.shape) == 1:
            # single global feature vector -> treat as a 1x1 grid
            ssl_feat = ssl_feat[:, None]
        h = np.sqrt(ssl_feat.shape[1]).astype(int)

        ssl_feat = torch.tensor(ssl_feat.reshape((-1, h, h)))

        # resize ssl_feat down to the requested grid size
        if self.feat_target_size != -1 and h > self.feat_target_size:
            shape = (self.feat_target_size, self.feat_target_size)
            ssl_feat = F.adaptive_avg_pool2d(ssl_feat, shape)

        # normalize ssl_feat per channel
        if self.normalize_ssl:
            mean = ssl_feat.mean(axis=0, keepdims=True)
            std = ssl_feat.std(axis=0, keepdims=True)
            ssl_feat = (ssl_feat - mean) / (std + 1e-8)

        # load image (stored as an encoded byte buffer inside a .npy)
        if self.return_image:
            image = np.load(folder_path / f"{idx}_img.npy")
            image = Image.open(io.BytesIO(image))
            image = np.array(image).astype(np.uint8)
        else:
            image = np.ones((1, 1, 1, 3), dtype=np.float16)

        return {
            "image": image,
            "vae_feat": vae_feat,
            "ssl_feat": ssl_feat,
            "idx": idx,
            "mag": MAG_DICT[mag_choice],
        }
ldm/data/naip.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+ from torch.utils.data import Dataset
4
+ from PIL import Image
5
+ import h5py
6
+ import io
7
+ import torch.nn.functional as F
8
+ import torch
9
+ from einops import rearrange
10
+
11
# Magnification label -> integer conditioning code.
MAG_DICT = {
    "1x": 0,
    "2x": 1,
    "3x": 2,
    "4x": 3,
}

# Number of stored patches per magnification level.
MAG_NUM_IMGS = {
    "1x": 365119,
    "2x": 94263,
    "3x": 25690,
    "4x": 8772,
}


class NAIPDataset(Dataset):
    """NAIP aerial-imagery dataset serving precomputed VAE latents and DINO
    feature grids at a fixed or randomly sampled magnification.

    Config keys: root, mag (fixed key or None), feat_target_size,
    return_image, normalize_ssl — same contract as the TCGA dataset.
    """

    def __init__(self, config=None):
        self.root = Path(config.get("root"))
        self.mag = config.get("mag", None)

        self.keys = list(MAG_DICT.keys())
        self.feat_target_size = config.get("feat_target_size", -1)
        self.return_image = config.get("return_image", False)
        self.normalize_ssl = config.get("normalize_ssl", False)

    def __len__(self):
        # Fixed magnification: size of that split; otherwise all splits combined.
        return MAG_NUM_IMGS[self.mag] if self.mag else sum(MAG_NUM_IMGS.values())

    def __getitem__(self, idx):
        if self.mag:
            mag_choice = self.mag
        else:
            # sample a magnification, then a fresh index within it
            mag_choice = np.random.choice(self.keys)
            idx = np.random.randint(0, MAG_NUM_IMGS[mag_choice])

        folder_path = self.root / f"{mag_choice}/"

        vae_feat = np.load(folder_path / f"{idx}_vae.npy").astype(np.float16)
        ssl_feat = np.load(folder_path / f"{idx}_dino_grid.npy").astype(np.float16)

        # tokens are stored flat; recover the square grid side length
        h = np.sqrt(ssl_feat.shape[0]).astype(int)
        ssl_feat = torch.tensor(rearrange(ssl_feat, "(h1 h2) dim -> dim h1 h2", h1=h))

        # optionally pool the grid down to the requested size
        if self.feat_target_size != -1 and h > self.feat_target_size:
            target = (self.feat_target_size, self.feat_target_size)
            ssl_feat = F.adaptive_avg_pool2d(ssl_feat, target)

        # optionally standardize per channel
        if self.normalize_ssl:
            mu = ssl_feat.mean(axis=0, keepdims=True)
            sigma = ssl_feat.std(axis=0, keepdims=True)
            ssl_feat = (ssl_feat - mu) / (sigma + 1e-8)

        if self.return_image:
            image = np.array(Image.open(folder_path / f"{idx}.jpg")).astype(np.uint8)
        else:
            image = np.ones((1, 1, 1, 3), dtype=np.float16)

        return {
            "image": image,
            "vae_feat": vae_feat,
            "ssl_feat": ssl_feat,
            "idx": idx,
            "mag": MAG_DICT[mag_choice],
            "img_path": str(folder_path / f"{idx}.jpg"),
        }
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
class LambdaWarmUpCosineScheduler:
    """
    Linear warm-up followed by cosine decay, returned as an LR multiplier.
    note: use with a base_lr of 1.0
    """

    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            # linear ramp from lr_start up to lr_max
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
        else:
            # cosine decay from lr_max to lr_min, clamped after max_decay_steps
            t = min((n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps), 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
35
+
36
+
37
class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """

    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # cumulative cycle boundaries, starting at 0
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        # Index of the cycle that global step ``n`` falls into.
        for interval, upper in enumerate(self.cum_cycles[1:]):
            if n <= upper:
                return interval

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            # linear warm-up within this cycle
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
        else:
            # cosine decay over the remainder of the cycle, clamped at t = 1
            t = min((n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]), 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Same per-cycle warm-up, but with linear (not cosine) decay afterwards."""

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
        else:
            # linear decay from f_max toward f_min across the cycle
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (
                self.cycle_lengths[cycle]
            )
        self.last_f = f
        return f
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import pytorch_lightning as pl
4
+ import torch.nn.functional as F
5
+ from contextlib import contextmanager
6
+
7
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
8
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
9
+
10
+ from ldm.util import instantiate_from_config
11
+ from ldm.modules.ema import LitEma
12
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
13
+
14
+
15
class VQModel(pl.LightningModule):
    """VQ-GAN style first-stage model: Encoder -> VectorQuantizer -> Decoder,
    trained adversarially with two optimizers (autoencoder vs. discriminator).

    NOTE(review): ``training_step(..., optimizer_idx)`` is the
    pytorch-lightning 1.x multi-optimizer convention — confirm the pinned
    lightning version still supports it.
    """

    def __init__(
        self,
        ddconfig,                # Encoder/Decoder hyper-parameters; must contain "z_channels"
        lossconfig,              # config node instantiated into the (adversarial) loss module
        n_embed,                 # codebook size
        embed_dim,               # dimensionality of each codebook entry
        ckpt_path=None,          # optional checkpoint to restore from
        ignore_keys=[],          # state_dict key prefixes to drop when restoring
        image_key="image",       # batch-dict key holding the input images
        colorize_nlabels=None,   # if set, register a random projection for segmentation viz
        monitor=None,            # metric name monitored by checkpoint callbacks
        batch_resize_range=None, # optional (lower, upper) range for per-batch random resizing
        scheduler_config=None,   # optional LR-scheduler config
        lr_g_factor=1.0,         # multiplier on the generator learning rate
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
        use_ema=False,           # keep an exponential moving average of the weights
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Context manager: temporarily swap in EMA weights, restore on exit.

        A no-op when ``use_ema`` is False.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load weights from a lightning checkpoint, dropping any key that
        starts with a prefix in ``ignore_keys``; non-strict load."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        # Update the EMA shadow weights after every training batch.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode an image; returns (quantized latent, embedding loss, info)."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode to the continuous latent, skipping quantization."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode a quantized latent back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode directly from codebook indices."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Full reconstruction pass; optionally return the codebook indices."""
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Pull ``batch[k]``, convert to contiguous NCHW float, and optionally
        apply per-batch random resizing (multiples of 16 within the range)."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        # optimizer_idx 0: autoencoder step; 1: discriminator step.
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(
                qloss,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(
                qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train"
            )
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Run validation with both the live and the EMA weights; metrics are
        emitted via ``self.log`` inside ``_validation_step``, so the EMA
        return value is intentionally unused."""
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Shared validation body; ``suffix`` distinguishes EMA metrics."""
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(
            qloss,
            x,
            xrec,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + suffix,
        )

        discloss, log_dict_disc = self.loss(
            qloss,
            x,
            xrec,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + suffix,
        )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(
            f"val{suffix}/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True
        )
        self.log(
            f"val{suffix}/aeloss", aeloss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True
        )
        # rec_loss was already logged explicitly above; avoid double logging.
        del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers: generator (encoder/decoder/quantizer) and
        discriminator, with optional per-step LambdaLR schedulers.

        NOTE(review): ``LambdaLR`` is referenced below but does not appear in
        this module's visible imports — confirm it is imported (e.g. from
        ``torch.optim.lr_scheduler``) before enabling ``scheduler_config``.
        """
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quantize.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()),
            lr=lr_g,
            betas=(0.5, 0.9),
        )
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {"scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1},
                {"scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1},
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # Last decoder layer's weights, used by the loss for adaptive weighting.
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return a dict of input/reconstruction tensors for image logging;
        optionally also reconstructions under the EMA weights."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels via a
        fixed random projection, rescaled to [-1, 1] for visualization."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
+ return x
271
+
272
+
273
class VQModelInterface(VQModel):
    """VQModel variant exposing the LDM first-stage interface: ``encode``
    returns the continuous (pre-quantization) latent, and ``decode`` may
    optionally quantize before decoding."""

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        # Unlike VQModel.encode, stop before the quantizer.
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if force_not_quantize:
            quant = h
        else:
            quant, _, _ = self.quantize(h)
        return self.decoder(self.post_quant_conv(quant))
+ return dec
292
+
293
+
294
class AutoencoderKL(pl.LightningModule):
    """KL-regularized image autoencoder (LDM first-stage VAE).

    The encoder emits the moments of a diagonal Gaussian posterior; the
    decoder reconstructs from a posterior sample (or its mode). Trained
    adversarially with two optimizers (autoencoder vs. discriminator).

    NOTE(review): ``training_step(..., optimizer_idx)`` is the
    pytorch-lightning 1.x multi-optimizer convention — confirm the pinned
    lightning version still supports it.
    """

    def __init__(self,
                 ddconfig,           # Encoder/Decoder hyper-parameters; must enable "double_z"
                 lossconfig,         # config node instantiated into the loss module
                 embed_dim,          # latent channel count
                 ckpt_path=None,     # optional checkpoint to restore from
                 ignore_keys=[],     # state_dict key prefixes to drop on restore
                 image_key="image",  # batch-dict key holding the input images
                 colorize_nlabels=None,  # if set, random projection buffer for segmentation viz
                 monitor=None,       # metric name for checkpoint monitoring
                 ema_decay=None,     # if set in (0, 1), keep an EMA copy of the weights
                 learn_logvar=False  # also optimize the loss module's logvar parameter
                 ):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # double_z: encoder outputs 2*z_channels (mean and logvar stacked).
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Restore weights from a lightning checkpoint (non-strict), dropping
        any key that starts with a prefix in ``ignore_keys``."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Context manager: temporarily swap in EMA weights, restore on exit.

        A no-op when ``use_ema`` is False.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        # Update the EMA shadow weights after every training batch.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Return the diagonal Gaussian posterior q(z|x)."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct ``input``; sample the posterior (default) or use its mode.

        Returns (reconstruction, posterior).
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Pull ``batch[k]`` and convert to contiguous NCHW float."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # optimizer_idx 0: encoder+decoder(+logvar); 1: discriminator.
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Run validation with both live and EMA weights; metrics are emitted
        via ``self.log`` inside ``_validation_step``, so the EMA return value
        is intentionally unused."""
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        """Shared validation body; ``postfix`` distinguishes EMA metrics."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val"+postfix)

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val"+postfix)

        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers: autoencoder (optionally including the loss
        module's logvar) and discriminator; no LR schedulers."""
        lr = self.learning_rate
        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list,
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # Last decoder layer's weights, used by the loss for adaptive weighting.
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Return a dict of inputs, reconstructions, and prior samples for
        image logging; optionally also the EMA-weight versions."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            # decode noise of the posterior's shape as an unconditional sample
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels via a
        fixed random projection, rescaled to [-1, 1] for visualization."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
+ return x
480
+
481
+
482
class IdentityFirstStage(torch.nn.Module):
    """No-op first-stage stand-in: encode/decode/forward return the input
    unchanged. With ``vq_interface=True``, ``quantize`` mimics a vector
    quantizer's ``(quant, emb_loss, info)`` return shape."""

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Call nn.Module's __init__ *before* assigning attributes: the
        # original assigned self.vq_interface first, which only works by
        # accident of torch's __setattr__ fallback and is fragile.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            # Match VectorQuantizer's (quant, emb_loss, info) tuple shape.
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
+ return x
500
+
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import (
8
+ make_ddim_sampling_parameters,
9
+ make_ddim_timesteps,
10
+ noise_like,
11
+ extract_into_tensor,
12
+ )
13
+
14
+
15
class DDIMSampler(object):
    """DDIM sampler wrapped around a trained DDPM-style diffusion model.

    ``model`` must expose ``num_timesteps``, the noise-schedule buffers
    (``betas``, ``alphas_cumprod``, ``alphas_cumprod_prev``) and
    ``apply_model``; for v-prediction it must also provide the eps/x0
    conversion helpers. Per-run schedule tensors are (re)built by
    ``make_schedule`` before each sampling pass.
    """

    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
        self.device = device

    def register_buffer(self, name, attr):
        # Not an nn.Module buffer: simply stores the tensor as an attribute,
        # moving it to the sampler's device first.
        if type(attr) == torch.Tensor:
            if attr.device != self.device:
                attr = attr.to(self.device)
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        """Precompute the DDIM timestep subsequence and per-step alphas/sigmas."""
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        dynamic_threshold=None,
        ucg_schedule=None,
        **kwargs,
    ):
        """Run DDIM sampling for ``S`` steps; return ``(samples, intermediates)``.

        ``shape`` excludes the batch dimension. ``conditioning`` may be a
        tensor, a list of tensors, or a dict of (lists of) tensors; a
        batch-size mismatch only triggers a warning.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(
                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
                    )

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        # BUGFIX: upstream interpolated the undefined name `cbs`
                        # here, raising NameError whenever this warning fired.
                        print(
                            f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}"
                        )

            else:
                if conditioning.shape[0] != batch_size:
                    print(
                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                    )

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        size = (batch_size, *shape)
        print(f"Data shape for DDIM sampling is {size}, eta {eta}")

        samples, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            dynamic_threshold=dynamic_threshold,
            ucg_schedule=ucg_schedule,
        )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
        ucg_schedule=None,
    ):
        """Core DDIM loop: iterate the timestep subsequence from T down to 0.

        When ``mask``/``x0`` are given, known regions are re-noised from
        ``x0`` at each step (inpainting). Returns the final latent and a dict
        of intermediate ``x_inter``/``pred_x0`` snapshots.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
                img = img_orig * mask + (1.0 - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            with torch.cuda.amp.autocast():
                outs = self.p_sample_ddim(
                    img,
                    cond,
                    ts,
                    index=index,
                    use_original_steps=ddim_use_original_steps,
                    quantize_denoised=quantize_denoised,
                    temperature=temperature,
                    noise_dropout=noise_dropout,
                    score_corrector=score_corrector,
                    corrector_kwargs=corrector_kwargs,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=unconditional_conditioning,
                    dynamic_threshold=dynamic_threshold,
                )
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
    ):
        """One DDIM update x_t -> x_{t-1}; returns ``(x_prev, pred_x0)``.

        Supports classifier-free guidance (conditioning may be a tensor,
        list or dict; the unconditional counterpart must match its shape).
        """
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
            model_output = self.model.apply_model(x, t, c)
        else:
            # classifier-free guidance: one batched forward pass for
            # unconditional + conditional, then extrapolate
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (
                model_t - model_uncond
            )

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", "not implemented"
            e_t = score_corrector.modify_score(
                self.model, e_t, x, t, c, **corrector_kwargs
            )

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep;
        # x can be 3 or 4 dimensional, so broadcast shapes are built dynamically
        a_t = torch.full((b, *([1] * (len(x.shape) - 1))), alphas[index], device=device)
        a_prev = torch.full(
            (b, *([1] * (len(x.shape) - 1))), alphas_prev[index], device=device
        )
        sigma_t = torch.full(
            (b, *([1] * (len(x.shape) - 1))), sigmas[index], device=device
        )
        sqrt_one_minus_at = torch.full(
            (b, *([1] * (len(x.shape) - 1))),
            sqrt_one_minus_alphas[index],
            device=device,
        )

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(
        self,
        x0,
        c,
        t_enc,
        use_original_steps=False,
        return_intermediates=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        callback=None,
    ):
        """Deterministic DDIM inversion: run ``t_enc`` forward steps on ``x0``.

        Returns ``(x_encoded, out_dict)`` where the dict carries the encoded
        latent, the recorded intermediate steps and (optionally) the
        intermediate latents themselves.
        """
        num_reference_steps = (
            self.ddpm_num_timesteps
            if use_original_steps
            else self.ddim_timesteps.shape[0]
        )

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc="Encoding Image"):
            t = torch.full(
                (x0.shape[0],), i, device=self.model.device, dtype=torch.long
            )
            if unconditional_guidance_scale == 1.0:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(
                        torch.cat((x_next, x_next)),
                        torch.cat((t, t)),
                        torch.cat((unconditional_conditioning, c)),
                    ),
                    2,
                )
                noise_pred = e_t_uncond + unconditional_guidance_scale * (
                    noise_pred - e_t_uncond
                )

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = (
                alphas_next[i].sqrt()
                * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
                * noise_pred
            )
            x_next = xt_weighted + weighted_noise_pred
            if (
                return_intermediates
                and i % (num_steps // return_intermediates) == 0
                and i < num_steps - 1
            ):
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback:
                callback(i)

        out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
        if return_intermediates:
            out.update({"intermediates": intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Noise ``x0`` to timestep ``t`` in one shot (q_sample-style).

        Fast, but does not allow for exact reconstruction; ``t`` serves as
        an index to gather the correct alphas.
        """
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
        )

    @torch.no_grad()
    def decode(
        self,
        x_latent,
        cond,
        t_start,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        callback=None,
    ):
        """Denoise ``x_latent`` from timestep ``t_start`` down to 0."""
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
            if use_original_steps
            else self.ddim_timesteps
        )
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full(
                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
            )
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                cond,
                ts,
                index=index,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
            )
            if callback:
                callback(i)
        return x_dec
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager
16
+ from functools import partial
17
+ from tqdm import tqdm
18
+ from torchvision.utils import make_grid
19
+
20
+ try:
21
+ from pytorch_lightning.utilities.distributed import rank_zero_only
22
+ except:
23
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
24
+
25
+ import bitsandbytes as bnb
26
+
27
+
28
+ from ldm.util import (
29
+ log_txt_as_img,
30
+ exists,
31
+ default,
32
+ ismap,
33
+ isimage,
34
+ mean_flat,
35
+ count_params,
36
+ instantiate_from_config,
37
+ )
38
+ from ldm.modules.ema import LitEma
39
+ from ldm.modules.distributions.distributions import (
40
+ normal_kl,
41
+ DiagonalGaussianDistribution,
42
+ )
43
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
44
+ from ldm.modules.diffusionmodules.util import (
45
+ make_beta_schedule,
46
+ extract_into_tensor,
47
+ noise_like,
48
+ )
49
+ from ldm.models.diffusion.ddim import DDIMSampler
50
+ # from pytorch_fid.inception import InceptionV3
51
+ # from pytorch_fid.fid_score import calculate_frechet_distance
52
+ from torchvision import transforms
53
+
54
+
55
# Maps conditioning mode ("concat" / "crossattn" / "adm") to the keyword
# argument name the diffusion model wrapper expects for that conditioning.
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
56
+
57
+
58
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore. The requested ``mode`` is deliberately ignored."""
    return self
62
+
63
+
64
def uniform_on_device(r1, r2, shape, device):
    """Draw a tensor of the given shape uniformly from the interval between r2 and r1."""
    span = r1 - r2
    return span * torch.rand(*shape, device=device) + r2
66
+
67
+
68
+ class DDPM(pl.LightningModule):
69
+ # classic DDPM with Gaussian diffusion, in image space
70
def __init__(
    self,
    unet_config,
    timesteps=1000,
    beta_schedule="linear",
    loss_type="l2",
    ckpt_path=None,
    ignore_keys=None,
    load_only_unet=False,
    monitor="val/loss",
    use_ema=True,
    first_stage_key="image",
    image_size=256,
    channels=3,
    log_every_t=100,
    clip_denoised=True,
    linear_start=1e-4,
    linear_end=2e-2,
    cosine_s=8e-3,
    given_betas=None,
    original_elbo_weight=0.0,
    v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
    l_simple_weight=1.0,
    conditioning_key=None,
    parameterization="eps",  # all assuming fixed variance schedules
    scheduler_config=None,
    use_positional_encodings=False,
    learn_logvar=False,
    logvar_init=0.0,
):
    """Set up the diffusion wrapper, EMA, noise schedule and loss weights.

    ``ignore_keys`` (default None -> empty list) lists state-dict key
    prefixes to drop when loading ``ckpt_path``. The original signature used
    a mutable ``[]`` default, which is an anti-pattern even if harmless here.
    """
    super().__init__()
    assert parameterization in [
        "eps",
        "x0",
    ], 'currently only supporting "eps" and "x0"'
    self.parameterization = parameterization
    print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
    self.cond_stage_model = None
    self.clip_denoised = clip_denoised
    self.log_every_t = log_every_t
    self.first_stage_key = first_stage_key
    self.image_size = image_size  # try conv?
    self.channels = channels
    self.use_positional_encodings = use_positional_encodings
    # DiffusionWrapper routes the conditioning into the UNet per conditioning_key
    self.model = DiffusionWrapper(unet_config, conditioning_key)
    count_params(self.model, verbose=True)
    self.use_ema = use_ema
    if self.use_ema:
        self.model_ema = LitEma(self.model)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    self.use_scheduler = scheduler_config is not None
    if self.use_scheduler:
        self.scheduler_config = scheduler_config

    self.v_posterior = v_posterior
    self.original_elbo_weight = original_elbo_weight
    self.l_simple_weight = l_simple_weight

    if monitor is not None:
        self.monitor = monitor
    if ckpt_path is not None:
        self.init_from_ckpt(
            ckpt_path,
            ignore_keys=ignore_keys if ignore_keys is not None else [],
            only_model=load_only_unet,
        )

    self.register_schedule(
        given_betas=given_betas,
        beta_schedule=beta_schedule,
        timesteps=timesteps,
        linear_start=linear_start,
        linear_end=linear_end,
        cosine_s=cosine_s,
    )

    self.loss_type = loss_type

    self.learn_logvar = learn_logvar
    # per-timestep log-variance; becomes trainable only when learn_logvar is set
    self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
    if self.learn_logvar:
        self.logvar = nn.Parameter(self.logvar, requires_grad=True)
149
+
150
def register_schedule(
    self,
    given_betas=None,
    beta_schedule="linear",
    timesteps=1000,
    linear_start=1e-4,
    linear_end=2e-2,
    cosine_s=8e-3,
):
    """Build the beta schedule and register every derived quantity as a buffer.

    Either uses ``given_betas`` directly or constructs one via
    ``make_beta_schedule``; then precomputes the cumulative-alpha terms for
    q(x_t | x_0), the posterior q(x_{t-1} | x_t, x_0) coefficients, and the
    per-timestep VLB weights for the configured parameterization.
    """
    if exists(given_betas):
        betas = given_betas
    else:
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    # prepend 1.0 so index t holds alpha-bar_{t-1}
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

    # the actual number of steps comes from the betas array, not the argument
    (timesteps,) = betas.shape
    self.num_timesteps = int(timesteps)
    self.linear_start = linear_start
    self.linear_end = linear_end
    assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep"

    to_torch = partial(torch.tensor, dtype=torch.float32)

    self.register_buffer("betas", to_torch(betas))
    self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
    self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

    # calculations for diffusion q(x_t | x_{t-1}) and others
    self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
    self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)))
    self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)))
    self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)))
    self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)))

    # calculations for posterior q(x_{t-1} | x_t, x_0); v_posterior blends
    # between beta_tilde (v=0) and beta (v=1) as the posterior variance
    posterior_variance = (1 - self.v_posterior) * betas * (1.0 - alphas_cumprod_prev) / (
        1.0 - alphas_cumprod
    ) + self.v_posterior * betas
    # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    self.register_buffer("posterior_variance", to_torch(posterior_variance))
    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
    self.register_buffer(
        "posterior_log_variance_clipped",
        to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
    )
    self.register_buffer(
        "posterior_mean_coef1",
        to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
    )
    self.register_buffer(
        "posterior_mean_coef2",
        to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)),
    )

    if self.parameterization == "eps":
        lvlb_weights = self.betas**2 / (
            2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
        )
    elif self.parameterization == "x0":
        # NOTE(review): `2.0 * 1 - torch.Tensor(...)` evaluates as
        # `2 - alphas_cumprod`, not `2 * (1 - alphas_cumprod)` — looks like a
        # precedence slip inherited from upstream; left unchanged. Confirm
        # before relying on x0 VLB weights.
        lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
    else:
        raise NotImplementedError("mu not supported")
    # TODO how to choose this term
    lvlb_weights[0] = lvlb_weights[1]
    self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
    # NOTE(review): `.all()` only fails if *every* weight is NaN; `.any()`
    # may have been intended here.
    assert not torch.isnan(self.lvlb_weights).all()
224
+
225
@contextmanager
def ema_scope(self, context=None):
    """Context manager that temporarily swaps the model to its EMA weights.

    On entry the live parameters are stored and replaced by the EMA copy;
    on exit they are restored. A no-op when ``use_ema`` is False. ``context``
    is an optional label printed on switch/restore.
    """
    if self.use_ema:
        self.model_ema.store(self.model.parameters())
        self.model_ema.copy_to(self.model)
        if context is not None:
            print(f"{context}: Switched to EMA weights")
    try:
        yield None
    finally:
        if self.use_ema:
            self.model_ema.restore(self.model.parameters())
            if context is not None:
                print(f"{context}: Restored training weights")
239
+
240
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
    """Load weights from a checkpoint, optionally dropping keys by prefix.

    Parameters
    ----------
    path : checkpoint file readable by ``torch.load``.
    ignore_keys : iterable of state-dict key prefixes to delete before
        loading (default None -> nothing dropped). The original signature
        used a mutable ``list()`` default.
    only_model : load into ``self.model`` only instead of the full module.
    """
    ignore_keys = ignore_keys if ignore_keys is not None else []
    sd = torch.load(path, map_location="cpu")
    if "state_dict" in list(sd.keys()):
        sd = sd["state_dict"]
    keys = list(sd.keys())
    for k in keys:
        for ik in ignore_keys:
            if k.startswith(ik):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
                # BUGFIX: stop after the first matching prefix — the original
                # kept iterating and a second match re-deleted the same key,
                # raising KeyError.
                break
    missing, unexpected = (
        self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
    )
    print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
    if len(missing) > 0:
        print(f"Missing Keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected Keys: {unexpected}")
258
+
259
def q_mean_variance(self, x_start, t):
    """
    Get the distribution q(x_t | x_0).
    :param x_start: the [N x C x ...] tensor of noiseless inputs.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    """
    shape = x_start.shape
    mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, shape) * x_start
    variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, shape)
    log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, shape)
    return mean, variance, log_variance
270
+
271
def predict_start_from_noise(self, x_t, t, noise):
    """Recover the x_0 estimate from x_t and predicted noise (eps parameterization)."""
    coef_recip = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    coef_recipm1 = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return coef_recip * x_t - coef_recipm1 * noise
276
+
277
def q_posterior(self, x_start, x_t, t):
    """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
    shape = x_t.shape
    mean = (
        extract_into_tensor(self.posterior_mean_coef1, t, shape) * x_start
        + extract_into_tensor(self.posterior_mean_coef2, t, shape) * x_t
    )
    variance = extract_into_tensor(self.posterior_variance, t, shape)
    log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, shape)
    return mean, variance, log_variance_clipped
285
+
286
def p_mean_variance(self, x, t, clip_denoised: bool):
    """Posterior parameters for one reverse step, from the model's prediction.

    Runs the model, converts its output to an x_0 estimate per the
    configured parameterization (optionally clamped to [-1, 1] in place),
    and returns the q_posterior (mean, variance, log_variance) triple.
    """
    model_out = self.model(x, t)
    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    if clip_denoised:
        x_recon.clamp_(-1.0, 1.0)

    return self.q_posterior(x_start=x_recon, x_t=x, t=t)
297
+
298
+ @torch.no_grad()
299
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
300
+ b, *_, device = *x.shape, x.device
301
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
302
+ noise = noise_like(x.shape, device, repeat_noise)
303
+ # no noise when t == 0
304
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
305
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
306
+
307
+ @torch.no_grad()
308
+ def p_sample_loop(self, shape, return_intermediates=False):
309
+ device = self.betas.device
310
+ b = shape[0]
311
+ img = torch.randn(shape, device=device)
312
+ intermediates = [img]
313
+ for i in tqdm(
314
+ reversed(range(0, self.num_timesteps)),
315
+ desc="Sampling t",
316
+ total=self.num_timesteps,
317
+ ):
318
+ img = self.p_sample(
319
+ img,
320
+ torch.full((b,), i, device=device, dtype=torch.long),
321
+ clip_denoised=self.clip_denoised,
322
+ )
323
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
324
+ intermediates.append(img)
325
+ if return_intermediates:
326
+ return img, intermediates
327
+ return img
328
+
329
+ @torch.no_grad()
330
+ def sample(self, batch_size=16, return_intermediates=False):
331
+ image_size = self.image_size
332
+ channels = self.channels
333
+ return self.p_sample_loop(
334
+ (batch_size, channels, image_size, image_size),
335
+ return_intermediates=return_intermediates,
336
+ )
337
+
338
+ def q_sample(self, x_start, t, noise=None):
339
+ noise = default(noise, lambda: torch.randn_like(x_start))
340
+ return (
341
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
342
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
343
+ )
344
+
345
+ def get_loss(self, pred, target, mean=True):
346
+ if self.loss_type == "l1":
347
+ loss = (target - pred).abs()
348
+ if mean:
349
+ loss = loss.mean()
350
+ elif self.loss_type == "l2":
351
+ if mean:
352
+ loss = torch.nn.functional.mse_loss(target, pred)
353
+ else:
354
+ loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
355
+ else:
356
+ raise NotImplementedError("unknown loss type '{loss_type}'")
357
+
358
+ return loss
359
+
360
+ def p_losses(self, x_start, t, noise=None):
361
+ noise = default(noise, lambda: torch.randn_like(x_start))
362
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
363
+ model_out = self.model(x_noisy, t)
364
+
365
+ loss_dict = {}
366
+ if self.parameterization == "eps":
367
+ target = noise
368
+ elif self.parameterization == "x0":
369
+ target = x_start
370
+ else:
371
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
372
+
373
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
374
+
375
+ log_prefix = "train" if self.training else "val"
376
+
377
+ loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
378
+ loss_simple = loss.mean() * self.l_simple_weight
379
+
380
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
381
+ loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})
382
+
383
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
384
+
385
+ loss_dict.update({f"{log_prefix}/loss": loss})
386
+
387
+ return loss, loss_dict
388
+
389
+ def forward(self, x, *args, **kwargs):
390
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
391
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
392
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
393
+ return self.p_losses(x, t, *args, **kwargs)
394
+
395
+ def get_input(self, batch, k):
396
+ x = batch[k]
397
+ if len(x.shape) == 3:
398
+ x = x[..., None]
399
+ x = rearrange(x, "b h w c -> b c h w")
400
+ x = x.to(memory_format=torch.contiguous_format).float()
401
+ return x
402
+
403
+ def shared_step(self, batch):
404
+ x = self.get_input(batch, self.first_stage_key)
405
+ loss, loss_dict = self(x)
406
+ return loss, loss_dict
407
+
408
+ def training_step(self, batch, batch_idx):
409
+ loss, loss_dict = self.shared_step(batch)
410
+
411
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
412
+
413
+ self.log(
414
+ "global_step",
415
+ self.global_step,
416
+ prog_bar=True,
417
+ logger=True,
418
+ on_step=True,
419
+ on_epoch=False,
420
+ )
421
+
422
+ if self.use_scheduler:
423
+ lr = self.optimizers().param_groups[0]["lr"]
424
+ self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
425
+
426
+ return loss
427
+
428
+ @torch.no_grad()
429
+ def validation_step(self, batch, batch_idx):
430
+ _, loss_dict_no_ema = self.shared_step(batch)
431
+ with self.ema_scope():
432
+ _, loss_dict_ema = self.shared_step(batch)
433
+ loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema}
434
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
435
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
436
+
437
+ def on_train_batch_end(self, *args, **kwargs):
438
+ if self.use_ema:
439
+ self.model_ema(self.model)
440
+
441
+ def _get_rows_from_list(self, samples):
442
+ n_imgs_per_row = len(samples)
443
+ denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
444
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
445
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
446
+ return denoise_grid
447
+
448
+ @torch.no_grad()
449
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
450
+ log = dict()
451
+ x = self.get_input(batch, self.first_stage_key)
452
+ N = min(x.shape[0], N)
453
+ n_row = min(x.shape[0], n_row)
454
+ x = x.to(self.device)[:N]
455
+ log["inputs"] = x
456
+
457
+ # get diffusion row
458
+ diffusion_row = list()
459
+ x_start = x[:n_row]
460
+
461
+ for t in range(self.num_timesteps):
462
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
463
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
464
+ t = t.to(self.device).long()
465
+ noise = torch.randn_like(x_start)
466
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
467
+ diffusion_row.append(x_noisy)
468
+
469
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
470
+
471
+ if sample:
472
+ # get denoise row
473
+ with self.ema_scope("Plotting"):
474
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
475
+
476
+ log["samples"] = samples
477
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
478
+
479
+ if return_keys:
480
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
481
+ return log
482
+ else:
483
+ return {key: log[key] for key in return_keys}
484
+ return log
485
+
486
+ def configure_optimizers(self):
487
+ lr = self.learning_rate
488
+ params = list(self.model.parameters())
489
+ if self.learn_logvar:
490
+ params = params + [self.logvar]
491
+ opt = torch.optim.AdamW(params, lr=lr)
492
+ return opt
493
+
494
+
495
+ class LatentDiffusion(DDPM):
496
+ """main class"""
497
+
498
+ def __init__(
499
+ self,
500
+ first_stage_config,
501
+ cond_stage_config,
502
+ num_timesteps_cond=None,
503
+ cond_stage_key="image",
504
+ cond_stage_trainable=False,
505
+ concat_mode=True,
506
+ cond_stage_forward=None,
507
+ conditioning_key=None,
508
+ scale_factor=1.0,
509
+ scale_by_std=False,
510
+ x_feat_extracted=False,
511
+ x_feat_key = "vae_feat",
512
+ *args,
513
+ **kwargs,
514
+ ):
515
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
516
+ self.scale_by_std = scale_by_std
517
+ assert self.num_timesteps_cond <= kwargs["timesteps"]
518
+ # for backwards compatibility after implementation of DiffusionWrapper
519
+ if conditioning_key is None:
520
+ conditioning_key = "concat" if concat_mode else "crossattn"
521
+ # if cond_stage_config == "__is_unconditional__":
522
+ # conditioning_key = None
523
+ ckpt_path = kwargs.pop("ckpt_path", None)
524
+ ignore_keys = kwargs.pop("ignore_keys", [])
525
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
526
+ self.concat_mode = concat_mode
527
+ self.cond_stage_trainable = cond_stage_trainable
528
+ self.cond_stage_key = cond_stage_key
529
+ try:
530
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
531
+ except:
532
+ self.num_downs = 0
533
+ if not scale_by_std:
534
+ self.scale_factor = scale_factor
535
+ else:
536
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
537
+ self.instantiate_first_stage(first_stage_config)
538
+ self.instantiate_cond_stage(cond_stage_config)
539
+ self.cond_stage_forward = cond_stage_forward
540
+ self.clip_denoised = False
541
+ self.bbox_tokenizer = None
542
+
543
+ self.restarted_from_ckpt = False
544
+ if ckpt_path is not None:
545
+ self.init_from_ckpt(ckpt_path, ignore_keys)
546
+ self.restarted_from_ckpt = True
547
+
548
+ # if using preextracted vae features
549
+ self.x_feat_extracted=x_feat_extracted
550
+ self.x_feat_key = x_feat_key
551
+
552
+
553
+
554
+ def make_cond_schedule(
555
+ self,
556
+ ):
557
+ self.cond_ids = torch.full(
558
+ size=(self.num_timesteps,),
559
+ fill_value=self.num_timesteps - 1,
560
+ dtype=torch.long,
561
+ )
562
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
563
+ self.cond_ids[: self.num_timesteps_cond] = ids
564
+
565
+ @rank_zero_only
566
+ @torch.no_grad()
567
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
568
+ # only for very first batch
569
+ if (
570
+ self.scale_by_std
571
+ and self.current_epoch == 0
572
+ and self.global_step == 0
573
+ and batch_idx == 0
574
+ and not self.restarted_from_ckpt
575
+ ):
576
+ assert self.scale_factor == 1.0, "rather not use custom rescaling and std-rescaling simultaneously"
577
+ # set rescale weight to 1./std of encodings
578
+ print("### USING STD-RESCALING ###")
579
+ x = super().get_input(batch, self.first_stage_key)
580
+ x = x.to(self.device)
581
+ encoder_posterior = self.encode_first_stage(x)
582
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
583
+ del self.scale_factor
584
+ self.register_buffer("scale_factor", 1.0 / z.flatten().std())
585
+ print(f"setting self.scale_factor to {self.scale_factor}")
586
+ print("### USING STD-RESCALING ###")
587
+
588
+ def register_schedule(
589
+ self,
590
+ given_betas=None,
591
+ beta_schedule="linear",
592
+ timesteps=1000,
593
+ linear_start=1e-4,
594
+ linear_end=2e-2,
595
+ cosine_s=8e-3,
596
+ ):
597
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
598
+
599
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
600
+ if self.shorten_cond_schedule:
601
+ self.make_cond_schedule()
602
+
603
+ def instantiate_first_stage(self, config):
604
+ model = instantiate_from_config(config)
605
+ self.first_stage_model = model.eval()
606
+ self.first_stage_model.train = disabled_train
607
+ for param in self.first_stage_model.parameters():
608
+ param.requires_grad = False
609
+
610
+ def instantiate_cond_stage(self, config):
611
+ if not self.cond_stage_trainable:
612
+ if config == "__is_first_stage__":
613
+ print("Using first stage also as cond stage.")
614
+ self.cond_stage_model = self.first_stage_model
615
+ elif config == "__is_unconditional__":
616
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
617
+ self.cond_stage_model = None
618
+ # self.be_unconditional = True
619
+ else:
620
+ model = instantiate_from_config(config)
621
+ self.cond_stage_model = model.eval()
622
+ self.cond_stage_model.train = disabled_train
623
+ for param in self.cond_stage_model.parameters():
624
+ param.requires_grad = False
625
+ else:
626
+ assert config != "__is_first_stage__"
627
+ assert config != "__is_unconditional__"
628
+ model = instantiate_from_config(config)
629
+ self.cond_stage_model = model
630
+
631
+ def _get_denoise_row_from_list(self, samples, desc="", force_no_decoder_quantization=False):
632
+ denoise_row = []
633
+ for zd in tqdm(samples, desc=desc):
634
+ denoise_row.append(
635
+ self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization)
636
+ )
637
+ n_imgs_per_row = len(denoise_row)
638
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
639
+ denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
640
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
641
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
642
+ return denoise_grid
643
+
644
+ def get_first_stage_encoding(self, encoder_posterior):
645
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
646
+ z = encoder_posterior.sample()
647
+ elif isinstance(encoder_posterior, torch.Tensor):
648
+ z = encoder_posterior
649
+ else:
650
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
651
+ return self.scale_factor * z
652
+
653
+ def get_learned_conditioning(self, c):
654
+ if self.cond_stage_forward is None:
655
+ if hasattr(self.cond_stage_model, "encode") and callable(self.cond_stage_model.encode):
656
+ c = self.cond_stage_model.encode(c)
657
+ if isinstance(c, DiagonalGaussianDistribution):
658
+ c = c.mode()
659
+ else:
660
+ c = self.cond_stage_model(c)
661
+ else:
662
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
663
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
664
+ return c
665
+
666
+ def meshgrid(self, h, w):
667
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
668
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
669
+
670
+ arr = torch.cat([y, x], dim=-1)
671
+ return arr
672
+
673
+ def delta_border(self, h, w):
674
+ """
675
+ :param h: height
676
+ :param w: width
677
+ :return: normalized distance to image border,
678
+ wtith min distance = 0 at border and max dist = 0.5 at image center
679
+ """
680
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
681
+ arr = self.meshgrid(h, w) / lower_right_corner
682
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
683
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
684
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
685
+ return edge_dist
686
+
687
+ def get_weighting(self, h, w, Ly, Lx, device):
688
+ weighting = self.delta_border(h, w)
689
+ weighting = torch.clip(
690
+ weighting,
691
+ self.split_input_params["clip_min_weight"],
692
+ self.split_input_params["clip_max_weight"],
693
+ )
694
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
695
+
696
+ if self.split_input_params["tie_braker"]:
697
+ L_weighting = self.delta_border(Ly, Lx)
698
+ L_weighting = torch.clip(
699
+ L_weighting,
700
+ self.split_input_params["clip_min_tie_weight"],
701
+ self.split_input_params["clip_max_tie_weight"],
702
+ )
703
+
704
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
705
+ weighting = weighting * L_weighting
706
+ return weighting
707
+
708
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
709
+ """
710
+ :param x: img of size (bs, c, h, w)
711
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
712
+ """
713
+ bs, nc, h, w = x.shape
714
+
715
+ # number of crops in image
716
+ Ly = (h - kernel_size[0]) // stride[0] + 1
717
+ Lx = (w - kernel_size[1]) // stride[1] + 1
718
+
719
+ if uf == 1 and df == 1:
720
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
721
+ unfold = torch.nn.Unfold(**fold_params)
722
+
723
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
724
+
725
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
726
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
727
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
728
+
729
+ elif uf > 1 and df == 1:
730
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
731
+ unfold = torch.nn.Unfold(**fold_params)
732
+
733
+ fold_params2 = dict(
734
+ kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
735
+ dilation=1,
736
+ padding=0,
737
+ stride=(stride[0] * uf, stride[1] * uf),
738
+ )
739
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
740
+
741
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
742
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
743
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
744
+
745
+ elif df > 1 and uf == 1:
746
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
747
+ unfold = torch.nn.Unfold(**fold_params)
748
+
749
+ fold_params2 = dict(
750
+ kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
751
+ dilation=1,
752
+ padding=0,
753
+ stride=(stride[0] // df, stride[1] // df),
754
+ )
755
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
756
+
757
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
758
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
759
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
760
+
761
+ else:
762
+ raise NotImplementedError
763
+
764
+ return fold, unfold, normalization, weighting
765
+
766
+ @torch.no_grad()
767
+ def get_input(
768
+ self,
769
+ batch,
770
+ k,
771
+ return_first_stage_outputs=False,
772
+ force_c_encode=False,
773
+ cond_key=None,
774
+ return_original_cond=False,
775
+ bs=None,
776
+ ):
777
+
778
+ if self.x_feat_extracted and self.x_feat_key == "vae_feat":
779
+ z = batch[self.x_feat_key].to(self.device)
780
+ if bs is not None:
781
+ z = z[:bs]
782
+ x = None
783
+
784
+ elif self.x_feat_key == "ssl_feat":
785
+ with torch.no_grad():
786
+ z = self.first_stage_model(batch)
787
+
788
+ z *= self.scale_factor
789
+
790
+ if bs is not None:
791
+ z = z[:bs]
792
+ x = None
793
+
794
+ else:
795
+
796
+ x = super().get_input(batch, k)
797
+ if bs is not None:
798
+ x = x[:bs]
799
+ x = x.to(self.device)
800
+ encoder_posterior = self.encode_first_stage(x)
801
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
802
+
803
+ if self.model.conditioning_key is not None:
804
+ if cond_key is None:
805
+ cond_key = self.cond_stage_key
806
+ if cond_key != self.first_stage_key:
807
+ if cond_key in ["caption", "coordinates_bbox", "mag"]:
808
+ xc = batch[cond_key]
809
+ elif cond_key in ["class_label", "hybrid"]:
810
+ xc = batch
811
+ else:
812
+ xc = super().get_input(batch, cond_key).to(self.device)
813
+ else:
814
+ xc = x
815
+ if cond_key != "mag" and (not self.cond_stage_trainable or force_c_encode):
816
+ if isinstance(xc, dict) or isinstance(xc, list):
817
+ c = self.get_learned_conditioning(xc)
818
+ else:
819
+ c = self.get_learned_conditioning(xc.to(self.device))
820
+ else:
821
+ c = xc
822
+ if bs is not None:
823
+ if isinstance(c, list):
824
+ c[0] = c[0][:bs]
825
+ c[1] = c[1][:bs]
826
+
827
+ c = c[:bs]
828
+
829
+ if self.use_positional_encodings:
830
+ pos_x, pos_y = self.compute_latent_shifts(batch)
831
+ ckey = __conditioning_keys__[self.model.conditioning_key]
832
+ c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y}
833
+
834
+ else:
835
+ c = None
836
+ xc = None
837
+ if self.use_positional_encodings:
838
+ pos_x, pos_y = self.compute_latent_shifts(batch)
839
+ c = {"pos_x": pos_x, "pos_y": pos_y}
840
+ out = [z, c]
841
+ if return_first_stage_outputs:
842
+ xrec = self.decode_first_stage(z)
843
+ out.extend([x, xrec])
844
+ if return_original_cond:
845
+ out.append(xc)
846
+ return out
847
+
848
+ @torch.no_grad()
849
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
850
+ if predict_cids:
851
+ if z.dim() == 4:
852
+ z = torch.argmax(z.exp(), dim=1).long()
853
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
854
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
855
+
856
+ z = 1.0 / self.scale_factor * z
857
+
858
+ if hasattr(self, "split_input_params"):
859
+ if self.split_input_params["patch_distributed_vq"]:
860
+ ks = self.split_input_params["ks"] # eg. (128, 128)
861
+ stride = self.split_input_params["stride"] # eg. (64, 64)
862
+ uf = self.split_input_params["vqf"]
863
+ bs, nc, h, w = z.shape
864
+ if ks[0] > h or ks[1] > w:
865
+ ks = (min(ks[0], h), min(ks[1], w))
866
+ print("reducing Kernel")
867
+
868
+ if stride[0] > h or stride[1] > w:
869
+ stride = (min(stride[0], h), min(stride[1], w))
870
+ print("reducing stride")
871
+
872
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
873
+
874
+ z = unfold(z) # (bn, nc * prod(**ks), L)
875
+ # 1. Reshape to img shape
876
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
877
+
878
+ # 2. apply model loop over last dim
879
+ if isinstance(self.first_stage_model, VQModelInterface):
880
+ output_list = [
881
+ self.first_stage_model.decode(
882
+ z[:, :, :, :, i],
883
+ force_not_quantize=predict_cids or force_not_quantize,
884
+ )
885
+ for i in range(z.shape[-1])
886
+ ]
887
+ else:
888
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1])]
889
+
890
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
891
+ o = o * weighting
892
+ # Reverse 1. reshape to img shape
893
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
894
+ # stitch crops together
895
+ decoded = fold(o)
896
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
897
+ return decoded
898
+ else:
899
+ if isinstance(self.first_stage_model, VQModelInterface):
900
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
901
+ else:
902
+ return self.first_stage_model.decode(z)
903
+
904
+ else:
905
+ if isinstance(self.first_stage_model, VQModelInterface):
906
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
907
+ else:
908
+ return self.first_stage_model.decode(z)
909
+
910
+ # same as above but without decorator
911
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
912
+ if predict_cids:
913
+ if z.dim() == 4:
914
+ z = torch.argmax(z.exp(), dim=1).long()
915
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
916
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
917
+
918
+ z = 1.0 / self.scale_factor * z
919
+
920
+ if hasattr(self, "split_input_params"):
921
+ if self.split_input_params["patch_distributed_vq"]:
922
+ ks = self.split_input_params["ks"] # eg. (128, 128)
923
+ stride = self.split_input_params["stride"] # eg. (64, 64)
924
+ uf = self.split_input_params["vqf"]
925
+ bs, nc, h, w = z.shape
926
+ if ks[0] > h or ks[1] > w:
927
+ ks = (min(ks[0], h), min(ks[1], w))
928
+ print("reducing Kernel")
929
+
930
+ if stride[0] > h or stride[1] > w:
931
+ stride = (min(stride[0], h), min(stride[1], w))
932
+ print("reducing stride")
933
+
934
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
935
+
936
+ z = unfold(z) # (bn, nc * prod(**ks), L)
937
+ # 1. Reshape to img shape
938
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
939
+
940
+ # 2. apply model loop over last dim
941
+ if isinstance(self.first_stage_model, VQModelInterface):
942
+ output_list = [
943
+ self.first_stage_model.decode(
944
+ z[:, :, :, :, i],
945
+ force_not_quantize=predict_cids or force_not_quantize,
946
+ )
947
+ for i in range(z.shape[-1])
948
+ ]
949
+ else:
950
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1])]
951
+
952
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
953
+ o = o * weighting
954
+ # Reverse 1. reshape to img shape
955
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
956
+ # stitch crops together
957
+ decoded = fold(o)
958
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
959
+ return decoded
960
+ else:
961
+ if isinstance(self.first_stage_model, VQModelInterface):
962
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
963
+ else:
964
+ return self.first_stage_model.decode(z)
965
+
966
+ else:
967
+ if isinstance(self.first_stage_model, VQModelInterface):
968
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
969
+ else:
970
+ return self.first_stage_model.decode(z)
971
+
972
+ @torch.no_grad()
973
+ def encode_first_stage(self, x):
974
+ if hasattr(self, "split_input_params"):
975
+ if self.split_input_params["patch_distributed_vq"]:
976
+ ks = self.split_input_params["ks"] # eg. (128, 128)
977
+ stride = self.split_input_params["stride"] # eg. (64, 64)
978
+ df = self.split_input_params["vqf"]
979
+ self.split_input_params["original_image_size"] = x.shape[-2:]
980
+ bs, nc, h, w = x.shape
981
+ if ks[0] > h or ks[1] > w:
982
+ ks = (min(ks[0], h), min(ks[1], w))
983
+ print("reducing Kernel")
984
+
985
+ if stride[0] > h or stride[1] > w:
986
+ stride = (min(stride[0], h), min(stride[1], w))
987
+ print("reducing stride")
988
+
989
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
990
+ z = unfold(x) # (bn, nc * prod(**ks), L)
991
+ # Reshape to img shape
992
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
993
+
994
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) for i in range(z.shape[-1])]
995
+
996
+ o = torch.stack(output_list, axis=-1)
997
+ o = o * weighting
998
+
999
+ # Reverse reshape to img shape
1000
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1001
+ # stitch crops together
1002
+ decoded = fold(o)
1003
+ decoded = decoded / normalization
1004
+ return decoded
1005
+
1006
+ else:
1007
+ return self.first_stage_model.encode(x)
1008
+ else:
1009
+ return self.first_stage_model.encode(x)
1010
+
1011
+ def shared_step(self, batch, **kwargs):
1012
+ x, c = self.get_input(batch, self.first_stage_key)
1013
+ if self.model.conditioning_key == 'hybrid':
1014
+ c_concat = rearrange(batch["LR_image"], 'n h w c -> n c h w')
1015
+ kwargs["c_concat"] = [c_concat]
1016
+ loss = self(x, c, **kwargs)
1017
+ return loss
1018
+
1019
+ def forward(self, x, c, *args, **kwargs):
1020
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
1021
+ if self.model.conditioning_key is not None:
1022
+ assert c is not None
1023
+ if self.cond_stage_trainable:
1024
+ c = self.get_learned_conditioning(c)
1025
+ if self.shorten_cond_schedule: # TODO: drop this option
1026
+ tc = self.cond_ids[t].to(self.device)
1027
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
1028
+ return self.p_losses(x, c, t, *args, **kwargs)
1029
+
1030
+
1031
+ def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs):
1032
+ if isinstance(cond, dict):
1033
+ # hybrid case, cond is exptected to be a dict
1034
+ pass
1035
+ else:
1036
+ if not isinstance(cond, list):
1037
+ cond = [cond]
1038
+ key = "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
1039
+ cond = {key: cond}
1040
+
1041
+ if hasattr(self, "split_input_params"):
1042
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
1043
+ assert not return_ids
1044
+ ks = self.split_input_params["ks"] # eg. (128, 128)
1045
+ stride = self.split_input_params["stride"] # eg. (64, 64)
1046
+
1047
+ h, w = x_noisy.shape[-2:]
1048
+
1049
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
1050
+
1051
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
1052
+ # Reshape to img shape
1053
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
1054
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
1055
+
1056
+ if (
1057
+ self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"] and self.model.conditioning_key
1058
+ ): # todo check for completeness
1059
+ c_key = next(iter(cond.keys())) # get key
1060
+ c = next(iter(cond.values())) # get value
1061
+ assert len(c) == 1 # todo extend to list with more than one elem
1062
+ c = c[0] # get element
1063
+
1064
+ c = unfold(c)
1065
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
1066
+
1067
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
1068
+
1069
+
1070
+ else:
1071
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
1072
+
1073
+ # apply model by loop over crops
1074
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1075
+ assert not isinstance(
1076
+ output_list[0], tuple
1077
+ ) # todo cant deal with multiple model outputs check this never happens
1078
+
1079
+ o = torch.stack(output_list, axis=-1)
1080
+ o = o * weighting
1081
+ # Reverse reshape to img shape
1082
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1083
+ # stitch crops together
1084
+ x_recon = fold(o) / normalization
1085
+
1086
+ else:
1087
+ with torch.cuda.amp.autocast():
1088
+ x_recon = self.model(x_noisy, t, **cond, **kwargs)
1089
+
1090
+ if isinstance(x_recon, tuple) and not return_ids:
1091
+ return x_recon[0]
1092
+ else:
1093
+ return x_recon
1094
+
1095
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1096
+ return (
1097
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
1098
+ ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1099
+
1100
+ def _prior_bpd(self, x_start):
1101
+ """
1102
+ Get the prior KL term for the variational lower-bound, measured in
1103
+ bits-per-dim.
1104
+ This term can't be optimized, as it only depends on the encoder.
1105
+ :param x_start: the [N x C x ...] tensor of inputs.
1106
+ :return: a batch of [N] KL values (in bits), one per batch element.
1107
+ """
1108
+ batch_size = x_start.shape[0]
1109
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1110
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1111
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1112
+ return mean_flat(kl_prior) / np.log(2.0)
1113
+
1114
def p_losses(self, x_start, cond, t, noise=None, **kwargs):
    """Compute the diffusion training loss at timesteps `t`.

    Noises `x_start`, runs the model, and combines the simple MSE-style loss
    with an (optionally learned) per-timestep log-variance weighting plus a
    VLB term scaled by `original_elbo_weight`.

    :param x_start: clean latents, shape [N x C x ...].
    :param cond: conditioning passed through to `apply_model`.
    :param t: integer timesteps, shape [N].
    :param noise: optional pre-drawn noise; defaults to fresh Gaussian noise.
    :return: (scalar loss, dict of logged components prefixed "train"/"val").
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_output = self.apply_model(x_noisy, t, cond, **kwargs)

    loss_dict = {}
    prefix = "train" if self.training else "val"

    # Regression target depends on the parameterization: predict clean data
    # ("x0") or the added noise ("eps").
    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    else:
        raise NotImplementedError()

    # All dims except batch, so per-sample losses survive for the [t]-indexed weights.
    dims_non_bs = tuple(range(1, target.dim()))

    loss_simple = self.get_loss(model_output, target, mean=False).mean(dims_non_bs)
    loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})

    # Keep the (possibly learnable) logvar buffer on the right device before indexing.
    self.logvar = self.logvar.to(self.device)
    logvar_t = self.logvar[t].to(self.device)
    # Heteroscedastic weighting: loss / exp(logvar) + logvar (a no-op when logvar == 0).
    loss = loss_simple / torch.exp(logvar_t) + logvar_t
    # loss = loss_simple / torch.exp(self.logvar) + self.logvar
    if self.learn_logvar:
        loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
        loss_dict.update({"logvar": self.logvar.data.mean()})

    loss = self.l_simple_weight * loss.mean()

    # VLB term: same per-sample loss reweighted by the precomputed lvlb_weights[t].
    loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=dims_non_bs)
    loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
    loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
    loss += self.original_elbo_weight * loss_vlb
    loss_dict.update({f"{prefix}/loss": loss})

    return loss, loss_dict
+
1152
def p_mean_variance(
    self,
    x,
    c,
    t,
    clip_denoised: bool,
    return_codebook_ids=False,
    quantize_denoised=False,
    return_x0=False,
    score_corrector=None,
    corrector_kwargs=None,
):
    """Compute the reverse-process posterior mean/variance p(x_{t-1} | x_t).

    Runs the model, converts its output to a prediction of x_0 (depending on
    the parameterization), optionally clips/quantizes it, and plugs it into
    the closed-form q_posterior.

    :param clip_denoised: clamp the predicted x_0 into [-1, 1] before the posterior.
    :param quantize_denoised: pass predicted x_0 through the first stage's VQ codebook.
    :param score_corrector: optional object that adjusts the eps prediction.
    :return: (mean, variance, log_variance[, logits | x_recon]) depending on flags.
    """
    t_in = t
    model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

    if score_corrector is not None:
        # Score correction is only defined for the eps parameterization.
        assert self.parameterization == "eps"
        model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

    if return_codebook_ids:
        model_out, logits = model_out

    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    else:
        raise NotImplementedError()

    if clip_denoised:
        x_recon.clamp_(-1.0, 1.0)
    if quantize_denoised:
        # NOTE(review): assumes first_stage_model is a VQ model exposing .quantize — confirm.
        x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    if return_codebook_ids:
        return model_mean, posterior_variance, posterior_log_variance, logits
    elif return_x0:
        return model_mean, posterior_variance, posterior_log_variance, x_recon
    else:
        return model_mean, posterior_variance, posterior_log_variance
+
1193
@torch.no_grad()
def p_sample(
    self,
    x,
    c,
    t,
    clip_denoised=False,
    repeat_noise=False,
    return_codebook_ids=False,
    quantize_denoised=False,
    return_x0=False,
    temperature=1.0,
    noise_dropout=0.0,
    score_corrector=None,
    corrector_kwargs=None,
):
    """Draw one ancestral sampling step x_{t-1} ~ p(x_{t-1} | x_t).

    Noise is scaled by `temperature`, optionally dropped out, and suppressed
    entirely at t == 0 via the `nonzero_mask`.

    :return: the new sample; additionally x0 when `return_x0` is set.
    """
    b, *_, device = *x.shape, x.device
    outputs = self.p_mean_variance(
        x=x,
        c=c,
        t=t,
        clip_denoised=clip_denoised,
        return_codebook_ids=return_codebook_ids,
        quantize_denoised=quantize_denoised,
        return_x0=return_x0,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
    )
    if return_codebook_ids:
        # The codebook-ids path was deprecated upstream; this raise makes the
        # unpack below (and the matching return further down) unreachable.
        raise DeprecationWarning("Support dropped.")
        model_mean, _, model_log_variance, logits = outputs
    elif return_x0:
        model_mean, _, model_log_variance, x0 = outputs
    else:
        model_mean, _, model_log_variance = outputs

    noise = noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.0:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

    if return_codebook_ids:
        # Dead code: guarded out by the raise above.
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
    if return_x0:
        return (
            model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
            x0,
        )
    else:
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
1245
@torch.no_grad()
def progressive_denoising(
    self,
    cond,
    shape,
    verbose=True,
    callback=None,
    quantize_denoised=False,
    img_callback=None,
    mask=None,
    x0=None,
    temperature=1.0,
    noise_dropout=0.0,
    score_corrector=None,
    corrector_kwargs=None,
    batch_size=None,
    x_T=None,
    start_T=None,
    log_every_t=None,
):
    """Run the full ancestral sampling loop, collecting intermediate x0 predictions.

    Like `p_sample_loop` but each step uses `return_x0=True`, so the returned
    intermediates show the model's progressively refined x0 estimate. Supports
    per-step temperatures, inpainting via `mask`/`x0`, and truncated schedules
    via `start_T`.

    :return: (final image, list of intermediate x0 predictions).
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    timesteps = self.num_timesteps
    if batch_size is not None:
        # `shape` is per-sample here; prepend the batch dimension.
        b = batch_size if batch_size is not None else shape[0]
        shape = [batch_size] + list(shape)
    else:
        # `shape` already includes the batch dimension.
        b = batch_size = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=self.device)
    else:
        img = x_T
    intermediates = []
    if cond is not None:
        # Trim conditioning to the batch size, whatever container it arrives in.
        if isinstance(cond, dict):
            cond = {
                key: cond[key][:batch_size]
                if not isinstance(cond[key], list)
                else list(map(lambda x: x[:batch_size], cond[key]))
                for key in cond
            }
        else:
            cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

    if start_T is not None:
        timesteps = min(timesteps, start_T)
    iterator = (
        tqdm(
            reversed(range(0, timesteps)),
            desc="Progressive Generation",
            total=timesteps,
        )
        if verbose
        else reversed(range(0, timesteps))
    )
    # Broadcast a scalar temperature into a per-step schedule.
    if type(temperature) == float:
        temperature = [temperature] * timesteps

    for i in iterator:
        ts = torch.full((b,), i, device=self.device, dtype=torch.long)
        if self.shorten_cond_schedule:
            # NOTE(review): this branch assumes `cond` is a tensor (uses cond.device) — confirm callers.
            assert self.model.conditioning_key != "hybrid"
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img, x0_partial = self.p_sample(
            img,
            cond,
            ts,
            clip_denoised=self.clip_denoised,
            quantize_denoised=quantize_denoised,
            return_x0=True,
            temperature=temperature[i],
            noise_dropout=noise_dropout,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
        )
        if mask is not None:
            # Inpainting: keep the known region at its correctly-noised value.
            assert x0 is not None
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1.0 - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(x0_partial)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)
    return img, intermediates
+
1335
@torch.no_grad()
def p_sample_loop(
    self,
    cond,
    shape,
    return_intermediates=False,
    x_T=None,
    verbose=True,
    callback=None,
    timesteps=None,
    quantize_denoised=False,
    mask=None,
    x0=None,
    img_callback=None,
    start_T=None,
    log_every_t=None,
):
    """Full ancestral DDPM sampling loop from pure noise (or `x_T`) down to t=0.

    Supports inpainting (`mask` + known `x0`), truncated schedules (`timesteps`,
    `start_T`), and periodic logging of intermediates every `log_every_t` steps.

    :param shape: full sample shape including batch, [N x C x H x W].
    :return: final image, plus the intermediates list when `return_intermediates`.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    device = self.betas.device
    b = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=device)
    else:
        img = x_T

    intermediates = [img]
    if timesteps is None:
        timesteps = self.num_timesteps

    if start_T is not None:
        timesteps = min(timesteps, start_T)
    iterator = (
        tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
        if verbose
        else reversed(range(0, timesteps))
    )

    if mask is not None:
        assert x0 is not None
        # NOTE(review): this only compares dim 2 (the [2:3] slice), not the full
        # spatial extent — width mismatches slip through; confirm intent.
        assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

    for i in iterator:
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            # NOTE(review): assumes `cond` is a tensor here (uses cond.device) — confirm.
            assert self.model.conditioning_key != "hybrid"
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img = self.p_sample(
            img,
            cond,
            ts,
            clip_denoised=self.clip_denoised,
            quantize_denoised=quantize_denoised,
        )
        if mask is not None:
            # Inpainting: overwrite the known region with its correctly-noised original.
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1.0 - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(img)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)

    if return_intermediates:
        return img, intermediates
    return img
+
1406
@torch.no_grad()
def sample(
    self,
    cond,
    batch_size=16,
    return_intermediates=False,
    x_T=None,
    verbose=True,
    timesteps=None,
    quantize_denoised=False,
    mask=None,
    x0=None,
    shape=None,
    **kwargs,
):
    """Draw `batch_size` samples via the ancestral loop.

    Thin wrapper over `p_sample_loop`: derives the latent shape from the model
    config when not given, and trims the conditioning (dict / list / tensor)
    down to `batch_size` entries first.
    """
    if shape is None:
        shape = (batch_size, self.channels, self.image_size, self.image_size)
    if cond is not None:
        # Conditioning may be a dict of tensors/lists, a list of tensors, or a tensor.
        if isinstance(cond, dict):
            trimmed = {}
            for key, value in cond.items():
                if isinstance(value, list):
                    trimmed[key] = [item[:batch_size] for item in value]
                else:
                    trimmed[key] = value[:batch_size]
            cond = trimmed
        elif isinstance(cond, list):
            cond = [entry[:batch_size] for entry in cond]
        else:
            cond = cond[:batch_size]
    return self.p_sample_loop(
        cond,
        shape,
        return_intermediates=return_intermediates,
        x_T=x_T,
        verbose=verbose,
        timesteps=timesteps,
        quantize_denoised=quantize_denoised,
        mask=mask,
        x0=x0,
    )
+
1445
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
    """Sample latents either with DDIM (fast, `ddim_steps` steps) or the full ancestral loop.

    :return: (samples, intermediates) in both branches.
    """
    if not ddim:
        return self.sample(cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs)

    sampler = DDIMSampler(self)
    # DDIMSampler expects the per-sample shape (no batch dimension).
    latent_shape = (self.channels, self.image_size, self.image_size)
    samples, intermediates = sampler.sample(ddim_steps, batch_size, latent_shape, cond, verbose=False, **kwargs)
    return samples, intermediates
+
1457
+
1458
@torch.no_grad()
def get_images_and_latents(self, batch, **ddim_kwargs):
    """Returns input images, denoised images and latents for clustering"""
    z, c, x, xrec, xc = self.get_input(
        batch,
        self.first_stage_key,
        return_first_stage_outputs=True,
        force_c_encode=True,
        return_original_cond=True,
    )

    # Sample under EMA weights for cleaner outputs.
    with self.ema_scope("Plotting"):
        samples_latent, _ = self.sample_log(cond=c, batch_size=x.shape[0], ddim=True, **ddim_kwargs)

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    # Decode and map from [-1, 1] to uint8 [0, 255].
    decoded = self.decode_first_stage(samples_latent)
    decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
    decoded = (decoded * 255).to(torch.uint8)
    samples_np = to_numpy(decoded)

    input_arr = (127.5 * (x + 1)).detach().cpu().numpy().astype(np.uint8)
    latents_np = to_numpy(samples_latent)

    return input_arr, samples_np, latents_np
+
1485
@torch.no_grad()
def log_images(
    self,
    batch,
    N=8,
    n_row=4,
    sample=True,
    ddim_steps=200,
    ddim_eta=1.0,
    return_keys=None,
    quantize_denoised=True,
    inpaint=True,
    plot_denoise_rows=False,
    plot_progressive_rows=True,
    plot_diffusion_rows=True,
    **kwargs,
):
    """Assemble a dict of visualization tensors for logging.

    Depending on flags, includes: inputs/reconstructions, a rendering of the
    conditioning, forward-diffusion rows, samples (plain / x0-quantized),
    inpainting/outpainting demos, and progressive-denoising rows.

    :param N: max batch entries to visualize; :param n_row: rows in grids.
    :param return_keys: if given, restrict the returned dict to these keys.
    :return: dict mapping names to image tensors/grids.
    """
    use_ddim = ddim_steps is not None

    log = dict()
    z, c, x, xrec, xc = self.get_input(
        batch,
        self.first_stage_key,
        return_first_stage_outputs=True,
        force_c_encode=True,
        return_original_cond=True,
        bs=N,
    )
    N = min(x.shape[0], N)
    n_row = min(x.shape[0], n_row)
    log["inputs"] = x
    log["reconstruction"] = xrec
    if self.model.conditioning_key is not None:
        # Render the conditioning in whatever way fits its modality.
        if hasattr(self.cond_stage_model, "decode"):
            xc = self.cond_stage_model.decode(c)
            log["conditioning"] = xc
        elif self.cond_stage_key in ["caption"]:
            xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
            log["conditioning"] = xc
        elif self.cond_stage_key in ["class_label", "hybrid"]:
            xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log["conditioning"] = xc
        elif isimage(xc):
            log["conditioning"] = xc
        if ismap(xc):
            log["original_conditioning"] = self.to_rgb(xc)

    if plot_diffusion_rows:
        # get diffusion row: forward-noise z at logged timesteps and decode each.
        diffusion_row = list()
        z_start = z[:n_row]
        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(z_start)
                z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                diffusion_row.append(self.decode_first_stage(z_noisy))

        diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
        diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
        diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
        diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
        log["diffusion_row"] = diffusion_grid

    if sample:
        # get denoise row (sampling runs under EMA weights)
        with self.ema_scope("Plotting"):
            samples, z_denoise_row = self.sample_log(
                cond=c,
                batch_size=N,
                ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=ddim_eta,
            )
        # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
        x_samples = self.decode_first_stage(samples)
        log["samples"] = x_samples
        if plot_denoise_rows:
            denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
            log["denoise_row"] = denoise_grid

    if (
        quantize_denoised
        and not isinstance(self.first_stage_model, AutoencoderKL)
        and not isinstance(self.first_stage_model, IdentityFirstStage)
    ):
        # also display when quantizing x0 while sampling
        with self.ema_scope("Plotting Quantized Denoised"):
            samples, z_denoise_row = self.sample_log(
                cond=c,
                batch_size=N,
                ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=ddim_eta,
                quantize_denoised=True,
            )
        # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
        #                                      quantize_denoised=True)
        x_samples = self.decode_first_stage(samples.to(self.device))
        log["samples_x0_quantized"] = x_samples

    if inpaint:
        # make a simple center square
        b, h, w = z.shape[0], z.shape[2], z.shape[3]
        mask = torch.ones(N, h, w).to(self.device)
        # zeros will be filled in
        mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
        mask = mask[:, None, ...]
        with self.ema_scope("Plotting Inpaint"):
            samples, _ = self.sample_log(
                cond=c,
                batch_size=N,
                ddim=use_ddim,
                eta=ddim_eta,
                ddim_steps=ddim_steps,
                x0=z[:N],
                mask=mask,
            )
        x_samples = self.decode_first_stage(samples.to(self.device))
        log["samples_inpainting"] = x_samples
        log["mask"] = mask

        # outpaint
        # NOTE(review): this reuses the same mask as inpainting (not inverted),
        # so "outpainting" regenerates the same center square — confirm whether
        # `1.0 - mask` was intended here.
        with self.ema_scope("Plotting Outpaint"):
            samples, _ = self.sample_log(
                cond=c,
                batch_size=N,
                ddim=use_ddim,
                eta=ddim_eta,
                ddim_steps=ddim_steps,
                x0=z[:N],
                mask=mask,
            )
        x_samples = self.decode_first_stage(samples.to(self.device))
        log["samples_outpainting"] = x_samples

    if plot_progressive_rows:
        with self.ema_scope("Plotting Progressives"):
            img, progressives = self.progressive_denoising(
                c,
                shape=(self.channels, self.image_size, self.image_size),
                batch_size=N,
            )
        prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
        log["progressive_row"] = prog_row

    if return_keys:
        if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
            return log
        else:
            return {key: log[key] for key in return_keys}
    return log
+
1639
def configure_optimizers(self):
    """Build the AdamW optimizer (and optional LambdaLR schedule) for Lightning.

    Optimizes the diffusion model, plus the conditioning encoder when
    `cond_stage_trainable`, plus the per-timestep logvar when `learn_logvar`.
    Returns either a bare optimizer or the ([optimizer], [scheduler]) pair
    Lightning expects when `use_scheduler` is set.
    """
    lr = self.learning_rate
    params = list(self.model.parameters())
    if self.cond_stage_trainable:
        print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
        params = params + list(self.cond_stage_model.parameters())
    if self.learn_logvar:
        print("Diffusion model optimizing logvar")
        params.append(self.logvar)

    opt = torch.optim.AdamW(params, lr=lr)
    # opt = bnb.optim.AdamW8bit(params, lr=lr)

    if not self.use_scheduler:
        return opt

    assert "target" in self.scheduler_config
    schedule_obj = instantiate_from_config(self.scheduler_config)

    print("Setting up LambdaLR scheduler...")
    lr_schedulers = [
        {
            "scheduler": LambdaLR(opt, lr_lambda=schedule_obj.schedule),
            "interval": "step",
            "frequency": 1,
        }
    ]
    return [opt], lr_schedulers
+
1665
@torch.no_grad()
def to_rgb(self, x):
    """Project an arbitrary-channel map to 3 channels for visualization.

    Uses a fixed random 1x1 convolution (cached on the instance so the color
    mapping stays stable across calls) and rescales the result to [-1, 1].
    """
    x = x.float()
    if not hasattr(self, "colorize"):
        # Lazily create the random projection, matching x's device/dtype.
        self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
    rgb = nn.functional.conv2d(x, weight=self.colorize)
    lo, hi = rgb.min(), rgb.max()
    # Min-max normalize to [-1, 1] (degenerate if the map is constant).
    return 2.0 * (rgb - lo) / (hi - lo) - 1.0
+
1674
+
1675
class DiffusionWrapper(pl.LightningModule):
    """Routes conditioning tensors into the wrapped UNet.

    `conditioning_key` selects how `c_concat` / `c_crossattn` reach the model:
    channel concatenation, cross-attention context, both ("hybrid"), or a
    class-embedding style `y` input ("adm").
    """

    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [
            None,
            "concat",
            "crossattn",
            "hybrid",
            "adm",
        ]

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        key = self.conditioning_key
        if key is None:
            return self.diffusion_model(x, t)
        if key == "concat":
            return self.diffusion_model(torch.cat([x] + c_concat, dim=1), t)
        if key == "crossattn":
            return self.diffusion_model(x, t, context=torch.cat(c_crossattn, 1))
        if key == "hybrid":
            stacked = torch.cat([x] + c_concat, dim=1)
            return self.diffusion_model(stacked, t, context=torch.cat(c_crossattn, 1))
        if key == "adm":
            return self.diffusion_model(x, t, y=c_crossattn[0])
        raise NotImplementedError()
+
ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampler import DPMSolverSampler
ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from tqdm import tqdm
5
+
6
+
7
class NoiseScheduleVP:
    def __init__(
        self,
        schedule='discrete',
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.,
    ):
        """Create a wrapper class for the forward SDE (VP type).

        The forward SDE gives q_{t|0}(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 * I),
        and lambda_t = log(alpha_t) - log(sigma_t) is the half-logSNR used by
        DPM-Solver. This class exposes alpha_t, sigma_t, lambda_t and the inverse
        of lambda as functions of continuous time t in [0, T].

        Supported schedules:
          - 'discrete': for discrete-time DPMs (n = 0..N-1). Provide exactly one
            of `betas` or `alphas_cumprod` (the DDPM \\hat{alpha}_n array; note
            alpha_{t_n} = sqrt(\\hat{alpha}_n)). Discrete steps are mapped to
            continuous times t_i = (i+1)/N and log(alpha_t) is piecewise-linearly
            interpolated; T = 1.
          - 'linear': continuous VPSDE with beta(t) linear between
            `continuous_beta_0` and `continuous_beta_1` (DDPM defaults); T = 1.
          - 'cosine': improved-DDPM cosine schedule with fixed hyperparameters
            (cosine_s = 0.008); T = 0.9946 to avoid numerical issues near t = 1.

        Example:
            >>> ns = NoiseScheduleVP('discrete', betas=betas)
            >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                # log(alpha_{t_n}) = 0.5 * sum_{i<=n} log(1 - beta_i)
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            # Table of (t_i, log_alpha_i) pairs for piecewise-linear interpolation.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
                    1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            # Piecewise-linear interpolation of the precomputed log_alpha table.
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
                                  self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            # Closed form for beta(t) linear in t.
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].

        Uses the VP identity sigma_t^2 = 1 - alpha_t^2.
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            # Solve the quadratic in t via a numerically-stable form.
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0 ** 2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            # Invert the interpolation table: log_alpha -> t (arrays flipped to be increasing).
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                               torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            # Cosine schedule: closed-form inverse via arccos.
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
                    1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
+
160
+
161
+ def model_wrapper(
162
+ model,
163
+ noise_schedule,
164
+ model_type="noise",
165
+ model_kwargs={},
166
+ guidance_type="uncond",
167
+ condition=None,
168
+ unconditional_condition=None,
169
+ guidance_scale=1.,
170
+ classifier_fn=None,
171
+ classifier_kwargs={},
172
+ ):
173
+ """Create a wrapper function for the noise prediction model.
174
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
175
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
176
+ We support four types of the diffusion model by setting `model_type`:
177
+ 1. "noise": noise prediction model. (Trained by predicting noise).
178
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
179
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
180
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
181
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
182
+ arXiv preprint arXiv:2202.00512 (2022).
183
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
184
+ arXiv preprint arXiv:2210.02303 (2022).
185
+
186
+ 4. "score": marginal score function. (Trained by denoising score matching).
187
+ Note that the score function and the noise prediction model follows a simple relationship:
188
+ ```
189
+ noise(x_t, t) = -sigma_t * score(x_t, t)
190
+ ```
191
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
192
+ 1. "uncond": unconditional sampling by DPMs.
193
+ The input `model` has the following format:
194
+ ``
195
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
196
+ ``
197
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
198
+ The input `model` has the following format:
199
+ ``
200
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
201
+ ``
202
+ The input `classifier_fn` has the following format:
203
+ ``
204
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
205
+ ``
206
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
207
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
208
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
209
+ The input `model` has the following format:
210
+ ``
211
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
212
+ ``
213
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
214
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
215
+ arXiv preprint arXiv:2207.12598 (2022).
216
+
217
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
218
+ or continuous-time labels (i.e. epsilon to T).
219
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
220
+ ``
221
+ def model_fn(x, t_continuous) -> noise:
222
+ t_input = get_model_input_time(t_continuous)
223
+ return noise_pred(model, x, t_input, **model_kwargs)
224
+ ``
225
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
226
+ ===============================================================
227
+ Args:
228
+ model: A diffusion model with the corresponding format described above.
229
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
230
+ model_type: A `str`. The parameterization type of the diffusion model.
231
+ "noise" or "x_start" or "v" or "score".
232
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
233
+ guidance_type: A `str`. The type of the guidance for sampling.
234
+ "uncond" or "classifier" or "classifier-free".
235
+ condition: A pytorch tensor. The condition for the guided sampling.
236
+ Only used for "classifier" or "classifier-free" guidance type.
237
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
238
+ Only used for "classifier-free" guidance type.
239
+ guidance_scale: A `float`. The scale for the guided sampling.
240
+ classifier_fn: A classifier function. Only used for the classifier guidance.
241
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
242
+ Returns:
243
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
244
+ """
245
+
246
+ def get_model_input_time(t_continuous):
247
+ """
248
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
249
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
250
+ For continuous-time DPMs, we just use `t_continuous`.
251
+ """
252
+ if noise_schedule.schedule == 'discrete':
253
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
254
+ else:
255
+ return t_continuous
256
+
257
+ def noise_pred_fn(x, t_continuous, cond=None):
258
+ if t_continuous.reshape((-1,)).shape[0] == 1:
259
+ t_continuous = t_continuous.expand((x.shape[0]))
260
+ t_input = get_model_input_time(t_continuous)
261
+ if cond is None:
262
+ output = model(x, t_input, **model_kwargs)
263
+ else:
264
+ output = model(x, t_input, cond, **model_kwargs)
265
+ if model_type == "noise":
266
+ return output
267
+ elif model_type == "x_start":
268
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
269
+ dims = x.dim()
270
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
271
+ elif model_type == "v":
272
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
273
+ dims = x.dim()
274
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
275
+ elif model_type == "score":
276
+ sigma_t = noise_schedule.marginal_std(t_continuous)
277
+ dims = x.dim()
278
+ return -expand_dims(sigma_t, dims) * output
279
+
280
+ def cond_grad_fn(x, t_input):
281
+ """
282
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
283
+ """
284
+ with torch.enable_grad():
285
+ x_in = x.detach().requires_grad_(True)
286
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
287
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
288
+
289
+ def model_fn(x, t_continuous):
290
+ """
291
+ The noise predicition model function that is used for DPM-Solver.
292
+ """
293
+ if t_continuous.reshape((-1,)).shape[0] == 1:
294
+ t_continuous = t_continuous.expand((x.shape[0]))
295
+ if guidance_type == "uncond":
296
+ return noise_pred_fn(x, t_continuous)
297
+ elif guidance_type == "classifier":
298
+ assert classifier_fn is not None
299
+ t_input = get_model_input_time(t_continuous)
300
+ cond_grad = cond_grad_fn(x, t_input)
301
+ sigma_t = noise_schedule.marginal_std(t_continuous)
302
+ noise = noise_pred_fn(x, t_continuous)
303
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
304
+ elif guidance_type == "classifier-free":
305
+ if guidance_scale == 1. or unconditional_condition is None:
306
+ return noise_pred_fn(x, t_continuous, cond=condition)
307
+ else:
308
+ x_in = torch.cat([x] * 2)
309
+ t_in = torch.cat([t_continuous] * 2)
310
+ if isinstance(condition, dict):
311
+ assert isinstance(unconditional_condition, dict)
312
+ c_in = dict()
313
+ for k in condition:
314
+ if isinstance(condition[k], list):
315
+ c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
316
+ else:
317
+ c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
318
+ else:
319
+ c_in = torch.cat([unconditional_condition, condition])
320
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
321
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
322
+
323
+ assert model_type in ["noise", "x_start", "v"]
324
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
325
+ return model_fn
326
+
327
+
328
+ class DPM_Solver:
329
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
330
+ """Construct a DPM-Solver.
331
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
332
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
333
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
334
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
335
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
336
+ Args:
337
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
338
+ ``
339
+ def model_fn(x, t_continuous):
340
+ return noise
341
+ ``
342
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
343
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
344
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
345
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
346
+
347
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
348
+ """
349
+ self.model = model_fn
350
+ self.noise_schedule = noise_schedule
351
+ self.predict_x0 = predict_x0
352
+ self.thresholding = thresholding
353
+ self.max_val = max_val
354
+
355
+ def noise_prediction_fn(self, x, t):
356
+ """
357
+ Return the noise prediction model.
358
+ """
359
+ return self.model(x, t)
360
+
361
+ def data_prediction_fn(self, x, t):
362
+ """
363
+ Return the data prediction model (with thresholding).
364
+ """
365
+ noise = self.noise_prediction_fn(x, t)
366
+ dims = x.dim()
367
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
368
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
369
+ if self.thresholding:
370
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
371
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
372
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
373
+ x0 = torch.clamp(x0, -s, s) / s
374
+ return x0
375
+
376
+ def model_fn(self, x, t):
377
+ """
378
+ Convert the model to the noise prediction model or the data prediction model.
379
+ """
380
+ if self.predict_x0:
381
+ return self.data_prediction_fn(x, t)
382
+ else:
383
+ return self.noise_prediction_fn(x, t)
384
+
385
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
386
+ """Compute the intermediate time steps for sampling.
387
+ Args:
388
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
389
+ - 'logSNR': uniform logSNR for the time steps.
390
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
391
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
392
+ t_T: A `float`. The starting time of the sampling (default is T).
393
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
394
+ N: A `int`. The total number of the spacing of the time steps.
395
+ device: A torch device.
396
+ Returns:
397
+ A pytorch tensor of the time steps, with the shape (N + 1,).
398
+ """
399
+ if skip_type == 'logSNR':
400
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
401
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
402
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
403
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
404
+ elif skip_type == 'time_uniform':
405
+ return torch.linspace(t_T, t_0, N + 1).to(device)
406
+ elif skip_type == 'time_quadratic':
407
+ t_order = 2
408
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
409
+ return t
410
+ else:
411
+ raise ValueError(
412
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
413
+
414
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
415
+ """
416
+ Get the order of each step for sampling by the singlestep DPM-Solver.
417
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
418
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
419
+ - If order == 1:
420
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
421
+ - If order == 2:
422
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
423
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
424
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
425
+ - If order == 3:
426
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
427
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
428
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
429
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
430
+ ============================================
431
+ Args:
432
+ order: A `int`. The max order for the solver (2 or 3).
433
+ steps: A `int`. The total number of function evaluations (NFE).
434
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
435
+ - 'logSNR': uniform logSNR for the time steps.
436
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
437
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
438
+ t_T: A `float`. The starting time of the sampling (default is T).
439
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
440
+ device: A torch device.
441
+ Returns:
442
+ orders: A list of the solver order of each step.
443
+ """
444
+ if order == 3:
445
+ K = steps // 3 + 1
446
+ if steps % 3 == 0:
447
+ orders = [3, ] * (K - 2) + [2, 1]
448
+ elif steps % 3 == 1:
449
+ orders = [3, ] * (K - 1) + [1]
450
+ else:
451
+ orders = [3, ] * (K - 1) + [2]
452
+ elif order == 2:
453
+ if steps % 2 == 0:
454
+ K = steps // 2
455
+ orders = [2, ] * K
456
+ else:
457
+ K = steps // 2 + 1
458
+ orders = [2, ] * (K - 1) + [1]
459
+ elif order == 1:
460
+ K = 1
461
+ orders = [1, ] * steps
462
+ else:
463
+ raise ValueError("'order' must be '1' or '2' or '3'.")
464
+ if skip_type == 'logSNR':
465
+ # To reproduce the results in DPM-Solver paper
466
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
467
+ else:
468
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
469
+ torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
470
+ return timesteps_outer, orders
471
+
472
+ def denoise_to_zero_fn(self, x, s):
473
+ """
474
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
475
+ """
476
+ return self.data_prediction_fn(x, s)
477
+
478
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
479
+ """
480
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
481
+ Args:
482
+ x: A pytorch tensor. The initial value at time `s`.
483
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
484
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
485
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
486
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
487
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
488
+ Returns:
489
+ x_t: A pytorch tensor. The approximated solution at time `t`.
490
+ """
491
+ ns = self.noise_schedule
492
+ dims = x.dim()
493
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
494
+ h = lambda_t - lambda_s
495
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
496
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
497
+ alpha_t = torch.exp(log_alpha_t)
498
+
499
+ if self.predict_x0:
500
+ phi_1 = torch.expm1(-h)
501
+ if model_s is None:
502
+ model_s = self.model_fn(x, s)
503
+ x_t = (
504
+ expand_dims(sigma_t / sigma_s, dims) * x
505
+ - expand_dims(alpha_t * phi_1, dims) * model_s
506
+ )
507
+ if return_intermediate:
508
+ return x_t, {'model_s': model_s}
509
+ else:
510
+ return x_t
511
+ else:
512
+ phi_1 = torch.expm1(h)
513
+ if model_s is None:
514
+ model_s = self.model_fn(x, s)
515
+ x_t = (
516
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
517
+ - expand_dims(sigma_t * phi_1, dims) * model_s
518
+ )
519
+ if return_intermediate:
520
+ return x_t, {'model_s': model_s}
521
+ else:
522
+ return x_t
523
+
524
    def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
                                            solver_type='dpm_solver'):
        """
        Singlestep solver DPM-Solver-2 from time `s` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            r1: A `float`. The hyperparameter of the second-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        Raises:
            ValueError: If `solver_type` is not 'dpm_solver' or 'taylor'.
        """
        if solver_type not in ['dpm_solver', 'taylor']:
            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
        if r1 is None:
            r1 = 0.5
        ns = self.noise_schedule
        dims = x.dim()
        # Half-log-SNR values at the endpoints; h is the step size in lambda space.
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        # Intermediate time s1 is placed at the fraction r1 of the step in lambda space.
        lambda_s1 = lambda_s + r1 * h
        s1 = ns.inverse_lambda(lambda_s1)
        log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
            s1), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)

        if self.predict_x0:
            # Data-prediction (DPM-Solver++) coefficients.
            phi_11 = torch.expm1(-r1 * h)
            phi_1 = torch.expm1(-h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            # First (Euler-like) stage: step to the intermediate time s1.
            x_s1 = (
                    expand_dims(sigma_s1 / sigma_s, dims) * x
                    - expand_dims(alpha_s1 * phi_11, dims) * model_s
            )
            model_s1 = self.model_fn(x_s1, s1)
            # Second stage: full step using the first-difference correction.
            if solver_type == 'dpm_solver':
                x_t = (
                        expand_dims(sigma_t / sigma_s, dims) * x
                        - expand_dims(alpha_t * phi_1, dims) * model_s
                        - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
                )
            elif solver_type == 'taylor':
                x_t = (
                        expand_dims(sigma_t / sigma_s, dims) * x
                        - expand_dims(alpha_t * phi_1, dims) * model_s
                        + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
                                model_s1 - model_s)
                )
        else:
            # Noise-prediction (DPM-Solver) coefficients.
            phi_11 = torch.expm1(r1 * h)
            phi_1 = torch.expm1(h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            # First (Euler-like) stage: step to the intermediate time s1.
            x_s1 = (
                    expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
                    - expand_dims(sigma_s1 * phi_11, dims) * model_s
            )
            model_s1 = self.model_fn(x_s1, s1)
            # Second stage: full step using the first-difference correction.
            if solver_type == 'dpm_solver':
                x_t = (
                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                        - expand_dims(sigma_t * phi_1, dims) * model_s
                        - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
                )
            elif solver_type == 'taylor':
                x_t = (
                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                        - expand_dims(sigma_t * phi_1, dims) * model_s
                        - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
                )
        if return_intermediate:
            return x_t, {'model_s': model_s, 'model_s1': model_s1}
        else:
            return x_t
+
608
    def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
                                           return_intermediate=False, solver_type='dpm_solver'):
        """
        Singlestep solver DPM-Solver-3 from time `s` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            r1: A `float`. The hyperparameter of the third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        Raises:
            ValueError: If `solver_type` is not 'dpm_solver' or 'taylor'.
        """
        if solver_type not in ['dpm_solver', 'taylor']:
            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
        if r1 is None:
            r1 = 1. / 3.
        if r2 is None:
            r2 = 2. / 3.
        ns = self.noise_schedule
        dims = x.dim()
        # Step size h in lambda (half-log-SNR) space.
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        # Two intermediate times s1, s2 placed at fractions r1, r2 of the step.
        lambda_s1 = lambda_s + r1 * h
        lambda_s2 = lambda_s + r2 * h
        s1 = ns.inverse_lambda(lambda_s1)
        s2 = ns.inverse_lambda(lambda_s2)
        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
            s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
            s2), ns.marginal_std(t)
        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)

        if self.predict_x0:
            # Data-prediction (DPM-Solver++) exponential-integrator coefficients.
            phi_11 = torch.expm1(-r1 * h)
            phi_12 = torch.expm1(-r2 * h)
            phi_1 = torch.expm1(-h)
            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
            phi_2 = phi_1 / h + 1.
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                # First stage: step to s1 with a first-order update.
                x_s1 = (
                        expand_dims(sigma_s1 / sigma_s, dims) * x
                        - expand_dims(alpha_s1 * phi_11, dims) * model_s
                )
                model_s1 = self.model_fn(x_s1, s1)
            # Second stage: step to s2 with a second-order correction.
            x_s2 = (
                    expand_dims(sigma_s2 / sigma_s, dims) * x
                    - expand_dims(alpha_s2 * phi_12, dims) * model_s
                    + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
            # Third stage: full step.
            if solver_type == 'dpm_solver':
                x_t = (
                        expand_dims(sigma_t / sigma_s, dims) * x
                        - expand_dims(alpha_t * phi_1, dims) * model_s
                        + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
                )
            elif solver_type == 'taylor':
                # Finite-difference estimates of the 1st and 2nd derivatives
                # of the model output with respect to lambda.
                D1_0 = (1. / r1) * (model_s1 - model_s)
                D1_1 = (1. / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                        expand_dims(sigma_t / sigma_s, dims) * x
                        - expand_dims(alpha_t * phi_1, dims) * model_s
                        + expand_dims(alpha_t * phi_2, dims) * D1
                        - expand_dims(alpha_t * phi_3, dims) * D2
                )
        else:
            # Noise-prediction (DPM-Solver) exponential-integrator coefficients.
            phi_11 = torch.expm1(r1 * h)
            phi_12 = torch.expm1(r2 * h)
            phi_1 = torch.expm1(h)
            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
            phi_2 = phi_1 / h - 1.
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                # First stage: step to s1 with a first-order update.
                x_s1 = (
                        expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
                        - expand_dims(sigma_s1 * phi_11, dims) * model_s
                )
                model_s1 = self.model_fn(x_s1, s1)
            # Second stage: step to s2 with a second-order correction.
            x_s2 = (
                    expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
                    - expand_dims(sigma_s2 * phi_12, dims) * model_s
                    - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
            # Third stage: full step.
            if solver_type == 'dpm_solver':
                x_t = (
                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                        - expand_dims(sigma_t * phi_1, dims) * model_s
                        - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
                )
            elif solver_type == 'taylor':
                # Finite-difference estimates of the 1st and 2nd derivatives
                # of the model output with respect to lambda.
                D1_0 = (1. / r1) * (model_s1 - model_s)
                D1_1 = (1. / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                        - expand_dims(sigma_t * phi_1, dims) * model_s
                        - expand_dims(sigma_t * phi_2, dims) * D1
                        - expand_dims(sigma_t * phi_3, dims) * D2
                )

        if return_intermediate:
            return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
        else:
            return x_t
+
732
    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
        """
        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensor. The previous computed model values.
            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        Raises:
            ValueError: If `solver_type` is not 'dpm_solver' or 'taylor'.
        """
        if solver_type not in ['dpm_solver', 'taylor']:
            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
        ns = self.noise_schedule
        dims = x.dim()
        # Unpack the two cached model values / times (oldest first).
        model_prev_1, model_prev_0 = model_prev_list
        t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
            t_prev_0), ns.marginal_lambda(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        # Step sizes in lambda space: h_0 for the previous step, h for this one.
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        # Backward finite-difference estimate of the model derivative w.r.t. lambda.
        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
        if self.predict_x0:
            # Data-prediction (DPM-Solver++) branch.
            if solver_type == 'dpm_solver':
                x_t = (
                        expand_dims(sigma_t / sigma_prev_0, dims) * x
                        - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
                        - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
                )
            elif solver_type == 'taylor':
                x_t = (
                        expand_dims(sigma_t / sigma_prev_0, dims) * x
                        - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
                        + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
                )
        else:
            # Noise-prediction (DPM-Solver) branch.
            if solver_type == 'dpm_solver':
                x_t = (
                        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                        - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
                        - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
                )
            elif solver_type == 'taylor':
                x_t = (
                        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                        - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
                        - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
                )
        return x_t
+
789
    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
        """
        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensor. The previous computed model values.
            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
                NOTE(review): unlike the second-order variant, this method does not
                validate `solver_type` and the update formula itself does not branch
                on it — this matches the original implementation.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        dims = x.dim()
        # Unpack the three cached model values / times (oldest first).
        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
            t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        # Step sizes in lambda space for the two previous steps and this one.
        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        # Backward finite-difference estimates of the 1st and 2nd derivatives
        # of the model output with respect to lambda.
        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
        D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
        D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
        if self.predict_x0:
            # Data-prediction (DPM-Solver++) third-order update.
            x_t = (
                    expand_dims(sigma_t / sigma_prev_0, dims) * x
                    - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
                    + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
                    - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
            )
        else:
            # Noise-prediction (DPM-Solver) third-order update.
            x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                    - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
                    - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
                    - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
            )
        return x_t
+
836
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
837
+ r2=None):
838
+ """
839
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
840
+ Args:
841
+ x: A pytorch tensor. The initial value at time `s`.
842
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
843
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
844
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
845
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
846
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
847
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
848
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
849
+ r2: A `float`. The hyperparameter of the third-order solver.
850
+ Returns:
851
+ x_t: A pytorch tensor. The approximated solution at time `t`.
852
+ """
853
+ if order == 1:
854
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
855
+ elif order == 2:
856
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
857
+ solver_type=solver_type, r1=r1)
858
+ elif order == 3:
859
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
860
+ solver_type=solver_type, r1=r1, r2=r2)
861
+ else:
862
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
863
+
864
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
865
+ """
866
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
867
+ Args:
868
+ x: A pytorch tensor. The initial value at time `s`.
869
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
870
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
871
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
872
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
873
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
874
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
875
+ Returns:
876
+ x_t: A pytorch tensor. The approximated solution at time `t`.
877
+ """
878
+ if order == 1:
879
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
880
+ elif order == 2:
881
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
882
+ elif order == 3:
883
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
884
+ else:
885
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
886
+
887
    def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
                            solver_type='dpm_solver'):
        """
        The adaptive step size solver based on singlestep DPM-Solver.
        Args:
            x: A pytorch tensor. The initial value at time `t_T`.
            order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            h_init: A `float`. The initial step size (for logSNR).
            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_0: A pytorch tensor. The approximated solution at time `t_0`.
        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
        """
        ns = self.noise_schedule
        # Current time `s`, broadcast to the batch; stepping is done in logSNR (lambda) space.
        s = t_T * torch.ones((x.shape[0],)).to(x)
        lambda_s = ns.marginal_lambda(s)
        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
        h = h_init * torch.ones_like(s).to(x)
        x_prev = x
        nfe = 0  # number of function (model) evaluations, for diagnostics only
        # Pair a lower-order and a higher-order update; their difference estimates the local error.
        if order == 2:
            r1 = 0.5
            lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
                                                                                              solver_type=solver_type,
                                                                                              **kwargs)
        elif order == 3:
            r1, r2 = 1. / 3., 2. / 3.
            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
                                                                                   return_intermediate=True,
                                                                                   solver_type=solver_type)
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
                                                                                             solver_type=solver_type,
                                                                                             **kwargs)
        else:
            raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
        while torch.abs((s - t_0)).mean() > t_err:
            t = ns.inverse_lambda(lambda_s + h)
            # The lower-order step also returns intermediate model outputs, reused by the higher-order step.
            x_lower, lower_noise_kwargs = lower_update(x, s, t)
            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
            # Mixed absolute/relative tolerance, elementwise (standard adaptive-ODE error control).
            delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
            norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
            E = norm_fn((x_higher - x_lower) / delta).max()
            if torch.all(E <= 1.):
                # Step accepted: keep the higher-order solution and advance the time.
                x = x_higher
                s = t
                x_prev = x_lower
                lambda_s = ns.marginal_lambda(s)
            # Shrink/grow the step from the error estimate; never overshoot lambda_0.
            h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
            nfe += order
        print('adaptive solver nfe', nfe)
        return x
947
+
948
    def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
               method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
               atol=0.0078, rtol=0.05,
               ):
        """
        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
        =====================================================
        We support the following algorithms for both noise prediction model and data prediction model:
            - 'singlestep':
                Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
                We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
                The total number of function evaluations (NFE) == `steps`.
                Given a fixed NFE == `steps`, the sampling procedure is:
                    - If `order` == 1:
                        - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
                    - If `order` == 2:
                        - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
                        - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
                        - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                    - If `order` == 3:
                        - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                        - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                        - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
                        - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
            - 'multistep':
                Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
                We initialize the first `order` values by lower order multistep solvers.
                Given a fixed NFE == `steps`, the sampling procedure is:
                    Denote K = steps.
                    - If `order` == 1:
                        - We use K steps of DPM-Solver-1 (i.e. DDIM).
                    - If `order` == 2:
                        - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
                    - If `order` == 3:
                        - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
            - 'singlestep_fixed':
                Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
                We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
            - 'adaptive':
                Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
                We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
                (NFE) and the sample quality.
                    - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
                    - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
        =====================================================
        Some advices for choosing the algorithm:
            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
                e.g.
                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                            skip_type='time_uniform', method='singlestep')
            - For **guided sampling with large guidance scale** by DPMs:
                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
                e.g.
                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
                            skip_type='time_uniform', method='multistep')
        We support three types of `skip_type`:
            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
            - 'time_quadratic': quadratic time for the time steps.
        =====================================================
        Args:
            x: A pytorch tensor. The initial value at time `t_start`
                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
            steps: A `int`. The total number of function evaluations (NFE).
            t_start: A `float`. The starting time of the sampling.
                If `T` is None, we use self.noise_schedule.T (default is 1.0).
            t_end: A `float`. The ending time of the sampling.
                If `t_end` is None, we use 1. / self.noise_schedule.total_N.
                e.g. if total_N == 1000, we have `t_end` == 1e-3.
                For discrete-time DPMs:
                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
                For continuous-time DPMs:
                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
            order: A `int`. The order of DPM-Solver.
            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
                This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
                score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
                for diffusion models sampling by diffusion SDEs for low-resolutional images
                (such as CIFAR-10). However, we observed that such trick does not matter for
                high-resolutional images. As it needs an additional NFE, we do not recommend
                it for high-resolutional images.
            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
                Only valid for `method=multistep` and `steps < 15`. We empirically find that
                this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
                (especially for steps <= 10). So we recommend to set it to be `True`.
            solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
        Returns:
            x_end: A pytorch tensor. The approximated solution at time `t_end`.
        """
        # Resolve the time interval [t_0, t_T]; defaults come from the noise schedule.
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        if method == 'adaptive':
            with torch.no_grad():
                x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
                                             solver_type=solver_type)
        elif method == 'multistep':
            assert steps >= order
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            assert timesteps.shape[0] - 1 == steps
            with torch.no_grad():
                vec_t = timesteps[0].expand((x.shape[0]))
                model_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]
                # Init the first `order` values by lower order multistep DPM-Solver.
                for init_order in tqdm(range(1, order), desc="DPM init order"):
                    vec_t = timesteps[init_order].expand(x.shape[0])
                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
                                                         solver_type=solver_type)
                    model_prev_list.append(self.model_fn(x, vec_t))
                    t_prev_list.append(vec_t)
                # Compute the remaining values by `order`-th order multistep DPM-Solver.
                for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
                    vec_t = timesteps[step].expand(x.shape[0])
                    # Optionally drop to lower orders for the last few steps (stabilizes few-step sampling).
                    if lower_order_final and steps < 15:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
                                                         solver_type=solver_type)
                    # Shift the history window of previous model outputs / times by one.
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = vec_t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        model_prev_list[-1] = self.model_fn(x, vec_t)
        elif method in ['singlestep', 'singlestep_fixed']:
            if method == 'singlestep':
                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
                                                                                              skip_type=skip_type,
                                                                                              t_T=t_T, t_0=t_0,
                                                                                              device=device)
            elif method == 'singlestep_fixed':
                K = steps // order
                orders = [order, ] * K
                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
            for i, order in enumerate(orders):
                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
                                                      N=order, device=device)
                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
                h = lambda_inner[-1] - lambda_inner[0]
                # r1/r2 are the intermediate logSNR ratios used by the 2nd/3rd-order updates.
                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
        if denoise_to_zero:
            # One extra model evaluation to denoise directly to time 0 (total NFE becomes steps + 1).
            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
        return x
1107
+
1108
+
1109
+ #############################################################
1110
+ # other utility functions
1111
+ #############################################################
1112
+
1113
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    # Insert each query x among the keypoints, then sort; the rank of x tells us
    # which segment [xp[i], xp[i+1]] it falls into — all without data-dependent control flow.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_combined, sort_order = torch.sort(combined, dim=2)
    insert_pos = torch.argmin(sort_order, dim=2)
    tentative_start = insert_pos - 1
    # Clamp the segment index so out-of-range queries extrapolate along the outermost segment.
    left_idx = torch.where(
        torch.eq(insert_pos, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(insert_pos, K), torch.tensor(K - 2, device=x.device), tentative_start,
        ),
    )
    right_idx = torch.where(torch.eq(left_idx, tentative_start), left_idx + 2, left_idx + 1)
    x_left = torch.gather(sorted_combined, dim=2, index=left_idx.unsqueeze(2)).squeeze(2)
    x_right = torch.gather(sorted_combined, dim=2, index=right_idx.unsqueeze(2)).squeeze(2)
    # Index into yp (which was not part of the sort) with the matching segment start.
    y_left_idx = torch.where(
        torch.eq(insert_pos, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(insert_pos, K), torch.tensor(K - 2, device=x.device), tentative_start,
        ),
    )
    yp_batched = yp.unsqueeze(0).expand(N, -1, -1)
    y_left = torch.gather(yp_batched, dim=2, index=y_left_idx.unsqueeze(2)).squeeze(2)
    y_right = torch.gather(yp_batched, dim=2, index=(y_left_idx + 1).unsqueeze(2)).squeeze(2)
    # Standard linear interpolation on the selected segment.
    return y_left + (x - x_left) * (y_right - y_left) / (x_right - x_left)
1152
+
1153
+
1154
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims` by appending trailing singleton axes.
    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: a `int`, the desired total number of dimensions.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Equivalent to v[(...,) + (None,) * (dims - 1)]: a zero-copy view with extra axes.
    return v.reshape(v.shape + (1,) * (dims - 1))
ldm/models/diffusion/dpm_solver/sampler.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+ import torch
3
+
4
+ from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5
+
6
+ MODEL_TYPES = {
7
+ "eps": "noise",
8
+ "v": "v"
9
+ }
10
+
11
+
12
class DPMSolverSampler(object):
    """Adapter that samples from a latent-diffusion model using DPM-Solver.

    Wraps the model's `apply_model` in `model_wrapper` with classifier-free
    guidance and runs 2nd-order multistep DPM-Solver in x0-prediction mode.
    """

    def __init__(self, model, device=torch.device("cuda"), **kwargs):
        super().__init__()
        self.model = model
        self.device = device
        # Detach so the sampler never carries gradients into the schedule buffers.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        # Not an nn.Module: "buffers" are plain attributes, moved to self.device if needed.
        if type(attr) == torch.Tensor:
            if attr.device != self.device:
                attr = attr.to(self.device)
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Draw `batch_size` samples of latent `shape` (C, H, W) with `S` solver steps.

        Returns a tuple `(samples, None)`; the second element mirrors other
        samplers' (samples, intermediates) interface. Many keyword arguments
        (eta, mask, x0, callbacks, ...) are accepted only for interface
        compatibility and are unused here.
        """
        # Sanity-check that the conditioning batch matches; warn but do not fail,
        # so callers relying on broadcasting behave as before.
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                if isinstance(ctmp, torch.Tensor):
                    cbs = ctmp.shape[0]
                    if cbs != batch_size:
                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
            else:
                if isinstance(conditioning, torch.Tensor):
                    if conditioning.shape[0] != batch_size:
                        print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')

        device = self.model.betas.device
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        # Discrete VP schedule built from the model's alphas_cumprod.
        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        model_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            ns,
            model_type=MODEL_TYPES[self.model.parameterization],
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
        )

        # predict_x0=True + order=2 multistep: the recommended setting for guided sampling.
        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
                              lower_order_final=True)

        return x.to(device), None
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+ from ldm.models.diffusion.sampling_util import norm_thresholding
10
+
11
+
12
class PLMSSampler(object):
    """Pseudo Linear Multistep (PLMS) sampler for latent-diffusion models.

    Uses Adams-Bashforth-style multistep extrapolation of the predicted noise
    over a DDIM timestep schedule (eta is required to be 0).
    """

    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
        self.device = device

    def register_buffer(self, name, attr):
        # Not an nn.Module: "buffers" are plain attributes, moved to self.device if needed.
        if type(attr) == torch.Tensor:
            if attr.device != self.device:
                attr = attr.to(self.device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timestep subset and all schedule-derived buffers."""
        # PLMS is only valid for the deterministic (eta == 0) DDIM update.
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # With ddim_eta == 0 these sigmas are identically zero; kept for interface parity.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               timesteps=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        """Draw `batch_size` samples of latent `shape` (C, H, W) with `S` PLMS steps.

        Returns (samples, intermediates) where intermediates holds periodic
        'x_inter' and 'pred_x0' snapshots.
        """
        # Batch-size/conditioning mismatch is warned, not raised, to preserve caller behavior.
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        if verbose:
            print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    timesteps=timesteps,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    verbose=verbose,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None, verbose=True):
        """Run the PLMS reverse-diffusion loop from pure noise (or `x_T`) to a sample."""
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            timesteps = timesteps

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        if verbose:
            print(f"Running PLMS Sampling with {total_steps} timesteps")
            iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        else:
            iterator = time_range
        old_eps = []  # history of predicted noise for the Adams-Bashforth extrapolation

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # The next (smaller) timestep is needed by the 2nd-order startup step.
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                # Inpainting: keep masked region pinned to the noised ground truth x0.
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                # Keep at most 3 previous eps values (4th-order multistep window).
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                      dynamic_threshold=None):
        """Single PLMS update: returns (x_prev, pred_x0, e_t) for timestep `index`."""
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # Predicted noise, optionally with classifier-free guidance (batched cond/uncond pass).
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        # Adams-Bashforth extrapolation of e_t; the order grows with the available history.
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
ldm/models/diffusion/sampling_util.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
    extra = target_dims - x.ndim
    if extra < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # A reshape with trailing singleton axes is equivalent to x[(...,) + (None,) * extra].
    return x.reshape(x.shape + (1,) * extra)
12
+
13
+
14
def norm_thresholding(x0, value):
    """Rescale each sample of `x0` so its per-sample RMS norm does not exceed `value`.

    Samples whose RMS is already below `value` are left unchanged (the clamp
    makes the scale factor exactly 1 in that case).
    """
    per_sample_rms = x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value)
    denom = append_dims(per_sample_rms, x0.ndim)
    return x0 * (value / denom)
17
+
18
+
19
def spatial_norm_thresholding(x0, value):
    """Clamp the per-pixel channel RMS of `x0` (b, c, h, w) to at most `value`.

    Pixels whose channel RMS is already below `value` are unchanged.
    """
    # b c h w
    channel_rms = torch.sqrt(torch.mean(x0 * x0, dim=1, keepdim=True))
    scale = value / channel_rms.clamp(min=value)
    return x0 * scale
ldm/modules/attention.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
+
9
+ from ldm.modules.diffusionmodules.util import checkpoint
10
+
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+
16
+ XFORMERS_IS_AVAILBLE = True
17
+ except:
18
+ XFORMERS_IS_AVAILBLE = False
19
+
20
+ # CrossAttn precision handling
21
+ import os
22
+
23
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
24
+
25
+
26
def exists(val):
    """Return True unless *val* is None (zero/empty values still count as existing)."""
    return not (val is None)
28
+
29
+
30
def uniq(arr):
    """Order-preserving de-duplication; returns a dict keys view over the unique items."""
    return dict.fromkeys(arr).keys()
32
+
33
+
34
def default(val, d):
    """Return *val* if it is not None; otherwise the fallback *d* (called if it is a function)."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
38
+
39
+
40
def max_neg_value(t):
    """Most negative finite value representable in t's dtype (used to mask attention logits)."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max
42
+
43
+
44
def init_(tensor):
    """In-place uniform init in [-1/sqrt(d), 1/sqrt(d)], d = size of the last dim."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
49
+
50
+
51
+ # feedforward
52
class GEGLU(nn.Module):
    """Gated-GELU projection: a 2x-wide linear whose output is split into value and gate."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
60
+
61
+
62
class FeedForward(nn.Module):
    """Transformer MLP block: (Linear+GELU or GEGLU) -> Dropout -> Linear."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden_dim = int(dim * mult)
        out_dim = dim if dim_out is None else dim_out
        if glu:
            project_in = GEGLU(dim, hidden_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())

        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(hidden_dim, out_dim))

    def forward(self, x):
        return self.net(x)
73
+
74
+
75
def zero_module(module):
    """
    Zero out the parameters of a module and return it (used so residual branches
    start as the identity).
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
82
+
83
+
84
def Normalize(in_channels):
    """32-group GroupNorm with the eps/affine settings the pretrained weights expect."""
    norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return norm
86
+
87
+
88
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a (b, c, h, w) map."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # attention weights over all spatial positions, scaled by 1/sqrt(c)
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        weights = torch.einsum("bij,bjk->bik", q, k) * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # weighted sum of the values
        v = rearrange(v, "b c h w -> b c (h w)")
        weights = rearrange(weights, "b i j -> b j i")
        out = torch.einsum("bij,bjk->bik", v, weights)
        out = rearrange(out, "b c (h w) -> b c h w", h=h)
        return x + self.proj_out(out)
123
+
124
+
125
class CrossAttention(nn.Module):
    """Multi-head (cross-)attention; acts as self-attention when no context is passed."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        num_heads = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # fold heads into the batch dimension
        q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=num_heads) for t in (q, k, v))

        if _ATTN_PRECISION == "fp32":
            # force cast to fp32 to avoid overflowing under autocast
            with torch.autocast(enabled=False, device_type="cuda"):
                q, k = q.float(), k.float()
                sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        else:
            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            neg_inf = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=num_heads)
            sim.masked_fill_(~mask, neg_inf)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
        return self.to_out(out)
172
+
173
+
174
class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """Cross-attention backed by xformers' memory_efficient_attention kernel."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{heads} heads."
        )
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # optional explicit xformers op; None lets xformers pick
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        batch = q.shape[0]

        def split_heads(t):
            # (b, n, heads*dim_head) -> (b*heads, n, dim_head), contiguous for xformers
            return (
                t.unsqueeze(3)
                .reshape(batch, t.shape[1], self.heads, self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(batch * self.heads, t.shape[1], self.dim_head)
                .contiguous()
            )

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        if exists(mask):
            raise NotImplementedError
        # merge the heads back: (b*heads, n, dim_head) -> (b, n, heads*dim_head)
        out = (
            out.unsqueeze(0)
            .reshape(batch, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(batch, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
223
+
224
+
225
class BasicTransformerBlock(nn.Module):
    """Pre-LN transformer block: self-attn -> cross-attn -> feed-forward, each with a residual."""

    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
        assert mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[mode]
        self.disable_self_attn = disable_self_attn
        # attn1 is plain self-attention unless disable_self_attn, in which case it
        # also attends to the context
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if disable_self_attn else None,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # attn2 is cross-attention (self-attention if context is None)
        self.attn2 = attn_cls(
            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # route through ldm's checkpoint helper to trade compute for activation memory
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        attn1_ctx = context if self.disable_self_attn else None
        x = x + self.attn1(self.norm1(x), context=attn1_ctx)
        x = x + self.attn2(self.norm2(x), context=context)
        x = x + self.ff(self.norm3(x))
        return x
271
+
272
+
273
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
    ):
        super().__init__()
        # BUG FIX: context_dim=None used to crash on context_dim[d] below, even
        # though forward() documents that no context means self-attention.
        if context_dim is None:
            context_dim = [None] * depth
        elif not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            # BUG FIX: the linear output projection must map inner_dim back to
            # in_channels; the original nn.Linear(in_channels, inner_dim) only
            # happened to work when the two were equal.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        # zero-initialized proj_out makes this block the identity at init
        return x + x_in
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from ldm.modules.attention import MemoryEfficientCrossAttention
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+
15
+ XFORMERS_IS_AVAILBLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILBLE = False
18
+ print("No module 'xformers'. Proceeding without it.")
19
+
20
+
21
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic Models
    (and fairseq/tensor2tensor; differs slightly from Section 3.5 of
    "Attention Is All You Need").
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(-scale * torch.arange(half_dim, dtype=torch.float32))
    freqs = freqs.to(device=timesteps.device)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad to reach an odd target width
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
40
+
41
+
42
def nonlinearity(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
45
+
46
+
47
def Normalize(in_channels, num_groups=32):
    """GroupNorm with the eps/affine settings the pretrained VAE checkpoints expect."""
    norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    return norm
49
+
50
+
51
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
63
+
64
+
65
class Downsample(nn.Module):
    """2x downsampling: strided 3x3 conv (with manual asymmetric padding) or 2x2 avg-pool."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # pad only right/bottom so a stride-2 3x3 conv halves the resolution
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
81
+
82
+
83
class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional timestep-embedding injection.

    When the in/out channel counts differ, the skip path is projected with a
    3x3 conv (conv_shortcut=True) or a 1x1 conv otherwise.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, self.out_channels)
        self.norm2 = Normalize(self.out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, self.out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, self.out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        out = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # broadcast the projected timestep embedding over the spatial dims
            out = out + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        out = self.conv2(self.dropout(nonlinearity(self.norm2(out))))

        residual = x
        if self.in_channels != self.out_channels:
            residual = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return residual + out
125
+
126
+
127
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over (b, c, h, w) feature maps."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # attention weights: (b, hw, hw), row i attends over every spatial position
        b, c, h, w = q.shape
        q_flat = q.reshape(b, c, h * w).permute(0, 2, 1)  # b, hw, c
        k_flat = k.reshape(b, c, h * w)  # b, c, hw
        weights = torch.bmm(q_flat, k_flat) * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # attend to values: h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        v_flat = v.reshape(b, c, h * w)
        attended = torch.bmm(v_flat, weights.permute(0, 2, 1)).reshape(b, c, h, w)

        return x + self.proj_out(attended)
163
+
164
+
165
class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # optional explicit xformers op; None lets xformers pick
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        B, C, H, W = q.shape

        def to_sequence(t):
            # (b, c, h, w) -> contiguous (b, h*w, c) with a (folded) singleton head
            t = rearrange(t, "b c h w -> b (h w) c")
            return (
                t.unsqueeze(3)
                .reshape(B, t.shape[1], 1, C)
                .permute(0, 2, 1, 3)
                .reshape(B * 1, t.shape[1], C)
                .contiguous()
            )

        q, k, v = to_sequence(q), to_sequence(k), to_sequence(v)
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)
        out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
        return x + self.proj_out(out)
209
+
210
+
211
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    """Adapts sequence-shaped MemoryEfficientCrossAttention to (b, c, h, w) maps with a residual."""

    def forward(self, x, context=None, mask=None):
        b, c, h, w = x.shape
        x_in = x
        x = rearrange(x, "b c h w -> b (h w) c")
        out = super().forward(x, context=context, mask=mask)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
        # BUG FIX: the residual must be the original (b, c, h, w) input; the old
        # `x + out` added the flattened (b, h*w, c) tensor, which either raised a
        # shape error or silently broadcast incorrectly when h*w == c.
        return x_in + out
218
+
219
+
220
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    """Factory for the spatial attention block used in the autoencoder.

    "vanilla" is silently upgraded to the xformers variant when available;
    "none" yields an identity; "linear" is declared but not implemented and
    raises NotImplementedError.
    """
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
        # BUG FIX: this branch compared the *builtin* `type` to the string
        # (`elif type == ...`), which is always False, so this attn_type fell
        # through to NotImplementedError.
        attn_kwargs["query_dim"] = in_channels
        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise NotImplementedError()
244
+
245
+
246
class Model(nn.Module):
    """UNet-style denoising model: down path + middle + up path with skip connections.

    `ch` is the base channel count, `ch_mult` the per-resolution multiplier,
    `attn_resolutions` the spatial sizes at which attention blocks are inserted.
    With `use_timestep` the forward pass expects a 1-D tensor of timesteps `t`.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        super().__init__()
        if use_linear_attn:
            # NOTE(review): "linear" is not implemented in make_attn and will raise.
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4  # width of the timestep-embedding MLP
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: two-layer MLP applied to the sinusoidal embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # upsampling (built in reverse; one extra res-block per level to consume skips)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    # last block of the level receives the pre-downsample skip
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None, context=None):
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling: keep every intermediate as a skip connection
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: consume the skip stack in reverse order
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        # used by e.g. adaptive loss weighting to read the final conv weights
        return self.conv_out.weight
404
+
405
+
406
class Encoder(nn.Module):
    """VAE encoder: image -> latent feature map.

    With ``double_z`` the output has ``2 * z_channels`` channels (mean and
    log-variance halves for the diagonal Gaussian posterior). No timestep
    conditioning: ``temb_ch`` is 0 and ResnetBlocks receive ``temb=None``.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            # NOTE(review): "linear" is not implemented in make_attn and will raise.
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        # timestep embedding (unused in the autoencoder)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
506
+
507
+
508
class Decoder(nn.Module):
    """VAE decoder: latent z -> image, mirroring the Encoder's layout.

    ``give_pre_end`` returns the feature map before the final norm/conv;
    ``tanh_out`` squashes the output to [-1, 1]. No timestep conditioning.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            # NOTE(review): "linear" is not implemented in make_attn and will raise.
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # upsampling (no skip connections in the decoder, unlike Model)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        # remember the last input shape for debugging/inspection
        self.last_z_shape = z.shape

        # timestep embedding (unused in the autoencoder)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
621
+
622
+
623
class SimpleDecoder(nn.Module):
    """Minimal decoder: 1x1 conv, three resnet blocks, 1x1 conv, one 2x
    upsample, then a 3x3 output conv."""

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList(
            [
                nn.Conv2d(in_channels, in_channels, 1),
                ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
                ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
                ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
                nn.Conv2d(2 * in_channels, in_channels, 1),
                Upsample(in_channels, with_conv=True),
            ]
        )
        self.norm_out = Normalize(in_channels)
        self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        resnet_positions = {1, 2, 3}  # ResnetBlock layers expect an (x, temb) pair
        for idx, layer in enumerate(self.model):
            x = layer(x, None) if idx in resnet_positions else layer(x)
        h = nonlinearity(self.norm_out(x))
        return self.conv_out(h)
651
+
652
+
653
class UpsampleDecoder(nn.Module):
    """Decoder variant that only upsamples: a stack of resnet blocks per
    level with an Upsample between levels (no attention, no middle stage)."""

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
        super().__init__()
        self.temb_ch = 0  # no timestep conditioning
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for level, mult in enumerate(ch_mult):
            stage = []
            block_out = ch * mult
            for _ in range(self.num_res_blocks + 1):
                stage.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(stage))
            if level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # output projection
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = x
        # res_blocks and upsample_blocks are index-aligned per level
        for level in range(self.num_resolutions):
            for block in self.res_blocks[level]:
                h = block(h, None)
            if level != self.num_resolutions - 1:
                h = self.upsample_blocks[level](h)
        h = self.conv_out(nonlinearity(self.norm_out(h)))
        return h
695
+
696
+
697
class LatentRescaler(nn.Module):
    """Spatially rescale a latent by ``factor``: conv in, resnet blocks,
    interpolation to the target size, attention, resnet blocks, 1x1 conv out."""

    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)

        def _res_stack():
            # depth resnet blocks at constant width
            return nn.ModuleList(
                ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
                for _ in range(depth)
            )

        self.res_block1 = _res_stack()
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = _res_stack()
        self.conv_out = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        # resize by self.factor, rounding each spatial dim independently
        target = (int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor)))
        x = torch.nn.functional.interpolate(x, size=target)
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        return self.conv_out(x)
735
+
736
+
737
class MergedRescaleEncoder(nn.Module):
    """Encoder followed by a LatentRescaler, so encoding and latent
    resizing happen in one module."""

    def __init__(
        self,
        in_channels,
        ch,
        resolution,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        ch_mult=(1, 2, 4, 8),
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        # encoder emits ch * ch_mult[-1] channels; the rescaler maps those to out_ch
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=intermediate_chn,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=intermediate_chn,
            mid_channels=intermediate_chn,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        return self.rescaler(self.encoder(x))
779
+
780
+
781
class MergedRescaleDecoder(nn.Module):
    """LatentRescaler followed by a Decoder, so latent resizing and decoding
    happen in one module."""

    def __init__(
        self,
        z_channels,
        out_ch,
        resolution,
        num_res_blocks,
        attn_resolutions,
        ch,
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        resamp_with_conv=True,
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        # rescaler widens z_channels to tmp_chn, which the decoder consumes
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=tmp_chn,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=tmp_chn,
            out_channels=tmp_chn,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        return self.decoder(self.rescaler(x))
822
+
823
+
824
class Upsampler(nn.Module):
    """Upsample latents from ``in_size`` to ``out_size`` via a LatentRescaler
    followed by a Decoder with enough resolution-doubling levels."""

    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        # one decoder level per doubling, plus one
        num_blocks = int(np.log2(out_size // in_size)) + 1
        # fractional remainder handled by the rescaler's interpolation
        factor_up = 1.0 + (out_size % in_size)
        print(
            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
        )
        self.rescaler = LatentRescaler(
            factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, out_channels=in_channels
        )
        self.decoder = Decoder(
            out_ch=out_channels,
            resolution=out_size,
            z_channels=in_channels,
            num_res_blocks=2,
            attn_resolutions=[],
            in_channels=None,
            ch=in_channels,
            ch_mult=[ch_mult for _ in range(num_blocks)],
        )

    def forward(self, x):
        return self.decoder(self.rescaler(x))
851
+
852
+
853
class Resize(nn.Module):
    """Resize feature maps by an arbitrary ``scale_factor`` via interpolation.

    The learned (convolutional) variant is not implemented; constructing with
    ``learned=True`` raises ``NotImplementedError``.
    """

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # FIX: was `self.__class__.__name`, which name-mangles to the
            # nonexistent attribute `_Resize__name` and raised AttributeError
            # instead of printing the note and raising NotImplementedError.
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError()
            # unreachable, kept from upstream: sketch of the learned path
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, scale_factor=1.0):
        """Interpolate ``x`` by ``scale_factor``; identity when it is 1.0."""
        if scale_factor == 1.0:
            return x
        return torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from abc import abstractmethod
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from ldm.modules.diffusionmodules.util import (
11
+ checkpoint,
12
+ conv_nd,
13
+ linear,
14
+ avg_pool_nd,
15
+ zero_module,
16
+ normalization,
17
+ timestep_embedding,
18
+ )
19
+ from ldm.modules.attention import SpatialTransformer
20
+ from ldm.util import exists
21
+
22
+
23
+ # dummy replace
24
def convert_module_to_f16(x):
    """No-op placeholder; fp16 conversion is handled elsewhere ("dummy replace")."""
    return None
26
+
27
+
28
def convert_module_to_f32(x):
    """No-op placeholder; fp32 conversion is handled elsewhere ("dummy replace")."""
    return None
30
+
31
+
32
+ ## go
33
class AttentionPool2d(nn.Module):
    """
    Attention-based global pooling, adapted from CLIP:
    https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # learnable position embedding over HW + 1 tokens (extra slot for the mean token)
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        tokens = x.reshape(b, c, -1)  # N x C x (HW)
        # prepend the mean token, then add positional embeddings
        tokens = th.cat([tokens.mean(dim=-1, keepdim=True), tokens], dim=-1)
        tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)
        attended = self.attention(self.qkv_proj(tokens))
        projected = self.c_proj(attended)
        return projected[:, :, 0]  # pooled output: the (former mean) token
61
+
62
+
63
class TimestepBlock(nn.Module):
    """Interface for modules whose ``forward`` also consumes a timestep
    embedding as a second argument."""

    @abstractmethod
    def forward(self, x, emb):
        """Apply the module to ``x`` conditioned on timestep embeddings ``emb``."""
73
+
74
+
75
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """``nn.Sequential`` that routes the timestep embedding (and optional
    cross-attention context) to the children that accept them."""

    def forward(self, x, emb, context=None):
        for module in self:
            if isinstance(module, TimestepBlock):
                x = module(x, emb)  # timestep-conditioned child
            elif isinstance(module, SpatialTransformer):
                x = module(x, context)  # cross-attention child
            else:
                x = module(x)
        return x
90
+
91
+
92
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling with an optional trailing convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, only the
        inner-two dimensions are upsampled.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # keep depth, double height/width
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
119
+
120
+
121
+ class TransposedUpsample(nn.Module):
122
+ "Learned 2x upsampling without padding"
123
+
124
+ def __init__(self, channels, out_channels=None, ks=5):
125
+ super().__init__()
126
+ self.channels = channels
127
+ self.out_channels = out_channels or channels
128
+
129
+ self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)
130
+
131
+ def forward(self, x):
132
+ return self.up(x)
133
+
134
+
135
class Downsample(nn.Module):
    """
    2x downsampling via a strided convolution (``use_conv``) or average pooling.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, only the
        inner-two dimensions are downsampled.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # pooling cannot change the channel count
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
160
+
161
+
162
class ResBlock(TimestepBlock):
    """
    Residual block with timestep conditioning; can optionally change the
    channel count and up/downsample its input.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a 3x3 conv
        instead of a 1x1 conv on the skip connection.
    :param use_scale_shift_norm: FiLM-style conditioning (scale/shift after norm).
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # embedding -> (scale, shift) pair under scale-shift norm, else an additive bias
        emb_out_channels = 2 * self.out_channels if use_scale_shift_norm else self.out_channels
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                emb_out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # zero-initialized so the block starts out as an identity residual
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        if self.updown:
            # resample between the (norm, SiLU) prefix and the conv of in_layers
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_conv(self.h_upd(in_rest(x)))
            x = self.x_upd(x)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]  # broadcast over the spatial dims
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_rest(out_norm(h) * (1 + scale) + shift)
        else:
            h = self.out_layers(h + emb_out)
        return self.skip_connection(x) + h
268
+
269
+
270
class AttentionBlock(nn.Module):
    """
    Spatial self-attention block letting positions attend to each other.
    Originally ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    and adapted to the N-d case.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # new order: split qkv before heads; legacy: split heads before qkv
        if use_new_attention_order:
            self.attention = QKVAttention(self.num_heads)
        else:
            self.attention = QKVAttentionLegacy(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # NOTE(review): checkpointing is hard-coded on here (upstream TODO,
        # ignoring self.use_checkpoint) — preserved as-is.
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        flat = x.reshape(b, c, -1)
        h = self.proj_out(self.attention(self.qkv(self.norm(flat))))
        return (flat + h).reshape(b, c, *spatial)
319
+
320
+
321
def count_flops_attn(model, _x, y):
    """
    `thop` custom-op counter for a QKV attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # two equally-sized matmuls: one for the attention weights,
    # one for combining the value vectors — each b * T^2 * c MACs
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])
339
+
340
+
341
class QKVAttentionLegacy(nn.Module):
    """
    QKV attention that splits heads before splitting q/k/v — matches the
    legacy QKVAttention input/output head shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # scale q and k by ch^-1/4 each (equivalent to dividing scores by
        # sqrt(ch)); more stable with f16 than dividing afterwards
        scale = 1 / math.sqrt(math.sqrt(ch))
        scores = th.einsum("bct,bcs->bts", q * scale, k * scale)
        probs = th.softmax(scores.float(), dim=-1).type(scores.dtype)
        out = th.einsum("bts,bcs->bct", probs, v)
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
369
+
370
+
371
class QKVAttention(nn.Module):
    """
    QKV attention that splits q/k/v before splitting heads (the "new"
    attention order).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # scale q and k by ch^-1/4 each; more stable with f16 than
        # dividing the scores afterwards
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
403
+
404
+
405
class Timestep(nn.Module):
    """Module wrapper around ``timestep_embedding`` producing sinusoidal
    embeddings of width ``dim`` for a batch of timesteps."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # embedding width

    def forward(self, t):
        return timestep_embedding(t, self.dim)
412
+
413
+
414
+ class UNetModel(nn.Module):
415
+ """
416
+ The full UNet model with attention and timestep embedding.
417
+ :param in_channels: channels in the input Tensor.
418
+ :param model_channels: base channel count for the model.
419
+ :param out_channels: channels in the output Tensor.
420
+ :param num_res_blocks: number of residual blocks per downsample.
421
+ :param attention_resolutions: a collection of downsample rates at which
422
+ attention will take place. May be a set, list, or tuple.
423
+ For example, if this contains 4, then at 4x downsampling, attention
424
+ will be used.
425
+ :param dropout: the dropout probability.
426
+ :param channel_mult: channel multiplier for each level of the UNet.
427
+ :param conv_resample: if True, use learned convolutions for upsampling and
428
+ downsampling.
429
+ :param dims: determines if the signal is 1D, 2D, or 3D.
430
+ :param num_classes: if specified (as an int), then this model will be
431
+ class-conditional with `num_classes` classes.
432
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
433
+ :param num_heads: the number of attention heads in each attention layer.
434
+ :param num_heads_channels: if specified, ignore num_heads and instead use
435
+ a fixed channel width per attention head.
436
+ :param num_heads_upsample: works with num_heads to set a different number
437
+ of heads for upsampling. Deprecated.
438
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
439
+ :param resblock_updown: use residual blocks for up/downsampling.
440
+ :param use_new_attention_order: use a different attention pattern for potentially
441
+ increased efficiency.
442
+ """
443
+
444
+ def __init__(
445
+ self,
446
+ image_size,
447
+ in_channels,
448
+ model_channels,
449
+ out_channels,
450
+ num_res_blocks,
451
+ attention_resolutions,
452
+ dropout=0,
453
+ channel_mult=(1, 2, 4, 8),
454
+ conv_resample=True,
455
+ dims=2,
456
+ num_classes=None,
457
+ use_checkpoint=False,
458
+ use_fp16=False,
459
+ use_bf16=False,
460
+ num_heads=-1,
461
+ num_head_channels=-1,
462
+ num_heads_upsample=-1,
463
+ use_scale_shift_norm=False,
464
+ resblock_updown=False,
465
+ use_new_attention_order=False,
466
+ use_spatial_transformer=False, # custom transformer support
467
+ transformer_depth=1, # custom transformer support
468
+ context_dim=None, # custom transformer support
469
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
470
+ legacy=True,
471
+ disable_self_attentions=None,
472
+ num_attention_blocks=None,
473
+ disable_middle_self_attn=False,
474
+ use_linear_in_transformer=False,
475
+ adm_in_channels=None,
476
+ ckpt_path=None,
477
+ ignore_keys=[], # ignore keys for loading checkpoint
478
+ ):
479
+ super().__init__()
480
+ if use_spatial_transformer:
481
+ assert (
482
+ context_dim is not None
483
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
484
+
485
+ if context_dim is not None:
486
+ assert (
487
+ use_spatial_transformer
488
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
489
+ from omegaconf.listconfig import ListConfig
490
+
491
+ if type(context_dim) == ListConfig:
492
+ context_dim = list(context_dim)
493
+
494
+ if num_heads_upsample == -1:
495
+ num_heads_upsample = num_heads
496
+
497
+ if num_heads == -1:
498
+ assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
499
+
500
+ if num_head_channels == -1:
501
+ assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
502
+
503
+ self.image_size = image_size
504
+ self.in_channels = in_channels
505
+ self.model_channels = model_channels
506
+ self.out_channels = out_channels
507
+ if isinstance(num_res_blocks, int):
508
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
509
+ else:
510
+ if len(num_res_blocks) != len(channel_mult):
511
+ raise ValueError(
512
+ "provide num_res_blocks either as an int (globally constant) or "
513
+ "as a list/tuple (per-level) with the same length as channel_mult"
514
+ )
515
+ self.num_res_blocks = num_res_blocks
516
+ if disable_self_attentions is not None:
517
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
518
+ assert len(disable_self_attentions) == len(channel_mult)
519
+ if num_attention_blocks is not None:
520
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
521
+ assert all(
522
+ map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))
523
+ )
524
+ print(
525
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
526
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
527
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
528
+ f"attention will still not be set."
529
+ )
530
+
531
+ self.attention_resolutions = attention_resolutions
532
+ self.dropout = dropout
533
+ self.channel_mult = channel_mult
534
+ self.conv_resample = conv_resample
535
+ self.num_classes = num_classes
536
+ self.use_checkpoint = use_checkpoint
537
+ self.dtype = th.float16 if use_fp16 else th.float32
538
+ self.dtype = th.bfloat16 if use_bf16 else self.dtype
539
+ self.num_heads = num_heads
540
+ self.num_head_channels = num_head_channels
541
+ self.num_heads_upsample = num_heads_upsample
542
+ self.predict_codebook_ids = n_embed is not None
543
+
544
+ time_embed_dim = model_channels * 4
545
+ self.time_embed = nn.Sequential(
546
+ linear(model_channels, time_embed_dim),
547
+ nn.SiLU(),
548
+ linear(time_embed_dim, time_embed_dim),
549
+ )
550
+
551
+ if self.num_classes is not None:
552
+ if isinstance(self.num_classes, int):
553
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
554
+ elif self.num_classes == "continuous":
555
+ print("setting up linear c_adm embedding layer")
556
+ self.label_emb = nn.Linear(1, time_embed_dim)
557
+ elif self.num_classes == "sequential":
558
+ assert adm_in_channels is not None
559
+ self.label_emb = nn.Sequential(
560
+ nn.Sequential(
561
+ linear(adm_in_channels, time_embed_dim),
562
+ nn.SiLU(),
563
+ linear(time_embed_dim, time_embed_dim),
564
+ )
565
+ )
566
+ else:
567
+ raise ValueError()
568
+
569
+ self.input_blocks = nn.ModuleList(
570
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
571
+ )
572
+ self._feature_size = model_channels
573
+ input_block_chans = [model_channels]
574
+ ch = model_channels
575
+ ds = 1
576
+ for level, mult in enumerate(channel_mult):
577
+ for nr in range(self.num_res_blocks[level]):
578
+ layers = [
579
+ ResBlock(
580
+ ch,
581
+ time_embed_dim,
582
+ dropout,
583
+ out_channels=mult * model_channels,
584
+ dims=dims,
585
+ use_checkpoint=use_checkpoint,
586
+ use_scale_shift_norm=use_scale_shift_norm,
587
+ )
588
+ ]
589
+ ch = mult * model_channels
590
+ if ds in attention_resolutions:
591
+ if num_head_channels == -1:
592
+ dim_head = ch // num_heads
593
+ else:
594
+ num_heads = ch // num_head_channels
595
+ dim_head = num_head_channels
596
+ if legacy:
597
+ # num_heads = 1
598
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
599
+ if exists(disable_self_attentions):
600
+ disabled_sa = disable_self_attentions[level]
601
+ else:
602
+ disabled_sa = False
603
+
604
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
605
+ layers.append(
606
+ AttentionBlock(
607
+ ch,
608
+ use_checkpoint=use_checkpoint,
609
+ num_heads=num_heads,
610
+ num_head_channels=dim_head,
611
+ use_new_attention_order=use_new_attention_order,
612
+ )
613
+ if not use_spatial_transformer
614
+ else SpatialTransformer(
615
+ ch,
616
+ num_heads,
617
+ dim_head,
618
+ depth=transformer_depth,
619
+ context_dim=context_dim,
620
+ disable_self_attn=disabled_sa,
621
+ use_linear=use_linear_in_transformer,
622
+ use_checkpoint=use_checkpoint,
623
+ )
624
+ )
625
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
626
+ self._feature_size += ch
627
+ input_block_chans.append(ch)
628
+ if level != len(channel_mult) - 1:
629
+ out_ch = ch
630
+ self.input_blocks.append(
631
+ TimestepEmbedSequential(
632
+ ResBlock(
633
+ ch,
634
+ time_embed_dim,
635
+ dropout,
636
+ out_channels=out_ch,
637
+ dims=dims,
638
+ use_checkpoint=use_checkpoint,
639
+ use_scale_shift_norm=use_scale_shift_norm,
640
+ down=True,
641
+ )
642
+ if resblock_updown
643
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
644
+ )
645
+ )
646
+ ch = out_ch
647
+ input_block_chans.append(ch)
648
+ ds *= 2
649
+ self._feature_size += ch
650
+
651
+ if num_head_channels == -1:
652
+ dim_head = ch // num_heads
653
+ else:
654
+ num_heads = ch // num_head_channels
655
+ dim_head = num_head_channels
656
+ if legacy:
657
+ # num_heads = 1
658
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
659
+ self.middle_block = TimestepEmbedSequential(
660
+ ResBlock(
661
+ ch,
662
+ time_embed_dim,
663
+ dropout,
664
+ dims=dims,
665
+ use_checkpoint=use_checkpoint,
666
+ use_scale_shift_norm=use_scale_shift_norm,
667
+ ),
668
+ AttentionBlock(
669
+ ch,
670
+ use_checkpoint=use_checkpoint,
671
+ num_heads=num_heads,
672
+ num_head_channels=dim_head,
673
+ use_new_attention_order=use_new_attention_order,
674
+ )
675
+ if not use_spatial_transformer
676
+ else SpatialTransformer( # always uses a self-attn
677
+ ch,
678
+ num_heads,
679
+ dim_head,
680
+ depth=transformer_depth,
681
+ context_dim=context_dim,
682
+ disable_self_attn=disable_middle_self_attn,
683
+ use_linear=use_linear_in_transformer,
684
+ use_checkpoint=use_checkpoint,
685
+ ),
686
+ ResBlock(
687
+ ch,
688
+ time_embed_dim,
689
+ dropout,
690
+ dims=dims,
691
+ use_checkpoint=use_checkpoint,
692
+ use_scale_shift_norm=use_scale_shift_norm,
693
+ ),
694
+ )
695
+ self._feature_size += ch
696
+
697
+ self.output_blocks = nn.ModuleList([])
698
+ for level, mult in list(enumerate(channel_mult))[::-1]:
699
+ for i in range(self.num_res_blocks[level] + 1):
700
+ ich = input_block_chans.pop()
701
+ layers = [
702
+ ResBlock(
703
+ ch + ich,
704
+ time_embed_dim,
705
+ dropout,
706
+ out_channels=model_channels * mult,
707
+ dims=dims,
708
+ use_checkpoint=use_checkpoint,
709
+ use_scale_shift_norm=use_scale_shift_norm,
710
+ )
711
+ ]
712
+ ch = model_channels * mult
713
+ if ds in attention_resolutions:
714
+ if num_head_channels == -1:
715
+ dim_head = ch // num_heads
716
+ else:
717
+ num_heads = ch // num_head_channels
718
+ dim_head = num_head_channels
719
+ if legacy:
720
+ # num_heads = 1
721
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
722
+ if exists(disable_self_attentions):
723
+ disabled_sa = disable_self_attentions[level]
724
+ else:
725
+ disabled_sa = False
726
+
727
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
728
+ layers.append(
729
+ AttentionBlock(
730
+ ch,
731
+ use_checkpoint=use_checkpoint,
732
+ num_heads=num_heads_upsample,
733
+ num_head_channels=dim_head,
734
+ use_new_attention_order=use_new_attention_order,
735
+ )
736
+ if not use_spatial_transformer
737
+ else SpatialTransformer(
738
+ ch,
739
+ num_heads,
740
+ dim_head,
741
+ depth=transformer_depth,
742
+ context_dim=context_dim,
743
+ disable_self_attn=disabled_sa,
744
+ use_linear=use_linear_in_transformer,
745
+ use_checkpoint=use_checkpoint,
746
+ )
747
+ )
748
+ if level and i == self.num_res_blocks[level]:
749
+ out_ch = ch
750
+ layers.append(
751
+ ResBlock(
752
+ ch,
753
+ time_embed_dim,
754
+ dropout,
755
+ out_channels=out_ch,
756
+ dims=dims,
757
+ use_checkpoint=use_checkpoint,
758
+ use_scale_shift_norm=use_scale_shift_norm,
759
+ up=True,
760
+ )
761
+ if resblock_updown
762
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
763
+ )
764
+ ds //= 2
765
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
766
+ self._feature_size += ch
767
+
768
+ self.out = nn.Sequential(
769
+ normalization(ch),
770
+ nn.SiLU(),
771
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
772
+ )
773
+ if self.predict_codebook_ids:
774
+ self.id_predictor = nn.Sequential(
775
+ normalization(ch),
776
+ conv_nd(dims, model_channels, n_embed, 1),
777
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
778
+ )
779
+
780
+ if ckpt_path is not None:
781
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
782
+
783
+ def init_from_ckpt(self, path, ignore_keys=list()):
784
+ sd = th.load(path, map_location="cpu")["state_dict"]
785
+ keys = list(sd.keys())
786
+ for k in keys:
787
+ for ik in ignore_keys:
788
+ if k.startswith(ik):
789
+ print("Deleting key {} from state_dict.".format(k))
790
+ del sd[k]
791
+ self.load_state_dict(sd, strict=False)
792
+ print(f"Restored from {path}")
793
+
794
+ def convert_to_fp16(self):
795
+ """
796
+ Convert the torso of the model to float16.
797
+ """
798
+ self.input_blocks.apply(convert_module_to_f16)
799
+ self.middle_block.apply(convert_module_to_f16)
800
+ self.output_blocks.apply(convert_module_to_f16)
801
+
802
+ def convert_to_fp32(self):
803
+ """
804
+ Convert the torso of the model to float32.
805
+ """
806
+ self.input_blocks.apply(convert_module_to_f32)
807
+ self.middle_block.apply(convert_module_to_f32)
808
+ self.output_blocks.apply(convert_module_to_f32)
809
+
810
    def forward(self, x, timesteps=None, context=None, y=None, return_intermediates=False, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :param return_intermediates: if True, return the list of decoder
            activations (middle-block output first) instead of the final output.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        # Sinusoidal timestep embedding projected through the time MLP.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            # Label/conditioning embedding is added to the time embedding.
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            # Keep encoder activations for the U-Net skip connections below.
            hs.append(h)
        h = self.middle_block(h, emb, context)

        intermediates = [h]

        for module in self.output_blocks:
            # Skip connection: concatenate the matching encoder feature map.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if return_intermediates:
                intermediates.append(h)
        h = h.type(x.dtype)
        if return_intermediates:
            return intermediates
        if self.predict_codebook_ids:
            # VQ-style head: non-normalized logits over codebook ids.
            return self.id_predictor(h)
        else:
            return self.out(h)
ldm/modules/diffusionmodules/upscaling.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7
+ from ldm.util import default
8
+
9
+
10
class AbstractLowScaleModel(nn.Module):
    """
    Base class for low-scale conditioning models.

    Holds an optional diffusion noise schedule so a downsampled image can be
    noised with ``q_sample`` before being concatenated to the latent
    representation.
    """

    # for concatenating a downsampled image to the latent representation
    def __init__(self, noise_schedule_config=None):
        super(AbstractLowScaleModel, self).__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(
        self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
    ):
        """Precompute the diffusion schedule and register it as float32 buffers."""
        betas = make_beta_schedule(
            beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s
        )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # Buffers consumed by q_sample and by downstream samplers.
        schedule_buffers = (
            ("betas", betas),
            ("alphas_cumprod", alphas_cumprod),
            ("alphas_cumprod_prev", alphas_cumprod_prev),
            # calculations for diffusion q(x_t | x_{t-1}) and others
            ("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod)),
            ("sqrt_one_minus_alphas_cumprod", np.sqrt(1.0 - alphas_cumprod)),
            ("log_one_minus_alphas_cumprod", np.log(1.0 - alphas_cumprod)),
            ("sqrt_recip_alphas_cumprod", np.sqrt(1.0 / alphas_cumprod)),
            ("sqrt_recipm1_alphas_cumprod", np.sqrt(1.0 / alphas_cumprod - 1)),
        )
        for name, value in schedule_buffers:
            self.register_buffer(name, to_torch(value))

    def q_sample(self, x_start, t, noise=None):
        """Diffuse x_start to step t: sqrt(a-bar_t) * x0 + sqrt(1 - a-bar_t) * eps."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        sigma = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return scale * x_start + sigma * noise

    def forward(self, x):
        # Identity pass-through; subclasses return a noise level as well.
        return x, None

    def decode(self, x):
        return x
58
+
59
+
60
class SimpleImageConcat(AbstractLowScaleModel):
    """Concat conditioning with no noise augmentation (constant zero noise level)."""

    # no noise level conditioning
    def __init__(self):
        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        # fix to constant noise level
        zero_level = torch.zeros(x.shape[0], device=x.device).long()
        return x, zero_level
69
+
70
+
71
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """
    Concat conditioning where the low-scale image is noised via ``q_sample``
    at a (possibly random) noise level before concatenation.

    NOTE: the ``to_cuda`` argument is accepted for config compatibility but
    is not used by this class.
    """

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        if noise_level is None:
            # Sample an independent integer noise level per batch element.
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        z = self.q_sample(x, noise_level)
        return z, noise_level
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a diffusion beta (noise-variance) schedule.

    :param schedule: one of "linear", "cosine", "squaredcos_cap_v2",
        "sqrt_linear", "sqrt".
    :param n_timestep: number of betas to produce.
    :param linear_start: first beta of the linear-family schedules.
    :param linear_end: last beta of the linear-family schedules.
    :param cosine_s: small offset used by the cosine schedule.
    :return: float64 numpy array of length ``n_timestep``.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, squared back afterwards.
        sqrt_betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64)
        betas = sqrt_betas ** 2
    elif schedule == "cosine":
        steps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        alphas = torch.cos(steps / (1 + cosine_s) * np.pi / 2).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)
    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
        # betas_for_alpha_bar already returns a numpy array; return early.
        return betas_for_alpha_bar(
            n_timestep,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
51
+
52
+
53
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Select the subset of DDPM timesteps used by a DDIM sampler.

    :param ddim_discr_method: 'uniform' (even stride) or 'quad' (quadratic spacing).
    :param num_ddim_timesteps: number of sampling steps desired.
    :param num_ddpm_timesteps: total training timesteps.
    :param verbose: print the chosen timesteps.
    :return: numpy array of selected (1-shifted) timesteps.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
68
+
69
+
70
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
71
+ # select alphas for computing the variance schedule
72
+ alphas = alphacums[ddim_timesteps]
73
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
74
+
75
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
76
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
77
+ if verbose:
78
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
79
+ print(f'For the chosen value of eta, which is {eta}, '
80
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
81
+ return sigmas, alphas, alphas_prev
82
+
83
+
84
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    betas = [
        min(1 - alpha_bar((i + 1) / num_diffusion_timesteps) / alpha_bar(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return np.array(betas)
101
+
102
+
103
def extract_into_tensor(a, t, x_shape):
    """
    Gather per-timestep coefficients and reshape for broadcasting.

    :param a: 1-D tensor of per-timestep values.
    :param t: [batch] tensor of integer timestep indices.
    :param x_shape: shape of the tensor the result must broadcast against.
    :return: tensor of shape [batch, 1, 1, ...] with x_shape's rank.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # Append singleton dims so the result broadcasts over x's trailing axes.
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
107
+
108
+
109
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    # Checkpointed path: forward runs under no_grad, backward recomputes.
    args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *args)
124
+
125
+
126
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function implementing gradient checkpointing: the forward pass
    runs without storing intermediate activations; the backward pass re-runs
    the function under grad mode to recompute them.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # First `length` args are the real inputs; the rest are parameters the
        # function closes over (kept so their gradients are produced too).
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Capture the ambient GPU autocast state so backward can replay it.
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-enable grad tracking on the saved inputs before recomputation.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
            input_grads = torch.autograd.grad(
                output_tensors,
                ctx.input_tensors + ctx.input_params,
                output_grads,
                allow_unused=True,
            )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # None grads correspond to the run_function and length arguments.
        return (None, None) + input_grads
159
+
160
+
161
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, just tile the raw timestep across dim.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd output dim: pad with one zero column.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
182
+
183
+
184
def zero_module(module):
    """
    Zero out the parameters of *module* in place and return it.
    """
    for param in module.parameters():
        param.detach().zero_()
    return module
191
+
192
+
193
def scale_module(module, scale):
    """
    Multiply every parameter of *module* by *scale* in place and return it.
    """
    for param in module.parameters():
        param.detach().mul_(scale)
    return module
200
+
201
+
202
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    reduce_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=reduce_dims)
207
+
208
+
209
def normalization(channels):
    """
    Make a standard normalization layer: GroupNorm with 32 groups that
    computes its statistics in float32.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)
216
+
217
+
218
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
219
class SiLU(nn.Module):
    """Sigmoid Linear Unit: x * sigmoid(x)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
222
+
223
+
224
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes statistics in float32, then casts back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
227
+
228
+
229
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in conv_types:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_types[dims](*args, **kwargs)
240
+
241
+
242
def linear(*args, **kwargs):
    """
    Create a linear module (thin alias kept for API symmetry with conv_nd).
    """
    return nn.Linear(*args, **kwargs)
247
+
248
+
249
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    pool_types = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in pool_types:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_types[dims](*args, **kwargs)
260
+
261
+
262
class HybridConditioner(nn.Module):
    """
    Pairs a concat conditioner with a cross-attention conditioner, each built
    from its own config, and returns both conditionings in the LDM dict format.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        return {
            'c_concat': [self.concat_conditioner(c_concat)],
            'c_crossattn': [self.crossattn_conditioner(c_crossattn)],
        }
273
+
274
+
275
def noise_like(shape, device, repeat=False):
    """
    Sample standard Gaussian noise of the given shape on *device*.

    If *repeat* is True, draw a single sample and tile it along the batch
    dimension so every batch element shares identical noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
class AbstractDistribution:
    """Minimal distribution interface: subclasses implement sample() and mode()."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
11
+
12
+
13
class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: sampling and mode both return the stored value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value
22
+
23
+
24
class DiagonalGaussianDistribution(object):
    """
    Factorized Gaussian over latents, parameterized by mean and log-variance
    concatenated along the channel dimension.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        # First half of the channel dim is the mean, second half the logvar.
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp before exponentiating for numerical stability.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate (Dirac) case: zero variance, sampling returns the mean.
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        eps = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * eps

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        if other is None:
            # KL against a standard normal prior.
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean
63
+
64
+
65
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Find any tensor argument so scalars can be promoted onto its device/dtype.
    tensor = next((obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, torch.Tensor)), None)
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]

    return 0.5 * (
        -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
ldm/modules/ema.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class LitEma(nn.Module):
    """
    Exponential moving average (EMA) of a model's trainable parameters.

    Shadow copies of all parameters with ``requires_grad`` are stored as
    buffers (with '.' stripped from their names, since buffer names may not
    contain dots). Calling the module moves the shadow parameters toward the
    live model's current parameters.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        # Maps original parameter name -> sanitized buffer name.
        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        # num_updates == -1 disables the warm-up decay ramp in forward().
        self.register_buffer(
            "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int)
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        # Scratch storage for store()/restore().
        self.collected_params = []

    def reset_num_updates(self):
        """Restart the warm-up counter (re-registers the buffer at zero)."""
        del self.num_updates
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        """Update the shadow parameters toward *model*'s current parameters."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: effective decay ramps up from ~0.1 toward self.decay.
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # shadow <- shadow - (1 - decay) * (shadow - param)
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Copy the EMA (shadow) weights into *model*'s parameters in place."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
ldm/modules/encoders/__init__.py ADDED
File without changes
ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ from torchvision.models.vision_transformer import Encoder
7
+
8
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement encode()."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError
14
+
15
+
16
class IdentityEncoder(AbstractEncoder):
    """Pass-through encoder: the conditioning is returned unchanged."""

    def encode(self, x):
        return x
19
+
20
+
21
class ClassEmbedder(nn.Module):
    """
    Embeds integer class labels for cross-attention conditioning.

    The embedding table has n_classes + 1 rows; during training, labels are
    randomly replaced (with probability ucg_rate) by the unconditional class
    used for classifier-free guidance.
    """

    def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes + 1, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        key = self.key if key is None else key
        # this is for use in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0.0 and not disable_dropout:
            # Randomly swap labels for the ucg class (classifier-free guidance).
            keep = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = keep * c + (1 - keep) * torch.ones_like(c) * (self.n_classes - 1)
        return self.embedding(c.long())

    def get_unconditional_conditioning(self, bs, device="cuda"):
        uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        labels = torch.ones((bs,), device=device) * uc_class
        return {self.key: labels}
+ return uc
46
+
47
+
48
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore.

    The *mode* argument is accepted (to match nn.Module.train's signature)
    but deliberately ignored; the bound module is returned unchanged.
    """
    return self
52
+
53
+
54
class EmbeddingViT2(nn.Module):

    """
    Transformer encoder that turns pre-extracted SSL feature maps plus a
    discrete magnification level into a fixed-length conditioning sequence.

    1. more transformer blocks
    2. correct padding : non zero embeddings at the center, instead of beginning
    3. classifier guidance null token replacement AFTER transformer instead of before
    """

    def __init__(
        self,
        feat_key="feat",
        mag_key="mag",
        input_channels=1024,
        hidden_channels=512,
        vit_mlp_dim=2048,
        output_channels=512,
        seq_length=64,
        mag_levels=8,
        num_layers=12,
        num_heads=8,
        p_uncond=0,
        ckpt_path=None,
        ignore_keys=[],
    ):
        super(EmbeddingViT2, self).__init__()

        # Learned embedding for the discrete magnification level.
        self.mag_embedding = nn.Embedding(mag_levels, hidden_channels)
        self.feat_key = feat_key
        self.mag_key = mag_key
        self.hidden_channels = hidden_channels

        # Projects raw SSL features down to the transformer width.
        self.dim_reduce = nn.Linear(input_channels, hidden_channels)

        # Learned token used to pad short sequences up to seq_length.
        self.pad_token = nn.Parameter(torch.randn(1, 1, hidden_channels))
        # seq_length + 1 accounts for the prepended magnification token.
        self.encoder = Encoder(
            seq_length=seq_length + 1,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_channels,
            mlp_dim=vit_mlp_dim,
            dropout=0,
            attention_dropout=0,
        )
        self.final_proj = nn.Linear(hidden_channels, output_channels)
        # Probability of replacing the conditioning with zeros in encode().
        self.p_uncond = p_uncond

        # if ckpt_path is not None:
        #     self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def forward(self, batch):
        """Encode SSL features + magnification into a token sequence."""
        x = batch[self.feat_key]
        int_mag = batch[self.mag_key]

        # Process inputs
        x = self.process_input_batch(x)  # Shape: [batch_size, 64, hidden_channels]

        mag_embed = self.mag_embedding(int_mag).unsqueeze(1)  # Shape: [batch_size, 1, hidden_channels]
        x = torch.cat((mag_embed, x), dim=1)  # Shape: [batch_size, 65, hidden_channels]

        x = self.encoder(x)

        x = self.final_proj(x)  # Shape: [batch_size, 65, output_channels]

        return x

    def process_input_batch(self, x):
        # Normalize batched / list / single inputs into a stacked tensor.
        if isinstance(x, torch.Tensor):
            x = list(x)
        if isinstance(x, list):
            return torch.stack([self.process_single_input(item) for item in x])
        else:
            return self.process_single_input(x).unsqueeze(0)

    def process_single_input(self, x):
        # Ensure x is 3D: [channels, height, width]
        if x.dim() == 2:
            x = x.unsqueeze(0)

        c, h, w = x.shape

        n = h * w

        # Flatten spatial dims to tokens and reduce channel dim.
        x = x.view(c, -1).transpose(0, 1)
        x = self.dim_reduce(x)

        if h == w == 1:
            # center the token
            # Single token: via broadcasting below, position 32 gets the token
            # and all other positions get the learned pad token.
            mask = torch.ones(64, device=x.device)
            mask[32] = 0

        elif h < 8 or w < 8:
            # pad x to 64 tokens, keep the original tokens at the center
            # NOTE(review): for odd n this pads to 64 - 2*(n//2) + n = 65
            # tokens, not 64 — confirm inputs always have even n.
            x = F.pad(x, (0, 0, 32 - n // 2, 32 - n // 2))
            mask = torch.ones(64, device=x.device)
            mask[32 - n // 2 : 32 + n // 2] = 0

        else:
            # we used avg pooling in the dataloader
            return x

        # Fill masked positions with the learned pad token.
        x = x * (1 - mask.unsqueeze(1)) + self.pad_token * mask.unsqueeze(1)
        return x.squeeze()  # Return as [64, hidden_channels]

    def encode(self, batch):
        c = self.forward(batch)
        # replace features with zeros with probability p_uncond
        if self.p_uncond > 0.0:
            mask = 1.0 - torch.bernoulli(torch.ones(len(c)) * self.p_uncond)
            mask = mask[:, None, None].to(c.device)
            c = mask * c
        return c
166
+
167
+
168
+
169
class EmbeddingViT2_5(EmbeddingViT2):

    """
    v2 but layer norm at the end

    Identical to :class:`EmbeddingViT2` except the projected output is passed
    through a final ``LayerNorm`` over the output channels.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the value stored by the parent constructor. The original
        # kwargs.get("hidden_channels") returned None (crashing nn.LayerNorm)
        # whenever hidden_channels was omitted or passed positionally.
        self.layer_norm = nn.LayerNorm(self.hidden_channels)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load this module's weights from a full LDM checkpoint's cond_stage_model."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        # Keep only conditioning-encoder weights and strip their prefix.
        sd_cond_stage = {k.replace("cond_stage_model.", ""): v for k, v in sd.items() if "cond_stage_model" in k}
        self.load_state_dict(sd_cond_stage, strict=True)
        print(f"Restored from {path}")

    def forward(self, batch):
        """Encode SSL features + magnification, with a final layer norm."""
        x = batch[self.feat_key]
        int_mag = batch[self.mag_key]

        # Process inputs
        x = self.process_input_batch(x)  # Shape: [batch_size, 64, hidden_channels]

        mag_embed = self.mag_embedding(int_mag).unsqueeze(1)  # Shape: [batch_size, 1, hidden_channels]
        x = torch.cat((mag_embed, x), dim=1)  # Shape: [batch_size, 65, hidden_channels]

        x = self.encoder(x)

        x = self.final_proj(x)
        x = self.layer_norm(x)

        return x
ldm/util.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import torch
4
+ from torch import optim
5
+ import numpy as np
6
+
7
+ from inspect import isfunction
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+
11
+ def autocast(f):
12
+ def do_autocast(*args, **kwargs):
13
+ with torch.cuda.amp.autocast(
14
+ enabled=True, dtype=torch.get_autocast_gpu_dtype(), cache_enabled=torch.is_autocast_cache_enabled()
15
+ ):
16
+ return f(*args, **kwargs)
17
+
18
+ return do_autocast
19
+
20
+
21
+ def log_txt_as_img(wh, xc, size=10):
22
+ # wh a tuple of (width, height)
23
+ # xc a list of captions to plot
24
+ b = len(xc)
25
+ txts = list()
26
+ for bi in range(b):
27
+ txt = Image.new("RGB", wh, color="white")
28
+ draw = ImageDraw.Draw(txt)
29
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
30
+ nc = int(40 * (wh[0] / 256))
31
+ lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc))
32
+
33
+ try:
34
+ draw.text((0, 0), lines, fill="black", font=font)
35
+ except UnicodeEncodeError:
36
+ print("Cant encode string for logging. Skipping.")
37
+
38
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
39
+ txts.append(txt)
40
+ txts = np.stack(txts)
41
+ txts = torch.tensor(txts)
42
+ return txts
43
+
44
+
45
+ def ismap(x):
46
+ if not isinstance(x, torch.Tensor):
47
+ return False
48
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
49
+
50
+
51
+ def isimage(x):
52
+ if not isinstance(x, torch.Tensor):
53
+ return False
54
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
55
+
56
+
57
+ def exists(x):
58
+ return x is not None
59
+
60
+
61
+ def default(val, d):
62
+ if exists(val):
63
+ return val
64
+ return d() if isfunction(d) else d
65
+
66
+
67
+ def mean_flat(tensor):
68
+ """
69
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
70
+ Take the mean over all non-batch dimensions.
71
+ """
72
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
73
+
74
+
75
+ def count_params(model, verbose=False):
76
+ total_params = sum(p.numel() for p in model.parameters())
77
+ if verbose:
78
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
79
+ return total_params
80
+
81
+
82
+ def instantiate_from_config(config):
83
+ if not "target" in config:
84
+ if config == "__is_first_stage__":
85
+ return None
86
+ elif config == "__is_unconditional__":
87
+ return None
88
+ raise KeyError("Expected key `target` to instantiate.")
89
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
90
+
91
+
92
+ def get_obj_from_str(string, reload=False):
93
+ module, cls = string.rsplit(".", 1)
94
+ if reload:
95
+ module_imp = importlib.import_module(module)
96
+ importlib.reload(module_imp)
97
+ return getattr(importlib.import_module(module, package=None), cls)
98
+
99
+
100
+ class AdamWwithEMAandWings(optim.Optimizer):
101
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
102
+ def __init__(
103
+ self,
104
+ params,
105
+ lr=1.0e-3,
106
+ betas=(0.9, 0.999),
107
+ eps=1.0e-8, # TODO: check hyperparameters before using
108
+ weight_decay=1.0e-2,
109
+ amsgrad=False,
110
+ ema_decay=0.9999, # ema decay to match previous code
111
+ ema_power=1.0,
112
+ param_names=(),
113
+ ):
114
+ """AdamW that saves EMA versions of the parameters."""
115
+ if not 0.0 <= lr:
116
+ raise ValueError("Invalid learning rate: {}".format(lr))
117
+ if not 0.0 <= eps:
118
+ raise ValueError("Invalid epsilon value: {}".format(eps))
119
+ if not 0.0 <= betas[0] < 1.0:
120
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
121
+ if not 0.0 <= betas[1] < 1.0:
122
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
123
+ if not 0.0 <= weight_decay:
124
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
125
+ if not 0.0 <= ema_decay <= 1.0:
126
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
127
+ defaults = dict(
128
+ lr=lr,
129
+ betas=betas,
130
+ eps=eps,
131
+ weight_decay=weight_decay,
132
+ amsgrad=amsgrad,
133
+ ema_decay=ema_decay,
134
+ ema_power=ema_power,
135
+ param_names=param_names,
136
+ )
137
+ super().__init__(params, defaults)
138
+
139
+ def __setstate__(self, state):
140
+ super().__setstate__(state)
141
+ for group in self.param_groups:
142
+ group.setdefault("amsgrad", False)
143
+
144
+ @torch.no_grad()
145
+ def step(self, closure=None):
146
+ """Performs a single optimization step.
147
+ Args:
148
+ closure (callable, optional): A closure that reevaluates the model
149
+ and returns the loss.
150
+ """
151
+ loss = None
152
+ if closure is not None:
153
+ with torch.enable_grad():
154
+ loss = closure()
155
+
156
+ for group in self.param_groups:
157
+ params_with_grad = []
158
+ grads = []
159
+ exp_avgs = []
160
+ exp_avg_sqs = []
161
+ ema_params_with_grad = []
162
+ state_sums = []
163
+ max_exp_avg_sqs = []
164
+ state_steps = []
165
+ amsgrad = group["amsgrad"]
166
+ beta1, beta2 = group["betas"]
167
+ ema_decay = group["ema_decay"]
168
+ ema_power = group["ema_power"]
169
+
170
+ for p in group["params"]:
171
+ if p.grad is None:
172
+ continue
173
+ params_with_grad.append(p)
174
+ if p.grad.is_sparse:
175
+ raise RuntimeError("AdamW does not support sparse gradients")
176
+ grads.append(p.grad)
177
+
178
+ state = self.state[p]
179
+
180
+ # State initialization
181
+ if len(state) == 0:
182
+ state["step"] = 0
183
+ # Exponential moving average of gradient values
184
+ state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
185
+ # Exponential moving average of squared gradient values
186
+ state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
187
+ if amsgrad:
188
+ # Maintains max of all exp. moving avg. of sq. grad. values
189
+ state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
190
+ # Exponential moving average of parameter values
191
+ state["param_exp_avg"] = p.detach().float().clone()
192
+
193
+ exp_avgs.append(state["exp_avg"])
194
+ exp_avg_sqs.append(state["exp_avg_sq"])
195
+ ema_params_with_grad.append(state["param_exp_avg"])
196
+
197
+ if amsgrad:
198
+ max_exp_avg_sqs.append(state["max_exp_avg_sq"])
199
+
200
+ # update the steps for each param group update
201
+ state["step"] += 1
202
+ # record the step after step update
203
+ state_steps.append(state["step"])
204
+
205
+ optim._functional.adamw(
206
+ params_with_grad,
207
+ grads,
208
+ exp_avgs,
209
+ exp_avg_sqs,
210
+ max_exp_avg_sqs,
211
+ state_steps,
212
+ amsgrad=amsgrad,
213
+ beta1=beta1,
214
+ beta2=beta2,
215
+ lr=group["lr"],
216
+ weight_decay=group["weight_decay"],
217
+ eps=group["eps"],
218
+ maximize=False,
219
+ )
220
+
221
+ cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power)
222
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
223
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
224
+
225
+ return loss
model_index.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ZoomLDMPipeline",
3
+ "_diffusers_version": "0.25.0",
4
+ "conditioning_encoder": [
5
+ "pipeline_zoomldm",
6
+ "ZoomLDMPipeline"
7
+ ],
8
+ "scheduler": [
9
+ "diffusers",
10
+ "DDIMScheduler"
11
+ ],
12
+ "unet": [
13
+ "pipeline_zoomldm",
14
+ "ZoomLDMPipeline"
15
+ ],
16
+ "vae": [
17
+ "pipeline_zoomldm",
18
+ "ZoomLDMPipeline"
19
+ ],
20
+ "scale_factor": 1.0,
21
+ "conditioning_key": "crossattn",
22
+ "variant": "brca"
23
+ }
pipeline_zoomldm.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom diffusers pipeline for ZoomLDM multi-scale image generation.
3
+
4
+ Dependencies: diffusers, torch; optional: safetensors, huggingface_hub, PyYAML.
5
+ Uses only stdlib (json, importlib) plus the above. No OmegaConf.
6
+ Model architectures (UNet, VAE, conditioning encoder) require ``ldm`` modules.
7
+ This pipeline auto-detects bundled local ``ldm`` folders when available.
8
+ """
9
+
10
+ import importlib
11
+ import importlib.util
12
+ import json
13
+ import sys
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+ from typing import List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from diffusers import DDIMScheduler, DiffusionPipeline
21
+ from diffusers.utils import BaseOutput
22
+ from PIL import Image
23
+
24
+
25
+ def _ensure_local_ldm_on_path():
26
+ """
27
+ Make local bundled ``ldm`` package importable without external repos.
28
+
29
+ Search near this pipeline file:
30
+ - <this_dir>/ldm
31
+ - <this_dir>/../ldm
32
+ """
33
+ if importlib.util.find_spec("ldm") is not None:
34
+ return
35
+
36
+ here = Path(__file__).resolve().parent
37
+ for candidate in (here / "ldm", here.parent / "ldm"):
38
+ if candidate.exists():
39
+ parent = str(candidate.parent)
40
+ if parent not in sys.path:
41
+ sys.path.insert(0, parent)
42
+ if importlib.util.find_spec("ldm") is not None:
43
+ return
44
+
45
+
46
+ _ensure_local_ldm_on_path()
47
+
48
+
49
+ def _get_class(target: str):
50
+ """Resolve a class from a dotted path (e.g. 'ldm.modules.xxx.UNetModel')."""
51
+ module_path, cls_name = target.rsplit(".", 1)
52
+ mod = importlib.import_module(module_path)
53
+ return getattr(mod, cls_name)
54
+
55
+
56
+ def _instantiate_from_config(config: dict):
57
+ """Instantiate from a dict with 'target' and optional 'params' (no OmegaConf)."""
58
+ if not isinstance(config, dict) or "target" not in config:
59
+ if config == "__is_first_stage__" or config == "__is_unconditional__":
60
+ return None
61
+ raise KeyError("Expected key 'target' to instantiate.")
62
+ cls = _get_class(config["target"])
63
+ params = config.get("params", {})
64
+ return cls(**params)
65
+
66
+
67
@dataclass
class ZoomLDMPipelineOutput(BaseOutput):
    """
    Output class for ZoomLDM pipeline.

    Args:
        images: Generated images — a list of PIL images, a numpy array, or a
            torch tensor, depending on the ``output_type`` passed to the
            pipeline call.
    """

    images: Union[List[Image.Image], np.ndarray, torch.Tensor]
77
+
78
+
79
class ZoomLDMPipeline(DiffusionPipeline):
    """
    Pipeline for multi-scale image generation with ZoomLDM.

    This pipeline wraps the ZoomLDM model components using the native
    huggingface/diffusers ``DiffusionPipeline`` interface, replacing custom
    samplers with the diffusers ``DDIMScheduler``.

    Args:
        unet: The UNet denoising model (``UNetModel`` from openaimodel).
        vae: The first-stage autoencoder (``VQModelInterface``).
        conditioning_encoder: The conditioning encoder
            (``EmbeddingViT2_5``).
        scheduler: A diffusers noise scheduler (e.g. ``DDIMScheduler``).
        scale_factor: Latent space scaling factor (default: 1.0).
        conditioning_key: Type of conditioning ("crossattn", "concat",
            "hybrid").
    """

    model_cpu_offload_seq = "conditioning_encoder->unet->vae"

    def __init__(
        self,
        unet: torch.nn.Module,
        vae: torch.nn.Module,
        conditioning_encoder: torch.nn.Module,
        scheduler: DDIMScheduler,
        scale_factor: float = 1.0,
        conditioning_key: str = "crossattn",
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            conditioning_encoder=conditioning_encoder,
            scheduler=scheduler,
        )
        # Plain attributes (not registered components).
        self.scale_factor = scale_factor
        self.conditioning_key = conditioning_key

    @property
    def device(self) -> torch.device:
        """Return the device of the pipeline's parameters."""
        try:
            return next(self.unet.parameters()).device
        except StopIteration:
            # Parameter-less UNet (unlikely) — fall back to CPU.
            return torch.device("cpu")

    def to(self, *args, **kwargs):
        """
        Move pipeline modules to a device/dtype.

        Diffusers' default ``DiffusionPipeline.to`` expects each module to
        expose a ``dtype`` attribute. ``EmbeddingViT2_5`` does not, which can
        raise an ``AttributeError``. This override keeps standard ``pipe.to``
        usage working for ZoomLDM custom components.
        """
        module_kwargs = {}
        for key in ("dtype", "non_blocking", "memory_format"):
            if key in kwargs:
                module_kwargs[key] = kwargs[key]

        # Ignore diffusers-only kwargs not accepted by torch.nn.Module.to.
        device_or_dtype_args = args
        if not device_or_dtype_args and "device" in kwargs:
            device_or_dtype_args = (kwargs["device"],)

        # NOTE(review): the scheduler is intentionally not moved here —
        # presumably it holds no parameters needing device placement; confirm.
        for name in ("unet", "vae", "conditioning_encoder"):
            module = getattr(self, name, None)
            if module is not None:
                module.to(*device_or_dtype_args, **module_kwargs)

        return self

    @classmethod
    def from_single_file(cls, config_path, ckpt_path, device=None, **kwargs):
        """
        Load a ``ZoomLDMPipeline`` from original ZoomLDM config and
        checkpoint files.

        Requires ``ldm`` modules. Bundled local ``ldm`` is auto-detected.

        Args:
            config_path: Path to the YAML config file.
            ckpt_path: Path to the model checkpoint (``.ckpt`` or
                ``.pt``).
            device: Device to load the model onto.

        Returns:
            A ``ZoomLDMPipeline`` instance.

        Example::

            from huggingface_hub import hf_hub_download

            ckpt = hf_hub_download(
                "StonyBrook-CVLab/ZoomLDM", "brca/weights.ckpt"
            )
            cfg = hf_hub_download(
                "StonyBrook-CVLab/ZoomLDM", "brca/config.yaml"
            )
            pipe = ZoomLDMPipeline.from_single_file(cfg, ckpt)
            pipe = pipe.to("cuda")
        """
        import yaml

        with open(config_path) as f:
            config = yaml.safe_load(f)
        model = _instantiate_from_config(config["model"])
        # NOTE: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # strict=False tolerates extra/missing keys — presumably EMA or loss
        # weights in the original checkpoint; verify against the checkpoint.
        model.load_state_dict(state_dict, strict=False)
        model.eval()

        pipe = cls.from_ldm_model(model)

        if device is not None:
            pipe = pipe.to(device)

        return pipe

    @classmethod
    def from_ldm_model(cls, model):
        """
        Create a ``ZoomLDMPipeline`` from an existing ``LatentDiffusion``
        model instance.

        Args:
            model: A ``LatentDiffusion`` model.

        Returns:
            A ``ZoomLDMPipeline`` instance.
        """
        # Apply EMA weights if available
        # (this overwrites the live weights in place).
        if hasattr(model, "use_ema") and model.use_ema:
            model.model_ema.copy_to(model.model)

        # Extract components
        unet = model.model.diffusion_model
        vae = model.first_stage_model
        conditioning_encoder = model.cond_stage_model

        # Disable classifier-free dropout in conditioning encoder
        if hasattr(conditioning_encoder, "p_uncond"):
            conditioning_encoder.p_uncond = 0

        # Determine scale_factor
        sf = model.scale_factor
        if isinstance(sf, torch.Tensor):
            sf = sf.item()

        # Create a diffusers DDIMScheduler that matches the original
        # noise schedule.
        # - The original "linear" beta schedule uses:
        #     betas = linspace(sqrt(start), sqrt(end), T) ** 2
        #   which corresponds to "scaled_linear" in diffusers.
        # - steps_offset=1 replicates the +1 shift used by the
        #   original DDIM sampler.
        scheduler = DDIMScheduler(
            num_train_timesteps=model.num_timesteps,
            beta_start=model.linear_start,
            beta_end=model.linear_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            prediction_type="epsilon",
            steps_offset=1,
        )

        # Determine the conditioning key
        conditioning_key = "crossattn"
        if hasattr(model, "model") and hasattr(model.model, "conditioning_key"):
            conditioning_key = model.model.conditioning_key or "crossattn"

        return cls(
            unet=unet,
            vae=vae,
            conditioning_encoder=conditioning_encoder,
            scheduler=scheduler,
            scale_factor=sf,
            conditioning_key=conditioning_key,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        variant: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ):
        """
        Load a ``ZoomLDMPipeline`` from a diffusers-format directory
        (created by ``convert_to_diffusers.py``).

        Args:
            pretrained_model_name_or_path: Path to the diffusers-format
                directory (or HuggingFace repo ID).
            variant: Optional model variant to load when
                ``pretrained_model_name_or_path`` points to a root directory
                containing multiple self-contained subfolders (e.g.
                ``"brca"``, ``"naip"``).
            device: Device to load the model onto.

        Returns:
            A ``ZoomLDMPipeline`` instance.

        Example::

            pipe = ZoomLDMPipeline.from_pretrained(
                "/root/workspace/models/BiliSakura/ZoomLDM",
                variant="brca",
            )
            pipe = pipe.to("cuda")
        """
        path = Path(pretrained_model_name_or_path)
        if not path.exists():
            # Not a local path — treat it as a Hub repo id and download.
            from huggingface_hub import snapshot_download

            path = Path(snapshot_download(pretrained_model_name_or_path))

        path = path.resolve()

        def _is_diffusers_model_dir(candidate: Path) -> bool:
            # A valid model dir must contain all four component configs.
            required = [
                candidate / "model_index.json",
                candidate / "scheduler" / "scheduler_config.json",
                candidate / "unet" / "config.json",
                candidate / "vae" / "config.json",
                candidate / "conditioning_encoder" / "config.json",
            ]
            return all(p.exists() for p in required)

        if variant:
            model_dir = path / variant
            if not _is_diffusers_model_dir(model_dir):
                raise FileNotFoundError(
                    f"Variant '{variant}' was requested, but '{model_dir}' is not a valid model directory."
                )
        elif _is_diffusers_model_dir(path):
            model_dir = path
        else:
            # Root layout: auto-discover a single variant subfolder.
            candidate_dirs = [d for d in path.iterdir() if d.is_dir() and _is_diffusers_model_dir(d)]
            if not candidate_dirs:
                raise FileNotFoundError(
                    f"No diffusers model found at '{path}'. "
                    "Expected model files in this directory or in subfolders (e.g. brca/, naip/)."
                )
            if len(candidate_dirs) > 1:
                variants = ", ".join(sorted(d.name for d in candidate_dirs))
                raise ValueError(
                    f"Multiple model variants found at '{path}': {variants}. "
                    "Pass variant='<name>' to select one."
                )
            model_dir = candidate_dirs[0]

        scheduler = DDIMScheduler.from_pretrained(model_dir / "scheduler")

        # Fallback class paths used when a component config has no 'target'.
        _TARGETS = {
            "unet": "ldm.modules.diffusionmodules.openaimodel.UNetModel",
            "vae": "ldm.models.autoencoder.VQModelInterface",
            "conditioning_encoder": "ldm.modules.encoders.modules.EmbeddingViT2_5",
        }

        def load_custom_component(name: str):
            # Instantiate a component from its config.json, then load weights.
            comp_path = model_dir / name
            with open(comp_path / "config.json") as f:
                cfg = json.load(f)

            if "target" in cfg:
                # Either a nested 'params' dict, or flat keys besides 'target'.
                params = dict(cfg.get("params", {k: v for k, v in cfg.items() if k != "target"}))
                # Checkpoint-loading keys are handled here, not by the class.
                params.pop("ckpt_path", None)
                params.pop("ignore_keys", None)
                component = _instantiate_from_config({"target": cfg["target"], "params": params})
            else:
                model_cls = _get_class(_TARGETS[name])
                params = dict(cfg)
                if name == "vae":
                    # The VQ autoencoder requires a loss config; substitute a
                    # no-op Identity for inference-only loading.
                    lc = params.get("lossconfig") or {}
                    if "target" not in lc:
                        params["lossconfig"] = {"target": "torch.nn.Identity", "params": {}}
                component = model_cls(**params)

            # Load weights
            safetensors_path = comp_path / "diffusion_pytorch_model.safetensors"
            bin_path = comp_path / "diffusion_pytorch_model.bin"
            if safetensors_path.exists():
                from safetensors.torch import load_file

                state = load_file(str(safetensors_path))
            elif bin_path.exists():
                try:
                    state = torch.load(bin_path, map_location="cpu", weights_only=True)
                except TypeError:
                    # Older torch without the weights_only argument.
                    state = torch.load(bin_path, map_location="cpu")
            else:
                raise FileNotFoundError(
                    f"No weights found in {comp_path} "
                    "(expected diffusion_pytorch_model.safetensors or .bin)"
                )
            component.load_state_dict(state, strict=True)
            component.eval()
            return component

        unet = load_custom_component("unet")
        vae = load_custom_component("vae")
        conditioning_encoder = load_custom_component("conditioning_encoder")

        # Disable classifier-free dropout at inference time.
        if hasattr(conditioning_encoder, "p_uncond"):
            conditioning_encoder.p_uncond = 0

        model_index_path = model_dir / "model_index.json"
        if model_index_path.exists():
            with open(model_index_path) as f:
                model_index = json.load(f)
            scale_factor = model_index.get("scale_factor", 1.0)
            conditioning_key = model_index.get("conditioning_key", "crossattn")
        else:
            scale_factor = 1.0
            conditioning_key = "crossattn"

        pipe = cls(
            unet=unet,
            vae=vae,
            conditioning_encoder=conditioning_encoder,
            scheduler=scheduler,
            scale_factor=scale_factor,
            conditioning_key=conditioning_key,
        )

        if device is not None:
            pipe = pipe.to(device)

        return pipe

    def encode_conditioning(self, ssl_features, magnification):
        """
        Encode conditioning inputs through the conditioning encoder.

        Args:
            ssl_features: SSL feature tensors (e.g. UNI or DINO-v2
                embeddings). NOTE(review): not moved to the pipeline device
                here (only ``magnification`` is) — callers are expected to
                pass device-resident features, as in the usage examples.
            magnification: Integer magnification level tensor.

        Returns:
            Encoded conditioning tensor.
        """
        device = self.device
        cond_dict = {
            self.conditioning_encoder.feat_key: ssl_features,
            self.conditioning_encoder.mag_key: magnification.to(device),
        }

        if hasattr(self.conditioning_encoder, "encode"):
            return self.conditioning_encoder.encode(cond_dict)
        return self.conditioning_encoder(cond_dict)

    def decode_latents(self, latents):
        """
        Decode latent representations to images using the VAE.

        Args:
            latents: Latent tensor from the diffusion process.

        Returns:
            Image tensor in ``[-1, 1]`` range.
        """
        # Undo the latent scaling applied during training before decoding.
        latents = (1.0 / self.scale_factor) * latents
        return self.vae.decode(latents)

    @torch.no_grad()
    def __call__(
        self,
        ssl_features: Union[torch.Tensor, list],
        magnification: torch.Tensor,
        num_inference_steps: int = 50,
        guidance_scale: float = 2.0,
        latent_shape: tuple = (3, 64, 64),
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """
        Generate images conditioned on SSL features and magnification
        level.

        Args:
            ssl_features: SSL feature tensor(s) for conditioning.
                Shape depends on the magnification level.
            magnification: Integer magnification levels
                (0=20x, 1=10x, 2=5x, 3=2.5x, 4=1.25x).
            num_inference_steps: Number of denoising steps (default: 50).
            guidance_scale: Classifier-free guidance scale (default: 2.0).
            latent_shape: Shape of each latent sample
                (default: ``(3, 64, 64)``).
            generator: Optional random number generator for
                reproducibility.
            latents: Optional pre-initialized latent noise tensor.
            output_type: Output format — ``"pil"``, ``"np"``, or
                ``"pt"`` (default: ``"pil"``).
            return_dict: Whether to return a ``ZoomLDMPipelineOutput``
                or a tuple (default: ``True``).

        Returns:
            ``ZoomLDMPipelineOutput`` with generated images, or a tuple.

        Example::

            pipe = ZoomLDMPipeline.from_single_file(cfg, ckpt)
            pipe = pipe.to("cuda")
            output = pipe(
                ssl_features=batch["ssl_feat"].to("cuda"),
                magnification=batch["mag"].to("cuda"),
                num_inference_steps=50,
                guidance_scale=2.0,
            )
            images = output.images
        """
        device = self.device
        dtype = next(self.unet.parameters()).dtype

        # Determine batch size
        if isinstance(ssl_features, list):
            batch_size = len(ssl_features)
        elif isinstance(ssl_features, torch.Tensor):
            batch_size = ssl_features.shape[0]
        else:
            batch_size = 1

        # 1. Encode conditioning
        cc = self.encode_conditioning(ssl_features, magnification)
        # Unconditional embedding for classifier-free guidance: all zeros.
        uc = torch.zeros_like(cc)

        # 2. Prepare latents
        if latents is None:
            latents = torch.randn(
                (batch_size, *latent_shape),
                generator=generator,
                device=device,
                dtype=dtype,
            )
        else:
            latents = latents.to(device=device, dtype=dtype)

        # 3. Set up scheduler timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 4. Denoising loop
        for t in self.progress_bar(timesteps):
            # Double the batch: [uncond, cond] — order must match the
            # chunk(2) unpack after the UNet call below.
            latent_model_input = torch.cat([latents, latents])
            t_batch = t.expand(latent_model_input.shape[0])
            cond_input = torch.cat([uc, cc])

            # Predict noise with the UNet
            with torch.amp.autocast(device_type=device.type, enabled=device.type != "cpu"):
                if self.conditioning_key == "crossattn":
                    noise_pred = self.unet(
                        latent_model_input,
                        t_batch,
                        context=cond_input,
                    )
                elif self.conditioning_key == "concat":
                    noise_pred = self.unet(
                        torch.cat(
                            [latent_model_input, cond_input], dim=1
                        ),
                        t_batch,
                    )
                elif self.conditioning_key == "hybrid":
                    raise NotImplementedError(
                        "Hybrid conditioning requires c_concat and "
                        "c_crossattn to be passed separately. Use the "
                        "original LatentDiffusion model for hybrid "
                        "conditioning."
                    )
                else:
                    noise_pred = self.unet(latent_model_input, t_batch)

            # Classifier-free guidance
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            # Scheduler step
            latents = self.scheduler.step(
                noise_pred, t, latents, generator=generator
            ).prev_sample

        # 5. Decode latents to images
        images = self.decode_latents(latents)
        # Map [-1, 1] -> [0, 1] and clamp.
        images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)

        # 6. Convert output format
        if output_type == "pt":
            pass
        elif output_type == "np":
            images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        elif output_type == "pil":
            images_np = images.cpu().permute(0, 2, 3, 1).float().numpy()
            images = [
                Image.fromarray((img * 255).astype(np.uint8))
                for img in images_np
            ]
        else:
            raise ValueError(
                f"Unknown output_type '{output_type}'. "
                "Use 'pil', 'np', or 'pt'."
            )

        if not return_dict:
            return (images,)

        return ZoomLDMPipelineOutput(images=images)
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "beta_end": 0.0195,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.0015,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
unet/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "use_checkpoint": true,
3
+ "use_fp16": true,
4
+ "image_size": 64,
5
+ "in_channels": 3,
6
+ "out_channels": 3,
7
+ "model_channels": 192,
8
+ "attention_resolutions": [
9
+ 8,
10
+ 4,
11
+ 2
12
+ ],
13
+ "num_res_blocks": 2,
14
+ "channel_mult": [
15
+ 1,
16
+ 2,
17
+ 3,
18
+ 5
19
+ ],
20
+ "num_heads": 1,
21
+ "use_spatial_transformer": true,
22
+ "transformer_depth": 1,
23
+ "context_dim": 512
24
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:986f315134782019941984a007527becf85c4ab9627be257451b37c3f69d90c8
3
+ size 1603762196
vae/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 3,
3
+ "n_embed": 8192,
4
+ "ddconfig": {
5
+ "double_z": false,
6
+ "z_channels": 3,
7
+ "resolution": 256,
8
+ "in_channels": 3,
9
+ "out_ch": 3,
10
+ "ch": 128,
11
+ "ch_mult": [
12
+ 1,
13
+ 2,
14
+ 4
15
+ ],
16
+ "num_res_blocks": 2,
17
+ "attn_resolutions": [],
18
+ "dropout": 0.0
19
+ },
20
+ "lossconfig": {}
21
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aaaa896c36dba0715ecd41ce26bd8d981b256c32ee433804f3b1a90197560924
3
+ size 221312136