HadiZayer commited on
Commit
b988971
·
1 Parent(s): b2d4649

add MagicFixup space (code only, examples uploaded via API)

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ examples/
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.egg-info/
README.md CHANGED
@@ -1,14 +1,20 @@
1
  ---
2
- title: MagicFixup
3
- emoji: 👀
4
- colorFrom: indigo
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.21.0
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
- short_description: spatially editing an image with generative models
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
1
  ---
2
+ title: Magic Fixup
3
+ emoji: 🪄
4
+ colorFrom: purple
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.21.0
8
  app_file: app.py
9
  pinned: false
10
+ license: other
11
+ short_description: Streamlining photo editing by watching dynamic videos
12
  ---
13
 
14
+ # Magic Fixup
15
+
16
+ Demo for the paper [Magic Fixup: Streamlining Photo Editing by Watching Dynamic Videos](https://magic-fixup.github.io).
17
+
18
+ Upload your **original image** and a **coarse edit** (PNG with an alpha channel marking the edited region). Magic Fixup refines the edit to look photorealistic.
19
+
20
+ For more details, training code, and the full codebase, see the [GitHub repo](https://github.com/adobe-research/MagicFixup).
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright 2024 Adobe. All rights reserved.

from huggingface_hub import hf_hub_download
from run_magicfu import MagicFixup
import torchvision
from torch import autocast
from PIL import Image
import gradio as gr
import numpy as np


# Download checkpoint from HF Hub at startup
# (hf_hub_download caches locally, so later restarts reuse the file).
checkpoint_path = hf_hub_download(repo_id="HadiZayer/MagicFixup", filename="magicfu_weights")

# Load the model once at import time so every request reuses it.
# NOTE(review): device placement happens inside MagicFixup — `sample` below
# assumes the model and inputs live on CUDA; confirm against run_magicfu.
magic_fixup = MagicFixup(model_path=checkpoint_path)
def sample(original_image, coarse_edit):
    """Refine a user's coarse edit with Magic Fixup.

    Args:
        original_image: PIL RGB image — the unedited reference photo.
        coarse_edit: PIL RGBA image — the user's rough edit; the alpha
            channel marks the edited region.

    Returns:
        PIL image with the refined edit, resized back to the coarse
        edit's original resolution.
    """
    to_tensor = torchvision.transforms.ToTensor()
    with autocast("cuda"):
        # Remember the input resolution so the output can be resized back.
        w, h = coarse_edit.size
        # The model operates at a fixed 512x512 resolution.
        ref_image_t = to_tensor(original_image.resize((512, 512))).half().cuda()
        # Decode/resize the coarse edit ONCE and reuse it for both the RGB
        # channels and the alpha mask (the previous code ran the identical
        # resize+ToTensor+half+cuda pipeline twice on the same image).
        coarse_edit_t = to_tensor(coarse_edit.resize((512, 512))).half().cuda()
        # Alpha channel -> (1, 1, H, W) mask of the edited region.
        mask_t = (coarse_edit_t[-1][None, None, ...]).half()
        # RGB channels only (drop alpha).
        coarse_edit_t_rgb = coarse_edit_t[:-1]

        out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50)
        # (C, H, W) float (assumed in [0, 1] — scaled to uint8 below) -> (H, W, C).
        output = out_rgb.squeeze().cpu().detach().moveaxis(0, -1).float().numpy()
        output = (output * 255.0).astype(np.uint8)
        output_pil = Image.fromarray(output)
        # Return at the caller's original resolution.
        output_pil = output_pil.resize((w, h))
        return output_pil
# Assemble the Gradio UI around `sample`.
input_components = [
    gr.Image(type="pil", image_mode="RGB", label="Original Image"),
    gr.Image(type="pil", image_mode="RGBA", label="Coarse Edit (with alpha mask)"),
]
output_component = gr.Image(label="Result")

demo = gr.Interface(
    fn=sample,
    inputs=input_components,
    outputs=output_component,
    examples="examples",
    title="Magic Fixup",
    description="Upload your original image and a coarse edit (PNG with alpha channel marking the edited region). Magic Fixup will refine the edit to look photorealistic.",
)

demo.launch()
configs/collage_composite_train.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Adobe. All rights reserved.
2
+ model:
3
+ base_learning_rate: 1.0e-05
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "inpaint"
12
+ cond_stage_key: "image"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: true # Note: different from the one we trained before
16
+ conditioning_key: "rewarp"
17
+ monitor: val/loss_simple_ema
18
+ u_cond_percent: 0.2
19
+ scale_factor: 0.18215
20
+ use_ema: False
21
+ context_embedding_dim: 768 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536
22
+
23
+
24
+ scheduler_config: # 10000 warmup steps
25
+ target: ldm.lr_scheduler.LambdaLinearScheduler
26
+ params:
27
+ warm_up_steps: [ 10000 ]
28
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
29
+ f_start: [ 1.e-6 ]
30
+ f_max: [ 1. ]
31
+ f_min: [ 1. ]
32
+
33
+ unet_config:
34
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
35
+ params:
36
+ image_size: 32 # unused
37
+ in_channels: 9
38
+ out_channels: 4
39
+ model_channels: 320
40
+ attention_resolutions: [ 4, 2, 1 ]
41
+ num_res_blocks: 2
42
+ channel_mult: [ 1, 2, 4, 4 ]
43
+ num_heads: 8
44
+ use_spatial_transformer: True
45
+ transformer_depth: 1
46
+ context_dim: 768
47
+ use_checkpoint: True
48
+ legacy: False
49
+ add_conv_in_front_of_unet: False
50
+
51
+ first_stage_config:
52
+ target: ldm.models.autoencoder.AutoencoderKL
53
+ params:
54
+ embed_dim: 4
55
+ monitor: val/rec_loss
56
+ ddconfig:
57
+ double_z: true
58
+ z_channels: 4
59
+ resolution: 256
60
+ in_channels: 3
61
+ out_ch: 3
62
+ ch: 128
63
+ ch_mult:
64
+ - 1
65
+ - 2
66
+ - 4
67
+ - 4
68
+ num_res_blocks: 2
69
+ attn_resolutions: []
70
+ dropout: 0.0
71
+ lossconfig:
72
+ target: torch.nn.Identity
73
+
74
+ cond_stage_config:
75
+ target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding
76
+ params:
77
+ dino_version: "big" # [small, big, large, huge]
78
+
79
+ data:
80
+ target: main.DataModuleFromConfig
81
+ params:
82
+ batch_size: 2
83
+ num_workers: 8
84
+ use_worker_init_fn: False
85
+ wrap: False
86
+ train:
87
+ target: ldm.data.collage_dataset.CollageDataset
88
+ params:
89
+ split_files: "<specify train split path>"
90
+ image_size: 512
91
+ embedding_type: 'dino' # TODO embedding
92
+ warping_type: 'collage'
93
+ validation:
94
+ target: ldm.data.collage_dataset.CollageDataset
95
+ params:
96
+ split_files: "<specify validation split path>"
97
+ image_size: 512
98
+ embedding_type: 'dino' # TODO embedding
99
+ warping_type: 'mix'
100
+ test:
101
+ target: ldm.data.collage_dataset.CollageDataset
102
+ params:
103
+ split_files: "<specify validation split path>"
104
+ image_size: 512
105
+ embedding_type: 'dino' # TODO embedding
106
+ warping_type: 'mix'
107
+
108
+ lightning:
109
+ trainer:
110
+ max_epochs: 500
111
+ num_nodes: 1
112
+ num_sanity_val_steps: 0
113
+ accelerator: 'gpu'
114
+ gpus: "0,1,2,3,4,5,6,7"
configs/collage_flow_train.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Adobe. All rights reserved.
2
+ model:
3
+ base_learning_rate: 1.0e-05
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "inpaint"
12
+ cond_stage_key: "image"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: true # Note: different from the one we trained before
16
+ conditioning_key: "rewarp"
17
+ monitor: val/loss_simple_ema
18
+ u_cond_percent: 0.2
19
+ scale_factor: 0.18215
20
+ use_ema: False
21
+ context_embedding_dim: 768 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536
22
+
23
+
24
+ scheduler_config: # 10000 warmup steps
25
+ target: ldm.lr_scheduler.LambdaLinearScheduler
26
+ params:
27
+ warm_up_steps: [ 10000 ]
28
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
29
+ f_start: [ 1.e-6 ]
30
+ f_max: [ 1. ]
31
+ f_min: [ 1. ]
32
+
33
+ unet_config:
34
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
35
+ params:
36
+ image_size: 32 # unused
37
+ in_channels: 9
38
+ out_channels: 4
39
+ model_channels: 320
40
+ attention_resolutions: [ 4, 2, 1 ]
41
+ num_res_blocks: 2
42
+ channel_mult: [ 1, 2, 4, 4 ]
43
+ num_heads: 8
44
+ use_spatial_transformer: True
45
+ transformer_depth: 1
46
+ context_dim: 768
47
+ use_checkpoint: True
48
+ legacy: False
49
+ add_conv_in_front_of_unet: False
50
+
51
+ first_stage_config:
52
+ target: ldm.models.autoencoder.AutoencoderKL
53
+ params:
54
+ embed_dim: 4
55
+ monitor: val/rec_loss
56
+ ddconfig:
57
+ double_z: true
58
+ z_channels: 4
59
+ resolution: 256
60
+ in_channels: 3
61
+ out_ch: 3
62
+ ch: 128
63
+ ch_mult:
64
+ - 1
65
+ - 2
66
+ - 4
67
+ - 4
68
+ num_res_blocks: 2
69
+ attn_resolutions: []
70
+ dropout: 0.0
71
+ lossconfig:
72
+ target: torch.nn.Identity
73
+
74
+ cond_stage_config:
75
+ target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding
76
+ params:
77
+ dino_version: "big" # [small, big, large, huge]
78
+
79
+ data:
80
+ target: main.DataModuleFromConfig
81
+ params:
82
+ batch_size: 2
83
+ num_workers: 8
84
+ use_worker_init_fn: False
85
+ wrap: False
86
+ train:
87
+ target: ldm.data.collage_dataset.CollageDataset
88
+ params:
89
+ split_files: /mnt/localssd/new_train
90
+ image_size: 512
91
+ embedding_type: 'dino' # TODO embedding
92
+ warping_type: 'flow'
93
+ validation:
94
+ target: ldm.data.collage_dataset.CollageDataset
95
+ params:
96
+ split_files: /mnt/localssd/new_val
97
+ image_size: 512
98
+ embedding_type: 'dino' # TODO embedding
99
+ warping_type: 'mix'
100
+ test:
101
+ target: ldm.data.collage_dataset.CollageDataset
102
+ params:
103
+ split_files: /mnt/localssd/new_val
104
+ image_size: 512
105
+ embedding_type: 'dino' # TODO embedding
106
+ warping_type: 'mix'
107
+
108
+ lightning:
109
+ trainer:
110
+ max_epochs: 500
111
+ num_nodes: 1
112
+ num_sanity_val_steps: 0
113
+ accelerator: 'gpu'
114
+ gpus: "0,1,2,3,4,5,6,7"
configs/collage_mix_train.yaml ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Adobe. All rights reserved.
2
+ model:
3
+ base_learning_rate: 1.0e-05
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "inpaint"
12
+ cond_stage_key: "image"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: true # Note: different from the one we trained before
16
+ conditioning_key: "rewarp"
17
+ monitor: val/loss_simple_ema
18
+ u_cond_percent: 0.2
19
+ scale_factor: 0.18215
20
+ use_ema: False
21
+ context_embedding_dim: 384 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536
22
+ dropping_warped_latent_prob: 0.2
23
+
24
+
25
+ scheduler_config: # 10000 warmup steps
26
+ target: ldm.lr_scheduler.LambdaLinearScheduler
27
+ params:
28
+ warm_up_steps: [ 10000 ]
29
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
30
+ f_start: [ 1.e-6 ]
31
+ f_max: [ 1. ]
32
+ f_min: [ 1. ]
33
+
34
+ unet_config:
35
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
36
+ params:
37
+ image_size: 32 # unused
38
+ in_channels: 9
39
+ out_channels: 4
40
+ model_channels: 320
41
+ attention_resolutions: [ 4, 2, 1 ]
42
+ num_res_blocks: 2
43
+ channel_mult: [ 1, 2, 4, 4 ]
44
+ num_heads: 8
45
+ use_spatial_transformer: True
46
+ transformer_depth: 1
47
+ context_dim: 768
48
+ use_checkpoint: True
49
+ legacy: False
50
+ add_conv_in_front_of_unet: False
51
+
52
+ first_stage_config:
53
+ target: ldm.models.autoencoder.AutoencoderKL
54
+ params:
55
+ embed_dim: 4
56
+ monitor: val/rec_loss
57
+ ddconfig:
58
+ double_z: true
59
+ z_channels: 4
60
+ resolution: 256
61
+ in_channels: 3
62
+ out_ch: 3
63
+ ch: 128
64
+ ch_mult:
65
+ - 1
66
+ - 2
67
+ - 4
68
+ - 4
69
+ num_res_blocks: 2
70
+ attn_resolutions: []
71
+ dropout: 0.0
72
+ lossconfig:
73
+ target: torch.nn.Identity
74
+
75
+ cond_stage_config:
76
+ target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding
77
+ params:
78
+ dino_version: "small" # [small, big, large, huge]
79
+
80
+ data:
81
+ target: main.DataModuleFromConfig
82
+ params:
83
+ batch_size: 4
84
+ num_workers: 8
85
+ use_worker_init_fn: False
86
+ wrap: False
87
+ train:
88
+ target: ldm.data.collage_dataset.CollageDataset
89
+ params:
90
+ split_files: /mnt/localssd/new_train
91
+ image_size: 512
92
+ embedding_type: 'dino' # TODO embedding
93
+ warping_type: 'mix'
94
+ validation:
95
+ target: ldm.data.collage_dataset.CollageDataset
96
+ params:
97
+ split_files: /mnt/localssd/new_val
98
+ image_size: 512
99
+ embedding_type: 'dino' # TODO embedding
100
+ warping_type: 'mix'
101
+ test:
102
+ target: ldm.data.collage_dataset.CollageDataset
103
+ params:
104
+ split_files: /mnt/localssd/new_val
105
+ image_size: 512
106
+ embedding_type: 'dino' # TODO embedding
107
+ warping_type: 'mix'
108
+
109
+ lightning:
110
+ trainer:
111
+ max_epochs: 500
112
+ num_nodes: 1
113
+ num_sanity_val_steps: 0
114
+ accelerator: 'gpu'
115
+ gpus: "0,1,2,3,4,5,6,7"
ldm/data/__init__.py ADDED
File without changes
ldm/data/collage_dataset.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Adobe. All rights reserved.
2
+
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import torchvision.transforms.functional as F
7
+ import glob
8
+ import torchvision
9
+ from PIL import Image
10
+ import time
11
+ import os
12
+ import tqdm
13
+ from torch.utils.data import Dataset
14
+ import pathlib
15
+ import cv2
16
+ from PIL import Image
17
+ import os
18
+ import json
19
+ import albumentations as A
20
+
21
def get_tensor(normalize=True, toTensor=True):
    """Build a PIL->tensor transform, optionally mapping [0, 1] to [-1, 1]."""
    steps = []
    if toTensor:
        steps.append(torchvision.transforms.ToTensor())
    if normalize:
        # Shift/scale each channel from [0, 1] into [-1, 1].
        steps.append(torchvision.transforms.Normalize((0.5, 0.5, 0.5),
                                                      (0.5, 0.5, 0.5)))
    return torchvision.transforms.Compose(steps)
32
+
33
def get_tensor_clip(normalize=True, toTensor=True):
    """Build a transform to a 224x224 tensor normalized with CLIP statistics."""
    steps = [torchvision.transforms.Resize((224, 224))]
    if toTensor:
        steps.append(torchvision.transforms.ToTensor())
    if normalize:
        # CLIP's published per-channel mean/std.
        steps.append(torchvision.transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711)))
    return torchvision.transforms.Compose(steps)
42
+
43
def get_tensor_dino(normalize=True, toTensor=True):
    """Build a transform to a 224x224 tensor normalized for DINO."""
    steps = [torchvision.transforms.Resize((224, 224))]
    if toTensor:
        steps.append(torchvision.transforms.ToTensor())
    if normalize:
        # DINO normalization works on 0-255-range RGB (first 3 channels),
        # then subtracts ImageNet mean / divides by ImageNet std.
        steps.append(lambda x: 255.0 * x[:3])
        steps.append(torchvision.transforms.Normalize(
            mean=(123.675, 116.28, 103.53),
            std=(58.395, 57.12, 57.375),
        ))
    return torchvision.transforms.Compose(steps)
55
+
56
def crawl_folders(folder_path):
    """Collect every `src_*png` file from the immediate subfolders of folder_path."""
    src_files = []
    for subfolder in glob.glob(f'{folder_path}/*'):
        src_files += glob.glob(f'{subfolder}/src_*png')
    return src_files
65
+
66
def get_grid(size):
    """Return a (size, size, 2) integer grid where out[i, j] == [i, j].

    Callers divide the result by `size` to obtain normalized pixel
    coordinates in [0, 1).
    """
    # np.meshgrid with 'ij' indexing yields the row-index and column-index
    # planes directly, replacing the opaque repeat/reshape/transpose dance
    # of the original (which produced exactly the same array).
    rows, cols = np.meshgrid(np.arange(size), np.arange(size), indexing='ij')
    return np.stack([rows, cols], -1)
72
+
73
+
74
class CollageDataset(Dataset):
    """Dataset of (reference, coarsely-warped, ground-truth) frame triplets.

    Crawls `split_files` for frames named `src_*png`; sibling files derived
    by filename substitution supply the target frame (`tgt_*`), a
    precomputed warp of the source (a 'collage' composite or a 'flow'
    warp), its validity mask, and a `.npy` correspondence grid of
    normalized pixel coordinates (presumably 64x64 — it is compared
    against a 64x64 grid below; confirm against the data-generation code).
    """

    def __init__(self, split_files, image_size, embedding_type, warping_type, blur_warped=False):
        # Side length all images are resized to.
        self.size = image_size
        # depends on the embedding type
        if embedding_type == 'clip':
            self.get_embedding_vector = get_tensor_clip()
        elif embedding_type == 'dino':
            self.get_embedding_vector = get_tensor_dino()
        # NOTE(review): any other embedding_type leaves get_embedding_vector
        # unset, which would fail later in __getitem__.
        self.get_tensor = get_tensor()
        self.resize = torchvision.transforms.Resize(size=(image_size, image_size))
        # Masks are converted to tensors WITHOUT the [-1, 1] normalization.
        self.to_mask_tensor = get_tensor(normalize=False)

        self.src_paths = crawl_folders(split_files)
        print('current split size', len(self.src_paths))
        print('for dir', split_files)

        assert warping_type in ['collage', 'flow', 'mix']
        self.warping_type = warping_type

        # NOTE(review): unused within this file.
        self.mask_threshold = 0.85

        # Heavy blur optionally applied to the inpainted warped image.
        self.blur_t = torchvision.transforms.GaussianBlur(kernel_size=51, sigma=20.0)
        self.blur_warped = blur_warped

        # self.save_folder = '/mnt/localssd/collage_out'
        # os.makedirs(self.save_folder, exist_ok=True)
        self.save_counter = 0
        self.save_subfolder = None

    def __len__(self):
        # One sample per crawled source frame.
        return len(self.src_paths)


    def __getitem__(self, idx, depth=0):
        # `depth` bounds the resampling recursion used when a sample's
        # valid warped region is too small (see the good_region check).

        if self.warping_type == 'mix':
            # randomly sample
            warping_type = np.random.choice(['collage', 'flow'])
        else:
            warping_type = self.warping_type

        src_path = self.src_paths[idx]
        tgt_path = src_path.replace('src_', 'tgt_')

        # Derive warped-image / mask / correspondence-grid paths from the
        # source filename by prefix substitution.
        # NOTE(review): `split('.')[0]` truncates at the FIRST dot, so paths
        # with dots elsewhere would break — verify the dataset layout.
        if warping_type == 'collage':
            warped_path = src_path.replace('src_', 'composite_')
            mask_path = src_path.replace('src_', 'composite_mask_')
            corresp_path = src_path.replace('src_', 'composite_grid_')
            corresp_path = corresp_path.split('.')[0]
            corresp_path += '.npy'
        elif warping_type == 'flow':
            warped_path = src_path.replace('src_', 'flow_warped_')
            mask_path = src_path.replace('src_', 'flow_mask_')
            corresp_path = src_path.replace('src_', 'flow_warped_grid_')
            corresp_path = corresp_path.split('.')[0]
            corresp_path += '.npy'
        else:
            raise ValueError

        # load reference image, warped image, and target GT image
        reference_img = Image.open(src_path).convert('RGB')
        gt_img = Image.open(tgt_path).convert('RGB')
        warped_img = Image.open(warped_path).convert('RGB')
        warping_mask = Image.open(mask_path).convert('RGB')

        # resize all
        reference_img = self.resize(reference_img)
        gt_img = self.resize(gt_img)
        warped_img = self.resize(warped_img)
        warping_mask = self.resize(warping_mask)


        # NO CROPPING PLEASE. ALL INPUTS ARE 512X512
        # Random crop
        # i, j, h, w = torchvision.transforms.RandomCrop.get_params(
        #     reference_img, output_size=(512, 512))

        # reference_img = torchvision.transforms.functional.crop(reference_img, i, j, h, w)
        # gt_img = torchvision.transforms.functional.crop(gt_img, i, j, h, w)
        # warped_img = torchvision.transforms.functional.crop(warped_img, i, j, h, w)
        # # TODO start using the warping mask
        # warping_mask = torchvision.transforms.functional.crop(warping_mask, i, j, h, w)

        grid_transformed = torch.tensor(np.load(corresp_path))
        # grid_transformed = torchvision.transforms.functional.crop(grid_transformed, i, j, h, w)



        # reference_t = to_tensor(reference_img)
        gt_t = self.get_tensor(gt_img)
        warped_t = self.get_tensor(warped_img)
        warping_mask_t = self.to_mask_tensor(warping_mask)
        clean_reference_t = self.get_tensor(reference_img)
        # compute error to generate mask
        # NOTE(review): this local `blur_t` is never used afterwards; the
        # blur applied below is the instance-level self.blur_t.
        blur_t = torchvision.transforms.GaussianBlur(kernel_size=(11,11), sigma=5.0)

        reference_clip_img = self.get_embedding_vector(reference_img)

        # NOTE(review): `mask` is unused.
        mask = torch.ones_like(gt_t)[:1]
        # Keep a single channel of the (replicated-RGB) mask.
        warping_mask_t = warping_mask_t[:1]

        # Fraction of the image covered by valid warped pixels.
        good_region = torch.mean(warping_mask_t)
        # print('good region', good_region)
        # print('good region frac', good_region)
        if good_region < 0.4 and depth < 3:
            # example too hard, sample something else
            # print('bad image, resampling..')
            rand_idx = np.random.randint(len(self.src_paths))
            return self.__getitem__(rand_idx, depth+1)

        # if mask is too large then ignore

        # #gaussian inpainting now
        # Pixels the warp left empty (mask value below 0.5).
        missing_mask = warping_mask_t[0] < 0.5


        # Fill the holes of the warped image with OpenCV's Navier-Stokes
        # inpainting, operating on a uint8 copy in [0, 255].
        reference = (warped_t.clone() + 1) / 2.0
        ref_cv = torch.moveaxis(reference, 0, -1).cpu().numpy()
        ref_cv = (ref_cv * 255).astype(np.uint8)
        cv_mask = missing_mask.int().squeeze().cpu().numpy().astype(np.uint8)
        kernel = np.ones((7,7))
        dilated_mask = cv2.dilate(cv_mask, kernel)
        # cv_mask = np.stack([cv_mask]*3, axis=-1)
        dst = cv2.inpaint(ref_cv,dilated_mask,5,cv2.INPAINT_NS)

        mask_resized = torchvision.transforms.functional.resize(warping_mask_t, (64,64))
        # print(mask_resized)
        # Build the normalized identity coordinate grid at 64x64 and flag
        # pixels whose correspondence moved by more than 0.04 (normalized
        # units) along either axis.
        size=512
        grid_np = (get_grid(size) / size).astype(np.float16)# 512 x 512 x 2
        grid_t = torch.tensor(grid_np).moveaxis(-1, 0) # 512 x 512 x 2
        grid_resized = torchvision.transforms.functional.resize(grid_t, (64,64)).to(torch.float16)
        changed_pixels = torch.logical_or((torch.abs(grid_resized - grid_transformed)[0] > 0.04) , (torch.abs(grid_resized - grid_transformed)[1] > 0.04))
        changed_pixels = changed_pixels.unsqueeze(0)
        # changed_pixels = torch.logical_and(changed_pixels, (mask_resized >= 0.3))
        changed_pixels = changed_pixels.float()

        # Back to a [-1, 1] float tensor in channel-first layout.
        inpainted_warped = (torch.tensor(dst).moveaxis(-1, 0).float() / 255.0) * 2.0 - 1.0

        if self.blur_warped:
            inpainted_warped= self.blur_t(inpainted_warped)

        out = {"GT": gt_t,"inpaint_image": inpainted_warped,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels}
        # out = {"GT": gt_t,"inpaint_image": inpainted_warped * 0.0,"inpaint_mask": torch.ones_like(warping_mask_t), "ref_imgs": reference_clip_img * 0.0, "clean_reference": gt_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels}
        # out = {"GT": gt_t,"inpaint_image": inpainted_warped * 0.0,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img * 0.0, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels}

        # out = {"GT": gt_t,"inpaint_image": warped_t,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, 'inpainted': inpainted_warped}
        # out_half = {key: out[key].half() for key in out}
        # if self.save_counter < 50:
        #     save_path = f'{self.save_folder}/output_{time.time()}.pt'
        #     torch.save(out, save_path)
        #     self.save_counter += 1

        return out
227
+
228
+
229
+
230
+
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import numpy as np
15
+
16
+
17
class LambdaWarmUpCosineScheduler:
    """Linear warm-up followed by a cosine decay of the LR multiplier.

    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            # Linear ramp from lr_start up to lr_max over the warm-up steps.
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
        else:
            # Cosine decay from lr_max down to lr_min; clamp progress at 1.
            progress = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            progress = min(progress, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(progress * np.pi))
        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
47
+
48
+
49
class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # Cumulative cycle boundaries, starting at 0.
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        # Index of the cycle containing global step n (None past the last).
        for idx, boundary in enumerate(self.cum_cycles[1:]):
            if n <= boundary:
                return idx

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")
        warm_up = self.lr_warm_up_steps[cycle]
        if n < warm_up:
            # Linear warm-up from f_start to f_max within this cycle.
            f = (self.f_max[cycle] - self.f_start[cycle]) / warm_up * n + self.f_start[cycle]
        else:
            # Cosine decay from f_max to f_min over the rest of the cycle.
            t = (n - warm_up) / (self.cycle_lengths[cycle] - warm_up)
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
92
+
93
+
94
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Linear variant: warm up linearly, then decay linearly toward f_min."""

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")

        warm_up = self.lr_warm_up_steps[cycle]
        if n < warm_up:
            # Linear warm-up from f_start to f_max within this cycle.
            f = (self.f_max[cycle] - self.f_start[cycle]) / warm_up * n + self.f_start[cycle]
        else:
            # Linear decay: multiplier shrinks toward f_min as n approaches
            # the cycle length.
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
        self.last_f = f
        return f
111
+
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ import pytorch_lightning as pl
16
+ import torch.nn.functional as F
17
+ from contextlib import contextmanager
18
+
19
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
20
+
21
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
22
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
23
+
24
+ from ldm.util import instantiate_from_config
25
+
26
+
27
class VQModel(pl.LightningModule):
    """VQGAN-style autoencoder: Encoder -> vector quantization -> Decoder.

    Trained adversarially with two optimizers (autoencoder/generator and
    discriminator); the discriminator is owned by ``self.loss``.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        """Build encoder/decoder, codebook, and loss from their configs.

        Args:
            ddconfig: Encoder/Decoder hyperparameters (must contain "z_channels").
            lossconfig: config instantiated into the adversarial loss module.
            n_embed: codebook size.
            embed_dim: dimensionality of each codebook entry.
            ckpt_path: optional checkpoint to restore from (non-strict).
            ignore_keys: state-dict key prefixes dropped when restoring.
            image_key: key under which images are found in a batch dict.
            colorize_nlabels: if set, register a random projection used by
                to_rgb() to visualize segmentation inputs.
            monitor: metric name used for checkpointing.
            batch_resize_range: optional (min, max) for random per-batch
                resizing during training (see get_input).
            scheduler_config: optional LR-scheduler config.
            lr_g_factor: multiplier on the generator learning rate.
            remap / sane_index_shape: forwarded to the vector quantizer.
            use_ema: keep an exponential moving average of the weights.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        # 1x1 convs mapping encoder output <-> codebook embedding space
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            # NOTE(review): LitEma must be imported at module level
            # (from ldm.modules.ema import LitEma) — not visible in this chunk;
            # confirm before enabling use_ema.
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights; restores originals on exit.

        A no-op context manager when ``use_ema`` is False.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load a checkpoint's state_dict non-strictly, dropping ignored key prefixes."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        # update the EMA shadow weights after every optimizer step
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode images to quantized latents; returns (quant, emb_loss, info)."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode to continuous latents, skipping vector quantization."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode quantized latents back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode directly from codebook indices."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Full reconstruction pass; returns (dec, diff[, indices]).

        ``diff`` is the codebook/embedding loss from the quantizer.
        """
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch images from the batch as float BCHW, with optional random resizing."""
        x = batch[k]
        if len(x.shape) == 3:
            # add a trailing channel axis for grayscale HWС batches
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                # NOTE(review): relies on a module-level `import numpy as np`
                # that is not visible in this chunk — confirm it exists.
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternate generator (optimizer_idx==0) and discriminator (==1) updates."""
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Validate with current weights and, if EMA is enabled, again with EMA weights."""
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            # logged under "_ema"-suffixed keys inside _validation_step;
            # the returned dict itself is unused here
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Compute and log AE and discriminator validation losses (keys get `suffix`)."""
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # newer PL forbids logging the same key twice, so drop it from the dict.
        # NOTE(review): `version` (packaging) is not imported in this chunk — confirm.
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers (generator, discriminator), optional LambdaLR schedulers."""
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            # NOTE(review): LambdaLR must be imported from torch.optim.lr_scheduler
            # at module level; the import is not visible in this chunk.
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Weight of the final decoder conv; used for adaptive loss balancing."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return a dict of input/reconstruction tensors for the image logger."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to RGB in [-1, 1] for viewing."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
275
+
276
+
277
class VQModelInterface(VQModel):
    """VQModel whose ``encode`` returns *pre-quantization* latents.

    Used as a first-stage model by latent diffusion: the diffusion model
    works on continuous latents, and quantization is applied (optionally)
    only at decode time.
    """

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Encode ``x`` to continuous latents, skipping the codebook lookup."""
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        """Decode latents ``h``; runs them through the quantizer first
        unless ``force_not_quantize`` is set."""
        if force_not_quantize:
            quant = h
        else:
            quant, _, _ = self.quantize(h)
        return self.decoder(self.post_quant_conv(quant))
296
+
297
+
298
class AutoencoderKL(pl.LightningModule):
    """Continuous KL-regularized autoencoder (VAE) used as a diffusion first stage.

    encode() returns a DiagonalGaussianDistribution over latents; decode()
    maps sampled latents back to images. Trained adversarially with two
    optimizers, like VQModel.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        """Build encoder/decoder and loss; optionally restore from ckpt_path.

        ddconfig must have double_z=True so the encoder emits mean and
        log-variance (2 * z_channels output channels).
        """
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        # 2x channels: the encoder output carries (mean, logvar) of the posterior
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load a checkpoint's state_dict non-strictly, dropping ignored key prefixes."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode images to a diagonal Gaussian posterior over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latents z back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct input; sample from the posterior or use its mode.

        Returns (reconstruction, posterior).
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch images from the batch as float BCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternate autoencoder (optimizer_idx==0) and discriminator (==1) updates."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Compute and log AE and discriminator validation losses."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers sharing one learning rate; no schedulers."""
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Weight of the final decoder conv; used for adaptive loss balancing."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Return a dict of inputs, reconstructions, and prior samples for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            # decode noise of the posterior's shape to show unconditional samples
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to RGB in [-1, 1] for viewing."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
437
+
438
+
439
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: passes inputs through unchanged.

    Used when the diffusion model should operate directly on its input
    space rather than on autoencoder latents. With ``vq_interface=True``,
    ``quantize`` mimics the ``(quant, emb_loss, info)`` return shape of a
    real vector quantizer so VQ-style call sites keep working.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Call nn.Module.__init__ *before* setting attributes: assigning on
        # an uninitialized Module is fragile (and breaks outright for
        # Parameter/Module attributes).
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        """Identity: return the input unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity: return the input unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity; with vq_interface, mimic a quantizer's tuple return."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity: return the input unchanged."""
        return x
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import os
15
+ import torch
16
+ import pytorch_lightning as pl
17
+ from omegaconf import OmegaConf
18
+ from torch.nn import functional as F
19
+ from torch.optim import AdamW
20
+ from torch.optim.lr_scheduler import LambdaLR
21
+ from copy import deepcopy
22
+ from einops import rearrange
23
+ from glob import glob
24
+ from natsort import natsorted
25
+
26
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
27
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
28
+
29
# Dispatch table: label/conditioning type -> classifier backbone operating
# on noisy latents (see NoisyLatentImageClassifier.load_classifier).
__models__ = {
    'class_label': EncoderUNetModel,
    'segmentation': UNetModel
}
33
+
34
+
35
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Bind this onto a frozen model (``model.train = disabled_train``) so that
    later ``.train()`` / ``.eval()`` calls can no longer flip its mode.
    Returns the module itself, matching the real ``train`` signature.
    """
    return self
39
+
40
+
41
class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on *noisy* diffusion latents.

    Wraps a frozen pretrained diffusion model: inputs are encoded to
    latents, noised with q_sample at a random timestep, and a UNet-based
    classifier predicts the label from the noisy latent (classifier
    guidance style).
    """

    def __init__(self,
                 diffusion_path,
                 num_classes,
                 ckpt_path=None,
                 pool='attention',
                 label_key=None,
                 diffusion_ckpt_path=None,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 log_steps=10,
                 monitor='val/loss',
                 *args,
                 **kwargs):
        """Load the (frozen) diffusion model from its log dir, then build the classifier.

        Args:
            diffusion_path: run directory containing configs/*-project.yaml.
            num_classes: number of output classes.
            ckpt_path: optional classifier checkpoint to restore.
            pool: pooling mode for the class_label classifier head.
            label_key: batch key of the labels; falls back to the diffusion
                model's cond_stage_key if that attribute exists.
            diffusion_ckpt_path: checkpoint for the diffusion model.
            scheduler_config: optional LR scheduler config.
            weight_decay: AdamW weight decay.
            log_steps: number of noise levels visualized by log_images.
            monitor: metric name used for checkpointing.
        """
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        # number of downsampling steps of the first-stage encoder (used to
        # shrink segmentation targets to latent resolution)
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        """Restore weights non-strictly; only_model restores just the classifier."""
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        """Instantiate the diffusion model, freeze it, and pin it in eval mode."""
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        # overwrite .train so nothing can un-freeze the mode later
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        """Build the classifier UNet from the diffusion UNet config, adapted to classify."""
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        # classifier consumes latents, i.e. the diffusion UNet's output channels
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        """Diffuse latents x to timestep t via the (frozen) model's q_sample."""
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        """Classify noisy latents at timestep t; returns logits."""
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        """Fetch images from the batch as float BCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        """Fetch targets; downsample segmentation maps to latent resolution."""
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            # halve resolution once per encoder downsampling level
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        """Top-k accuracy; scalar with reduction="mean", per-sample with "none"."""
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == "mean":
            return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
        elif reduction == "none":
            return (top_ks == labels[:, None]).float().sum(dim=-1)

    def on_train_epoch_start(self):
        # save some memory: the diffusion UNet itself is not needed for training
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        """Log loss, top-1/top-5 accuracy, and the current learning rate."""
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
        self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        """Encode, noise to timestep t (random if None), classify, and compute CE loss.

        Returns (loss, logits, x_noisy, targets).
        """
        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            # one-hot / channel-encoded targets -> class indices
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        """Reset per-noise-level accuracy accumulators used during validation."""
        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Validation loss plus accuracy tracked at each fixed noise level."""
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

        return loss

    def configure_optimizers(self):
        """AdamW over classifier params; optional LambdaLR scheduler."""
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        """Visualize inputs, labels, and predictions across increasing noise levels."""
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            # render human-readable labels as an image
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

        for step in range(self.log_steps):
            current_time = step * self.log_time_interval

            _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

            log[f'inputs@t{current_time}'] = x_noisy

            pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
            pred = rearrange(pred, 'b h w c -> b c h w')

            log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

        # keep at most N samples per logged key
        for key in log:
            log[key] = log[key][:N]

        return log
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ """SAMPLING ONLY."""
15
+
16
+ import torch
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ from functools import partial
20
+
21
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
22
+ extract_into_tensor
23
+
24
+
25
class DDIMSampler(object):
    """DDIM sampler around a trained DDPM-style model (sampling only)."""

    def __init__(self, model, schedule="linear", **kwargs):
        """Wrap `model` for DDIM sampling.

        `schedule` names the beta schedule; stored but not otherwise used in
        the code visible here.
        """
        super().__init__()
        self.model = model
        # number of timesteps the underlying DDPM was trained with
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
31
+
32
+ def register_buffer(self, name, attr):
33
+ if type(attr) == torch.Tensor:
34
+ if attr.device != torch.device("cuda"):
35
+ attr = attr.to(torch.device("cuda"))
36
+ setattr(self, name, attr)
37
+
38
    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True, steps=None):
        """Precompute the DDIM timestep subset and per-step alpha/sigma tensors.

        ddim_eta=0 gives deterministic DDIM sampling; larger eta injects
        DDPM-like stochasticity. All derived quantities are cached on the
        sampler via register_buffer (moved to the sampling device).
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose, steps=steps)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters (per selected DDIM timestep)
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # sigmas to use if sampling with the original (full) DDPM step count
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
68
+
69
    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               z_ref=None,
               ddim_discretize='uniform',
               schedule_steps=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Run DDIM sampling for S steps and return (samples, intermediates).

        S is the number of DDIM steps; shape is the per-sample latent shape
        (C, H, W). conditioning/unconditional_conditioning drive
        classifier-free guidance with scale unconditional_guidance_scale.
        x_T optionally fixes the starting noise; eta controls stochasticity.
        Remaining kwargs are forwarded to ddim_sampling / p_sample_ddim.
        """
        if conditioning is not None:
            # sanity check: conditioning batch size should match batch_size
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose, ddim_discretize=ddim_discretize, steps=schedule_steps)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    z_ref=z_ref,
                                                    **kwargs
                                                    )
        return samples, intermediates
129
+
130
    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, x0_step=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, z_ref=None, **kwargs):
        """Iterate p_sample_ddim from pure noise down to t=0.

        Returns (img, intermediates) where intermediates stores 'x_inter'
        and 'pred_x0' snapshots every log_every_t indices. If x0 and
        x0_step are given, the first x0_step iterations re-noise x0 instead
        of using the sampler output (partial resampling).
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # truncate the precomputed DDIM schedule to the requested length
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # iterate from the noisiest timestep down to 0
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if x0_step is not None and i < x0_step:
                # early steps: replace the sample with a re-noised x0
                assert x0 is not None
                img = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
            # img = img_orig * mask + (1. - mask) * img
            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      z_ref=z_ref,
                                      unconditional_conditioning=unconditional_conditioning, **kwargs)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates
181
+
182
+ @torch.no_grad()
183
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
184
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
185
+ unconditional_guidance_scale=1., unconditional_conditioning=None, z_ref=None, drop_latent_guidance=1.0,**kwargs):
186
+ b, *_, device = *x.shape, x.device
187
+ if 'test_model_kwargs' in kwargs:
188
+ kwargs=kwargs['test_model_kwargs']
189
+ if f'inpaint_mask_{index}' in kwargs:
190
+ x = torch.cat([x, kwargs['inpaint_image'], kwargs[f'inpaint_mask_{index}']],dim=1)
191
+ print('using proxy mask', index)
192
+ else:
193
+ x = torch.cat([x, kwargs['inpaint_image'], kwargs[f'inpaint_mask']],dim=1)
194
+ if 'changed_pixels' in kwargs:
195
+ x = torch.cat([x, kwargs['changed_pixels']],dim=1)
196
+ elif 'rest' in kwargs:
197
+ x = torch.cat((x, kwargs['rest']), dim=1)
198
+ else:
199
+ raise Exception("kwargs must contain either 'test_model_kwargs' or 'rest' key")
200
+
201
+ # maybe should assert not both of these are true
202
+ # print('index', index)
203
+ if isinstance(drop_latent_guidance, list):
204
+ cur_drop_latent_guidance = drop_latent_guidance[index]
205
+ else:
206
+ cur_drop_latent_guidance = drop_latent_guidance
207
+ # print('cur drop guidance', cur_drop_latent_guidance)
208
+
209
+ if (unconditional_conditioning is None or unconditional_guidance_scale == 1.) and cur_drop_latent_guidance == 1.:
210
+ e_t = self.model.apply_model(x, t, c, z_ref=z_ref)
211
+ elif cur_drop_latent_guidance != 1.:
212
+ assert (unconditional_conditioning is None or unconditional_guidance_scale == 1.)
213
+ x_dropped = x.clone()
214
+ # print('x dropped shape', x_dropped.shape)
215
+ x_dropped[:,4:9] *= 0.0
216
+ x_in = torch.cat([x_dropped, x])
217
+ t_in = torch.cat([t] * 2)
218
+ z_ref_in = torch.cat([z_ref] * 2)
219
+ c_in = torch.cat([c] * 2)
220
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in).chunk(2)
221
+ e_t = e_t_uncond + cur_drop_latent_guidance * (e_t - e_t_uncond)
222
+
223
+ else:
224
+ x_in = torch.cat([x] * 2)
225
+ t_in = torch.cat([t] * 2)
226
+ z_ref_in = torch.cat([z_ref] * 2)
227
+ # print('uncond shape', unconditional_conditioning.shape, 'c shape', c.shape)
228
+ c_in = torch.cat([unconditional_conditioning, c])
229
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, z_ref=z_ref_in).chunk(2)
230
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
231
+
232
+ if score_corrector is not None:
233
+ assert self.model.parameterization == "eps"
234
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
235
+
236
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
237
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
238
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
239
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
240
+ # select parameters corresponding to the currently considered timestep
241
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
242
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
243
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
244
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
245
+
246
+ # current prediction for x_0
247
+ if x.shape[1]!=4:
248
+ pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt()
249
+ else:
250
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
251
+ if quantize_denoised:
252
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
253
+ # direction pointing to x_t
254
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
255
+ noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature
256
+ if noise_dropout > 0.:
257
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
258
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
259
+ return x_prev, pred_x0
260
+
261
+ @torch.no_grad()
262
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
263
+ # fast, but does not allow for exact reconstruction
264
+ # t serves as an index to gather the correct alphas
265
+ if use_original_steps:
266
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
267
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
268
+ else:
269
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
270
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
271
+
272
+ if noise is None:
273
+ noise = torch.randn_like(x0)
274
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
275
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
276
+
277
+ @torch.no_grad()
278
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
279
+ use_original_steps=False):
280
+
281
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
282
+ timesteps = timesteps[:t_start]
283
+
284
+ time_range = np.flip(timesteps)
285
+ total_steps = timesteps.shape[0]
286
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
287
+
288
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
289
+ x_dec = x_latent
290
+ for i, step in enumerate(iterator):
291
+ index = total_steps - i - 1
292
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
293
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
294
+ unconditional_guidance_scale=unconditional_guidance_scale,
295
+ unconditional_conditioning=unconditional_conditioning)
296
+ return x_dec
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ """
15
+ wild mixture of
16
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
17
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
18
+ https://github.com/CompVis/taming-transformers
19
+ -- merci
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torchvision
25
+ import numpy as np
26
+ import pytorch_lightning as pl
27
+ from torch.optim.lr_scheduler import LambdaLR
28
+ from einops import rearrange, repeat
29
+ from contextlib import contextmanager
30
+ from functools import partial
31
+ from tqdm import tqdm
32
+ from torchvision.utils import make_grid
33
+ # from pytorch_lightning.utilities.distributed import rank_zero_only
34
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
35
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
36
+ from ldm.modules.ema import LitEma
37
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
38
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
39
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
40
+ from ldm.models.diffusion.ddim import DDIMSampler
41
+ from torchvision.transforms import Resize
42
+ import math
43
+ import time
44
+ import random
45
+ from torch.autograd import Variable
46
+ import copy
47
+ import os
48
+
49
+ __conditioning_keys__ = {'concat': 'c_concat',
50
+ 'crossattn': 'c_crossattn',
51
+ 'adm': 'y'}
52
+
53
+
54
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores mode switches.

    Bound onto frozen sub-models so later calls to ``.train()``/``.eval()``
    become no-ops and the module stays in its current mode.
    """
    return self
58
+
59
+
60
def uniform_on_device(r1, r2, shape, device):
    """Sample a tensor of the given shape uniformly between r1 and r2 on `device`."""
    span = r1 - r2
    return span * torch.rand(*shape, device=device) + r2
62
+
63
+
64
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so the terminal step has exactly zero SNR,
    following https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.FloatTensor`): the betas the scheduler is initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR.
    """
    # Work in sqrt(alpha_bar) space.
    alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = alphas_bar_sqrt[0].clone()
    last = alphas_bar_sqrt[-1].clone()

    # Shift so the final sqrt(alpha_bar) is exactly zero, then rescale so the
    # first entry keeps its original value.
    alphas_bar_sqrt = (alphas_bar_sqrt - last) * (first / (first - last))

    # Convert back: sqrt(alpha_bar) -> alpha_bar -> per-step alphas -> betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1 - alphas
98
+
99
+
100
+ class DDPM(pl.LightningModule):
101
+ # classic DDPM with Gaussian diffusion, in image space
102
+ def __init__(self,
103
+ unet_config,
104
+ timesteps=1000,
105
+ beta_schedule="linear",
106
+ loss_type="l2",
107
+ ckpt_path=None,
108
+ ignore_keys=[],
109
+ load_only_unet=False,
110
+ monitor="val/loss",
111
+ use_ema=True,
112
+ first_stage_key="image",
113
+ image_size=256,
114
+ channels=3,
115
+ log_every_t=100,
116
+ clip_denoised=True,
117
+ linear_start=1e-4,
118
+ linear_end=2e-2,
119
+ cosine_s=8e-3,
120
+ given_betas=None,
121
+ original_elbo_weight=0.,
122
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
123
+ l_simple_weight=1.,
124
+ conditioning_key=None,
125
+ parameterization="eps", # all assuming fixed variance schedules
126
+ scheduler_config=None,
127
+ use_positional_encodings=False,
128
+ learn_logvar=False,
129
+ logvar_init=0.,
130
+ u_cond_percent=0,
131
+ dropping_warped_latent_prob=0.,
132
+ remove_warped_latent=False,
133
+ gt_flag='GT',
134
+ sd_edit_step=850
135
+ ):
136
+ super().__init__()
137
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
138
+ self.parameterization = parameterization
139
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
140
+ self.cond_stage_model = None
141
+ self.clip_denoised = clip_denoised
142
+ self.log_every_t = log_every_t
143
+ self.first_stage_key = first_stage_key
144
+ self.image_size = image_size
145
+ self.channels = channels
146
+ self.u_cond_percent=u_cond_percent
147
+ self.use_positional_encodings = use_positional_encodings
148
+ self.gt_flag = gt_flag
149
+ self.sd_edit_step = sd_edit_step
150
+
151
+ self.remove_warped_latent = remove_warped_latent
152
+ self.dropping_warped_latent_prob = dropping_warped_latent_prob
153
+
154
+ if dropping_warped_latent_prob > 0.0:
155
+ assert not self.remove_warped_latent
156
+
157
+
158
+ self.use_ema = use_ema
159
+ if self.use_ema:
160
+ self.model_ema = LitEma(self.model)
161
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
162
+
163
+ self.use_scheduler = scheduler_config is not None
164
+ if self.use_scheduler:
165
+ self.scheduler_config = scheduler_config
166
+
167
+ self.v_posterior = v_posterior
168
+ self.original_elbo_weight = original_elbo_weight
169
+ self.l_simple_weight = l_simple_weight
170
+
171
+ if monitor is not None:
172
+ self.monitor = monitor
173
+ if ckpt_path is not None:
174
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
175
+
176
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
177
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
178
+
179
+ self.model = DiffusionWrapper(unet_config, conditioning_key, ddpm_parent=self,
180
+ sqrt_alphas_cumprod=self.sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=self.sqrt_one_minus_alphas_cumprod)
181
+ count_params(self.model, verbose=True)
182
+
183
+ self.loss_type = loss_type
184
+
185
+ self.learn_logvar = learn_logvar
186
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
187
+ if self.learn_logvar:
188
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
189
+
190
+
191
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
192
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
193
+ if exists(given_betas):
194
+ betas = given_betas
195
+ else:
196
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
197
+ cosine_s=cosine_s)
198
+
199
+ # rescale beta
200
+ rescale_beta = True
201
+ if rescale_beta:
202
+ betas = rescale_zero_terminal_snr(torch.tensor(betas)).numpy()
203
+
204
+ alphas = 1. - betas
205
+ alphas_cumprod = np.cumprod(alphas, axis=0)
206
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
207
+
208
+ timesteps, = betas.shape
209
+ self.num_timesteps = int(timesteps)
210
+ self.linear_start = linear_start
211
+ self.linear_end = linear_end
212
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
213
+
214
+ to_torch = partial(torch.tensor, dtype=torch.float32)
215
+
216
+ self.register_buffer('betas', to_torch(betas))
217
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
218
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
219
+
220
+ # calculations for diffusion q(x_t | x_{t-1}) and others
221
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
222
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
223
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
224
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
225
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
226
+
227
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
228
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
229
+ 1. - alphas_cumprod) + self.v_posterior * betas
230
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
231
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
232
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
233
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
234
+ self.register_buffer('posterior_mean_coef1', to_torch(
235
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
236
+ self.register_buffer('posterior_mean_coef2', to_torch(
237
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
238
+
239
+ if self.parameterization == "eps":
240
+ lvlb_weights = self.betas ** 2 / (
241
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
242
+ elif self.parameterization == "x0":
243
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
244
+ else:
245
+ raise NotImplementedError("mu not supported")
246
+ # pr_odo how to choose this term
247
+ lvlb_weights[0] = lvlb_weights[1]
248
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
249
+ assert not torch.isnan(self.lvlb_weights).all()
250
+
251
+ @contextmanager
252
+ def ema_scope(self, context=None):
253
+ if self.use_ema:
254
+ self.model_ema.store(self.model.parameters())
255
+ self.model_ema.copy_to(self.model)
256
+ if context is not None:
257
+ print(f"{context}: Switched to EMA weights")
258
+ try:
259
+ yield None
260
+ finally:
261
+ if self.use_ema:
262
+ self.model_ema.restore(self.model.parameters())
263
+ if context is not None:
264
+ print(f"{context}: Restored training weights")
265
+
266
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
267
+ sd = torch.load(path, map_location="cpu")
268
+ if "state_dict" in list(sd.keys()):
269
+ sd = sd["state_dict"]
270
+ keys = list(sd.keys())
271
+ for k in keys:
272
+ for ik in ignore_keys:
273
+ if k.startswith(ik):
274
+ print("Deleting key {} from state_dict.".format(k))
275
+ del sd[k]
276
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
277
+ sd, strict=False)
278
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
279
+ if len(missing) > 0:
280
+ print(f"Missing Keys: {missing}")
281
+ if len(unexpected) > 0:
282
+ print(f"Unexpected Keys: {unexpected}")
283
+
284
+ def q_mean_variance(self, x_start, t):
285
+ """
286
+ Get the distribution q(x_t | x_0).
287
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
288
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
289
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
290
+ """
291
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
292
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
293
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
294
+ return mean, variance, log_variance
295
+
296
+ def predict_start_from_noise(self, x_t, t, noise):
297
+ return (
298
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
299
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
300
+ )
301
+
302
+ def q_posterior(self, x_start, x_t, t):
303
+ posterior_mean = (
304
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
305
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
306
+ )
307
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
308
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
309
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
310
+
311
+ def p_mean_variance(self, x, t, clip_denoised: bool):
312
+ model_out = self.model(x, t)
313
+ if self.parameterization == "eps":
314
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
315
+ elif self.parameterization == "x0":
316
+ x_recon = model_out
317
+ if clip_denoised:
318
+ x_recon.clamp_(-1., 1.)
319
+
320
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
321
+ return model_mean, posterior_variance, posterior_log_variance
322
+
323
+ @torch.no_grad()
324
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
325
+ b, *_, device = *x.shape, x.device
326
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
327
+ noise = noise_like(x.shape, device, repeat_noise)
328
+ # no noise when t == 0
329
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
330
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
331
+
332
+ @torch.no_grad()
333
+ def p_sample_loop(self, shape, return_intermediates=False):
334
+ device = self.betas.device
335
+ b = shape[0]
336
+ img = torch.randn(shape, device=device)
337
+ intermediates = [img]
338
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
339
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
340
+ clip_denoised=self.clip_denoised)
341
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
342
+ intermediates.append(img)
343
+ if return_intermediates:
344
+ return img, intermediates
345
+ return img
346
+
347
+ @torch.no_grad()
348
+ def sample(self, batch_size=16, return_intermediates=False):
349
+ image_size = self.image_size
350
+ channels = self.channels
351
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
352
+ return_intermediates=return_intermediates)
353
+
354
+ def q_sample(self, x_start, t, noise=None):
355
+ noise = default(noise, lambda: torch.randn_like(x_start))
356
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
357
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
358
+
359
+ def get_loss(self, pred, target, mean=True):
360
+ if self.loss_type == 'l1':
361
+ loss = (target - pred).abs()
362
+ if mean:
363
+ loss = loss.mean()
364
+ elif self.loss_type == 'l2':
365
+ if mean:
366
+ loss = torch.nn.functional.mse_loss(target, pred)
367
+ else:
368
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
369
+ else:
370
+ raise NotImplementedError("unknown loss type '{loss_type}'")
371
+
372
+ return loss
373
+
374
+ def p_losses(self, x_start, t, noise=None):
375
+ noise = default(noise, lambda: torch.randn_like(x_start))
376
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
377
+ model_out = self.model(x_noisy, t)
378
+
379
+ loss_dict = {}
380
+ if self.parameterization == "eps":
381
+ target = noise
382
+ elif self.parameterization == "x0":
383
+ target = x_start
384
+ else:
385
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
386
+
387
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
388
+
389
+ log_prefix = 'train' if self.training else 'val'
390
+
391
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
392
+ loss_simple = loss.mean() * self.l_simple_weight
393
+
394
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
395
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
396
+
397
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
398
+
399
+ loss_dict.update({f'{log_prefix}/loss': loss})
400
+
401
+ return loss, loss_dict
402
+
403
+ def forward(self, x, *args, **kwargs):
404
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
405
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
406
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
407
+ return self.p_losses(x, t, *args, **kwargs)
408
+
409
+ def get_input(self, batch, k):
410
+ if k == "inpaint":
411
+ x = batch[self.gt_flag]
412
+ mask = batch['inpaint_mask']
413
+ inpaint = batch['inpaint_image']
414
+ reference = batch['ref_imgs']
415
+ clean_reference = batch['clean_reference']
416
+ grid_transformed = batch['grid_transformed']
417
+ changed_pixels = batch['changed_pixels']
418
+ else:
419
+ x = batch[k]
420
+ if len(x.shape) == 3:
421
+ x = x[..., None]
422
+ # x = rearrange(x, 'b h w c -> b c h w')
423
+ x = x.to(memory_format=torch.contiguous_format).float()
424
+ mask = mask.to(memory_format=torch.contiguous_format).float()
425
+ inpaint = inpaint.to(memory_format=torch.contiguous_format).float()
426
+ reference = reference.to(memory_format=torch.contiguous_format).float()
427
+ clean_reference = clean_reference.to(memory_format=torch.contiguous_format).float()
428
+ grid_transformed = grid_transformed.to(memory_format=torch.contiguous_format).float()
429
+ return x,inpaint,mask,reference, clean_reference, grid_transformed, changed_pixels
430
+
431
+ def shared_step(self, batch):
432
+ x = self.get_input(batch, self.first_stage_key)
433
+ loss, loss_dict = self(x)
434
+ return loss, loss_dict
435
+
436
+ def training_step(self, batch, batch_idx):
437
+ loss, loss_dict = self.shared_step(batch)
438
+
439
+ self.log_dict(loss_dict, prog_bar=True,
440
+ logger=True, on_step=True, on_epoch=True)
441
+
442
+ self.log("global_step", self.global_step,
443
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
444
+
445
+ if self.use_scheduler:
446
+ lr = self.optimizers().param_groups[0]['lr']
447
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
448
+
449
+ return loss
450
+
451
+ @torch.no_grad()
452
+ def validation_step(self, batch, batch_idx):
453
+ _, loss_dict_no_ema = self.shared_step(batch)
454
+ with self.ema_scope():
455
+ _, loss_dict_ema = self.shared_step(batch)
456
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
457
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
458
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
459
+
460
+ def on_train_batch_end(self, *args, **kwargs):
461
+ if self.use_ema:
462
+ self.model_ema(self.model)
463
+
464
+ def _get_rows_from_list(self, samples):
465
+ n_imgs_per_row = len(samples)
466
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
467
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
468
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
469
+ return denoise_grid
470
+
471
    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        """Build a dict of image tensors for logging: the inputs, a forward-diffusion
        row, and (optionally) EMA samples with their denoising trajectory.

        :param batch: raw data batch; images come from self.first_stage_key.
        :param N: max number of images to log.
        :param n_row: max number of rows in the diffusion grid.
        :param sample: whether to also draw samples under the EMA weights.
        :param return_keys: if given, restrict the returned dict to these keys
            (falls back to the full dict when none of them are present).
        :return: dict of image tensors keyed by log name.
        """
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row: progressively noised versions of the first n_row inputs
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            # log a subset of timesteps, always including the last
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row: sample under EMA weights for cleaner images
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log
508
+
509
+ def configure_optimizers(self):
510
+ lr = self.learning_rate
511
+ params = list(self.model.parameters())
512
+ if self.learn_logvar:
513
+ params = params + [self.logvar]
514
+ opt = torch.optim.AdamW(params, lr=lr)
515
+ return opt
516
+
517
+
518
+ class LatentDiffusion(DDPM):
519
+ """main class"""
520
+ def __init__(self,
521
+ first_stage_config,
522
+ cond_stage_config,
523
+ num_timesteps_cond=None,
524
+ cond_stage_key="image",
525
+ cond_stage_trainable=False,
526
+ concat_mode=True,
527
+ cond_stage_forward=None,
528
+ conditioning_key=None,
529
+ scale_factor=1.0,
530
+ scale_by_std=False,
531
+ context_embedding_dim=1024, # dim used for clip image encoder
532
+ *args, **kwargs):
533
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
534
+ self.scale_by_std = scale_by_std
535
+ assert self.num_timesteps_cond <= kwargs['timesteps']
536
+ # for backwards compatibility after implementation of DiffusionWrapper
537
+ if conditioning_key is None:
538
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
539
+ if cond_stage_config == '__is_unconditional__':
540
+ conditioning_key = None
541
+ ckpt_path = kwargs.pop("ckpt_path", None)
542
+ ignore_keys = kwargs.pop("ignore_keys", [])
543
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
544
+ self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
545
+ self.proj_out=nn.Linear(context_embedding_dim, 768)
546
+ self.concat_mode = concat_mode
547
+ self.cond_stage_trainable = cond_stage_trainable
548
+ self.cond_stage_key = cond_stage_key
549
+ try:
550
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
551
+ except:
552
+ self.num_downs = 0
553
+ if not scale_by_std:
554
+ self.scale_factor = scale_factor
555
+ else:
556
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
557
+ self.instantiate_first_stage(first_stage_config)
558
+ self.instantiate_cond_stage(cond_stage_config)
559
+ self.cond_stage_forward = cond_stage_forward
560
+ self.clip_denoised = False
561
+ self.bbox_tokenizer = None
562
+
563
+ self.restarted_from_ckpt = False
564
+ if ckpt_path is not None:
565
+ self.init_from_ckpt(ckpt_path, ignore_keys)
566
+ self.restarted_from_ckpt = True
567
+
568
+ def make_cond_schedule(self, ):
569
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
570
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
571
+ self.cond_ids[:self.num_timesteps_cond] = ids
572
+
573
    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx):
        """On the very first training batch only: when scale_by_std is set,
        replace scale_factor with 1/std of the first-stage encodings so the
        latents are roughly unit-variance."""
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            # NOTE(review): in this fork super().get_input returns a tuple (see
            # get_input below) — confirm this std-rescaling path is actually used
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            # re-register as a buffer so the rescaled factor survives checkpointing
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")
589
+
590
    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Register the diffusion beta schedule (delegated to DDPM), then
        (re)build the shortened conditioning schedule when
        num_timesteps_cond > 1."""
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        # conditioning may use fewer timesteps than the diffusion itself
        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()
598
+
599
+ def instantiate_first_stage(self, config):
600
+ model = instantiate_from_config(config)
601
+ self.first_stage_model = model.eval()
602
+ self.first_stage_model.train = disabled_train
603
+ for param in self.first_stage_model.parameters():
604
+ param.requires_grad = False
605
+
606
+ def instantiate_cond_stage(self, config):
607
+ if not self.cond_stage_trainable:
608
+ if config == "__is_first_stage__":
609
+ print("Using first stage also as cond stage.")
610
+ self.cond_stage_model = self.first_stage_model
611
+ elif config == "__is_unconditional__":
612
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
613
+ self.cond_stage_model = None
614
+ # self.be_unconditional = True
615
+ else:
616
+ model = instantiate_from_config(config)
617
+ self.cond_stage_model = model.eval()
618
+ self.cond_stage_model.train = disabled_train
619
+ for param in self.cond_stage_model.parameters():
620
+ param.requires_grad = False
621
+ else:
622
+ assert config != '__is_first_stage__'
623
+ assert config != '__is_unconditional__'
624
+ model = instantiate_from_config(config)
625
+ self.cond_stage_model = model
626
+
627
+
628
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
629
+ denoise_row = []
630
+ for zd in tqdm(samples, desc=desc):
631
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
632
+ force_not_quantize=force_no_decoder_quantization))
633
+ n_imgs_per_row = len(denoise_row)
634
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
635
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
636
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
637
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
638
+ return denoise_grid
639
+
640
+ def get_first_stage_encoding(self, encoder_posterior):
641
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
642
+ z = encoder_posterior.sample()
643
+ elif isinstance(encoder_posterior, torch.Tensor):
644
+ z = encoder_posterior
645
+ else:
646
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
647
+ return self.scale_factor * z
648
+
649
    def get_learned_conditioning(self, c):
        """Encode conditioning input c with the cond-stage model.

        Uses model.encode() when available (taking the mode of a Gaussian
        posterior), otherwise calls the model directly; an explicit
        cond_stage_forward attribute name overrides both.
        """
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    # deterministic conditioning: use the distribution mode
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            # caller-specified encoder method, e.g. 'forward_features'
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c
661
+
662
+
663
+ def meshgrid(self, h, w):
664
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
665
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
666
+
667
+ arr = torch.cat([y, x], dim=-1)
668
+ return arr
669
+
670
+ def delta_border(self, h, w):
671
+ """
672
+ :param h: height
673
+ :param w: width
674
+ :return: normalized distance to image border,
675
+ wtith min distance = 0 at border and max dist = 0.5 at image center
676
+ """
677
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
678
+ arr = self.meshgrid(h, w) / lower_right_corner
679
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
680
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
681
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
682
+ return edge_dist
683
+
684
    def get_weighting(self, h, w, Ly, Lx, device):
        """Per-pixel blending weights for overlapping crops: border pixels of a
        crop get lower weight (clipped delta_border), optionally modulated by a
        crop-level "tie breaker" weighting across the Ly x Lx crop grid.

        :return: tensor of shape (1, h*w, Ly*Lx) on `device`.
        """
        weighting = self.delta_border(h, w)
        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
                               self.split_input_params["clip_max_weight"], )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            # additional weighting over the crop grid itself, so central crops dominate
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(L_weighting,
                                     self.split_input_params["clip_min_tie_weight"],
                                     self.split_input_params["clip_max_tie_weight"])

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting
699
+
700
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # pr_odo load once not every time, shorten code
701
+ """
702
+ :param x: img of size (bs, c, h, w)
703
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
704
+ """
705
+ bs, nc, h, w = x.shape
706
+
707
+ # number of crops in image
708
+ Ly = (h - kernel_size[0]) // stride[0] + 1
709
+ Lx = (w - kernel_size[1]) // stride[1] + 1
710
+
711
+ if uf == 1 and df == 1:
712
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
713
+ unfold = torch.nn.Unfold(**fold_params)
714
+
715
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
716
+
717
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
718
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
719
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
720
+
721
+ elif uf > 1 and df == 1:
722
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
723
+ unfold = torch.nn.Unfold(**fold_params)
724
+
725
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
726
+ dilation=1, padding=0,
727
+ stride=(stride[0] * uf, stride[1] * uf))
728
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
729
+
730
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
731
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
732
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
733
+
734
+ elif df > 1 and uf == 1:
735
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
736
+ unfold = torch.nn.Unfold(**fold_params)
737
+
738
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
739
+ dilation=1, padding=0,
740
+ stride=(stride[0] // df, stride[1] // df))
741
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
742
+
743
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
744
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
745
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
746
+
747
+ else:
748
+ raise NotImplementedError
749
+
750
+ return fold, unfold, normalization, weighting
751
+
752
    @torch.no_grad()
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None, get_mask=False, get_reference=False, get_inpaint=False, get_clean_ref=False, get_ref_rec=False,
                  get_changed_pixels=False):
        """Assemble the model inputs for one batch.

        Encodes the target image, the coarse (warped) edit and the clean
        reference into latents, resizes the mask and flow grid to the latent
        resolution, concatenates them into z_new, and builds the cross-attention
        conditioning c from the reference image.

        :param batch: raw data batch.
        :param k: first-stage key selecting the target image in the batch.
        :param return_first_stage_outputs: also append (x, xrec).
        :param force_c_encode: encode conditioning even if the cond stage is trainable.
        :param cond_key: override self.cond_stage_key.
        :param return_original_cond: also append the raw conditioning input xc.
        :param bs: optional batch-size cap.
        :param get_mask/get_reference/get_inpaint/get_clean_ref/get_ref_rec/get_changed_pixels:
            append the corresponding tensor(s) to the returned list.
        :return: [z_new, c, z_reference] plus any requested extras, in that order.
        """
        # parent returns: target, warped edit, mask, cond reference,
        # clean reference, flow grid, changed-pixel map
        x, inpaint, mask, reference, clean_reference, grid_transformed, changed_pixels = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
            inpaint = inpaint[:bs]
            mask = mask[:bs]
            reference = reference[:bs]
            clean_reference = clean_reference[:bs]
            grid_transformed = grid_transformed[:bs]
            changed_pixels = changed_pixels[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()
        encoder_posterior_inpaint = self.encode_first_stage(inpaint)
        z_inpaint = self.get_first_stage_encoding(encoder_posterior_inpaint).detach()

        encoder_posterior_inpaint = self.encode_first_stage(clean_reference)
        z_reference = self.get_first_stage_encoding(encoder_posterior_inpaint).detach()
        # breakpoint()
        # resize mask and flow grid to the latent spatial size (assumes square latents)
        mask_resize = Resize([z.shape[-1], z.shape[-1]])(mask)
        grid_resized = Resize([z.shape[-1], z.shape[-1]])(grid_transformed)
        z_new = torch.cat((z, z_inpaint, mask_resize, grid_resized), dim=1)
        # z_new = torch.cat((z,z_inpaint,mask_resize, changed_pixels, grid_resized),dim=1)
        # z_new = torch.cat((z,z_inpaint,mask_resize, grid_resized),dim=1)

        if self.model.conditioning_key is not None:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ['txt', 'caption', 'coordinates_bbox']:
                    xc = batch[cond_key]
                elif cond_key == 'image':
                    # condition on the (augmented) reference image
                    xc = reference
                elif cond_key == 'class_label':
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    # import pudb; pudb.set_trace()
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
                    # project image-encoder context into the 768-dim UNet context space
                    c = self.proj_out(c)
                    c = c.float()
            else:
                # trainable cond stage: defer encoding to forward()
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}

        # embed reference latent into cond
        # c = [c, z_reference]
        out = [z_new, c, z_reference]
        if return_first_stage_outputs:
            if self.first_stage_key == 'inpaint':
                # only the first 4 channels are the actual image latent
                xrec = self.decode_first_stage(z[:, :4, :, :])
            else:
                xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_original_cond:
            out.append(xc)
        if get_mask:
            out.append(mask)
        if get_reference:
            out.append(reference)
        if get_inpaint:
            out.append(inpaint)
        if get_clean_ref:
            out.append(clean_reference)
        if get_ref_rec:
            # autoencoder reconstruction of the clean reference
            ref_rec = self.decode_first_stage(z_reference)
            out.append(ref_rec)
        if get_changed_pixels:
            out.append(changed_pixels)
        return out
845
+
846
    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Decode latents z back to image space with the frozen first stage.

        :param z: latent tensor (still scaled by self.scale_factor).
        :param predict_cids: interpret z as codebook logits (VQ models only).
        :param force_not_quantize: skip quantization in the VQ decoder.
        """
        if predict_cids:
            # z holds codebook logits: take argmax ids and look up embeddings
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        # undo the latent scaling applied in get_first_stage_encoding
        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            # patched decoding: decode overlapping crops and blend them back
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                uf = self.split_input_params["vqf"]
                bs, nc, h, w = z.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (bn, nc * prod(**ks), L)
                # 1. Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                # 2. apply model loop over last dim
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:

                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
                o = o * weighting
                # Reverse 1. reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                if self.first_stage_key == 'inpaint':
                    # concatenated latents: only the first 4 channels decode to an image
                    return self.first_stage_model.decode(z[:, :4, :, :])
                else:
                    return self.first_stage_model.decode(z)
908
+
909
    # same as decode_first_stage but without the @torch.no_grad decorator, so
    # gradients can flow through the decoder
    def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Gradient-carrying decode of latents z (see decode_first_stage)."""
        if predict_cids:
            # z holds codebook logits: take argmax ids and look up embeddings
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        # undo the latent scaling applied in get_first_stage_encoding
        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            # patched decoding: decode overlapping crops and blend them back
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                uf = self.split_input_params["vqf"]
                bs, nc, h, w = z.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (bn, nc * prod(**ks), L)
                # 1. Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                # 2. apply model loop over last dim
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:

                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
                o = o * weighting
                # Reverse 1. reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)
968
+
969
    @torch.no_grad()
    def encode_first_stage(self, x):
        """Encode image x with the frozen first stage; supports patched encoding
        of overlapping crops when split_input_params is configured."""
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                df = self.split_input_params["vqf"]
                # remembered for bbox rescaling elsewhere
                self.split_input_params['original_image_size'] = x.shape[-2:]
                bs, nc, h, w = x.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
                z = unfold(x)  # (bn, nc * prod(**ks), L)
                # Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                # encode each crop, then blend with the overlap weighting
                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)
                o = o * weighting

                # Reverse reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization
                return decoded

            else:
                return self.first_stage_model.encode(x)
        else:
            return self.first_stage_model.encode(x)
1008
+
1009
+ def shared_step(self, batch, **kwargs):
1010
+ x, c, z_reference = self.get_input(batch, self.first_stage_key)
1011
+ loss = self(x, c, z_reference)
1012
+ return loss
1013
+
1014
    def forward(self, x, c, z_reference, *args, **kwargs):
        """Sample a random timestep per example and compute p_losses, replacing
        the conditioning with the learned null vector (classifier-free dropout)
        with probability u_cond_percent."""
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        # stored on self so the dropout draw for this step can be inspected elsewhere
        self.u_cond_prop = random.uniform(0, 1)
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
                # project encoder context into the 768-dim UNet context space
                c = self.proj_out(c)

            if self.shorten_cond_schedule:  # pr_odo: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))

        if self.u_cond_prop < self.u_cond_percent:
            # unconditional branch: learned null vector instead of the conditioning
            return self.p_losses(x, self.learnable_vector.repeat(x.shape[0], 1, 1), t, z_ref=z_reference, *args, **kwargs)
        else:
            return self.p_losses(x, c, t, z_ref=z_reference, *args, **kwargs)
1031
+
1032
+ def _rescale_annotations(self, bboxes, crop_coordinates): # pr_odo: move to dataset
1033
+ def rescale_bbox(bbox):
1034
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
1035
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
1036
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
1037
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
1038
+ return x0, y0, w, h
1039
+
1040
+ return [rescale_bbox(b) for b in bboxes]
1041
+
1042
+ def apply_model(self, x_noisy, t, cond, z_ref, return_ids=False):
1043
+
1044
+ if isinstance(cond, dict):
1045
+ # hybrid case, cond is exptected to be a dict
1046
+ pass
1047
+ else:
1048
+ if not isinstance(cond, list):
1049
+ cond = [cond]
1050
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
1051
+ cond = {key: cond}
1052
+
1053
+ if hasattr(self, "split_input_params"):
1054
+ raise ValueError('attempting to split input')
1055
+ # assert len(cond) == 1 # pr_odo can only deal with one conditioning atm
1056
+ # assert not return_ids
1057
+ # ks = self.split_input_params["ks"] # eg. (128, 128)
1058
+ # stride = self.split_input_params["stride"] # eg. (64, 64)
1059
+
1060
+ # h, w = x_noisy.shape[-2:]
1061
+
1062
+ # fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
1063
+
1064
+ # z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
1065
+ # # Reshape to img shape
1066
+ # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
1067
+ # z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
1068
+
1069
+ # if self.cond_stage_key in ["image", "LR_image", "segmentation",
1070
+ # 'bbox_img'] and self.model.conditioning_key: # pr_odo check for completeness
1071
+ # c_key = next(iter(cond.keys())) # get key
1072
+ # c = next(iter(cond.values())) # get value
1073
+ # assert (len(c) == 1) # pr_odo extend to list with more than one elem
1074
+ # c = c[0] # get element
1075
+
1076
+ # c = unfold(c)
1077
+ # c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
1078
+
1079
+ # cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
1080
+
1081
+ # elif self.cond_stage_key == 'coordinates_bbox':
1082
+ # assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
1083
+
1084
+ # # assuming padding of unfold is always 0 and its dilation is always 1
1085
+ # n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
1086
+ # full_img_h, full_img_w = self.split_input_params['original_image_size']
1087
+ # # as we are operating on latents, we need the factor from the original image size to the
1088
+ # # spatial latent size to properly rescale the crops for regenerating the bbox annotations
1089
+ # num_downs = self.first_stage_model.encoder.num_resolutions - 1
1090
+ # rescale_latent = 2 ** (num_downs)
1091
+
1092
+ # # get top left positions of patches as conforming for the bbbox tokenizer, therefore we
1093
+ # # need to rescale the tl patch coordinates to be in between (0,1)
1094
+ # tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
1095
+ # rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
1096
+ # for patch_nr in range(z.shape[-1])]
1097
+
1098
+ # # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
1099
+ # patch_limits = [(x_tl, y_tl,
1100
+ # rescale_latent * ks[0] / full_img_w,
1101
+ # rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
1102
+ # # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
1103
+
1104
+ # # tokenize crop coordinates for the bounding boxes of the respective patches
1105
+ # patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
1106
+ # for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
1107
+ # print(patch_limits_tknzd[0].shape)
1108
+ # # cut tknzd crop position from conditioning
1109
+ # assert isinstance(cond, dict), 'cond must be dict to be fed into model'
1110
+ # cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
1111
+
1112
+ # adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
1113
+ # adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
1114
+ # adapted_cond = self.get_learned_conditioning(adapted_cond)
1115
+ # adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
1116
+
1117
+ # cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
1118
+
1119
+ # else:
1120
+ # cond_list = [cond for i in range(z.shape[-1])] # pr_odo make this more efficient
1121
+
1122
+ # # apply model by loop over crops
1123
+ # output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1124
+ # assert not isinstance(output_list[0],
1125
+ # tuple) # pr_odo cant deal with multiple model outputs check this never happens
1126
+
1127
+ # o = torch.stack(output_list, axis=-1)
1128
+ # o = o * weighting
1129
+ # # Reverse reshape to img shape
1130
+ # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1131
+ # # stitch crops together
1132
+ # x_recon = fold(o) / normalization
1133
+
1134
+ else:
1135
+ # TODO address passing ref
1136
+ zeroed_out_warped_latent = x_noisy.clone()
1137
+ if self.remove_warped_latent:
1138
+ zeroed_out_warped_latent[:,4:8] *= 0.0
1139
+ x_recon = self.model(zeroed_out_warped_latent, t, z_ref=z_ref, **cond)
1140
+
1141
+ if isinstance(x_recon, tuple) and not return_ids:
1142
+ return x_recon[0]
1143
+ else:
1144
+ return x_recon
1145
+
1146
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1147
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1148
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1149
+
1150
+ def _prior_bpd(self, x_start):
1151
+ """
1152
+ Get the prior KL term for the variational lower-bound, measured in
1153
+ bits-per-dim.
1154
+ This term can't be optimized, as it only depends on the encoder.
1155
+ :param x_start: the [N x C x ...] tensor of inputs.
1156
+ :return: a batch of [N] KL values (in bits), one per batch element.
1157
+ """
1158
+ batch_size = x_start.shape[0]
1159
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1160
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1161
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1162
+ return mean_flat(kl_prior) / np.log(2.0)
1163
+
1164
    def p_losses(self, x_start, cond, t, z_ref, noise=None):
        """Diffusion training loss at timesteps t.

        For the 'inpaint' first stage, x_start packs [image latent (0:4),
        warped latent (4:8), mask/grid channels (8:)]; only the image latent is
        noised, the remaining channels are concatenated back as conditioning
        (randomly zeroed with probability dropping_warped_latent_prob).

        :return: (loss, loss_dict) with train/val-prefixed entries.
        """
        if self.first_stage_key == 'inpaint':
            # x_start=x_start[:,:4,:,:]
            latents = x_start[:, :4, :, :]
            latents_warped = x_start[:, 4:8, :, :]
            noise = default(noise, lambda: torch.randn_like(x_start[:, :4, :, :]))
            # offset noise
            # noise += 0.05 * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
            # TODO address the reference latent
            # warped_mask = t > self.sd_edit_step

            # only the image latent receives the forward-diffusion noise
            x_noisy = self.q_sample(x_start=latents, t=t, noise=noise)
            # warped_noisy = self.q_sample(x_start=latents_warped, t=t, noise=noise)
            # x_noisy[warped_mask] = warped_noisy[warped_mask]

            # randomly drop the warped-latent conditioning channels
            remove_latent_prob = random.uniform(0, 1)

            if remove_latent_prob < self.dropping_warped_latent_prob:
                modified_x_start = x_start.clone()
                # dropping warped latent and mask
                modified_x_start[:, 4:9] *= 0.0

                # print('using modified x start')
                x_noisy = torch.cat((x_noisy, modified_x_start[:, 4:, :, :]), dim=1)
            else:
                x_noisy = torch.cat((x_noisy, x_start[:, 4:, :, :]), dim=1)
        else:
            noise = default(noise, lambda: torch.randn_like(x_start))
            x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output = self.apply_model(x_noisy, t, cond, z_ref)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()

        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        # per-timestep (optionally learned) log-variance weighting
        self.logvar = self.logvar.to(self.device)
        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        # NOTE(review): this recomputes the exact per-sample loss already held
        # in loss_simple; it could simply reuse that tensor
        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
1226
+
1227
    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                        return_x0=False, score_corrector=None, corrector_kwargs=None, z_ref=None):
        """One reverse-diffusion step: predict the posterior mean/variance of
        x_{t-1} given x_t, conditioning c, and the reference latent z_ref.

        :return: (model_mean, posterior_variance, posterior_log_variance), plus
            logits when return_codebook_ids or x_recon when return_x0.
        """
        t_in = t
        # TODO pass reference
        model_out = self.apply_model(x, t_in, c, z_ref=z_ref, return_ids=return_codebook_ids)

        if score_corrector is not None:
            # external guidance may adjust the predicted score (eps only)
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        if quantize_denoised:
            # snap the predicted x_0 onto the VQ codebook
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance
1258
+
1259
    @torch.no_grad()
    def p_sample(self, x, c, t, z_ref=None, clip_denoised=False, repeat_noise=False,
                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        """Draw one ancestral sample x_{t-1} ~ p(x_{t-1} | x_t).

        Wraps :meth:`p_mean_variance` and adds scaled Gaussian noise, except at
        t == 0 where the mean is returned deterministically.
        """
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(x=x, c=c, t=t, z_ref=z_ref, clip_denoised=clip_denoised,
                                       return_codebook_ids=return_codebook_ids,
                                       quantize_denoised=quantize_denoised,
                                       return_x0=return_x0,
                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if return_codebook_ids:
            # NOTE(review): this raises immediately, so the codebook unpacking and
            # the matching return below are dead code — presumably kept as a stub.
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        if return_codebook_ids:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
        if return_x0:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1289
+
1290
    @torch.no_grad()
    def progressive_denoising(self, cond, shape, z_ref=None, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None):
        """Run the full reverse process while collecting intermediate x0 predictions.

        Returns ``(img, intermediates)`` where ``intermediates`` holds the partial
        x0 estimates logged every ``log_every_t`` steps. ``mask``/``x0`` enable
        inpainting-style sampling where masked regions are re-noised from ``x0``.
        """
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size if batch_size is not None else shape[0]
            # Caller passed a per-sample shape; prepend the batch dimension.
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                # Trim every conditioning entry to the batch size.
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

        if start_T is not None:
            # Optionally start the reverse process from an intermediate timestep.
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                        total=timesteps) if verbose else reversed(
            range(0, timesteps))
        if type(temperature) == float:
            # Broadcast a scalar temperature to a per-step schedule.
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img, x0_partial = self.p_sample(img, cond, ts, z_ref=z_ref,
                                            clip_denoised=self.clip_denoised,
                                            quantize_denoised=quantize_denoised, return_x0=True,
                                            temperature=temperature[i], noise_dropout=noise_dropout,
                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
            if mask is not None:
                # Inpainting: keep known regions pinned to the noised ground truth.
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback: callback(i)
            if img_callback: img_callback(img, i)
        return img, intermediates
1345
+
1346
+ @torch.no_grad()
1347
+ def p_sample_loop(self, cond, shape, z_ref=None, return_intermediates=False,
1348
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1349
+ mask=None, x0=None, img_callback=None, start_T=None,
1350
+ log_every_t=None):
1351
+
1352
+ if not log_every_t:
1353
+ log_every_t = self.log_every_t
1354
+ device = self.betas.device
1355
+ b = shape[0]
1356
+ if x_T is None:
1357
+ img = torch.randn(shape, device=device)
1358
+ else:
1359
+ img = x_T
1360
+
1361
+ intermediates = [img]
1362
+ if timesteps is None:
1363
+ timesteps = self.num_timesteps
1364
+
1365
+ if start_T is not None:
1366
+ timesteps = min(timesteps, start_T)
1367
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1368
+ range(0, timesteps))
1369
+
1370
+ if mask is not None:
1371
+ assert x0 is not None
1372
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1373
+
1374
+ for i in iterator:
1375
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1376
+ if self.shorten_cond_schedule:
1377
+ assert self.model.conditioning_key != 'hybrid'
1378
+ tc = self.cond_ids[ts].to(cond.device)
1379
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1380
+
1381
+ img = self.p_sample(img, cond, ts, z_ref=z_ref,
1382
+ clip_denoised=self.clip_denoised,
1383
+ quantize_denoised=quantize_denoised)
1384
+ if mask is not None:
1385
+ img_orig = self.q_sample(x0, ts)
1386
+ img = img_orig * mask + (1. - mask) * img
1387
+
1388
+ if i % log_every_t == 0 or i == timesteps - 1:
1389
+ intermediates.append(img)
1390
+ if callback: callback(i)
1391
+ if img_callback: img_callback(img, i)
1392
+
1393
+ if return_intermediates:
1394
+ return img, intermediates
1395
+ return img
1396
+
1397
    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None,**kwargs):
        """Convenience wrapper around :meth:`p_sample_loop`.

        Builds the default latent shape from ``self.channels``/``self.image_size``
        and trims conditioning to ``batch_size``.
        NOTE(review): ``**kwargs`` is accepted but not forwarded to
        ``p_sample_loop`` (and no ``z_ref`` is passed) — confirm this is intended.
        """
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                # Trim every conditioning entry to the batch size.
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                                  mask=mask, x0=x0)
1414
+
1415
    @torch.no_grad()
    def sample_log(self,cond,batch_size,ddim, ddim_steps, z_ref=None, full_z=None,**kwargs):
        """Sample latents for logging, via DDIM when ``ddim`` else ancestral sampling.

        NOTE(review): when ``ddim`` is True, ``full_z`` must be provided (its
        channels 4:8 are sliced as the warped/inpaint latent) even though it
        defaults to None — confirm all callers pass it.
        """
        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            # Channels 4:8 of the full latent stack hold the warped-image latent.
            z_inpaint = full_z[:,4:8]
            step=1


            samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
                                                        shape,cond, z_ref=z_ref,verbose=False, x0=z_inpaint,
                                                        x0_step=step,**kwargs)

        else:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 return_intermediates=True,**kwargs)

        return samples, intermediates
1434
+
1435
+
1436
    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
                   plot_diffusion_rows=True, **kwargs):
        """Assemble a dict of image tensors for logger visualization.

        Pulls inputs/reconstructions/masks from :meth:`get_input`, optionally adds
        a forward-diffusion row, samples (DDIM when ``ddim_steps`` is set),
        quantized samples, in/outpainting demos, and progressive-denoising rows.
        ``return_keys`` filters the returned dict.
        """

        use_ddim = ddim_steps is not None

        log = dict()

        # get_input returns the full latent stack plus every auxiliary tensor
        # requested by the flags below (mask, reference, warped image, etc.).
        z, c, z_ref, x, xrec, xc, mask, reference, inpaint_img, clean_ref, ref_rec, changed_pixels = self.get_input(batch, self.first_stage_key,
                                                                                                                   return_first_stage_outputs=True,
                                                                                                                   force_c_encode=True,
                                                                                                                   return_original_cond=True,
                                                                                                                   get_mask=True,
                                                                                                                   get_reference=True,
                                                                                                                   get_inpaint=True,
                                                                                                                   bs=N,
                                                                                                                   get_clean_ref=True,
                                                                                                                   get_ref_rec=True,
                                                                                                                   get_changed_pixels=True)

        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        log["mask"]=mask
        log['changed_pixels'] = changed_pixels
        log["warped"]=inpaint_img
        log["original"] = clean_ref
        log["ref_rec"] = ref_rec
        # log["reference"]=reference
        if self.model.conditioning_key is not None:
            # Render the conditioning in whatever form it takes (decoded latents,
            # text drawn as an image, class labels, or a raw image).
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption","txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key])
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row: visualize q_sample at increasing noise levels.
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row: draw samples under the EMA weights.
            with self.ema_scope("Plotting"):
                if self.first_stage_key=='inpaint':
                    # Pass the non-noise latent channels (warped image + mask) through.
                    samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim, full_z=z,
                                                             ddim_steps=ddim_steps,eta=ddim_eta,rest=z[:,4:,:,:])
                else:
                    samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim,
                                                             ddim_steps=ddim_steps,eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # also display when quantizing x0 while sampling
                with self.ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c, z_ref=z_ref, batch_size=N,ddim=use_ddim,
                                                             ddim_steps=ddim_steps,eta=ddim_eta,
                                                             quantize_denoised=True)
                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                    #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

        if inpaint:
            # make a simple center square mask (1 = keep, 0 = fill in).
            b, h, w = z.shape[0], z.shape[2], z.shape[3]
            mask = torch.ones(N, h, w).to(self.device)
            # zeros will be filled in
            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
            mask = mask[:, None, ...]
            with self.ema_scope("Plotting Inpaint"):

                samples, _ = self.sample_log(cond=c, z_ref=z_ref,batch_size=N,ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N,:4], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_inpainting"] = x_samples
            log["mask"] = mask

            # outpaint: same mask, inverted usage of the known region.
            with self.ema_scope("Plotting Outpaint"):
                samples, _ = self.sample_log(cond=c, z_ref=z_ref, batch_size=N, ddim=use_ddim,eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with self.ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               z_ref=z_ref,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log
1565
+
1566
    def configure_optimizers(self):
        """Lightning hook: build AdamW over the diffusion model parameters.

        Optionally adds conditioner projection params (``cond_stage_trainable``)
        and the learned log-variance (``learn_logvar``); wraps the optimizer in a
        per-step LambdaLR schedule when ``use_scheduler`` is set.
        """
        lr = self.learning_rate
        params = list(self.model.parameters())

        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            # need to add final_ln.parameters() TODO
            params = params + list(self.cond_stage_model.final_ln.parameters())+list(self.cond_stage_model.mapper.parameters())+list(self.proj_out.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
            params.append(self.learnable_vector)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            # Rebind: the instantiated schedule object becomes a Lightning
            # scheduler-config list driven by its `.schedule` callable.
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt
1594
+
1595
+ @torch.no_grad()
1596
+ def to_rgb(self, x):
1597
+ x = x.float()
1598
+ if not hasattr(self, "colorize"):
1599
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1600
+ x = nn.functional.conv2d(x, weight=self.colorize)
1601
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1602
+ return x
1603
+
1604
+
1605
class DiffusionWrapper(pl.LightningModule):
    """Routes conditioning into the denoising UNet.

    Depending on ``conditioning_key`` the wrapper concatenates conditioning
    channels, cross-attends, or (``crossref``/``rewarp``*) runs a frozen copy of
    the UNet on a noised reference latent to extract attention contexts that are
    injected into the main UNet, optionally guided by pixel correspondences
    computed from a warping grid.
    """
    def __init__(self, diff_model_config, conditioning_key, sqrt_alphas_cumprod=None, sqrt_one_minus_alphas_cumprod=None, ddpm_parent=None):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'crossref', 'rewarp', 'rewarp_grid']
        # self.save_folder = '/mnt/localssd/collage_latents_lovely_new_data'
        # self.save_counter = 0
        # self.save_subfolder = None

        # os.makedirs(self.save_folder, exist_ok=True)
        # Noise-schedule tensors are injected by the owning DDPM for q_sample.
        self.sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod
        self.sqrt_alphas_cumprod = sqrt_alphas_cumprod
        self.og_grid = None
        self.transformed_grid = None
        # BUGFIX: the assert above allows conditioning_key=None, but
        # `'rewarp' in None` raises TypeError — guard against None first.
        if self.conditioning_key is not None and (
                self.conditioning_key == 'crossref' or 'rewarp' in self.conditioning_key):
            # Separate (deep-copied) UNet used only to extract reference contexts.
            self.reference_model = copy.deepcopy(self.diffusion_model)

    def get_grid(self, size, batch_size):
        """Return a (batch, 2, size, size) integer coordinate grid: channel 0 is
        the row index, channel 1 the column index."""
        # raise ValueError TODO Fix
        y = np.repeat(np.arange(size)[None, ...], size)
        y = y.reshape(size, size)
        x = y.transpose()
        out = np.stack([y,x], 0)
        out = torch.tensor(out)
        out = out.unsqueeze(0)
        out = out.repeat(batch_size, 1, 1, 1)
        return out

    def compute_correspondences(self, grid_transformed, masks, original_size=512, add_grids=False):
        """Build nearest-neighbor correspondence maps at UNet feature resolutions.

        For each resolution d in {8, 16, 32, 64}, resizes the identity grid and
        the warped grid, then for every warped location finds the index of the
        nearest source location. Returns a dict
        ``{d: (min_indices, missing_flat[, grids])}`` where ``missing_flat``
        marks positions whose (resized) mask falls below 0.7, i.e. regions with
        no source pixel that need inpainting.
        """
        # create the correspondence map for all the needed sizes
        corresp_indices = {}
        batch_size = grid_transformed.shape[0]

        if self.og_grid is None:
            # Normalized identity grid at the original image resolution.
            grid_og = self.get_grid(original_size, batch_size).to(grid_transformed.device) / float(original_size)
        else:
            grid_og = self.og_grid

        for d in [8, 16, 32, 64]:
            resized_grid_1 = torchvision.transforms.functional.resize(grid_og, size=(d,d))
            resized_grid_2 = torchvision.transforms.functional.resize(grid_transformed, size=(d,d))
            # the mask is at 64x64. 1 means exist in image. 0 is missing (needs inpainting)
            resized_mask = torchvision.transforms.functional.resize(masks, size=(d,d))

            missing_mask = resized_mask.squeeze(1) < 0.7  # torch.sum(resized_grid_2, dim=1) < 0.1

            src_grid = resized_grid_1.permute(0,2,3,1)  # B x d x d x 2
            guide_grid = resized_grid_2.permute(0,2,3,1)

            src1_flat = src_grid.reshape(batch_size, d**2, 2)
            src2_flat = guide_grid.reshape(batch_size, d**2, 2)
            missing_flat = missing_mask.reshape(batch_size, d**2)

            # Pairwise distances between warped positions and source positions.
            torch_dist = torch.cdist(src2_flat.float(), src1_flat.float())

            # Nearest source index for every warped position.
            min_indices = torch.argmin(torch_dist, dim=-1)
            if add_grids:
                corresp_indices[d] = (min_indices, missing_flat, resized_grid_1, resized_grid_2)
            else:
                corresp_indices[d] = (min_indices, missing_flat)
        return corresp_indices

    def q_sample(self, x_start, t, noise=None):
        """Forward-diffuse ``x_start`` to timestep ``t`` using the schedule
        tensors injected at construction time."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(x_start.device)
        sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(x_start.device)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, z_ref = None):
        """Dispatch the denoising call according to ``self.conditioning_key``.

        When ``x`` carries >= 11 channels, the trailing 2 are the warped
        coordinate grid; they are stripped off and stashed on
        ``self.transformed_grid`` for correspondence computation.
        """
        num_ch = x.shape[1]
        if num_ch >= 11:
            self.transformed_grid = x[:, -2:]
            x = x[:, :-2]

        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)

        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            # Class/embedding conditioning passed via `y`.
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        elif self.conditioning_key == 'rewarp' or self.conditioning_key == 'crossref':  # also include the crossref for now
            cc = torch.cat(c_crossattn, 1)
            if self.conditioning_key == 'crossref':
                raise ValueError('currently not implemented properly. please fix attention')
            assert z_ref is not None
            # Noise the reference latent to the same timestep as x.
            noisy_z_ref = self.q_sample(z_ref, t)

            mask = x[:, -1:]  # mask and new regions
            z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask], dim=1)

            # NOTE(review): assumes self.transformed_grid was set (num_ch >= 11
            # above, or assigned by the caller) and a 512px source — confirm.
            correspondeces = self.compute_correspondences(self.transformed_grid, mask, original_size=512)  ## TODO make input dependent

            # Run the reference UNet only to harvest its attention contexts,
            # then inject them into the main UNet together with correspondences.
            _, contexts = self.reference_model(z_ref_concat, t, context=cc, get_contexts=True)
            out = self.diffusion_model(x, t, context=cc, passed_contexts=contexts, corresp=correspondeces)

        elif self.conditioning_key == 'rewarp_grid':
            # Same as 'rewarp', but the identity grid is also concatenated to
            # both the reference stack and the main input.
            grid_og = self.get_grid(64, batch_size=x.shape[0]).to(x.device) / 64.0
            cc = torch.cat(c_crossattn, 1)

            assert z_ref is not None
            noisy_z_ref = self.q_sample(z_ref, t)

            mask = x[:, -1:]  # mask and new regions
            z_ref_concat = torch.cat([noisy_z_ref, z_ref, mask, grid_og], dim=1)
            x = torch.cat([x, grid_og], dim=1)

            correspondeces = self.compute_correspondences(self.transformed_grid, mask, original_size=512)  ## TODO make input dependent

            _, contexts = self.reference_model(z_ref_concat, t, context=cc, get_contexts=True)
            out = self.diffusion_model(x, t, context=cc, passed_contexts=contexts, corresp=correspondeces)

        else:
            raise NotImplementedError()

        return out
1799
+
1800
+
1801
class Layout2ImgDiffusion(LatentDiffusion):
    """LatentDiffusion specialized to bounding-box layout conditioning."""
    # pr_odo: move all layout-specific hacks to this class
    def __init__(self, cond_stage_key, *args, **kwargs):
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        """Extend the parent's logging dict with a rendering of the bbox layout."""
        logs = super().log_images(batch=batch, N=N, *args, **kwargs)

        # Look up the dataset's conditional builder to plot tokenized bboxes.
        key = 'train' if self.training else 'validation'
        dset = self.trainer.datamodule.datasets[key]
        mapper = dset.conditional_builders[self.cond_stage_key]

        bbox_imgs = []
        map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
        for tknzd_bbox in batch[self.cond_stage_key][:N]:
            bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
            bbox_imgs.append(bboximg)

        cond_img = torch.stack(bbox_imgs, dim=0)
        logs['bbox_image'] = cond_img
        return logs
1823
+
1824
class LatentInpaintDiffusion(LatentDiffusion):
    """LatentDiffusion variant whose conditioning concatenates a mask and a
    masked-image latent (standard latent inpainting setup)."""
    def __init__(
        self,
        concat_keys=("mask", "masked_image"),
        masked_image_key="masked_image",
        finetune_keys=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys
        self.concat_keys = concat_keys

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        """Build (latent, conditioning) pairs where the conditioning dict holds
        both the concatenated mask/masked-image channels and the cross-attention
        conditioning.

        NOTE(review): assumes the parent ``get_input`` returns exactly
        ``z, c, x, xrec, xc`` under these flags — verify against the parent's
        extended signature in this file.
        """
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for inpainting"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert exists(self.concat_keys)
        c_cat = list()
        for ck in self.concat_keys:
            cc = (
                rearrange(batch[ck], "b h w c -> b c h w")
                .to(memory_format=torch.contiguous_format)
                .float()
            )
            if bs is not None:
                cc = cc[:bs]
            cc = cc.to(self.device)
            bchw = z.shape
            if ck != self.masked_image_key:
                # Masks are resized to latent resolution; images are encoded.
                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
            else:
                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ """SAMPLING ONLY."""
15
+
16
+ import torch
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ from functools import partial
20
+
21
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
22
+
23
+
24
+ class PLMSSampler(object):
25
    def __init__(self, model, schedule="linear", **kwargs):
        """Store the wrapped DDPM model and its timestep count.

        ``model`` must expose ``num_timesteps``, ``betas`` and
        ``alphas_cumprod`` (a LatentDiffusion/DDPM instance).
        """
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
30
+
31
+ def register_buffer(self, name, attr):
32
+ if type(attr) == torch.Tensor:
33
+ if attr.device != torch.device("cuda"):
34
+ attr = attr.to(torch.device("cuda"))
35
+ setattr(self, name, attr)
36
+
37
    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute all DDIM/PLMS schedule tensors as sampler buffers.

        PLMS is deterministic, so ``ddim_eta`` must be 0. Derives the subset of
        DDPM timesteps plus the alpha/sigma terms used by the sampling updates.
        """
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # Sigmas for the original (full-step) sampling process; zero when eta=0.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
69
+
70
    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Public entry point: run PLMS sampling for ``S`` steps.

        ``shape`` is the per-sample latent shape (C, H, W); the batch dimension
        is added here. Returns ``(samples, intermediates)`` from
        :meth:`plms_sampling`.
        """
        if conditioning is not None:
            # Sanity-check that conditioning batch matches the requested batch.
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    **kwargs
                                                    )
        return samples, intermediates
127
+
128
@torch.no_grad()
def plms_sampling(self, cond, shape,
                  x_T=None, ddim_use_original_steps=False,
                  callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, log_every_t=100,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs):
    """Main PLMS denoising loop.

    Iterates the (reversed) timestep schedule, calling ``p_sample_plms`` once
    per step and maintaining the history of the last 3 epsilon predictions
    (``old_eps``) required by the pseudo linear multistep update.

    Returns:
        ``(img, intermediates)`` where ``img`` is the final latent and
        ``intermediates`` holds logged ``x_inter``/``pred_x0`` snapshots.
    """
    device = self.model.betas.device
    b = shape[0]
    # Start from provided noise x_T, or draw fresh Gaussian noise.
    if x_T is None:
        img = torch.randn(shape, device=device)
    else:
        img = x_T

    if timesteps is None:
        timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
    elif timesteps is not None and not ddim_use_original_steps:
        # Truncate the precomputed DDIM schedule to a prefix proportional to `timesteps`.
        subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
        timesteps = self.ddim_timesteps[:subset_end]

    intermediates = {'x_inter': [img], 'pred_x0': [img]}
    # Walk the schedule from high noise to low noise.
    time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
    total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
    print(f"Running PLMS Sampling with {total_steps} timesteps")

    iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
    old_eps = []  # history of epsilon predictions for the multistep formula

    for i, step in enumerate(iterator):
        index = total_steps - i - 1  # index into the ddim alpha/sigma buffers
        ts = torch.full((b,), step, device=device, dtype=torch.long)
        # PLMS also needs the *next* timestep for the 2nd-order bootstrap step.
        ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

        if mask is not None:
            # Inpainting: re-noise the known region of x0 to the current level
            # and keep only the masked-out region from the running sample.
            assert x0 is not None
            img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
            img = img_orig * mask + (1. - mask) * img

        outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                  quantize_denoised=quantize_denoised, temperature=temperature,
                                  noise_dropout=noise_dropout, score_corrector=score_corrector,
                                  corrector_kwargs=corrector_kwargs,
                                  unconditional_guidance_scale=unconditional_guidance_scale,
                                  unconditional_conditioning=unconditional_conditioning,
                                  old_eps=old_eps, t_next=ts_next,**kwargs)
        img, pred_x0, e_t = outs
        old_eps.append(e_t)
        # Keep at most the 3 most recent predictions (4th-order multistep).
        if len(old_eps) >= 4:
            old_eps.pop(0)
        if callback: callback(i)
        if img_callback: img_callback(pred_x0, i)

        if index % log_every_t == 0 or index == total_steps - 1:
            intermediates['x_inter'].append(img)
            intermediates['pred_x0'].append(pred_x0)

    return img, intermediates
185
+
186
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,**kwargs):
    """One PLMS update step.

    Combines the current epsilon prediction with up to three previous ones
    (Adams-Bashforth style) and applies the DDIM-like transition to get the
    previous latent ``x_prev``.

    NOTE: requires ``kwargs['test_model_kwargs']`` with keys ``inpaint_image``
    and ``inpaint_mask``; these are channel-concatenated onto the latent before
    every UNet call (raises ``KeyError`` if absent).

    Returns:
        ``(x_prev, pred_x0, e_t)`` — next latent, current x0 estimate, and the
        raw epsilon of this step (fed back into ``old_eps`` by the caller).
    """
    b, *_, device = *x.shape, x.device

    def get_model_output(x, t):
        # Epsilon prediction with optional classifier-free guidance:
        # run conditional and unconditional in one doubled batch, then blend.
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            # Score correction only makes sense for eps-parameterized models.
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    # Pick schedule buffers: full DDPM schedule or the subsampled DDIM one.
    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    # Extra UNet inputs (coarse edit + mask) concatenated along channels.
    kwargs=kwargs['test_model_kwargs']
    x_new=torch.cat([x,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1)
    e_t = get_model_output(x_new, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        x_prev_new=torch.cat([x_prev,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1)
        e_t_next = get_model_output(x_prev_new, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    elif len(old_eps) >= 3:
        # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t
ldm/modules/attention.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe's modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe's modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ from inspect import isfunction
15
+ import math
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn, einsum
20
+ from einops import rearrange, repeat
21
+ import glob
22
+
23
+ from ldm.modules.diffusionmodules.util import checkpoint
24
+
25
+
26
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
28
+
29
+
30
def uniq(arr):
    """Deduplicate *arr* preserving first-occurrence order; returns a dict keys view."""
    return dict.fromkeys(arr).keys()
32
+
33
+
34
def default(val, d):
    """Return *val* unless it is None; otherwise return *d* (calling it first if it is a function)."""
    if val is None:
        return d() if isfunction(d) else d
    return val
38
+
39
+
40
def max_neg_value(t):
    """Largest-magnitude negative finite value for t's dtype (used to mask attention logits)."""
    info = torch.finfo(t.dtype)
    return -info.max
42
+
43
+
44
def init_(tensor):
    """Uniform init in [-1/sqrt(d), 1/sqrt(d)], d = last dim; in-place, returns the tensor."""
    bound = 1.0 / math.sqrt(tensor.shape[-1])
    return tensor.uniform_(-bound, bound)
49
+
50
+
51
+ # feedforward
52
# feedforward
class GEGLU(nn.Module):
    """GELU-gated linear unit: project to 2*dim_out, gate one half by GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
60
+
61
+
62
class FeedForward(nn.Module):
    """Transformer MLP: (Linear+GELU, or GEGLU when glu=True) -> Dropout -> Linear."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim),
                nn.GELU()
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)
80
+
81
+
82
def zero_module(module):
    """
    Zero out the parameters of a module and return it.

    Used so freshly-added layers start as the identity contribution.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
89
+
90
+
91
def Normalize(in_channels):
    """32-group GroupNorm over *in_channels* channels, eps 1e-6, learnable affine."""
    return torch.nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
93
+
94
+
95
class LinearAttention(nn.Module):
    """Linear-complexity attention: softmax over key positions, then two einsum contractions."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        batch, channels, height, width = x.shape
        qkv = self.to_qkv(x)
        # Split the fused projection into per-head q/k/v with flattened spatial dim.
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        # Aggregate values against normalized keys, then query the aggregate.
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended, 'b heads c (h w) -> b (heads c) h w',
                             heads=self.heads, h=height, w=width)
        return self.to_out(attended)
112
+
113
+
114
class SpatialSelfAttention(nn.Module):
    """Full (quadratic) self-attention over the spatial positions of a feature map.

    1x1 convs produce q/k/v; attention is computed over all h*w positions and
    the result is added residually to the input.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convolutions acting as per-pixel linear projections.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        # (h*w) x (h*w) similarity matrix per batch element.
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        # residual connection
        return x+h_
165
+
166
+
167
class CrossAttention(nn.Module):
    """Multi-head attention used by BasicTransformerBlock.

    Self-attention when ``context`` is None, cross-attention otherwise. In this
    modified (Magic Fixup) variant, self-attention additionally returns its
    q/k/v and per-head outputs so they can be re-injected into a second pass.

    NOTE(review): ``merge_attentions`` / ``merge_attentions_missing`` are
    created (zero-initialized) but never used in this forward; presumably
    consumed by a code path outside this file or kept for checkpoint
    compatibility — confirm before removing.
    """
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., only_crossref=False):
        super().__init__()
        inner_dim = dim_head * heads
        # forcing attention to only attend on vectors of same size
        # breaking the image2text attention
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5  # 1/sqrt(dim_head) logit scaling
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

        self.only_crossref = only_crossref
        # Zero-init 1x1 convs intended to merge per-head attention maps
        # (2 sources when only_crossref, else 3). Unused in forward below.
        if only_crossref:
            self.merge_attentions = zero_module(nn.Conv2d(self.heads * 2,
                                                          self.heads,
                                                          kernel_size=1,
                                                          stride=1,
                                                          padding=0))
        else:
            self.merge_attentions = zero_module(nn.Conv2d(self.heads * 3,
                                                          self.heads,
                                                          kernel_size=1,
                                                          stride=1,
                                                          padding=0))

        self.merge_attentions_missing = zero_module(nn.Conv2d(self.heads * 2,
                                                              self.heads,
                                                              kernel_size=1,
                                                              stride=1,
                                                              padding=0))

    def forward(self, x, context=None, mask=None, passed_qkv=None, masks=None, corresp=None, missing_region=None):
        # Self-attention iff no external context is supplied.
        is_self_attention = context is None

        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Fold heads into the batch dimension.
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            # Masking is intentionally disabled in this variant.
            assert False
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        inter_out = rearrange(out, '(b h) n d -> b h n d', h=h)

        combined_attention = inter_out
        out = rearrange(combined_attention, 'b h n d -> b n (h d)', h=h)

        final_out = self.to_out(out)

        if is_self_attention:
            # Expose internals so the caller can pass them into a second forward pass.
            return final_out, q, k, v, inter_out #TODO add attn out
        else:
            return final_out
259
+
260
+
261
class BasicTransformerBlock(nn.Module):
    """Transformer block: self-attention, optional reference cross-attention, feed-forward.

    Modified variant: ``attn2``'s context path is effectively unused in favor of
    ``attn3`` attending over features passed from a reference forward pass
    (``passed_qkv``); self-attention internals are returned for reuse.
    """
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        # attn3: attends current features against reference-pass features.
        self.attn3 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    # TODO add attn in
    def forward(self, x, context=None, passed_qkv=None, masks=None, corresp=None):
        # Route through gradient checkpointing; arguments must be a flat tuple.
        if passed_qkv is None:
            return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
        else:
            q, k, v, attn, x_features = passed_qkv
            # Recover the spatial resolution from the token count (assumes square maps).
            d = int(np.sqrt(q.shape[1]))
            current_mask = masks[d]
            if corresp:
                current_corresp, missing_region = corresp[d]
                current_corresp = current_corresp.float()
                missing_region = missing_region.float()
            else:
                # Empty corresp is treated as a hard error; the two assignments
                # below are unreachable after the raise.
                raise ValueError('cannot have empty corresp')
                current_corresp = None
                missing_region = current_mask.float()
            stuff = [q, k, v, attn, x_features, current_mask, current_corresp, missing_region]
            for element in stuff:
                assert element is not None
            return checkpoint(self._forward, (x, context, q, k, v, attn, x_features, current_mask, current_corresp, missing_region), self.parameters(), self.checkpoint)

    # TODO add attn in
    def _forward(self, x, context=None, q=None, k=None, v=None, attn=None, passed_x=None, masks=None, corresp=None, missing_region=None):
        if q is not None:
            passed_qkv = (q, k, v, attn, passed_x)
        else:
            passed_qkv = None
        x_features = self.norm1(x)
        # Self-attention; returns its internals for the caller to reuse.
        attended_x, q, k, v, attn = self.attn1(x_features, passed_qkv=passed_qkv, masks=masks, corresp=corresp, missing_region=missing_region)
        x = attended_x + x
        # killing CLIP features

        if passed_x is not None:
            # Cross-attend against the reference pass's pre-attention features.
            normed_x = self.norm2(x)
            attn_out = self.attn3(normed_x, context=passed_x)
            x = attn_out + x

        x = self.ff(self.norm3(x)) + x
        return x, q, k, v, attn, x_features
316
+
317
+
318
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image

    Modified variant: also threads ``passed_qkv``/``masks``/``corresp`` through
    each BasicTransformerBlock and returns the per-block qkv tuples alongside
    the output, so a second pass can attend against a reference pass.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        # 1x1 conv projection into the transformer width.
        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)]
        )

        # Zero-init output projection: block starts as identity (residual only).
        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    # TODO add attn in and corresp
    def forward(self, x, context=None, passed_qkv=None, masks=None, corresp=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        # Flatten spatial dims to a token sequence.
        x = rearrange(x, 'b c h w -> b (h w) c')

        qkvs = []
        for block in self.transformer_blocks:
            x, q, k, v, attn, x_features = block(x, context=context, passed_qkv=passed_qkv, masks=masks, corresp=corresp)
            qkv = (q,k,v,attn, x_features)
            qkvs.append(qkv)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        # Residual output plus the collected self-attention internals.
        return x + x_in, qkvs
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,848 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe's modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe's modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ # pytorch_diffusion + derived encoder decoder
15
+ import math
16
+ import torch
17
+ import torch.nn as nn
18
+ import numpy as np
19
+ from einops import rearrange
20
+
21
+ from ldm.util import instantiate_from_config
22
+ from ldm.modules.attention import LinearAttention
23
+
24
+
25
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    # Expect a 1-D tensor of timestep indices.
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # Geometric frequency ladder from 1 down to 1/10000.
    log_scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_scale)
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
44
+
45
+
46
def nonlinearity(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
49
+
50
+
51
def Normalize(in_channels, num_groups=32):
    """GroupNorm over *in_channels* channels with eps 1e-6 and learnable affine."""
    return nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)
53
+
54
+
55
class Upsample(nn.Module):
    """2x nearest-neighbour spatial upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

    def forward(self, x):
        upsampled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
71
+
72
+
73
class Downsample(nn.Module):
    """2x spatial downsampling: strided 3x3 conv (with manual right/bottom pad) or 2x2 avg pool."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=3,
                                  stride=2,
                                  padding=0)

    def forward(self, x):
        if not self.with_conv:
            return nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad one pixel on the right and bottom only, then stride-2 conv.
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
93
+
94
+
95
class ResnetBlock(nn.Module):
    """Residual block: two (norm -> swish -> conv) stages with optional timestep
    embedding injection and a learned shortcut when channel counts differ.
    """
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        # Timestep embedding projection; skipped entirely when temb_channels == 0.
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        # Shortcut to match channels: 3x3 conv or 1x1 conv depending on config.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        # Stage 1: norm -> swish -> conv.
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        # Add the projected timestep embedding as a per-channel bias.
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        # Stage 2: norm -> swish -> dropout -> conv.
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Project the residual path if the channel count changed.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h
155
+
156
+
157
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        # Single head whose head dim equals the channel count, so the block is
        # a drop-in replacement for AttnBlock(in_channels).
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
161
+
162
+
163
class AttnBlock(nn.Module):
    """Vanilla quadratic self-attention over spatial positions with residual output.

    Same computation as attention.SpatialSelfAttention, implemented with
    reshape/bmm instead of einops.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convolutions as per-pixel q/k/v/output projections.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        # residual connection
        return x+h_
216
+
217
+
218
def make_attn(in_channels, attn_type="vanilla"):
    """Factory for attention layers.

    'vanilla' -> AttnBlock, 'linear' -> LinAttnBlock, 'none' -> nn.Identity.
    """
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    if attn_type == "none":
        return nn.Identity(in_channels)
    return LinAttnBlock(in_channels)
227
+
228
+
229
+ class Model(nn.Module):
230
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
231
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
232
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
233
+ super().__init__()
234
+ if use_linear_attn: attn_type = "linear"
235
+ self.ch = ch
236
+ self.temb_ch = self.ch*4
237
+ self.num_resolutions = len(ch_mult)
238
+ self.num_res_blocks = num_res_blocks
239
+ self.resolution = resolution
240
+ self.in_channels = in_channels
241
+
242
+ self.use_timestep = use_timestep
243
+ if self.use_timestep:
244
+ # timestep embedding
245
+ self.temb = nn.Module()
246
+ self.temb.dense = nn.ModuleList([
247
+ torch.nn.Linear(self.ch,
248
+ self.temb_ch),
249
+ torch.nn.Linear(self.temb_ch,
250
+ self.temb_ch),
251
+ ])
252
+
253
+ # downsampling
254
+ self.conv_in = torch.nn.Conv2d(in_channels,
255
+ self.ch,
256
+ kernel_size=3,
257
+ stride=1,
258
+ padding=1)
259
+
260
+ curr_res = resolution
261
+ in_ch_mult = (1,)+tuple(ch_mult)
262
+ self.down = nn.ModuleList()
263
+ for i_level in range(self.num_resolutions):
264
+ block = nn.ModuleList()
265
+ attn = nn.ModuleList()
266
+ block_in = ch*in_ch_mult[i_level]
267
+ block_out = ch*ch_mult[i_level]
268
+ for i_block in range(self.num_res_blocks):
269
+ block.append(ResnetBlock(in_channels=block_in,
270
+ out_channels=block_out,
271
+ temb_channels=self.temb_ch,
272
+ dropout=dropout))
273
+ block_in = block_out
274
+ if curr_res in attn_resolutions:
275
+ attn.append(make_attn(block_in, attn_type=attn_type))
276
+ down = nn.Module()
277
+ down.block = block
278
+ down.attn = attn
279
+ if i_level != self.num_resolutions-1:
280
+ down.downsample = Downsample(block_in, resamp_with_conv)
281
+ curr_res = curr_res // 2
282
+ self.down.append(down)
283
+
284
+ # middle
285
+ self.mid = nn.Module()
286
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
287
+ out_channels=block_in,
288
+ temb_channels=self.temb_ch,
289
+ dropout=dropout)
290
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
291
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
292
+ out_channels=block_in,
293
+ temb_channels=self.temb_ch,
294
+ dropout=dropout)
295
+
296
+ # upsampling
297
+ self.up = nn.ModuleList()
298
+ for i_level in reversed(range(self.num_resolutions)):
299
+ block = nn.ModuleList()
300
+ attn = nn.ModuleList()
301
+ block_out = ch*ch_mult[i_level]
302
+ skip_in = ch*ch_mult[i_level]
303
+ for i_block in range(self.num_res_blocks+1):
304
+ if i_block == self.num_res_blocks:
305
+ skip_in = ch*in_ch_mult[i_level]
306
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
307
+ out_channels=block_out,
308
+ temb_channels=self.temb_ch,
309
+ dropout=dropout))
310
+ block_in = block_out
311
+ if curr_res in attn_resolutions:
312
+ attn.append(make_attn(block_in, attn_type=attn_type))
313
+ up = nn.Module()
314
+ up.block = block
315
+ up.attn = attn
316
+ if i_level != 0:
317
+ up.upsample = Upsample(block_in, resamp_with_conv)
318
+ curr_res = curr_res * 2
319
+ self.up.insert(0, up) # prepend to get consistent order
320
+
321
+ # end
322
+ self.norm_out = Normalize(block_in)
323
+ self.conv_out = torch.nn.Conv2d(block_in,
324
+ out_ch,
325
+ kernel_size=3,
326
+ stride=1,
327
+ padding=1)
328
+
329
    def forward(self, x, t=None, context=None):
        """Run the full U-Net: downsample, middle, upsample, project out.

        :param x: input feature map [N x in_channels x H x W]; if `context`
            is given it is concatenated to `x` along the channel axis first.
        :param t: timestep tensor, required iff `self.use_timestep`.
        :param context: optional conditioning map, spatially aligned with x.
        :return: output tensor [N x out_ch x H x W].
        """
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding: sinusoidal features passed through 2 linears
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling: keep every intermediate activation in `hs` so the
        # upsampling path can pop them off as skip connections
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: concatenate the popped skip activation channel-wise
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end: normalize, activate, project to out_ch channels
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
376
+
377
    def get_last_layer(self):
        """Return the weight of the final conv (used e.g. for adaptive loss weighting)."""
        return self.conv_out.weight
379
+
380
+
381
class Encoder(nn.Module):
    """Convolutional encoder: image -> latent feature map.

    Downsamples through `len(ch_mult)` resolution levels of ResnetBlocks
    (optional attention where the current resolution is in
    `attn_resolutions`), runs a middle block with attention, and projects to
    `2*z_channels` when `double_z` (mean and logvar) else `z_channels`.
    Extra keyword arguments are deliberately swallowed via `**ignore_kwargs`.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        # no timestep conditioning in the autoencoder (temb is None in forward)
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # prepend 1 so level i reads its input width from in_ch_mult[i]
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # no downsample after the last (lowest-resolution) level
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Encode image `x` to the latent feature map."""
        # timestep embedding (unused in the autoencoder path)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
473
+
474
+
475
class Decoder(nn.Module):
    """Convolutional decoder: latent feature map z -> image.

    Mirrors `Encoder`: a middle block with attention, then
    `len(ch_mult)` upsampling levels of ResnetBlocks, finishing with
    normalization and a 3x3 conv to `out_ch`. With `give_pre_end` the
    pre-output features are returned instead; `tanh_out` squashes the
    output to [-1, 1].
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        # no timestep conditioning in the autoencoder (temb is None in forward)
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)  # NOTE: unused below; kept from upstream
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling (built from lowest resolution up)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            # one extra block per level compared to the encoder
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        """Decode latent `z` to an image (or pre-end features)."""
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding (unused in the autoencoder path)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
582
+
583
+
584
class SimpleDecoder(nn.Module):
    """Minimal decoder: 1x1 conv, three ResnetBlocks (expand then contract
    channels), 1x1 conv back to `in_channels`, a 2x Upsample, then
    norm/act/3x3-conv to `out_channels`.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
                                    ResnetBlock(in_channels=in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=2 * in_channels,
                                                out_channels=4 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=4 * in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    nn.Conv2d(2*in_channels, in_channels, 1),
                                    Upsample(in_channels, with_conv=True)])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            # indices 1..3 are ResnetBlocks and take (x, temb); temb unused here
            if i in [1,2,3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x
618
+
619
+
620
class UpsampleDecoder(nn.Module):
    """Decoder built purely from ResnetBlocks with a 2x Upsample between
    resolution levels (no middle/attention blocks).
    """

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0  # no timestep conditioning
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in,
                                             out_channels=block_out,
                                             temb_channels=self.temb_ch,
                                             dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            # one fewer upsample than resolution levels
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # upsampling; k tracks the upsample index, i_level the res level
        # (they are equal here, enumerate kept from upstream)
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
666
+
667
+
668
class LatentRescaler(nn.Module):
    """Rescale a latent feature map by `factor` using nearest interpolation,
    sandwiched between two stacks of ResnetBlocks with an attention block
    in the middle. Channel projection: in -> mid (3x3) and mid -> out (1x1).
    """

    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels,
                                 mid_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0,
                                                     dropout=0.0) for _ in range(depth)])
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0,
                                                     dropout=0.0) for _ in range(depth)])

        self.conv_out = nn.Conv2d(mid_channels,
                                  out_channels,
                                  kernel_size=1,
                                  )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        # resize spatially by self.factor (rounded to integer sizes)
        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x
703
+
704
+
705
class MergedRescaleEncoder(nn.Module):
    """An Encoder whose bottleneck output is fed through a LatentRescaler,
    so encoding and latent rescaling happen in a single module.
    """

    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        # the encoder emits ch * ch_mult[-1] channels; the rescaler maps
        # those to out_ch while resizing by rescale_factor
        bottleneck_ch = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=bottleneck_ch,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=bottleneck_ch,
            mid_channels=bottleneck_ch,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Encode then rescale."""
        return self.rescaler(self.encoder(x))
722
+
723
+
724
class MergedRescaleDecoder(nn.Module):
    """A LatentRescaler followed by a Decoder: rescales the latent to the
    decoder's expected channel count, then decodes it to an image.
    """

    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        # decoder consumes z_channels * ch_mult[-1] channels
        widened_ch = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=widened_ch,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=widened_ch,
            out_channels=widened_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Rescale the latent, then decode."""
        return self.decoder(self.rescaler(x))
739
+
740
+
741
class Upsampler(nn.Module):
    """Upsample a latent from `in_size` to `out_size` via a LatentRescaler
    followed by a small Decoder with `log2(out_size/in_size)+1` levels.
    """

    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size//in_size))+1
        # NOTE(review): factor is 1 + (out_size % in_size) — zero remainder
        # gives factor 1.0 (decoder does all the upsampling); this formula
        # looks unusual but matches upstream LDM — confirm intent.
        factor_up = 1.+ (out_size % in_size)
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
                                       out_channels=in_channels)
        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
                               attn_resolutions=[], in_channels=None, ch=in_channels,
                               ch_mult=[ch_mult for _ in range(num_blocks)])

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x
758
+
759
+
760
class Resize(nn.Module):
    """Spatial resizing by interpolation.

    Only the fixed (non-learned) path is implemented: `forward` interpolates
    by `scale_factor` with the given `mode`, returning the input unchanged
    when `scale_factor == 1.0`. Requesting `learned=True` raises
    NotImplementedError.
    """

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # BUG FIX: original read `self.__class__.__name` (missing trailing
            # underscores), which raised AttributeError before the intended
            # NotImplementedError could be reached.
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError()
            # unreachable (kept from upstream): sketch of the learned path
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1)

    def forward(self, x, scale_factor=1.0):
        """Return `x` resized by `scale_factor` (identity when 1.0)."""
        if scale_factor==1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
            return x
782
+
783
class FirstStagePostProcessor(nn.Module):
    """Encode input with a frozen pretrained first-stage model, then refine
    the latent with a projection conv and a stack of ResnetBlock+Downsample
    pairs (one per entry in `ch_mult`).
    """

    def __init__(self, ch_mult:list, in_channels,
                 pretrained_model:nn.Module=None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        # exactly one of pretrained_model / pretrained_config must be given
        if pretrained_config is None:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape  # flatten spatial dims to tokens in forward

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
                              stride=1,padding=1)

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)


    def instantiate_pretrained(self, config):
        """Build the pretrained model from config, set eval mode, freeze params."""
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        # self.pretrained_model.train = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False


    @torch.no_grad()
    def encode_with_pretrained(self,x):
        """Encode `x`; collapse a DiagonalGaussianDistribution to its mode."""
        c = self.pretrained_model.encode(x)
        if isinstance(c, DiagonalGaussianDistribution):
            c = c.mode()
        return c

    def forward(self,x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)

        # alternate residual refinement and downsampling
        for submodel, downmodel in zip(self.model,self.downsampler):
            z = submodel(z,temb=None)
            z = downmodel(z)

        if self.do_reshape:
            # flatten to a token sequence: (B, C, H, W) -> (B, H*W, C)
            z = rearrange(z,'b c h w -> b (h w) c')
        return z
848
+
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ from abc import abstractmethod
15
+ from functools import partial
16
+ import math
17
+ from typing import Iterable
18
+ from collections import deque
19
+
20
+ import numpy as np
21
+ import torch as th
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import glob
25
+ import os
26
+
27
+ import torchvision
28
+
29
+ from ldm.modules.diffusionmodules.util import (
30
+ checkpoint,
31
+ conv_nd,
32
+ linear,
33
+ avg_pool_nd,
34
+ zero_module,
35
+ normalization,
36
+ timestep_embedding,
37
+ )
38
+ from ldm.modules.attention import SpatialTransformer
39
+
40
+
41
+ # dummy replace
42
def convert_module_to_f16(x):
    """No-op placeholder ("dummy replace") for fp16 conversion; kept so
    callers that expect the OpenAI UNet API still work."""
    pass
44
+
45
def convert_module_to_f32(x):
    """No-op placeholder ("dummy replace") for fp32 conversion; kept so
    callers that expect the OpenAI UNet API still work."""
    pass
47
+
48
+
49
+ ## go
50
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Pools a spatial feature map to a single vector: flattens H*W positions,
    prepends the mean token, adds a learned positional embedding, and runs
    one QKV attention pass, returning the output at the mean-token position.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # one position per spatial location plus one for the mean token
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # return only the (attended) mean-token output: shape [N x C]
        return x[:, :, 0]
78
+
79
+
80
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
90
+
91
+
92
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.

    Extended (Magic Fixup) to route extra arguments to SpatialTransformer
    children and to collect the attention (kqv) values each one returns, so
    they can be passed back in on a later call via `passed_kqv`.
    """

    def forward(self, x, emb, context=None, passed_kqv=None, kqv_idx=None, masks=None, corresp=None):
        """Run children in order.

        :param passed_kqv: optional list of previously collected attention
            values, indexed by `kqv_idx` for this layer group.
        :return: (output tensor, list of kqv values from each
            SpatialTransformer child).
        """
        attention_vals = []
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                if passed_kqv is not None:
                    assert kqv_idx is not None
                    passed_item = passed_kqv[kqv_idx]
                    # unwrap a singleton-nested entry down to the raw kqv
                    # NOTE(review): nesting shape assumed from caller — confirm
                    if len(passed_item) == 1:
                        passed_item = passed_item[0][0]
                else:
                    passed_item = None
                x, kqv = layer(x, context, passed_item, masks=masks, corresp=corresp)
                attention_vals.append(kqv)
            else:
                x = layer(x)
        return x, attention_vals
122
+
123
+
124
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D: leave the leading (depth) dimension untouched
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
153
+
154
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding (stride-2 transposed conv)."""

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        # stride-2 ConvTranspose2d doubles spatial size (plus kernel overhang)
        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )

    def forward(self, x):
        """Apply the learned transposed convolution."""
        upsampled = self.up(x)
        return upsampled
165
+
166
+
167
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D keeps the leading (depth) dimension at stride 1
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            # avg-pooling cannot change the channel count
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
194
+
195
+
196
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # resample both the residual path (h) and the skip path (x) the same way
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                # scale-shift norm needs two chunks (scale and shift)
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # zero-initialized so the block starts as (approximately) identity
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )


    def _forward(self, x, emb):
        if self.updown:
            # resample between the norm/act and the conv of in_layers
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # broadcast the embedding over the spatial dims
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: norm(h) * (1+scale) + shift
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
309
+
310
+
311
+ class My_ResBlock(TimestepBlock):
312
+ """
313
+ A residual block that can optionally change the number of channels.
314
+ :param channels: the number of input channels.
315
+ :param emb_channels: the number of timestep embedding channels.
316
+ :param dropout: the rate of dropout.
317
+ :param out_channels: if specified, the number of out channels.
318
+ :param use_conv: if True and out_channels is specified, use a spatial
319
+ convolution instead of a smaller 1x1 convolution to change the
320
+ channels in the skip connection.
321
+ :param dims: determines if the signal is 1D, 2D, or 3D.
322
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
323
+ :param up: if True, use this block for upsampling.
324
+ :param down: if True, use this block for downsampling.
325
+ """
326
+
327
    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        # Builds a timestep-conditioned residual-style block. Unlike the standard
        # ResBlock below, the final conv and the fallback 1x1 skip conv are
        # hard-coded to emit 4 channels — presumably an auxiliary/flow output
        # head for this model. TODO(review): confirm the meaning of the
        # 4-channel output against the callers.
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        # Default to a channel-preserving block when out_channels is not given.
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # Input stack: norm -> SiLU -> 3x3 conv (channels -> out_channels).
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            # Separate resampling paths for the hidden activation (h) and the
            # residual input (x).
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Timestep-embedding projection; doubled width when scale/shift
        # (FiLM-like) conditioning is enabled so it can be chunked into
        # (scale, shift) in _forward.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # Output stack; final conv is zero-initialized and maps to 4 channels
        # (NOT self.out_channels — see the note at the top of __init__).
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, 4, 3, padding=1)
            ),
        )

        # Skip path. NOTE(review): only the final branch maps to 4 channels to
        # match out_layers; the Identity/use_conv branches do not — confirm
        # which configurations are actually instantiated (also note that
        # _forward never applies skip_connection).
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, 4, 1)
390
+
391
+ def forward(self, x, emb):
392
+ """
393
+ Apply the block to a Tensor, conditioned on a timestep embedding.
394
+ :param x: an [N x C x ...] Tensor of features.
395
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
396
+ :return: an [N x C x ...] Tensor of outputs.
397
+ """
398
+ return checkpoint(
399
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
400
+ )
401
+
402
+
403
    def _forward(self, x, emb):
        # Core computation; forward() runs this under `checkpoint`.
        if self.updown:
            # Apply norm+SiLU first, resample both paths, then convolve — so the
            # resampling happens between the activation and the convolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        # Project the timestep embedding and right-pad singleton dims so it
        # broadcasts over the spatial axes of h.
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: normalize, then scale/shift by the embedding.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            # Additive conditioning before the output stack.
            h = h + emb_out
            h = self.out_layers(h)
        # NOTE(review): unlike a standard ResBlock, h is returned directly
        # without adding self.skip_connection(x) — confirm this is intentional
        # (the block emits a zero-initialized 4-channel output; see __init__).
        return h
424
+
425
+
426
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input (and output) channels.
    :param num_heads: number of attention heads, used when num_head_channels == -1.
    :param num_head_channels: if not -1, derive the head count as
        channels // num_head_channels instead of using num_heads.
    :param use_checkpoint: stored on the instance, but see the note in forward()
        — the checkpoint flag passed there is hard-coded to True.
    :param use_new_attention_order: if True, split qkv before splitting heads
        (QKVAttention); otherwise split heads before qkv (QKVAttentionLegacy).
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        # Single 1-D conv producing concatenated q, k, v over the token axis.
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialized output projection: the block starts as an identity
        # residual at initialization.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # NOTE(review): checkpointing is hard-coded to True here, ignoring
        # self.use_checkpoint — confirm whether that is intentional.
        return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        #return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        # Flatten all spatial dims into one token axis, attend, then restore shape.
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        # Residual connection around the attention operation.
        return (x + h).reshape(b, c, *spatial)
473
+
474
+
475
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    batch, channels, *spatial_dims = y[0].shape
    token_count = int(np.prod(spatial_dims))
    # Two matmuls with identical cost: one builds the attention weight matrix,
    # the other combines the value vectors with it.
    matmul_flops = 2 * batch * (token_count ** 2) * channels
    model.total_ops += th.DoubleTensor([matmul_flops])
493
+
494
+
495
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, tokens = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        # Legacy ordering: fold heads into the batch axis first, then slice the
        # per-head channel block into q, k, v.
        q, k, v = qkv.reshape(batch * self.n_heads, head_dim * 3, tokens).split(head_dim, dim=1)
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        # Scale q and k symmetrically before the matmul (more stable in fp16
        # than dividing the logits afterwards).
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = th.einsum("bts,bcs->bct", probs, v)
        return out.reshape(batch, -1, tokens)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
525
+
526
+
527
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, tokens = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        # New ordering: slice q, k, v first, then fold heads into the batch axis.
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        # Scale q and k symmetrically before the matmul (more stable in fp16
        # than dividing the logits afterwards).
        logits = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(batch * self.n_heads, head_dim, tokens),
            (k * scale).view(batch * self.n_heads, head_dim, tokens),
        )
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = th.einsum("bts,bcs->bct", probs, v.reshape(batch * self.n_heads, head_dim, tokens))
        return out.reshape(batch, -1, tokens)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
559
+
560
+
561
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    :param add_conv_in_front_of_unet: if True, prepend a 9-channel input conv and a
        My_ResBlock stage before the regular input blocks (stored in
        `self.add_resbolck`; the attribute name typo is kept because checkpoints
        reference it).
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        add_conv_in_front_of_unet=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # Normalize OmegaConf list configs into plain Python lists.
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None
        self.add_conv_in_front_of_unet=add_conv_in_front_of_unet

        # Self-attention context capture: when save_contexts is toggled on,
        # forward() appends each pass's collected kqv list to self.contexts.
        self.save_contexts = False
        self.use_contexts = False
        self.contexts = deque([])

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        if self.add_conv_in_front_of_unet:
            # Extra front-end: a conv hard-coded to 9 input channels followed by
            # a My_ResBlock stage. NOTE: attribute name "add_resbolck" is a typo
            # kept for checkpoint compatibility.
            self.add_resbolck = nn.ModuleList(
                [
                    TimestepEmbedSequential(
                        conv_nd(dims, 9, model_channels, 3, padding=1)
                    )
                ]
            )

            add_layers = [
                My_ResBlock(
                    model_channels,
                    time_embed_dim,
                    dropout,
                    out_channels=model_channels,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ]

            self.add_resbolck.append(TimestepEmbedSequential(*add_layers))

        # Encoder ("input") half of the UNet.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsample rate
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # Legacy head sizing: see upstream Stable Diffusion.
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last level).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # Bottleneck: ResBlock -> attention -> ResBlock at the lowest resolution.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # Decoder ("output") half; consumes skip connections in reverse order.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                if level and i == num_res_blocks:
                    # Upsample at the end of each level except the topmost one.
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # NOTE(review): the final conv takes model_channels inputs while the norm
        # takes ch — these agree only when channel_mult[0] == 1; confirm configs.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, get_contexts=False, passed_contexts=None, corresp=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs. The last channel is treated
            as a mask (see below).
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :param get_contexts: if True, also return the list of kqv contexts
            collected from every block.
        :param passed_contexts: previously captured contexts forwarded into each
            block (presumably for attention sharing — confirm against
            TimestepEmbedSequential).
        :param corresp: correspondence data forwarded to each block — semantics
            defined by the block implementations, not visible here.
        :return: an [N x C x ...] Tensor of outputs (optionally with contexts).
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        # Spatial sizes at which per-resolution masks are precomputed.
        ds = [8, 16, 32, 64]

        # The last input channel is a mask; resize it to each attention
        # resolution and flatten to [B, d*d].
        og_mask = x[:, -1:]  # assumes Bx1x64x64 — TODO confirm input resolution
        batch_size = og_mask.shape[0]
        masks = dict()

        for d in ds:
            resized_mask = torchvision.transforms.functional.resize(og_mask, size=(d, d))
            mask = resized_mask.reshape(batch_size, -1)
            masks[d] = mask

        all_kqvs = []

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)

        # Each block returns (h, kqv); kqv contexts are accumulated in order and
        # the running index len(all_kqvs) is passed so blocks can look up their
        # own slot in passed_contexts.
        if self.add_conv_in_front_of_unet:
            for module in self.add_resbolck:
                h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp)
                all_kqvs.append(kqv)

        for module in self.input_blocks:
            h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp)
            hs.append(h)
            all_kqvs.append(kqv)

        h, kqv = self.middle_block(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp)
        all_kqvs.append(kqv)
        for module in self.output_blocks:
            # Concatenate the matching encoder skip connection.
            h = th.cat([h, hs.pop()], dim=1)
            h, kqv = module(h, emb, context, passed_contexts, len(all_kqvs), masks=masks, corresp=corresp)
            all_kqvs.append(kqv)

        h = h.type(x.dtype)

        if self.predict_codebook_ids:
            out = self.id_predictor(h)
        else:
            out = self.out(h)

        if self.save_contexts:
            self.contexts.append(all_kqvs)

        if get_contexts:
            return out, all_kqvs
        else:
            return out

    def get_contexts(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        same as forward but saves self attention contexts

        NOTE(review): despite the docstring, this path calls each block as
        module(h, emb, context) with a single return value, whereas forward()
        unpacks (h, kqv) — confirm TimestepEmbedSequential supports both call
        signatures before using this method.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)

        if self.add_conv_in_front_of_unet:
            for module in self.add_resbolck:
                h = module(h, emb, context)

        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
1008
+
1009
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.

    Encoder-only variant: the input blocks and middle block are kept, the
    decoder is replaced by a pooling head selected via ``pool``
    ("adaptive", "attention", "spatial", or "spatial_v2") that maps features
    to an [N x out_channels] output.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Encoder blocks, mirroring UNetModel's input half.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsample rate
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last level).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # Bottleneck: ResBlock -> attention -> ResBlock.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        # Pooling head selection.
        self.pool = pool
        if pool == "adaptive":
            # Global average pool followed by a zero-initialized 1x1 conv.
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            # MLP over the concatenation of per-stage spatially-averaged features
            # (see forward(): _feature_size is the sum of all stage widths).
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                # Collect a spatially-averaged feature vector per stage.
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            # Concatenate all per-stage vectors and classify with the MLP head.
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
1225
+
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ # adopted from
15
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
16
+ # and
17
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
18
+ # and
19
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
20
+ #
21
+ # thanks!
22
+
23
+
24
+ import os
25
+ import math
26
+ import torch
27
+ import torch.nn as nn
28
+ import numpy as np
29
+ from einops import repeat
30
+
31
+ from ldm.util import instantiate_from_config
32
+
33
+
34
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a diffusion beta schedule.

    :param schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
    :param n_timestep: number of betas to produce.
    :param linear_start: first beta for the linear-family schedules.
    :param linear_end: last beta for the linear-family schedules.
    :param cosine_s: small offset for the cosine schedule.
    :return: a float64 numpy array of shape (n_timestep,).
    :raises ValueError: if `schedule` is not a known name.
    """
    if schedule == "linear":
        # Betas interpolate linearly in sqrt-space, then are squared.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Bug fix: the original called np.clip on a torch tensor, which returns
        # a numpy ndarray (via __array__), so the final `.numpy()` call raised
        # AttributeError. Clamp the tensor in torch instead.
        betas = betas.clamp(min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
57
+
58
+
59
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True, steps=None):
    """
    Select the subset of DDPM timesteps the DDIM sampler will visit.

    :param ddim_discr_method: "uniform", "quad", or "manual".
    :param num_ddim_timesteps: number of DDIM steps requested.
    :param num_ddpm_timesteps: total number of DDPM training timesteps.
    :param verbose: print the selected timesteps.
    :param steps: explicit timestep list, required for the "manual" method.
    :return: numpy int array of selected timesteps, shifted by +1.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.arange(0, num_ddpm_timesteps, stride)
    elif ddim_discr_method == 'quad':
        # Quadratic spacing: denser near t=0, sparser near the end.
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    elif ddim_discr_method == 'manual':
        assert steps is not None
        ddim_timesteps = np.asarray(steps)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
77
+
78
+
79
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the (sigmas, alphas, alphas_prev) triplet for the DDIM sampler.

    :param alphacums: cumulative alpha products over all DDPM timesteps.
    :param ddim_timesteps: the selected DDIM timesteps (indices into alphacums).
    :param eta: DDIM stochasticity parameter (0 = deterministic).
    :param verbose: print the selected schedule values.
    :return: (sigmas, alphas, alphas_prev) as numpy arrays.
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    # alphas_prev shifts the sequence right by one, re-using alphacums[0] first.
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
91
+
92
+
93
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    # Evaluate alpha_bar on a uniform grid and derive each beta from the ratio
    # of consecutive grid points, capped at max_beta.
    grid = [t / num_diffusion_timesteps for t in range(num_diffusion_timesteps + 1)]
    betas = [
        min(1 - alpha_bar(hi) / alpha_bar(lo), max_beta)
        for lo, hi in zip(grid[:-1], grid[1:])
    ]
    return np.array(betas)
110
+
111
+
112
def extract_into_tensor(a, t, x_shape):
    """Gather per-batch coefficients a[t] and reshape them to broadcast against x_shape.

    :param a: 1-D tensor of schedule coefficients.
    :param t: integer timestep indices, one per batch element.
    :param x_shape: shape of the tensor the result will be multiplied with.
    :return: tensor of shape (batch, 1, 1, ...) matching x_shape's rank.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing_ones)
116
+
117
+
118
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        # checkpointing disabled: plain call
        return func(*inputs)
    args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *args)
133
+
134
+
135
class CheckpointFunction(torch.autograd.Function):
    """Autograd Function implementing gradient checkpointing: forward runs under
    no_grad (activations are not stored); backward re-runs the function under
    enable_grad to recompute them before differentiating."""

    @staticmethod
    # @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) # added this for map
    def forward(ctx, run_function, length, *args):
        # args = input tensors followed by parameters; the first `length`
        # entries are the actual inputs passed to run_function.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        # deliberately no grad: activations are recomputed in backward instead
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    # @torch.cuda.amp.custom_bwd # added this for map
    def backward(ctx, *output_grads):
        # detach + requires_grad so the recomputation builds a fresh graph
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # two leading Nones correspond to the (run_function, length) arguments
        return (None, None) + input_grads
167
+
168
+
169
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, just tile the raw timestep value across dim.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    # geometric frequency ladder from 1 down to 1/max_period
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # odd dim: pad with one zero column
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
190
+
191
+
192
def zero_module(module):
    """
    Zero out the parameters of a module (in place) and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
199
+
200
+
201
def scale_module(module, scale):
    """
    Scale the parameters of a module (in place) and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
208
+
209
+
210
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
215
+
216
+
217
def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # GroupNorm with 32 groups; GroupNorm32 computes in fp32 and casts back.
    return GroupNorm32(32, channels)
224
+
225
+
226
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
227
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """SiLU/swish activation: x * sigmoid(x)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
230
+
231
+
232
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that always normalizes in float32, then casts back to the input dtype."""

    def forward(self, x):
        original_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(original_dtype)
235
+
236
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in conv_classes:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_classes[dims](*args, **kwargs)
247
+
248
+
249
def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    # Thin alias over nn.Linear so callers can swap the implementation in one place.
    return nn.Linear(*args, **kwargs)
254
+
255
+
256
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    pool_classes = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in pool_classes:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_classes[dims](*args, **kwargs)
267
+
268
+
269
class HybridConditioner(nn.Module):
    """Runs two conditioning branches and packages their outputs in the
    {'c_concat': [...], 'c_crossattn': [...]} dict format the diffusion model expects."""

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        # Both conditioners are built from OmegaConf-style configs.
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        # lists: downstream code concatenates multiple conditionings per key
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
280
+
281
+
282
def noise_like(shape, device, repeat=False):
    """Sample standard normal noise of the given shape on the given device.

    :param repeat: if True, draw one batch element and tile it across the
        batch so every element shares identical noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ import numpy as np
16
+
17
+
18
class AbstractDistribution:
    """Minimal distribution interface: subclasses implement sample() and mode()."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
25
+
26
class DiracDistribution(AbstractDistribution):
    """Degenerate distribution with all mass at a single value; sampling is deterministic."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        # a Dirac always yields its value
        return self.value

    def mode(self):
        return self.value
35
+
36
+
37
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by [mean, logvar] concatenated along dim 1.

    When `deterministic` is True, variance is forced to zero so sample() == mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        self.mean = mean
        # clamp keeps exp() numerically safe
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if deterministic:
            zeros = torch.zeros_like(self.mean).to(device=self.parameters.device)
            self.var = zeros
            self.std = zeros

    def sample(self):
        """Draw via the reparameterization trick: mean + std * eps."""
        eps = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * eps

    def kl(self, other=None):
        """KL divergence to `other`, or to the standard normal when other is None.
        Reduced over dims [1, 2, 3], returning one value per batch element."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of `sample` under this Gaussian, summed over `dims`."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """The mode of a Gaussian is its mean."""
        return self.mean
76
+
77
+
78
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    reference = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(reference)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(reference)

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
ldm/modules/ema.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+
18
class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    Shadow copies are stored as buffers so they travel with the state_dict;
    parameter names have '.' stripped because buffer names may not contain dots.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # maps original parameter name -> sanitized buffer name
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates >= 0 enables the decay warm-up in forward(); -1 disables it
        self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
                             else torch.tensor(-1,dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                #remove as '.'-character is not allowed in buffers
                s_name = name.replace('.','')
                self.m_name2s_name.update({name:s_name})
                self.register_buffer(s_name,p.clone().detach().data)

        # filled by store(); used by restore()
        self.collected_params = []

    def forward(self,model):
        """Move each shadow buffer toward the model's current parameter values."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # warm-up: effective decay ramps from ~0.1 up toward self.decay
            decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # shadow <- shadow - (1 - decay) * (shadow - param)
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite the model's trainable parameters with the EMA shadow values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
ldm/modules/encoders/__init__.py ADDED
File without changes
ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from functools import partial
17
+ import clip
18
+ from einops import rearrange, repeat
19
+ from transformers import CLIPTokenizer, CLIPTextModel,CLIPVisionModel,CLIPModel
20
+ import kornia
21
+ from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
22
+ from .xf import LayerNorm, Transformer
23
+ import math
24
+
25
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement encode()."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError
31
+
32
+
33
+
34
class ClassEmbedder(nn.Module):
    """Embeds integer class labels into vectors suitable for cross-attention conditioning."""

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        lookup_key = self.key if key is None else key
        # this is for use in crossattn: add a length-1 sequence axis -> (b, 1, embed_dim)
        labels = batch[lookup_key][:, None]
        return self.embedding(labels)
47
+
48
+
49
class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers"""

    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        # x-transformers wrapper: token embedding + `n_layer` encoder blocks of width `n_embed`
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))

    def forward(self, tokens):
        # move tokens to the configured device (caller may pass CPU tensors)
        tokens = tokens.to(self.device)
        # return_embeddings=True yields per-token features instead of logits
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, x):
        return self(x)
64
+
65
+
66
class BERTTokenizer(AbstractEncoder):
    """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        # vq_interface makes encode() mimic a VQ model's (quant, loss, info) return shape
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        # tokenize with truncation and fixed-length padding
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        # (quant, emb_loss, info)-style tuple; only the token-ids slot is filled
        return None, None, [None, None, tokens]

    def decode(self, text):
        # identity: token ids are consumed directly downstream
        return text
91
+
92
+
93
class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizr model and add some transformer encoder layers"""

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        # when use_tokenizer is False, forward() expects pre-tokenized ids
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        if self.use_tknz_fn:
            # NOTE: tokens stay on the tokenizer's device (no .to(self.device) here)
            tokens = self.tknz_fn(text)
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # output of length 77
        return self(text)
117
+
118
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        # weights are frozen at construction time
        self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and disable all gradients."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        # per-token features from the final layer, shape (b, max_length, hidden)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
144
+
145
+
146
class SpatialRescaler(nn.Module):
    """Rescales spatial inputs by applying `n_stages` rounds of interpolation,
    optionally remapping channels with a 1x1 conv afterwards."""

    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        assert n_stages >= 0
        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.n_stages = n_stages
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        # apply the scale factor n_stages times in a row
        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)
        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)
176
+
177
class FrozenCLIPTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """

    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
        super().__init__()
        # jit=False and CPU load; caller moves the module to `device` later
        self.model, _ = clip.load(version, jit=False, device="cpu")
        self.device = device
        self.max_length = max_length
        # encode() tiles the single text embedding n_repeat times along the sequence axis
        self.n_repeat = n_repeat
        self.normalize = normalize
        # NOTE(review): unlike FrozenCLIPEmbedder, freeze() is NOT called here,
        # so gradients stay enabled unless the caller freezes explicitly — confirm intended.

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(tokens)
        if self.normalize:
            # L2-normalize each embedding
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim == 2:
            # add a sequence axis: (b, d) -> (b, 1, d)
            z = z[:, None, :]
        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
        return z
207
+
208
class FrozenCLIPImageEmbedder(AbstractEncoder):
    """CLIP ViT image encoder (frozen) followed by a small trainable mapper
    transformer and LayerNorm, producing a single conditioning token."""

    def __init__(self, version="openai/clip-vit-large-patch14"):
        super().__init__()
        self.transformer = CLIPVisionModel.from_pretrained(version)
        self.final_ln = LayerNorm(1024)
        # 1-token, width-1024, 5-layer, 1-head transformer (see xf.Transformer signature)
        self.mapper = Transformer(
            1,
            1024,
            5,
            1,
        )

        self.freeze()

    def freeze(self):
        """Freeze the CLIP backbone but deliberately keep mapper and final_ln trainable."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        # re-enable gradients for the small adaptation head
        for param in self.mapper.parameters():
            param.requires_grad = True
        for param in self.final_ln.parameters():
            param.requires_grad = True

    def forward(self, image):
        outputs = self.transformer(pixel_values=image)
        # pooled CLS feature, shape (b, 1024)
        z = outputs.pooler_output
        z = z.unsqueeze(1)  # (b, 1, 1024): a single conditioning token
        z = self.mapper(z)
        z = self.final_ln(z)
        return z

    def encode(self, image):
        return self(image)
242
+
243
+
244
+
245
class DINOEmbedder(AbstractEncoder):
    """Frozen DINOv2 image encoder loaded from torch.hub.

    NOTE(review): forward() returns torch.zeros_like(features) — i.e. the DINO
    features are computed and then zeroed out. This looks like a deliberate
    null-conditioning/ablation; confirm before relying on real features here.
    """

    def __init__(self, dino_version):  # small, large, huge, gigantic
        super().__init__()
        assert dino_version in ['small', 'big', 'large', 'huge']
        # maps the size name to the letter used in the torch.hub model id
        letter_map = {
            'small': 's',
            'big': 'b',
            'large': 'l',
            'huge': 'g'
        }

        self.final_ln = LayerNorm(32)  # unused -- remove later
        self.mapper = LayerNorm(32)  # unused -- remove later

        letter = letter_map[dino_version]
        # NOTE(review): hard-coded .cuda() — this class cannot be constructed on a CPU-only host
        self.dino_model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{letter}14_reg').cuda()

        self.freeze()

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image):
        with torch.no_grad():
            outputs = self.dino_model.forward_features(image)
            # patch tokens plus the class token appended as the final sequence element
            patch_tokens = outputs['x_norm_patchtokens']
            global_token = outputs['x_norm_clstoken'].unsqueeze(1)
            features = torch.concat([patch_tokens, global_token], dim=1)
            # zeroed output, preserving shape/dtype of the real features
            return torch.zeros_like(features)

    def encode(self, image):
        return self(image)
288
+
289
+
290
class FixedVector(AbstractEncoder):
    """Returns a fixed (and, due to the * 0.0, always-zero) conditioning vector
    tiled to the batch size — effectively a null-conditioning encoder.

    NOTE(review): `nn.Parameter(...).cuda()` returns a plain tensor, so
    `fixed_vector` is never registered as a module parameter (it won't appear in
    state_dict or optimizer params), and construction requires a GPU. Since the
    output is multiplied by 0.0 this has no numeric effect, but confirm intent.
    """

    def __init__(self):
        super().__init__()
        self.final_ln = LayerNorm(32)  # unused
        self.mapper = LayerNorm(32)  # unused
        self.fixed_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True).cuda()

    def forward(self, image):
        # tile across the batch, move to the input's device, and zero out
        return self.fixed_vector.repeat(image.shape[0], 1, 1).to(image.device) * 0.0

    def encode(self, image):
        return self(image)
302
+
303
+
304
+
305
+
306
if __name__ == "__main__":
    # Quick sanity check: build the CLIP text encoder and report its parameter count.
    from ldm.util import count_params
    model = FrozenCLIPEmbedder()
    count_params(model, verbose=True)
ldm/modules/encoders/xf.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ """
15
+ Transformer implementation adapted from CLIP ViT:
16
+ https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py
17
+ """
18
+
19
+ import math
20
+
21
+ import torch as th
22
+ import torch.nn as nn
23
+
24
+
25
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16 (weights and biases in place).
    Intended for use with Module.apply().
    """
    primitive_types = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(l, primitive_types):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()
33
+
34
+
35
class LayerNorm(nn.LayerNorm):
    """
    Implementation that supports fp16 inputs but fp32 gains/biases:
    normalize in float32, then cast back to the input dtype.
    """

    def forward(self, x: th.Tensor):
        normed = super().forward(x.float())
        return normed.to(x.dtype)
42
+
43
+
44
class MultiheadAttention(nn.Module):
    """Self-attention block: fused qkv projection -> QKVMultiheadAttention -> output projection."""

    def __init__(self, n_ctx, width, heads):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        # single linear producing q, k, v concatenated along the channel dim
        self.c_qkv = nn.Linear(width, width * 3)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(heads, n_ctx)

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x
59
+
60
+
61
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Linear with a 4x hidden width."""

    def __init__(self, width):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * 4)
        self.c_proj = nn.Linear(width * 4, width)
        self.gelu = nn.GELU()

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        return self.c_proj(hidden)
71
+
72
+
73
class QKVMultiheadAttention(nn.Module):
    """Multi-head attention over a fused qkv tensor of shape (batch, n_ctx, 3 * heads * head_dim)."""

    def __init__(self, n_heads: int, n_ctx: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_ctx = n_ctx

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        head_dim = width // self.n_heads // 3
        # scale q and k by dim^-1/4 each: more stable with f16 than dividing the product afterwards
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
        q, k, v = th.split(qkv, head_dim, dim=-1)
        logits = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
        # softmax in float32, cast back to the logits dtype
        weights = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        return th.einsum("bhts,bshc->bthc", weights, v).reshape(bs, n_ctx, -1)
91
+
92
+
93
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln(x)), then x + mlp(ln(x))."""

    def __init__(
        self,
        n_ctx: int,
        width: int,
        heads: int,
    ):
        super().__init__()

        self.attn = MultiheadAttention(
            n_ctx,
            width,
            heads,
        )
        self.ln_1 = LayerNorm(width)
        self.mlp = MLP(width)
        self.ln_2 = LayerNorm(width)

    def forward(self, x: th.Tensor):
        # residual connections around both sub-layers
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
115
+
116
+
117
class Transformer(nn.Module):
    """Stack of `layers` ResidualAttentionBlocks applied in sequence
    (CLIP-style ViT transformer; no positional embedding added here)."""

    def __init__(
        self,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    n_ctx,
                    width,
                    heads,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: th.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x
ldm/modules/losses/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
ldm/modules/losses/contperceptual.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
18
+
19
+
20
class LPIPSWithDiscriminator(nn.Module):
    """Autoencoder (first-stage) training loss: pixel L1 + LPIPS + KL + patch-GAN.

    forward() serves both optimizers of the training loop:
    optimizer_idx == 0 -> generator/autoencoder objective,
    optimizer_idx == 1 -> discriminator objective.
    Helper symbols (LPIPS, NLayerDiscriminator, weights_init, hinge_d_loss,
    vanilla_d_loss, adopt_weight) come from the star-import of
    taming.modules.losses.vqperceptual at module top.
    """
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):

        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        # Frozen VGG-based perceptual metric (weights are not trained).
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        # Learned global log-variance of the reconstruction likelihood.
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        # GAN term is zeroed (via adopt_weight) before this global step.
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the GAN term against the reconstruction term by the ratio of
        their gradient norms at the decoder's last layer (VQGAN-style), clamped
        and detached so the weight itself receives no gradient."""
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # Fallback: assumes the owning module stored [last_layer] on self.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        # Per-pixel L1 reconstruction, optionally augmented with LPIPS.
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        # Gaussian NLL with the learned global log-variance.
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            # Optional per-element weighting of the NLL (e.g. masks).
            weighted_nll_loss = weights*nll_loss
        # Sum over all elements, divide by batch size (not per-pixel mean).
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # grad() raises in eval mode (no graph); fall back to 0.
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            # disc_factor ramps in the adversarial term after discriminator_iter_start.
            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            # Both real and fake are detached: only the discriminator trains here.
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
124
+
ldm/modules/losses/vqperceptual.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ from einops import repeat
18
+
19
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
20
+ from taming.modules.losses.lpips import LPIPS
21
+ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
22
+
23
+
24
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss with one weight per batch example.

    Hinge terms are averaged over the (C, H, W) dims per example, then
    weight-averaged over the batch.
    """
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    per_example_real = F.relu(1. - logits_real).mean(dim=[1, 2, 3])
    per_example_fake = F.relu(1. + logits_fake).mean(dim=[1, 2, 3])
    weight_total = weights.sum()
    weighted_real = (weights * per_example_real).sum() / weight_total
    weighted_fake = (weights * per_example_fake).sum() / weight_total
    return 0.5 * (weighted_real + weighted_fake)
32
+
33
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return `value` while `global_step` is below `threshold`, else `weight`.

    Used to delay the adversarial loss until the generator has warmed up.
    """
    return value if global_step < threshold else weight
37
+
38
+
39
def measure_perplexity(predicted_indices, n_embed):
    """Codebook-usage diagnostics: perplexity of the index histogram and the
    number of distinct codes used.

    Perplexity equals `n_embed` when all codes are used uniformly.
    src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    """
    one_hot = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    usage = one_hot.mean(0)
    # exp of the entropy of the empirical code distribution.
    perplexity = (-(usage * torch.log(usage + 1e-10)).sum()).exp()
    cluster_use = torch.sum(usage > 0)
    return perplexity, cluster_use
47
+
48
def l1(x, y):
    """Elementwise absolute difference |x - y|."""
    return (x - y).abs()
50
+
51
+
52
def l2(x, y):
    """Elementwise squared difference (x - y) ** 2."""
    return (x - y) ** 2
54
+
55
+
56
class VQLPIPSWithDiscriminator(nn.Module):
    """VQGAN training loss: pixel (L1/L2) + LPIPS + codebook term + patch-GAN.

    forward() serves both optimizers:
    optimizer_idx == 0 -> generator/autoencoder objective (incl. codebook loss),
    optimizer_idx == 1 -> discriminator objective.
    """
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            # Frozen VGG-based perceptual metric (weights are not trained).
            self.perceptual_loss = LPIPS().eval()
        else:
            # "clips"/"dists" pass the assert above but are not implemented here.
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        # GAN term is zeroed (via adopt_weight) before this global step.
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        # Codebook size, needed only when logging perplexity.
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the GAN term against the reconstruction term by the ratio of
        their gradient norms at the decoder's last layer (VQGAN-style)."""
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # Fallback: assumes the owning module stored [last_layer] on self.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        # BUG FIX: the original guarded with `exists(codebook_loss)`, but `exists`
        # is neither defined nor imported in this module, so a None codebook_loss
        # raised NameError. A plain None check is the intended behavior.
        if codebook_loss is None:
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                # grad() raises in eval mode (no graph); fall back to 0.
                assert not self.training
                d_weight = torch.tensor(0.0)

            # disc_factor ramps in the adversarial term after discriminator_iter_start.
            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            # Both real and fake are detached: only the discriminator trains here.
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
ldm/modules/x_transformer.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
15
+ import torch
16
+ from torch import nn, einsum
17
+ import torch.nn.functional as F
18
+ from functools import partial
19
+ from inspect import isfunction
20
+ from collections import namedtuple
21
+ from einops import rearrange, repeat, reduce
22
+
23
+ # constants
24
+
25
+ DEFAULT_DIM_HEAD = 64
26
+
27
+ Intermediates = namedtuple('Intermediates', [
28
+ 'pre_softmax_attn',
29
+ 'post_softmax_attn'
30
+ ])
31
+
32
+ LayerIntermediates = namedtuple('Intermediates', [
33
+ 'hiddens',
34
+ 'attn_intermediates'
35
+ ])
36
+
37
+
38
class AbsolutePositionalEmbedding(nn.Module):
    """Learned per-position embedding, returned with a broadcastable batch dim."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)
        self.init_()

    def init_(self):
        # Small-std init, GPT-style.
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, x):
        positions = torch.arange(x.shape[1], device=x.device)
        return self.emb(positions).unsqueeze(0)
50
+
51
+
52
class FixedPositionalEmbedding(nn.Module):
    """Non-learned sinusoidal positional embedding ('Attention Is All You Need')."""

    def __init__(self, dim):
        super().__init__()
        frequencies = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', frequencies)

    def forward(self, x, seq_dim=1, offset=0):
        positions = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
        angles = torch.einsum('i , j -> i j', positions, self.inv_freq)
        embedding = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return embedding.unsqueeze(0)
63
+
64
+
65
+ # helpers
66
+
67
def exists(val):
    """True when `val` is not None (falsy values like 0 and '' still count)."""
    if val is None:
        return False
    return True
69
+
70
+
71
def default(val, d):
    """Return `val` if it is set, otherwise the fallback `d`.

    A plain-function fallback is called lazily; any other value is returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
75
+
76
+
77
def always(val):
    """Return a function that ignores all arguments and always yields `val`."""
    def inner(*_args, **_kwargs):
        return val
    return inner
81
+
82
+
83
def not_equals(val):
    """Predicate factory: returned function is True for inputs != `val`."""
    return lambda x: x != val
87
+
88
+
89
def equals(val):
    """Predicate factory: returned function is True for inputs == `val`."""
    return lambda x: x == val
93
+
94
+
95
def max_neg_value(tensor):
    """Most negative finite value of `tensor`'s dtype (used as a mask fill)."""
    info = torch.finfo(tensor.dtype)
    return -info.max
97
+
98
+
99
+ # keyword argument helpers
100
+
101
def pick_and_pop(keys, d):
    """Remove `keys` from dict `d` (mutating it) and return them as a new dict."""
    return {key: d.pop(key) for key in keys}
104
+
105
+
106
def group_dict_by_key(cond, d):
    """Split `d` into a (matching, non-matching) pair of dicts by key predicate."""
    matching, rest = dict(), dict()
    for key, value in d.items():
        target = matching if cond(key) else rest
        target[key] = value
    return matching, rest
113
+
114
+
115
def string_begins_with(prefix, str):
    """True when `str` starts with `prefix` (prefix-first order suits partial())."""
    # NOTE: parameter name `str` shadows the builtin but is kept for API stability.
    return str[:len(prefix)] == prefix
117
+
118
+
119
def group_by_key_prefix(prefix, d):
    """Split `d` into (prefixed, rest) dicts by whether each key starts with `prefix`."""
    return group_dict_by_key(lambda key: string_begins_with(prefix, key), d)
121
+
122
+
123
def groupby_prefix_and_trim(prefix, d):
    """Pull out keys starting with `prefix`, strip the prefix from them, and
    return (trimmed_prefixed_dict, remaining_dict)."""
    with_prefix, without_prefix = group_dict_by_key(partial(string_begins_with, prefix), d)
    trimmed = {key[len(prefix):]: value for key, value in with_prefix.items()}
    return trimmed, without_prefix
127
+
128
+
129
+ # classes
130
class Scale(nn.Module):
    """Wrap `fn` and multiply the first element of its output tuple by a constant."""

    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out, *extras = self.fn(x, **kwargs)
        return (out * self.value, *extras)
139
+
140
+
141
class Rezero(nn.Module):
    """ReZero wrapper: scale `fn`'s first output by a learned gate initialised to 0,
    so the branch starts as an identity-suppressing no-op."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        out, *extras = self.fn(x, **kwargs)
        return (out * self.g, *extras)
150
+
151
+
152
class ScaleNorm(nn.Module):
    """L2 normalisation over the last dim with a single learned scalar gain
    ('Transformers without Tears')."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        denom = torch.norm(x, dim=-1, keepdim=True) * self.scale
        denom = denom.clamp(min=self.eps)
        return x / denom * self.g
162
+
163
+
164
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned per-channel gain."""

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        denom = torch.norm(x, dim=-1, keepdim=True) * self.scale
        denom = denom.clamp(min=self.eps)
        return x / denom * self.g
174
+
175
+
176
class Residual(nn.Module):
    """Plain additive residual connection: out = x + residual."""

    def forward(self, x, residual):
        combined = x + residual
        return combined
179
+
180
+
181
class GRUGating(nn.Module):
    """Gated residual connection (Stabilizing Transformers / GTrXL): a GRU cell
    combines the branch output `x` with the `residual` stream token-wise."""

    def __init__(self, dim):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, residual):
        # Flatten (batch, seq) into one axis so each token is a GRU "step".
        feature_dim = x.shape[-1]
        gated = self.gru(
            x.reshape(-1, feature_dim),
            residual.reshape(-1, feature_dim)
        )
        return gated.reshape_as(x)
193
+
194
+
195
+ # feedforward
196
+
197
class GEGLU(nn.Module):
    """Gated-GELU unit: project to 2x width, gate one half with GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
205
+
206
+
207
class FeedForward(nn.Module):
    """Position-wise transformer MLP: dim -> dim*mult -> dim_out with GELU
    (or GEGLU when `glu=True`) activation and dropout."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim if dim_out is None else dim_out
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim),
                nn.GELU()
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
225
+
226
+
227
+ # attention.
228
class Attention(nn.Module):
    """Multi-head self- or cross-attention with optional extras: causal masking,
    talking-heads mixing, top-k sparse attention, persistent memory key/values,
    and attention-on-attention (GLU) output projection.

    Returns (output, Intermediates) where Intermediates carries the pre- and
    post-softmax attention maps for residual-attention / introspection use.
    """
    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            mask=None,
            talking_heads=False,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False
    ):
        super().__init__()
        if use_entmax15:
            # entmax dependency was removed; only softmax attention is supported.
            raise NotImplementedError("Check out entmax activation instead of softmax activation!")
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.causal = causal
        self.mask = mask

        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # talking heads
        # Learned head-mixing matrices applied before and after softmax.
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        #self.attn_fn = entmax15 if use_entmax15 else F.softmax
        self.attn_fn = F.softmax

        # add memory key / values
        # Learned key/value slots prepended to every sequence (all-attn memory).
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            rel_pos=None,
            sinusoidal_emb=None,
            prev_attn=None,
            mem=None
    ):
        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
        # Self-attention when context is None, cross-attention otherwise.
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            # Cached memories (transformer-XL style) are prepended to keys/values.
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        # Split the head dimension out: (b, n, h*d) -> (b, h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            # Combine query- and key-side padding masks into a (b, 1, i, j) mask.
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            # Memory slots are always attendable, so the mask is padded with True.
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            # Residual attention: add previous layer's pre-softmax scores.
            dots = dots + prev_attn

        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            # Upper-triangular mask; F.pad shifts it to account for prepended mem/kv.
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            # Keep only the top-k scores per query; mask out the rest.
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates
381
+
382
+
383
class AttentionLayers(nn.Module):
    """A stack of attention ('a'), cross-attention ('c') and feed-forward ('f')
    layers, each wrapped as (norm, block, residual_fn).

    The layer schedule (`layer_types`) can be customised directly, derived from
    PAR ratios, or sandwich-style; otherwise it defaults to depth repetitions of
    ('a', ['c',] 'f'). Keyword arguments prefixed with 'ff_' / 'attn_' are routed
    to the FeedForward / Attention constructors respectively.
    """
    def __init__(
            self,
            dim,
            depth,
            heads=8,
            causal=False,
            cross_attend=False,
            only_cross=False,
            use_scalenorm=False,
            use_rmsnorm=False,
            use_rezero=False,
            rel_pos_num_buckets=32,
            rel_pos_max_distance=128,
            position_infused_attn=False,
            custom_layers=None,
            sandwich_coef=None,
            par_ratio=None,
            residual_attn=False,
            cross_residual_attn=False,
            macaron=False,
            pre_norm=True,
            gate_residual=False,
            **kwargs
    ):
        super().__init__()
        # Route prefixed kwargs to the sub-modules they belong to.
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])

        # Shortformer-style position-infused attention (added in Attention.forward).
        self.has_pos_emb = position_infused_attn
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
        self.rotary_pos_emb = always(None)

        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
        # Relative position bias is not wired up in this trimmed version.
        self.rel_pos = None

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        # ReZero replaces normalisation with a learned zero-init gate per branch.
        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            # Macaron-style: an extra half-weighted FF before the attention.
            default_block = ('f',) + default_block

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            # PAR schedule: front-load attention, pad the tail with feed-forwards.
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            # Sandwich schedule: attention-only head, FF-only tail.
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if isinstance(layer, Attention) and exists(branch_fn):
                layer = branch_fn(layer)

            if gate_residual:
                residual_fn = GRUGating(dim)
            else:
                residual_fn = Residual()

            self.layers.append(nn.ModuleList([
                norm_fn(),
                layer,
                residual_fn
            ]))

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            mems=None,
            return_hiddens=False
    ):
        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        # One cached memory per attention layer (transformer-XL style).
        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == 'a':
                hiddens.append(x)
                layer_mem = mems.pop(0)

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
                                   prev_attn=prev_attn, mem=layer_mem)
            elif layer_type == 'c':
                out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
            elif layer_type == 'f':
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            # Residual attention: carry pre-softmax scores to the next layer.
            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates
            )

            return x, intermediates

        return x
552
+
553
+
554
class Encoder(AttentionLayers):
    """AttentionLayers fixed to bidirectional (non-causal) attention."""
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)
558
+
559
+
560
+
561
class TransformerWrapper(nn.Module):
    """Token-in / logits-out wrapper around an AttentionLayers stack.

    Adds token + (optional) absolute positional embeddings, optional memory
    tokens (Memory Transformers), a final LayerNorm, and a logits head that may
    be tied to the token embedding.
    """
    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        # Skip the learned positional embedding when the attention layers already
        # inject positions (position-infused attention).
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        # Weight tying: reuse the token embedding as the output projection.
        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        # Small-std init for token embeddings, GPT-style.
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_mems=False,
            return_attn=False,
            mems=None,
            **kwargs
    ):
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            # Prepend learned memory tokens to every sequence in the batch.
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        # Strip the memory tokens back off before producing logits.
        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            # Extend each layer's cache with the new hiddens, truncated to max_mem_len.
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
            return out, new_mems

        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out
654
+
ldm/util.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
2
+ # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
3
+ # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
4
+ # CreativeML Open RAIL-M
5
+ #
6
+ # ==========================================================================================
7
+ #
8
+ # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
9
+ # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
10
+ # LICENSE.md.
11
+ #
12
+ # ==========================================================================================
13
+
14
+ import importlib
15
+
16
+ import torch
17
+ import numpy as np
18
+ from collections import abc
19
+ from einops import rearrange
20
+ from functools import partial
21
+
22
+ import multiprocessing as mp
23
+ from threading import Thread
24
+ from queue import Queue
25
+
26
+ from inspect import isfunction
27
+ from PIL import Image, ImageDraw, ImageFont
28
+
29
+
30
def log_txt_as_img(wh, xc, size=10):
    """Render a batch of captions onto white canvases for logging.

    Args:
        wh: (width, height) of each canvas.
        xc: list of caption strings, one per batch element.
        size: font size in points.

    Returns:
        A (b, 3, H, W) float tensor with values in [-1, 1].
    """
    rendered = list()
    for caption in xc:
        canvas = Image.new("RGB", wh, color="white")
        painter = ImageDraw.Draw(canvas)
        font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
        # wrap at ~40 chars for a 256-wide canvas, scaled with width
        chars_per_line = int(40 * (wh[0] / 256))
        wrapped = "\n".join(
            caption[pos:pos + chars_per_line]
            for pos in range(0, len(caption), chars_per_line)
        )

        try:
            painter.text((0, 0), wrapped, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        # HWC uint8 -> CHW float in [-1, 1]
        arr = np.array(canvas).transpose(2, 0, 1) / 127.5 - 1.0
        rendered.append(arr)
    return torch.tensor(np.stack(rendered))
52
+
53
+
54
def ismap(x):
    """Return True iff `x` is a 4-D tensor with more than 3 channels (a feature map)."""
    return isinstance(x, torch.Tensor) and x.ndim == 4 and x.shape[1] > 3
58
+
59
+
60
def isimage(x):
    """Return True iff `x` is a 4-D tensor with 1 or 3 channels (an image batch)."""
    return isinstance(x, torch.Tensor) and x.ndim == 4 and x.shape[1] in (1, 3)
64
+
65
+
66
def exists(x):
    """Return True unless `x` is None."""
    return x is not None


def default(val, d):
    """Return `val` if it is set; otherwise the fallback `d`.

    If the fallback is a function, it is called lazily so expensive
    defaults are only computed when actually needed.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
74
+
75
+
76
def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
82
+
83
+
84
def count_params(model, verbose=False):
    """Return the total number of parameters in `model`, optionally printing it in millions."""
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params
89
+
90
+
91
def instantiate_from_config(config):
    """Instantiate the object named by config["target"] with config["params"].

    The sentinel strings '__is_first_stage__' and '__is_unconditional__'
    deliberately map to None (placeholder stages in the diffusion configs).

    Raises:
        KeyError: if `config` is not a sentinel and lacks a "target" key.
    """
    # idiom fix: `"target" not in config` instead of `not "target" in config`
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
99
+
100
+
101
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like 'pkg.mod.Attr' to the named attribute.

    When `reload` is True, the containing module is re-imported first so
    code changes are picked up.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    return getattr(importlib.import_module(module_name, package=None), attr_name)
107
+
108
+
109
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    """Worker body: apply `func` to `data` and push [idx, result] onto `Q`,
    followed by the "Done" sentinel so the consumer can count finished workers.
    """
    result = func(data, worker_id=idx) if idx_to_fn else func(data)
    Q.put([idx, result])
    Q.put("Done")
119
+
120
+
121
def parallel_data_prefetch(
        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    """Apply `func` to `data` split into `n_proc` chunks, in parallel.

    Args:
        func: callable applied to each chunk (gets `worker_id=` kwarg when
            `use_worker_id` is True).
        data: np.ndarray or iterable (dict values are used, keys dropped).
        n_proc: number of workers.
        target_data_type: "ndarray" or "list"; controls chunking and how the
            per-worker results are re-assembled.
        cpu_intensive: True -> multiprocessing.Process, False -> threading.Thread.
        use_worker_id: pass the worker index to `func`.

    Returns:
        Concatenated ndarray, flattened list, or the raw per-worker results,
        ordered by worker index.
    """
    # if target_data_type not in ["ndarray", "list"]:
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    # Processes for CPU-bound work (true parallelism), threads for I/O-bound.
    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread
    # spawn processes
    if target_data_type == "ndarray":
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(np.array_split(data, n_proc))
        ]
    else:
        # ceil-divide the list into n_proc roughly equal slices
        step = (
            int(len(data) / n_proc + 1)
            if len(data) % n_proc != 0
            else int(len(data) / n_proc)
        )
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(
                [data[i: i + step] for i in range(0, len(data), step)]
            )
        ]
    processes = []
    for i in range(n_proc):
        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
        processes += [p]

    # start processes
    print(f"Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    try:
        for p in processes:
            p.start()

        # Each worker emits its result then a "Done" sentinel; drain the queue
        # until every worker has reported done.
        k = 0
        while k < n_proc:
            # get result
            res = Q.get()
            if res == "Done":
                k += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in processes:
            p.terminate()

        raise e
    finally:
        for p in processes:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.2.0
2
+ torchvision
3
+ torchaudio
4
+
5
+ diffusers
6
+ transformers
7
+ accelerate
8
+ # albumentations==0.4.3
9
+ gradio
10
+ opencv-python-headless==4.9.0.80
11
+ huggingface_hub
12
+ pudb==2019.2
13
+ invisible-watermark
14
+ imageio==2.9.0
15
+ imageio-ffmpeg==0.4.2
16
+ # pytorch-lightning==2.0.0
17
+ omegaconf==2.1.1
18
+ test-tube>=0.7.5
19
+ einops==0.3.0
20
+ torch-fidelity==0.3.0
21
+ torchmetrics==0.7.0
22
+ kornia==0.6
23
+
24
+ # git+https://github.com/CompVis/taming-transformers.git@master
25
+ # git+https://github.com/openai/CLIP.git@main
26
+ taming-transformers-rom1504
27
+ git+https://github.com/openai/CLIP.git@main#egg=clip
28
+
29
+ # install your package in editable mode
30
+ -e .
run_magicfu.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Adobe. All rights reserved.
2
+
3
+ #%%
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ from itertools import islice
10
+ from torch import autocast
11
+ import torchvision
12
+ from ldm.util import instantiate_from_config
13
+ from ldm.models.diffusion.ddim import DDIMSampler
14
+ from torchvision.transforms import Resize
15
+ import argparse
16
+ import os
17
+ import pathlib
18
+ import glob
19
+
20
+
21
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
22
+
23
def fix_img(test_img):
    """Make the image square by cropping from the top-left corner.

    Fix: the original used `height` for the right edge of the crop box, which
    reads past the right image boundary whenever width < height (portrait
    images). Using the smaller side keeps the crop inside the image for both
    orientations; landscape behavior is unchanged.

    Args:
        test_img: a PIL-style image exposing `.size` and `.crop((l, t, r, b))`.

    Returns:
        The image itself if already square, otherwise a square crop.
    """
    width, height = test_img.size
    if width == height:
        return test_img
    side = min(width, height)
    return test_img.crop((0, 0, side, side))
33
+ # util funcs
34
def chunk(it, size):
    """Yield successive `size`-tuples from iterable `it`; the last tuple may be shorter."""
    iterator = iter(it)
    # iter(callable, sentinel): keep slicing until an empty tuple appears
    return iter(lambda: tuple(islice(iterator, size)), ())
37
+
38
def get_tensor_clip(normalize=True, toTensor=True):
    """Compose CLIP preprocessing: optional ToTensor followed by CLIP's
    published RGB mean/std normalization."""
    ops = []
    if toTensor:
        ops.append(torchvision.transforms.ToTensor())

    if normalize:
        ops.append(
            torchvision.transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            )
        )
    return torchvision.transforms.Compose(ops)
47
+
48
def get_tensor_dino(normalize=True, toTensor=True):
    """Compose DINO preprocessing: resize to 224x224, optional ToTensor, then
    rescale RGB to [0, 255] and normalize with ImageNet mean/std (0-255 scale)."""
    ops = [torchvision.transforms.Resize((224, 224))]
    if toTensor:
        ops.append(torchvision.transforms.ToTensor())

    if normalize:
        # drop any alpha channel and undo ToTensor's [0, 1] scaling
        ops.append(lambda x: 255.0 * x[:3])
        ops.append(
            torchvision.transforms.Normalize(
                mean=(123.675, 116.28, 103.53),
                std=(58.395, 57.12, 57.375),
            )
        )
    return torchvision.transforms.Compose(ops)
60
+
61
def get_tensor(normalize=True, toTensor=True):
    """Compose SD-style preprocessing: optional ToTensor, normalization to
    [-1, 1], then resize the shorter side to 512 and center-crop to 512x512."""
    ops = []
    if toTensor:
        ops.append(torchvision.transforms.ToTensor())

    if normalize:
        ops.append(
            torchvision.transforms.Normalize((0.5, 0.5, 0.5),
                                             (0.5, 0.5, 0.5))
        )
    ops.append(torchvision.transforms.Resize(512))
    ops.append(torchvision.transforms.CenterCrop(512))
    return torchvision.transforms.Compose(ops)
74
+
75
+
76
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    # promote a single HWC image to a batch of one
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in batch]
86
+
87
+
88
+
89
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by `config.model` and optionally load weights.

    Args:
        config: OmegaConf config with a `model` section usable by
            instantiate_from_config.
        ckpt: path to a checkpoint file, or None to skip weight loading.
        verbose: print missing/unexpected state-dict keys when True.

    Returns:
        The model in eval mode, moved to the module-level `device`.
    """
    model = instantiate_from_config(config.model)

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        # Fix: the original commented out `sd = pl_sd["state_dict"]` but still
        # used `sd`, raising NameError for any non-None ckpt. Support both
        # Lightning-style checkpoints ({'state_dict': ...}) and raw state dicts.
        sd = pl_sd.get("state_dict", pl_sd)

        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)

    # Use the module-level `device` (was an unconditional model.cuda(), which
    # crashes on CPU-only machines; get_model() already moves to `device`).
    model.to(device)
    model.eval()
    return model
111
+
112
+
113
def get_model(config_path, ckpt_path):
    """Build the Magic Fixup model from a YAML config and load checkpoint weights.

    Args:
        config_path: path to the OmegaConf YAML model config.
        ckpt_path: path to a checkpoint containing a raw state dict
            (loaded with strict=True, so keys must match exactly).

    Returns:
        The model moved to the module-level `device`.
    """
    config = OmegaConf.load(f"{config_path}")
    # ckpt=None: weights are loaded below instead of inside load_model_from_config
    model = load_model_from_config(config, None)
    pl_sd = torch.load(ckpt_path, map_location="cpu")

    # NOTE(review): with strict=True, load_state_dict raises on any mismatch,
    # so the warning branches below are effectively dead code — confirm intent.
    m, u = model.load_state_dict(pl_sd, strict=True)
    if len(m) > 0:
        print("WARNING: missing keys:")
        print(m)
    if len(u) > 0:
        print("unexpected keys:")
        print(u)


    model = model.to(device)
    return model
129
+
130
def get_grid(size):
    """Return a (size, size, 2) int array of coordinates where out[i, j] == [i, j]."""
    rows, cols = np.indices((size, size))
    return np.stack([rows, cols], -1)
136
+
137
def un_norm(x):
    """Map values from the model's [-1, 1] range back to [0, 1]."""
    return 0.5 * (x + 1.0)
139
+
140
class MagicFixup:
    """Thin inference wrapper around the Magic Fixup diffusion model.

    Loads the model from `configs/collage_mix_train.yaml` plus a checkpoint,
    and exposes `edit_image`, which refines a coarse user edit of a reference
    photo into a photorealistic result via DDIM sampling.
    """

    def __init__(self, model_path='/sensei-fs/users/halzayer/collage2photo/Paint-by-Example/official_checkpoint_image_attn_200k.pt'):
        # NOTE(review): the default path is an internal location; callers are
        # expected to pass an explicit checkpoint path (see main()).
        self.model = get_model('configs/collage_mix_train.yaml',model_path)


    def edit_image(self, ref_image, coarse_edit, mask_tensor, start_step, steps):
        """Refine `coarse_edit` of `ref_image` into a realistic image.

        Args:
            ref_image: clean reference image tensor.
            coarse_edit: the user's coarse RGB edit of the reference.
            mask_tensor: mask of the edit; values < 0.5 are treated as holes
                to be inpainted before encoding.
            start_step: fraction of the diffusion trajectory to start from
                (passed to the sampler as x0_step).
            steps: number of DDIM sampling steps.

        Returns:
            A float tensor in [0, 1], NCHW, on CPU (decoded sample).

        NOTE(review): exact tensor shapes/dtypes are inferred from the 512x512
        transforms and the decode path — confirm against the model config.
        """
        # essentially sample
        sampler = DDIMSampler(self.model)

        start_code = None

        transformed_grid = torch.zeros((2, 64, 64))

        self.model.model.og_grid = None
        self.model.model.transformed_grid = transformed_grid.unsqueeze(0).to(self.model.device)

        scale = 1.0  # guidance scale; 1.0 disables the unconditional branch below
        C, f, H, W= 4, 8, 512, 512  # latent channels, downsample factor, output size
        n_samples = 1
        ddim_steps = steps
        ddim_eta = 1.0
        step = start_step

        with torch.no_grad():
            with autocast("cuda"):
                with self.model.ema_scope():
                    image_tensor = get_tensor(toTensor=False)(coarse_edit)

                    clean_ref_tensor = get_tensor(toTensor=False)(ref_image)
                    clean_ref_tensor = clean_ref_tensor.unsqueeze(0)

                    # DINO-normalized reference used for conditioning features
                    ref_tensor=get_tensor_dino(toTensor=False)(ref_image).unsqueeze(0)

                    # boolean hole mask: True where the edit left gaps
                    b_mask = mask_tensor.cpu() < 0.5

                    # inpainting: fill the holes of the coarse edit with
                    # OpenCV Navier-Stokes inpainting before encoding
                    reference = un_norm(image_tensor)
                    reference = reference.squeeze()
                    ref_cv = torch.moveaxis(reference, 0, -1).cpu().numpy()
                    ref_cv = (ref_cv * 255).astype(np.uint8)

                    cv_mask = b_mask.int().squeeze().cpu().numpy().astype(np.uint8)
                    # dilate the hole mask so inpainting covers edge artifacts
                    kernel = np.ones((7,7))
                    dilated_mask = cv2.dilate(cv_mask, kernel)

                    dst = cv2.inpaint(ref_cv,dilated_mask,3,cv2.INPAINT_NS)
                    # dst = inpaint.inpaint_biharmonic(ref_cv, dilated_mask, channel_axis=-1)
                    dst_tensor = torch.tensor(dst).moveaxis(-1, 0) / 255.0
                    image_tensor = (dst_tensor * 2.0) - 1.0  # back to [-1, 1]
                    image_tensor = image_tensor.unsqueeze(0)

                    ref_tensor = ref_tensor

                    inpaint_image = image_tensor#*mask_tensor

                    test_model_kwargs={}
                    test_model_kwargs['inpaint_mask']=mask_tensor.to(device)
                    test_model_kwargs['inpaint_image']=inpaint_image.to(device)
                    clean_ref_tensor = clean_ref_tensor.to(device)
                    ref_tensor=ref_tensor.to(device)
                    uc = None
                    if scale != 1.0:
                        uc = self.model.learnable_vector
                    # conditioning from the reference image's features
                    c = self.model.get_learned_conditioning(ref_tensor.to(torch.float16))
                    c = self.model.proj_out(c)

                    # encode the inpainted edit and the clean reference to latents
                    z_inpaint = self.model.encode_first_stage(test_model_kwargs['inpaint_image'])
                    z_inpaint = self.model.get_first_stage_encoding(z_inpaint).detach()


                    z_ref = self.model.encode_first_stage(clean_ref_tensor)
                    z_ref = self.model.get_first_stage_encoding(z_ref).detach()

                    # resize the mask to latent resolution
                    test_model_kwargs['inpaint_image']=z_inpaint
                    test_model_kwargs['inpaint_mask']=Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(test_model_kwargs['inpaint_mask'])


                    shape = [C, H // f, W // f]

                    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                     conditioning=c,
                                                     z_ref=z_ref,
                                                     batch_size=n_samples,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=uc,
                                                     eta=ddim_eta,
                                                     x_T=start_code,
                                                     test_model_kwargs=test_model_kwargs,
                                                     x0=z_inpaint,
                                                     x0_step=step,
                                                     ddim_discretize='uniform',
                                                     drop_latent_guidance=1.0
                                                     )


                    # decode latents and map from [-1, 1] to [0, 1]
                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    x_checked_image=x_samples_ddim
                    x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)


                    return x_checked_image_torch
246
+ #%%
247
+
248
+
249
+ #%%
250
+ import time
251
+
252
+
253
+
254
+ # %%
255
def file_exists(path):
    """argparse type-checker: return `path` if it names an existing regular file,
    otherwise raise an argparse error."""
    if os.path.isfile(path):
        return path
    raise argparse.ArgumentTypeError(f"{path} is not a valid file.")
260
+
261
def parse_arguments():
    """Parse command-line arguments for the Magic Fixup inference script."""
    parser = argparse.ArgumentParser(description="Process images based on provided paths.")
    add = parser.add_argument
    add("--checkpoint", type=file_exists, required=True, help="Path to the MagicFixup checkpoint file.")
    add("--reference", type=file_exists, default='examples/fox_drinking_og.png', help="Path to the reference original image.")
    add("--edit", type=file_exists, default='examples/fox_drinking__edit__01.png', help="Path to the image edit. Make sure the alpha channel is set properly")
    add("--output-dir", type=str, default='./outputs', help="Path to the folder where to save the outputs")
    add("--samples", type=int, default=5, help="number of samples to output")
    return parser.parse_args()
271
+
272
+
273
def main():
    """CLI entry point: load the model, refine the edit N times, save PNGs.

    Output files are named `<edit-stem>__sample__<counter>.png`; the counter
    starts after any existing samples so reruns never overwrite earlier ones.
    """
    # Parse arguments
    args = parse_arguments()

    # create magic fixup model
    magic_fixup = MagicFixup(model_path=args.checkpoint)
    output_dir = args.output_dir

    os.makedirs(output_dir, exist_ok=True)

    # run it here

    to_tensor = torchvision.transforms.ToTensor()



    ref_path = args.reference
    coarse_edit_path = args.edit
    # the edit PNG doubles as the mask source (its alpha channel)
    mask_edit_path = coarse_edit_path

    edit_file_name = pathlib.Path(coarse_edit_path).stem
    save_pattern = f'{output_dir}/{edit_file_name}__sample__*.png'
    # resume numbering after existing samples for this edit
    save_counter = len(glob.glob(save_pattern))

    all_rgbs = []
    for i in range(args.samples):
        # NOTE(review): autocast("cuda") assumes a CUDA device, as do the
        # .cuda() calls below — this script is GPU-only as written.
        with autocast("cuda"):
            ref_image_t = to_tensor(Image.open(ref_path).convert('RGB').resize((512,512))).half().cuda()
            coarse_edit_t = to_tensor(Image.open(coarse_edit_path).resize((512,512))).half().cuda()
            # get mask from coarse
            # mask_t = torch.ones_like(coarse_edit_t[-1][None, None,...])
            coarse_edit_mask_t = to_tensor(Image.open(mask_edit_path).resize((512,512))).half().cuda()
            # get mask from coarse: last channel (alpha) marks the edited region
            mask_t = (coarse_edit_mask_t[-1][None, None,...]).half() # do center crop
            # drop the alpha channel, keep RGB
            coarse_edit_t_rgb = coarse_edit_t[:-1]

            out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50)
            all_rgbs.append(out_rgb.squeeze().cpu().detach().float())

        save_path = f'{output_dir}/{edit_file_name}__sample__{save_counter:03d}.png'
        torchvision.utils.save_image(all_rgbs[i], save_path)
        save_counter += 1



if __name__ == "__main__":
    main()
setup.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup, find_packages

# Minimal packaging stub: lets `pip install -e .` (see requirements.txt) put
# the local packages (e.g. `ldm`) on sys.path. Runtime dependencies are pinned
# in requirements.txt rather than here.
setup(
    name='Magic-FU',
    version='0.0.1',
    description='',
    packages=find_packages(),
    install_requires=[],
)
src/clip ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
src/taming-transformers ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 3ba01b241669f5ade541ce990f7650a3b8f65318