vinhtong97 commited on
Commit
d382778
·
verified ·
1 Parent(s): 876a0c1

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +185 -0
  3. README.md +177 -3
  4. compute_fid.py +158 -0
  5. configs/afhqv2.yml +49 -0
  6. configs/cifar10.yml +49 -0
  7. configs/cifar10_order2.yml +49 -0
  8. configs/ffhq.yml +49 -0
  9. configs/latent-diffusion/celebahq-ldm-vq-4.yaml +86 -0
  10. configs/latent-diffusion/cin-ldm-vq-f8.yaml +98 -0
  11. configs/latent-diffusion/cin256-v2.yaml +68 -0
  12. configs/latent-diffusion/ffhq-ldm-vq-4.yaml +85 -0
  13. configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml +85 -0
  14. configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml +91 -0
  15. configs/latent-diffusion/txt2img-1p4B-eval.yaml +71 -0
  16. configs/latent_diff_LSUN.yml +55 -0
  17. configs/latent_diff_imn.yml +55 -0
  18. configs/stable-diffusion/v1-inference.yaml +70 -0
  19. configs/stable_diff_v1-4.yml +55 -0
  20. configs/stable_diff_v1-5.yml +55 -0
  21. data/coco_captions.txt +0 -0
  22. data/prompts.txt +5 -0
  23. dataset.py +46 -0
  24. dnnlib/__init__.py +8 -0
  25. dnnlib/util.py +491 -0
  26. gen_data.py +188 -0
  27. ldm/__init__.py +0 -0
  28. ldm/data/__init__.py +0 -0
  29. ldm/data/base.py +23 -0
  30. ldm/data/imagenet.py +394 -0
  31. ldm/data/lsun.py +92 -0
  32. ldm/lr_scheduler.py +98 -0
  33. ldm/models/autoencoder.py +442 -0
  34. ldm/models/diffusion/__init__.py +0 -0
  35. ldm/models/diffusion/classifier.py +267 -0
  36. ldm/models/diffusion/ddim.py +241 -0
  37. ldm/models/diffusion/ddpm.py +1445 -0
  38. ldm/models/diffusion/dpm_solver/__init__.py +1 -0
  39. ldm/models/diffusion/dpm_solver/dpm_solver.py +780 -0
  40. ldm/models/diffusion/dpm_solver/sampler.py +95 -0
  41. ldm/models/diffusion/dpm_solver_v3/__init__.py +1 -0
  42. ldm/models/diffusion/dpm_solver_v3/dpm_solver_v3.py +824 -0
  43. ldm/models/diffusion/dpm_solver_v3/sampler.py +95 -0
  44. ldm/models/diffusion/plms.py +236 -0
  45. ldm/models/diffusion/uni_pc/__init__.py +1 -0
  46. ldm/models/diffusion/uni_pc/sampler.py +83 -0
  47. ldm/models/diffusion/uni_pc/uni_pc.py +547 -0
  48. ldm/modules/attention.py +227 -0
  49. ldm/modules/diffusionmodules/__init__.py +0 -0
  50. ldm/modules/diffusionmodules/model.py +835 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ src/taming-transformers/scripts/reconstruction_usage.ipynb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data file
2
+ train_data*
3
+ test_gen_data
4
+ sampling_data
5
+ val_coco_images
6
+ temp
7
+ tmp
8
+ pretrained
9
+ fid-refs
10
+ captions_val2014.json
11
+ *json
12
+ *gz
13
+ !src/clip/clip/bpe_simple_vocab_16e6.txt.gz
14
+ *zip
15
+ logs/
16
+ *pt
17
+ *pkl
18
+ *png
19
+ *err
20
+ *out
21
+ test_scripts/*
22
+ all_logs
23
+ all_fids
24
+ # Byte-compiled / optimized / DLL files
25
+ __pycache__/
26
+ *.py[cod]
27
+ *$py.class
28
+
29
+ # C extensions
30
+ *.so
31
+
32
+ # Distribution / packaging
33
+ .Python
34
+ build/
35
+ develop-eggs/
36
+ dist/
37
+ downloads/
38
+ eggs/
39
+ .eggs/
40
+ lib/
41
+ lib64/
42
+ parts/
43
+ sdist/
44
+ var/
45
+ wheels/
46
+ share/python-wheels/
47
+ *.egg-info/
48
+ .installed.cfg
49
+ *.egg
50
+ MANIFEST
51
+
52
+ # PyInstaller
53
+ # Usually these files are written by a python script from a template
54
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
55
+ *.manifest
56
+ *.spec
57
+
58
+ # Installer logs
59
+ pip-log.txt
60
+ pip-delete-this-directory.txt
61
+
62
+ # Unit test / coverage reports
63
+ htmlcov/
64
+ .tox/
65
+ .nox/
66
+ .coverage
67
+ .coverage.*
68
+ .cache
69
+ nosetests.xml
70
+ coverage.xml
71
+ *.cover
72
+ *.py,cover
73
+ .hypothesis/
74
+ .pytest_cache/
75
+ cover/
76
+
77
+ # Translations
78
+ *.mo
79
+ *.pot
80
+
81
+ # Django stuff:
82
+ *.log
83
+ local_settings.py
84
+ db.sqlite3
85
+ db.sqlite3-journal
86
+
87
+ # Flask stuff:
88
+ instance/
89
+ .webassets-cache
90
+
91
+ # Scrapy stuff:
92
+ .scrapy
93
+
94
+ # Sphinx documentation
95
+ docs/_build/
96
+
97
+ # PyBuilder
98
+ .pybuilder/
99
+ target/
100
+
101
+ # Jupyter Notebook
102
+ .ipynb_checkpoints
103
+
104
+ # IPython
105
+ profile_default/
106
+ ipython_config.py
107
+
108
+ # pyenv
109
+ # For a library or package, you might want to ignore these files since the code is
110
+ # intended to run in multiple environments; otherwise, check them in:
111
+ # .python-version
112
+
113
+ # pipenv
114
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
115
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
116
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
117
+ # install all needed dependencies.
118
+ #Pipfile.lock
119
+
120
+ # poetry
121
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
122
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
123
+ # commonly ignored for libraries.
124
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
125
+ #poetry.lock
126
+
127
+ # pdm
128
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
129
+ #pdm.lock
130
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
131
+ # in version control.
132
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
133
+ .pdm.toml
134
+ .pdm-python
135
+ .pdm-build/
136
+
137
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
138
+ __pypackages__/
139
+
140
+ # Celery stuff
141
+ celerybeat-schedule
142
+ celerybeat.pid
143
+
144
+ # SageMath parsed files
145
+ *.sage.py
146
+
147
+ # Environments
148
+ .env
149
+ .venv
150
+ env/
151
+ venv/
152
+ ENV/
153
+ env.bak/
154
+ venv.bak/
155
+
156
+ # Spyder project settings
157
+ .spyderproject
158
+ .spyproject
159
+
160
+ # Rope project settings
161
+ .ropeproject
162
+
163
+ # mkdocs documentation
164
+ /site
165
+
166
+ # mypy
167
+ .mypy_cache/
168
+ .dmypy.json
169
+ dmypy.json
170
+
171
+ # Pyre type checker
172
+ .pyre/
173
+
174
+ # pytype static type analyzer
175
+ .pytype/
176
+
177
+ # Cython debug symbols
178
+ cython_debug/
179
+
180
+ # PyCharm
181
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
182
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
183
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
184
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
185
+ #.idea/
README.md CHANGED
@@ -1,3 +1,177 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Learning to Discretize Denoising Diffusion ODEs
2
+
3
+ 🏆 ![ICLR2025 Oral](https://img.shields.io/badge/ICLR2025-Oral-blue)
4
+
5
+ ### [Paper on OpenReview](https://openreview.net/forum?id=xDrFWUmCne)
6
+
7
+ Implementation of LD3, a lightweight framework designed to learn the optimal time discretization for sampling from pre-trained Diffusion Probabilistic Models (DPMs). LD3 can be combined with various samplers and consistently improves generation quality without having to retrain resource-intensive neural networks. LD3 offers an efficient approach to sampling from pre-trained diffusion models.
8
+
9
+ ![Alt Text](visualizations/illustration-lddd.png)
10
+
11
+ ## 🔥 Latest News
12
+
13
+ - **March 2025**: We have successfully applied **LD3** to the **Flux-dev** model and observed promising results.
14
+ - We are releasing the trained time steps for the Flux model soon! Stay tuned for updates.
15
+
16
+
17
+ ## Setup Environment
18
+
19
+ We will set up the environment using [Anaconda](https://docs.anaconda.com/anaconda/install/index.html).
20
+
21
+ ```bash
22
+ conda env create -f requirements.yml
23
+ conda activate ld3
24
+ pip install -e ./src/clip/
25
+ pip install -e ./src/taming-transformers/
26
+ pip install omegaconf
27
+ pip install PyYAML
28
+ pip install requests
29
+ pip install scipy
30
+ pip install torchmetrics
31
+ ```
32
+
33
+ ## Download Pretrained Models and FID Reference Sets
34
+
35
+ All necessary data will be automatically downloaded by the script. Note that this process may take some time. If you wish to skip certain downloads, you can comment out the corresponding lines in the script.
36
+
37
+ ```bash
38
+ bash scripts/download_model.sh
39
+ wget https://raw.githubusercontent.com/tylin/coco-caption/master/annotations/captions_val2014.json
40
+ ```
41
+
42
+
43
+ ## 🚀 Generating Training Data for LD3
44
+
45
+ Before training **LD3**, we first need to generate training data using the teacher solver. The script `gen_data.py` handles this process. Below is an example of generating training data with **20 sampling steps** for **CIFAR-10**, using the `uni_pc` solver and `time-edm` discretization.
46
+
47
+ ### 📌 Example: Generating CIFAR-10 Training Data
48
+
49
+ ```bash
50
+ CUDA_VISIBLE_DEVICES=0 python3 gen_data.py \
51
+ --all_config configs/cifar10.yml \
52
+ --total_samples 100 \
53
+ --sampling_batch_size 10 \
54
+ --steps 20 \
55
+ --solver_name uni_pc \
56
+ --skip_type edm \
57
+ --save_pt --save_png --data_dir train_data/train_data_cifar10 \
58
+ --low_gpu
59
+ ```
60
+ #### 📌 Key Arguments:
61
+
62
+ - `all_config`: Path to the default configuration file (mandatory). If other arguments are not specified, their values will be taken from this file.
63
+ - `solver_name`: Solver to use. Options include `uni_pc`, `dpm_solver++`, `euler`, and `ipndm`.
64
+ - `skip_type`: Discretization method. Options include `edm`, `time_uniform`, and `time_quadratic`.
65
+ - `low_gpu`: Enables the use of PyTorch's `checkpoint` feature to reduce GPU memory usage.
66
+ - `data_dir`: Root directory for saving the generated data. The script will create a subdirectory within this path using the naming format `${solver_name}_NFE${steps}_${skip_type}`.
67
+
68
+ ### 📌 Example: Generating Stable Diffusion Training Data
69
+
70
+ For Stable Diffusion, you must additionally specify the prompt file and the number of prompts. Below is an example:
71
+
72
+ ```bash
73
+ CUDA_VISIBLE_DEVICES=0 python3 gen_data.py \
74
+ --all_config configs/stable_diff_v1-4.yml \
75
+ --total_samples 100 \
76
+ --sampling_batch_size 2 \
77
+ --steps 6 \
78
+ --solver_name uni_pc \
79
+ --skip_type time_uniform \
80
+ --save_pt --save_png --data_dir train_data/train_data_stable_diff_v1-4 \
81
+ --low_gpu \
82
+ --num_prompts 5 --prompt_path captions_val2014.json
83
+ ```
84
+
85
+ ## Training LD3
86
+ After generating the training data, you can train **LD3** using the `main.py` script. Below is an example of training **LD3** on **CIFAR-10** with the following configurations:
87
+ - **Teacher**: 20 sampling steps, `uni_pc` solver, and `time-edm` discretization.
88
+ - **Student**: 10 sampling steps, `dpm_solver++` solver.
89
+
90
+ ```bash
91
+ CUDA_VISIBLE_DEVICES=0 python3 main.py \
92
+ --all_config configs/cifar10.yml \
93
+ --data_dir train_data/train_data_cifar10/uni_pc_NFE20_edm \
94
+ --num_train 50 --num_valid 50 \
95
+ --main_train_batch_size 1 \
96
+ --main_valid_batch_size 10 \
97
+ --solver_name dpm_solver++ \
98
+ --training_rounds_v1 2 \
99
+ --training_rounds_v2 5 \
100
+ --steps 10 \
101
+ --log_path logs/logs_cifar10
102
+ ```
103
+
104
+ **Trained timesteps are available [here](https://docs.google.com/spreadsheets/d/1nUrTDvvtpPHZuRuJcn3zzxGmVKrNX4fFHu8wYIIoGSM/edit?usp=sharing) and are still being updated.**
105
+
106
+
107
+
108
+ #### 📌 Key Arguments:
109
+
110
+ - `data_dir`: The full path to the training data directory (unlike the root directory used during data generation).
111
+ - `log_path`: The root directory for saving logs and models. The script will create a subdirectory within this path using the naming format: `${solver_name}-N${steps}-b${bound}-${loss_type}-lr2${lr2}rv1${rv1}-rv2${rv2}`, for example, `uni_pc-N10-b0.03072-LPIPS-lr20.01rv12-rv25`
112
+
113
+ ## FID Evaluation
114
+
115
+
116
+ ### ⚠️ Different FID Scores
117
+
118
+ It is important to note that FID (Fréchet Inception Distance) scores can vary significantly depending on the processing pipeline used. To ensure transparency and reproducibility, our framework provides a script `compute_fid.py` that supports FID evaluation for both EDM and Latent-Diffusion.
119
+
120
+ ### 📌 How FID Evaluation Works
121
+
122
+ The `compute_fid.py` script is a streamlined version of `gen_data.py` with a few differences:
123
+
124
+ The `--save_dir`, `--save_pt`, and `--save_png` arguments are ignored because the generated data is directly processed for FID calculation without being saved.
125
+
126
+ The data is automatically forwarded to the FID computation module to extract features.
127
+
128
+ Optionally, you can pass your own timesteps via `--custom_ts_1` and `--custom_ts_2`. If `custom_ts_2` is not specified, it will be set to the same value as `custom_ts_1`.
129
+ ### 📌 Example: Computing FID for Stable Diffusion
130
+
131
+ ```bash
132
+ CUDA_VISIBLE_DEVICES=0 python3 compute_fid.py \
133
+ --all_config configs/stable_diff_v1-5.yml \
134
+ --total_samples 100 \
135
+ --sampling_batch_size 2 \
136
+ --steps 6 \
137
+ --solver_name uni_pc \
138
+ --skip_type time_uniform \
139
+ --low_gpu \
140
+ --num_prompts 5 --prompt_path captions_val2014.json
141
+
142
+ CUDA_VISIBLE_DEVICES=0 python3 compute_fid.py \
143
+ --all_config configs/stable_diff_v1-5.yml \
144
+ --total_samples 100 \
145
+ --sampling_batch_size 2 \
146
+ --steps 4 \
147
+ --solver_name ipndm \
148
+ --skip_type custom \
149
+ --custom_ts_1 [1.0000e+00,7.6668e-01,4.8113e-01,1.8417e-01,1.0000e-03] \
150
+ --custom_ts_2 [1.0000e+00,7.6706e-01,4.8103e-01,1.8396e-01,1.0000e-03] \
151
+ --low_gpu \
152
+ --num_prompts 5 --prompt_path captions_val2014.json
153
+
154
+ ```
155
+ ## Citation
156
+
157
+ ```
158
+ @inproceedings{tong2024learning,
159
+ title = {Learning to Discretize Denoising Diffusion ODEs},
160
+ author = {Tong, Vinh and Hoang, Trung-Dung and Liu, Anji and Van den Broeck, Guy and Niepert, Mathias},
161
+ booktitle = {Proceedings of the 13th International Conference on Learning Representations},
162
+ year = {2025}
163
+ }
164
+ ```
165
+
166
+ ```
167
+ @article{tong2024learning,
168
+ title={Learning to Discretize Denoising Diffusion ODEs},
169
+ author={Tong, Vinh and Hoang, Trung-Dung and Liu, Anji and Broeck, Guy Van den and Niepert, Mathias},
170
+ journal={arXiv preprint arXiv:2405.15506},
171
+ year={2024}
172
+ }
173
+ ```
174
+
175
+ ## License
176
+ MIT
177
+
compute_fid.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.autograd import Variable
4
+ from torch.nn import functional as F
5
+
6
+ import numpy as np
7
+
8
+ import pickle
9
+ from tqdm import tqdm
10
+ from utils import (
11
+ parse_arguments,
12
+ check_fid_file,
13
+ prepare_paths,
14
+ adjust_hyper,
15
+ get_solvers,
16
+ set_seed_everything,
17
+ )
18
+ from models import prepare_stuff, prepare_condition_loader
19
+ import math
20
+ import dnnlib
21
+ import pickle
22
+ import scipy
23
+
24
+ from torch.nn.functional import adaptive_avg_pool2d
25
+ from pytorch_fid.inception import InceptionV3
26
+
27
+ from gen_data import Generator, get_data_inverse_scaler
28
+
29
def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    """Return the Fréchet Inception Distance between two Gaussians.

    Implements ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref)),
    the Fréchet distance between N(mu, sigma) and N(mu_ref, sigma_ref).

    Args:
        mu, sigma: mean vector and covariance matrix of the generated set.
        mu_ref, sigma_ref: mean vector and covariance matrix of the reference set.

    Returns:
        The FID as a plain float (imaginary round-off from the matrix square
        root is discarded via the real part).
    """
    mean_gap = ((mu - mu_ref) ** 2).sum()
    # disp=False suppresses scipy's accuracy warning; the error estimate is ignored.
    cov_sqrt, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
    trace_term = np.trace(sigma + sigma_ref - 2 * cov_sqrt)
    return float(np.real(mean_gap + trace_term))
34
+
35
def main(args):
    """Sample images with the configured solver and print two FID scores.

    Two FID pipelines are evaluated against the reference statistics loaded
    from ``args.ref_path`` (expected to hold ``mu`` and ``sigma`` arrays):

    * "EDM" style: features from the StyleGAN3 inception pickle on uint8
      images in [0, 255], with mean/covariance accumulated in a streaming
      fashion across batches.
    * "latent-diffusion" style: pytorch-fid InceptionV3 activations on float
      images clamped to [0, 1], with statistics computed over the full
      activation matrix at the end.

    Results are printed to stdout; nothing is returned or written to disk.
    """

    # Evaluation always uses EMA weights.
    if not args.use_ema:
        print("Auto update use_ema to True for evaluation")
        args.use_ema = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Start sampling...")

    # --- Latent-diffusion evaluation path: pytorch-fid InceptionV3 ---
    FEATURE_DIM = 2048  # pool3 feature dimensionality of InceptionV3
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM]
    fid_model = InceptionV3([block_idx]).to(device)
    fid_model.eval()

    # --- EDM evaluation path: StyleGAN3 inception feature extractor ---
    DETECTOR_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl"
    with dnnlib.util.open_url(DETECTOR_URL, verbose=True) as f:
        # NOTE(review): unpickles a downloaded file; only safe because the
        # URL is a fixed, trusted NVIDIA endpoint.
        detector_net = pickle.load(f).to(device)

    # Reference FID statistics (npz with "mu" and "sigma").
    with dnnlib.util.open_url(args.ref_path) as f:
        ref = dict(np.load(f))

    # prepare_stuff/prepare_condition_loader are project helpers; the
    # condition loader yields (conditioning, unconditional-conditioning)
    # pairs, or is None for unconditional models — TODO confirm.
    wrapped_model, model, decoding_fn, noise_schedule, latent_resolution, latent_channel, _, _ = prepare_stuff(args)
    condition_loader = prepare_condition_loader(model_type=args.model,
                                                model=model,
                                                scale=args.scale if hasattr(args, "scale") else None,
                                                condition=args.prompt_path or "uniform",
                                                sampling_batch_size=args.sampling_batch_size,
                                                num_prompt=None,
                                                )

    adjust_hyper(args, latent_resolution, latent_channel)
    _, _, skip_type = prepare_paths(args)

    # NOTE(review): `device` was already assigned above; this reassignment is
    # redundant but harmless.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    solver, steps, solver_extra_params = get_solvers(
        args.solver_name,
        NFEs=args.steps,
        order=args.order,
        noise_schedule=noise_schedule,
        unipc_variant=args.unipc_variant,
    )

    # Generator encapsulates the time discretization; custom_ts_1/2 allow
    # externally supplied (e.g. LD3-learned) timestep schedules.
    generator = Generator(
        noise_schedule=noise_schedule,
        solver=solver,
        order=args.order,
        skip_type=skip_type,
        load_from=args.load_from,
        timesteps_1=args.custom_ts_1,
        timesteps_2=args.custom_ts_2,
        steps=steps,
        solver_extra_params=solver_extra_params,
        device=device,
    )

    print(generator.timesteps, generator.timesteps2)
    # Maps model output from [-1, 1] back to [0, 1].
    inverse_scalar = get_data_inverse_scaler(centered=True)

    num_batches = math.ceil(args.total_samples / args.sampling_batch_size)
    batch_size = args.sampling_batch_size
    # NOTE(review): n_total_samples rounds up to a whole number of batches,
    # while the loop below caps the final batch at args.total_samples — the
    # EDM statistics are normalized by the rounded-up count. Verify intended.
    n_total_samples = batch_size * num_batches

    # Streaming accumulators for the EDM pipeline (sum and sum of outer
    # products), converted to mean / unbiased covariance after the loop.
    mu = torch.zeros([FEATURE_DIM], dtype=torch.float64, device=device)
    sigma = torch.zeros([FEATURE_DIM, FEATURE_DIM], dtype=torch.float64, device=device)
    # Full activation matrix for the latent-diffusion pipeline.
    act_arr = np.empty((n_total_samples, FEATURE_DIM))
    start_idx=0
    with torch.no_grad():
        for index in tqdm(range(num_batches)):
            # Last batch may be smaller than batch_size.
            current_batch_size = min(batch_size, args.total_samples - index * batch_size)
            sampling_shape = (current_batch_size, latent_channel, latent_resolution, latent_resolution)
            latents = torch.randn(sampling_shape, device=device)

            if condition_loader is not None:
                conditioning, conditioned_unconditioning = next(condition_loader)
            else:
                conditioning = None
                conditioned_unconditioning = None

            # Sample and map back to [0, 1]; presumably img_teacher is
            # (batch, channels, H, W) in image space after decoding.
            img_teacher = generator.sample(wrapped_model, decoding_fn, latents, conditioning, conditioned_unconditioning)
            img_teacher = inverse_scalar(img_teacher)

            # EDM path: uint8 images in [0, 255] through the StyleGAN3 detector.
            samples_edm = 255 * img_teacher
            images = torch.clip(samples_edm, 0, 255).to(torch.uint8)
            features = detector_net(images.to(device), return_features=True).to(
                torch.float64
            )
            mu += features.sum(0)
            sigma += features.T @ features

            # Latent-diffusion path: float images in [0, 1] through pytorch-fid.
            samples_latent_diff = torch.clamp(img_teacher, min=0.0, max=1.0)

            with torch.no_grad():
                pred = fid_model(samples_latent_diff.float())[0]

            # If model output is not scalar, apply global spatial average pooling.
            # This happens if you choose a dimensionality not equal 2048.
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

            pred = pred.squeeze(3).squeeze(2).cpu().numpy()
            act_arr[start_idx:start_idx + pred.shape[0]] = pred
            start_idx = start_idx + pred.shape[0]

    # Finalize EDM statistics: running sums -> mean and unbiased covariance.
    mu /= n_total_samples
    sigma -= mu.ger(mu) * n_total_samples
    sigma /= n_total_samples - 1
    mu = mu.cpu().numpy()
    sigma = sigma.cpu().numpy()
    fid_edm = calculate_fid_from_inception_stats(mu, sigma, ref["mu"], ref["sigma"])

    # Finalize latent-diffusion statistics from the full activation matrix.
    mu = np.mean(act_arr, axis=0)
    sigma = np.cov(act_arr, rowvar=False)
    fid_latent_diff = calculate_fid_from_inception_stats(mu, sigma, ref["mu"], ref["sigma"])

    print("FID EDM: {:.4f}".format(fid_edm))
    print("FID LD: {:.4f}".format(fid_latent_diff))
154
+
155
if __name__ == "__main__":
    # Script entry point: parse CLI/config arguments, fix all RNG seeds for
    # reproducible sampling, then run the FID evaluation.
    cli_args = parse_arguments()
    set_seed_everything(cli_args.seed)
    main(cli_args)
configs/afhqv2.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: edm
2
+ # model_parameters:
3
+ ckp_path: pretrained/edm-afhqv2-64x64-uncond-vp.pkl
4
+ solver_name: uni_pc
5
+ unipc_variant: bh1
6
+ steps: 10
7
+ order: 3
8
+ time_mode: lambda
9
+
10
+ # training_parameters:
11
+ seed: 0
12
+ log_path: all_logs/afhqv2_logs/
13
+ data_dir: train_data/train_data_afhqv2
14
+ win_rate: 0.5
15
+ prior_bound: 1.0
16
+ fix_bound: false
17
+ loss_type: LPIPS
18
+ main_train_batch_size: 2
19
+ main_valid_batch_size: 150
20
+ num_train: 50
21
+ num_valid: 50
22
+ training_rounds_v1: 2
23
+ training_rounds_v2: 5
24
+ shift_lr: null
25
+ lr_time_1: 0.005
26
+ lr_time_2: 0.1
27
+ min_lr_time_1: 0.00005
28
+ min_lr_time_2: 0.000001
29
+ momentum_time_1: 0.9
30
+ weight_decay_time_1: 0.0
31
+ shift_lr_decay: 0.5
32
+ lr_time_decay: 0.8
33
+ patient: 5
34
+ lr2_patient: 5
35
+ no_v1: false
36
+ visualize: true
37
+
38
+ # testing_parameters:
39
+ learn: false
40
+ load_from: null
41
+ skip_type: null
42
+ num_multi_steps_fid: null
43
+ fid_folder: all_fids/afhqv2_fids/
44
+ sampling_batch_size: 150
45
+ sampling_seed: 0
46
+ ref_path: fid-refs/afhqv2-64x64.npz
47
+ total_samples: 50000
48
+ save_png: false
49
+ save_pt: true
configs/cifar10.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: edm
2
+ # model_parameters:
3
+ ckp_path: pretrained/edm-cifar10-32x32-uncond-vp.pkl
4
+ solver_name: uni_pc
5
+ unipc_variant: bh1
6
+ steps: 10
7
+ order: 3
8
+ time_mode: lambda
9
+
10
+ # training_parameters:
11
+ seed: 0
12
+ log_path: all_logs/cifar10_logs/
13
+ data_dir: train_data/train_data_cifar
14
+ win_rate: 0.5
15
+ prior_bound: 1.0
16
+ fix_bound: false
17
+ loss_type: LPIPS
18
+ main_train_batch_size: 2
19
+ main_valid_batch_size: 150
20
+ num_train: 50
21
+ num_valid: 50
22
+ training_rounds_v1: 2
23
+ training_rounds_v2: 5
24
+ shift_lr: null
25
+ lr_time_1: 0.005
26
+ lr_time_2: 0.1
27
+ min_lr_time_1: 0.00005
28
+ min_lr_time_2: 0.000001
29
+ momentum_time_1: 0.9
30
+ weight_decay_time_1: 0.0
31
+ shift_lr_decay: 0.5
32
+ lr_time_decay: 0.8
33
+ patient: 5
34
+ lr2_patient: 5
35
+ no_v1: false
36
+ visualize: true
37
+
38
+ # testing_parameters:
39
+ learn: false
40
+ load_from: null
41
+ skip_type: null
42
+ num_multi_steps_fid: null
43
+ fid_folder: all_fids/cifar10_fids/
44
+ sampling_batch_size: 150
45
+ sampling_seed: 0
46
+ ref_path: fid-refs/cifar10-32x32.npz
47
+ total_samples: 50000
48
+ save_png: false
49
+ save_pt: true
configs/cifar10_order2.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: edm
2
+ # model_parameters:
3
+ ckp_path: pretrained/edm-cifar10-32x32-uncond-vp.pkl
4
+ solver_name: uni_pc
5
+ unipc_variant: bh1
6
+ steps: 10
7
+ order: 2
8
+ time_mode: lambda
9
+
10
+ # training_parameters:
11
+ seed: 0
12
+ log_path: all_logs/cifar10_logs/
13
+ data_dir: train_data/train_data_cifar
14
+ win_rate: 0.5
15
+ prior_bound: 1.0
16
+ fix_bound: false
17
+ loss_type: LPIPS
18
+ main_train_batch_size: 2
19
+ main_valid_batch_size: 150
20
+ num_train: 50
21
+ num_valid: 50
22
+ training_rounds_v1: 2
23
+ training_rounds_v2: 5
24
+ shift_lr: null
25
+ lr_time_1: 0.005
26
+ lr_time_2: 0.1
27
+ min_lr_time_1: 0.00005
28
+ min_lr_time_2: 0.000001
29
+ momentum_time_1: 0.9
30
+ weight_decay_time_1: 0.0
31
+ shift_lr_decay: 0.5
32
+ lr_time_decay: 0.8
33
+ patient: 5
34
+ lr2_patient: 5
35
+ no_v1: false
36
+ visualize: true
37
+
38
+ # testing_parameters:
39
+ learn: false
40
+ load_from: null
41
+ skip_type: null
42
+ num_multi_steps_fid: null
43
+ fid_folder: all_fids/cifar10_fids/
44
+ sampling_batch_size: 150
45
+ sampling_seed: 0
46
+ ref_path: fid-refs/cifar10-32x32.npz
47
+ total_samples: 50000
48
+ save_png: false
49
+ save_pt: true
configs/ffhq.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: edm
2
+ # model_parameters:
3
+ ckp_path: pretrained/edm-ffhq-64x64-uncond-vp.pkl
4
+ solver_name: uni_pc
5
+ unipc_variant: bh1
6
+ steps: 10
7
+ order: 3
8
+ time_mode: lambda
9
+
10
+ # training_parameters:
11
+ seed: 0
12
+ log_path: all_logs/ffhq_logs/
13
+ data_dir: train_data/train_data_ffhq
14
+ win_rate: 0.5
15
+ prior_bound: 1.0
16
+ fix_bound: false
17
+ loss_type: LPIPS
18
+ main_train_batch_size: 2
19
+ main_valid_batch_size: 150
20
+ num_train: 50
21
+ num_valid: 50
22
+ training_rounds_v1: 2
23
+ training_rounds_v2: 5
24
+ shift_lr: null
25
+ lr_time_1: 0.005
26
+ lr_time_2: 0.1
27
+ min_lr_time_1: 0.00005
28
+ min_lr_time_2: 0.000001
29
+ momentum_time_1: 0.9
30
+ weight_decay_time_1: 0.0
31
+ shift_lr_decay: 0.5
32
+ lr_time_decay: 0.8
33
+ patient: 5
34
+ lr2_patient: 5
35
+ no_v1: false
36
+ visualize: true
37
+
38
+ # testing_parameters:
39
+ learn: false
40
+ load_from: null
41
+ skip_type: null
42
+ num_multi_steps_fid: null
43
+ fid_folder: all_fids/ffhq_fids/
44
+ sampling_batch_size: 150
45
+ sampling_seed: 0
46
+ ref_path: fid-refs/ffhq-64x64.npz
47
+ total_samples: 50000
48
+ save_png: false
49
+ save_pt: true
configs/latent-diffusion/celebahq-ldm-vq-4.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+
15
+ unet_config:
16
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17
+ params:
18
+ image_size: 64
19
+ in_channels: 3
20
+ out_channels: 3
21
+ model_channels: 224
22
+ attention_resolutions:
23
+ # note: this isn't actually the resolution but
24
+ # the downsampling factor, i.e. this corresponds to
25
+ # attention on spatial resolution 8,16,32, as the
26
+ # spatial resolution of the latents is 64 for f4
27
+ - 8
28
+ - 4
29
+ - 2
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 4
36
+ num_head_channels: 32
37
+ first_stage_config:
38
+ target: ldm.models.autoencoder.VQModelInterface
39
+ params:
40
+ embed_dim: 3
41
+ n_embed: 8192
42
+ ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43
+ ddconfig:
44
+ double_z: false
45
+ z_channels: 3
46
+ resolution: 256
47
+ in_channels: 3
48
+ out_ch: 3
49
+ ch: 128
50
+ ch_mult:
51
+ - 1
52
+ - 2
53
+ - 4
54
+ num_res_blocks: 2
55
+ attn_resolutions: []
56
+ dropout: 0.0
57
+ lossconfig:
58
+ target: torch.nn.Identity
59
+ cond_stage_config: __is_unconditional__
60
+ data:
61
+ target: main.DataModuleFromConfig
62
+ params:
63
+ batch_size: 48
64
+ num_workers: 5
65
+ wrap: false
66
+ train:
67
+ target: taming.data.faceshq.CelebAHQTrain
68
+ params:
69
+ size: 256
70
+ validation:
71
+ target: taming.data.faceshq.CelebAHQValidation
72
+ params:
73
+ size: 256
74
+
75
+
76
+ lightning:
77
+ callbacks:
78
+ image_logger:
79
+ target: main.ImageLogger
80
+ params:
81
+ batch_frequency: 5000
82
+ max_images: 8
83
+ increase_log_steps: False
84
+
85
+ trainer:
86
+ benchmark: True
configs/latent-diffusion/cin-ldm-vq-f8.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: class_label
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ unet_config:
18
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ image_size: 32
21
+ in_channels: 4
22
+ out_channels: 4
23
+ model_channels: 256
24
+ attention_resolutions:
25
+ # note: this isn't actually the resolution but
26
+ # the downsampling factor, i.e. this corresponds to
27
+ # attention on spatial resolution 8,16,32, as the
28
+ # spatial resolution of the latents is 32 for f8
29
+ - 4
30
+ - 2
31
+ - 1
32
+ num_res_blocks: 2
33
+ channel_mult:
34
+ - 1
35
+ - 2
36
+ - 4
37
+ num_head_channels: 32
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 512
41
+ first_stage_config:
42
+ target: ldm.models.autoencoder.VQModelInterface
43
+ params:
44
+ embed_dim: 4
45
+ n_embed: 16384
46
+ ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47
+ ddconfig:
48
+ double_z: false
49
+ z_channels: 4
50
+ resolution: 256
51
+ in_channels: 3
52
+ out_ch: 3
53
+ ch: 128
54
+ ch_mult:
55
+ - 1
56
+ - 2
57
+ - 2
58
+ - 4
59
+ num_res_blocks: 2
60
+ attn_resolutions:
61
+ - 32
62
+ dropout: 0.0
63
+ lossconfig:
64
+ target: torch.nn.Identity
65
+ cond_stage_config:
66
+ target: ldm.modules.encoders.modules.ClassEmbedder
67
+ params:
68
+ embed_dim: 512
69
+ key: class_label
70
+ data:
71
+ target: main.DataModuleFromConfig
72
+ params:
73
+ batch_size: 64
74
+ num_workers: 12
75
+ wrap: false
76
+ train:
77
+ target: ldm.data.imagenet.ImageNetTrain
78
+ params:
79
+ config:
80
+ size: 256
81
+ validation:
82
+ target: ldm.data.imagenet.ImageNetValidation
83
+ params:
84
+ config:
85
+ size: 256
86
+
87
+
88
+ lightning:
89
+ callbacks:
90
+ image_logger:
91
+ target: main.ImageLogger
92
+ params:
93
+ batch_frequency: 5000
94
+ max_images: 8
95
+ increase_log_steps: False
96
+
97
+ trainer:
98
+ benchmark: True
configs/latent-diffusion/cin256-v2.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: class_label
12
+ image_size: 64
13
+ channels: 3
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss
17
+ use_ema: False
18
+
19
+ unet_config:
20
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21
+ params:
22
+ image_size: 64
23
+ in_channels: 3
24
+ out_channels: 3
25
+ model_channels: 192
26
+ attention_resolutions:
27
+ - 8
28
+ - 4
29
+ - 2
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 5
36
+ num_heads: 1
37
+ use_spatial_transformer: true
38
+ transformer_depth: 1
39
+ context_dim: 512
40
+
41
+ first_stage_config:
42
+ target: ldm.models.autoencoder.VQModelInterface
43
+ params:
44
+ embed_dim: 3
45
+ n_embed: 8192
46
+ ddconfig:
47
+ double_z: false
48
+ z_channels: 3
49
+ resolution: 256
50
+ in_channels: 3
51
+ out_ch: 3
52
+ ch: 128
53
+ ch_mult:
54
+ - 1
55
+ - 2
56
+ - 4
57
+ num_res_blocks: 2
58
+ attn_resolutions: []
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config:
64
+ target: ldm.modules.encoders.modules.ClassEmbedder
65
+ params:
66
+ n_classes: 1001
67
+ embed_dim: 512
68
+ key: class_label
configs/latent-diffusion/ffhq-ldm-vq-4.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+ unet_config:
15
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16
+ params:
17
+ image_size: 64
18
+ in_channels: 3
19
+ out_channels: 3
20
+ model_channels: 224
21
+ attention_resolutions:
22
+ # note: this isn\t actually the resolution but
23
+ # the downsampling factor, i.e. this corresnponds to
24
+ # attention on spatial resolution 8,16,32, as the
25
+ # spatial reolution of the latents is 64 for f4
26
+ - 8
27
+ - 4
28
+ - 2
29
+ num_res_blocks: 2
30
+ channel_mult:
31
+ - 1
32
+ - 2
33
+ - 3
34
+ - 4
35
+ num_head_channels: 32
36
+ first_stage_config:
37
+ target: ldm.models.autoencoder.VQModelInterface
38
+ params:
39
+ embed_dim: 3
40
+ n_embed: 8192
41
+ ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42
+ ddconfig:
43
+ double_z: false
44
+ z_channels: 3
45
+ resolution: 256
46
+ in_channels: 3
47
+ out_ch: 3
48
+ ch: 128
49
+ ch_mult:
50
+ - 1
51
+ - 2
52
+ - 4
53
+ num_res_blocks: 2
54
+ attn_resolutions: []
55
+ dropout: 0.0
56
+ lossconfig:
57
+ target: torch.nn.Identity
58
+ cond_stage_config: __is_unconditional__
59
+ data:
60
+ target: main.DataModuleFromConfig
61
+ params:
62
+ batch_size: 42
63
+ num_workers: 5
64
+ wrap: false
65
+ train:
66
+ target: taming.data.faceshq.FFHQTrain
67
+ params:
68
+ size: 256
69
+ validation:
70
+ target: taming.data.faceshq.FFHQValidation
71
+ params:
72
+ size: 256
73
+
74
+
75
+ lightning:
76
+ callbacks:
77
+ image_logger:
78
+ target: main.ImageLogger
79
+ params:
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ increase_log_steps: False
83
+
84
+ trainer:
85
+ benchmark: True
configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+ unet_config:
15
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16
+ params:
17
+ image_size: 64
18
+ in_channels: 3
19
+ out_channels: 3
20
+ model_channels: 224
21
+ attention_resolutions:
22
+ # note: this isn\t actually the resolution but
23
+ # the downsampling factor, i.e. this corresnponds to
24
+ # attention on spatial resolution 8,16,32, as the
25
+ # spatial reolution of the latents is 64 for f4
26
+ - 8
27
+ - 4
28
+ - 2
29
+ num_res_blocks: 2
30
+ channel_mult:
31
+ - 1
32
+ - 2
33
+ - 3
34
+ - 4
35
+ num_head_channels: 32
36
+ first_stage_config:
37
+ target: ldm.models.autoencoder.VQModelInterface
38
+ params:
39
+ ckpt_path: pretrained/first_stage_models/vq-f4/model.ckpt
40
+ embed_dim: 3
41
+ n_embed: 8192
42
+ ddconfig:
43
+ double_z: false
44
+ z_channels: 3
45
+ resolution: 256
46
+ in_channels: 3
47
+ out_ch: 3
48
+ ch: 128
49
+ ch_mult:
50
+ - 1
51
+ - 2
52
+ - 4
53
+ num_res_blocks: 2
54
+ attn_resolutions: []
55
+ dropout: 0.0
56
+ lossconfig:
57
+ target: torch.nn.Identity
58
+ cond_stage_config: __is_unconditional__
59
+ data:
60
+ target: main.DataModuleFromConfig
61
+ params:
62
+ batch_size: 48
63
+ num_workers: 5
64
+ wrap: false
65
+ train:
66
+ target: ldm.data.lsun.LSUNBedroomsTrain
67
+ params:
68
+ size: 256
69
+ validation:
70
+ target: ldm.data.lsun.LSUNBedroomsValidation
71
+ params:
72
+ size: 256
73
+
74
+
75
+ lightning:
76
+ callbacks:
77
+ image_logger:
78
+ target: main.ImageLogger
79
+ params:
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ increase_log_steps: False
83
+
84
+ trainer:
85
+ benchmark: True
configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0155
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ loss_type: l1
11
+ first_stage_key: "image"
12
+ cond_stage_key: "image"
13
+ image_size: 32
14
+ channels: 4
15
+ cond_stage_trainable: False
16
+ concat_mode: False
17
+ scale_by_std: True
18
+ monitor: 'val/loss_simple_ema'
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [10000]
24
+ cycle_lengths: [10000000000000]
25
+ f_start: [1.e-6]
26
+ f_max: [1.]
27
+ f_min: [ 1.]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 192
36
+ attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39
+ num_heads: 8
40
+ use_scale_shift_norm: True
41
+ resblock_updown: True
42
+
43
+ first_stage_config:
44
+ target: ldm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ embed_dim: 4
47
+ monitor: "val/rec_loss"
48
+ ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49
+ ddconfig:
50
+ double_z: True
51
+ z_channels: 4
52
+ resolution: 256
53
+ in_channels: 3
54
+ out_ch: 3
55
+ ch: 128
56
+ ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57
+ num_res_blocks: 2
58
+ attn_resolutions: [ ]
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config: "__is_unconditional__"
64
+
65
+ data:
66
+ target: main.DataModuleFromConfig
67
+ params:
68
+ batch_size: 96
69
+ num_workers: 5
70
+ wrap: False
71
+ train:
72
+ target: ldm.data.lsun.LSUNChurchesTrain
73
+ params:
74
+ size: 256
75
+ validation:
76
+ target: ldm.data.lsun.LSUNChurchesValidation
77
+ params:
78
+ size: 256
79
+
80
+ lightning:
81
+ callbacks:
82
+ image_logger:
83
+ target: main.ImageLogger
84
+ params:
85
+ batch_frequency: 5000
86
+ max_images: 8
87
+ increase_log_steps: False
88
+
89
+
90
+ trainer:
91
+ benchmark: True
configs/latent-diffusion/txt2img-1p4B-eval.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-05
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ unet_config:
21
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22
+ params:
23
+ image_size: 32
24
+ in_channels: 4
25
+ out_channels: 4
26
+ model_channels: 320
27
+ attention_resolutions:
28
+ - 4
29
+ - 2
30
+ - 1
31
+ num_res_blocks: 2
32
+ channel_mult:
33
+ - 1
34
+ - 2
35
+ - 4
36
+ - 4
37
+ num_heads: 8
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 1280
41
+ use_checkpoint: true
42
+ legacy: False
43
+
44
+ first_stage_config:
45
+ target: ldm.models.autoencoder.AutoencoderKL
46
+ params:
47
+ embed_dim: 4
48
+ monitor: val/rec_loss
49
+ ddconfig:
50
+ double_z: true
51
+ z_channels: 4
52
+ resolution: 256
53
+ in_channels: 3
54
+ out_ch: 3
55
+ ch: 128
56
+ ch_mult:
57
+ - 1
58
+ - 2
59
+ - 4
60
+ - 4
61
+ num_res_blocks: 2
62
+ attn_resolutions: []
63
+ dropout: 0.0
64
+ lossconfig:
65
+ target: torch.nn.Identity
66
+
67
+ cond_stage_config:
68
+ target: ldm.modules.encoders.modules.BERTEmbedder
69
+ params:
70
+ n_embed: 1280
71
+ n_layer: 32
configs/latent_diff_LSUN.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: latent_diff
2
+ # model_parameters:
3
+ ckp_path: pretrained/ldm/lsun_beds256/model.ckpt
4
+ config: configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
5
+ solver_name: uni_pc
6
+ unipc_variant: bh2
7
+ steps: 10
8
+ order: 3
9
+ H: 256
10
+ W: 256
11
+ C: 3
12
+ f: 4
13
+ scale: 0.0
14
+ time_mode: time
15
+
16
+ # training_parameters:
17
+ seed: 0
18
+ log_path: all_logs/latent_diff_LSUN_logs/
19
+ data_dir: train_data/train_data_LSUN
20
+ win_rate: 0.5
21
+ prior_bound: 1.0
22
+ fix_bound: false
23
+ loss_type: LPIPS
24
+ main_train_batch_size: 1
25
+ main_valid_batch_size: 50
26
+ num_train: 50
27
+ num_valid: 50
28
+ training_rounds_v1: 2
29
+ training_rounds_v2: 3
30
+ shift_lr: null
31
+ lr_time_1: 0.005
32
+ lr_time_2: 0.001
33
+ min_lr_time_1: 0.00005
34
+ min_lr_time_2: 0.000001
35
+ momentum_time_1: 0.9
36
+ weight_decay_time_1: 0.0
37
+ shift_lr_decay: 0.5
38
+ lr_time_decay: 0.8
39
+ patient: 5
40
+ lr2_patient: 5
41
+ no_v1: false
42
+ visualize: true
43
+
44
+ # testing_parameters:
45
+ learn: false
46
+ load_from: null
47
+ skip_type: null
48
+ num_multi_steps_fid: null
49
+ fid_folder: all_fids/latent_diff_LSUN_fids/
50
+ sampling_batch_size: 50
51
+ sampling_seed: 0
52
+ ref_path: fid-refs/VIRTUAL_lsun_bedroom256.npz
53
+ total_samples: 50000
54
+ save_png: false
55
+ save_pt: true
configs/latent_diff_imn.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: conditioned_latent_diff
2
+ # model_parameters:
3
+ ckp_path: pretrained/ldm/cin256-v2/model.ckpt
4
+ config: configs/latent-diffusion/cin256-v2.yaml
5
+ solver_name: uni_pc
6
+ unipc_variant: bh2
7
+ steps: 10
8
+ order: 3
9
+ H: 256
10
+ W: 256
11
+ C: 3
12
+ f: 4
13
+ scale: 2.0
14
+ time_mode: time
15
+
16
+ # training_parameters:
17
+ seed: 0
18
+ log_path: all_logs/latent_diff_imn_logs/
19
+ data_dir: train_data/train_data_imn
20
+ win_rate: 0.5
21
+ prior_bound: 1.0
22
+ fix_bound: false
23
+ loss_type: LPIPS
24
+ main_train_batch_size: 1
25
+ main_valid_batch_size: 50
26
+ num_train: 50
27
+ num_valid: 50
28
+ training_rounds_v1: 2
29
+ training_rounds_v2: 3
30
+ shift_lr: null
31
+ lr_time_1: 0.005
32
+ lr_time_2: 0.001
33
+ min_lr_time_1: 0.00005
34
+ min_lr_time_2: 0.000001
35
+ momentum_time_1: 0.9
36
+ weight_decay_time_1: 0.0
37
+ shift_lr_decay: 0.5
38
+ lr_time_decay: 0.8
39
+ patient: 5
40
+ lr2_patient: 5
41
+ no_v1: false
42
+ visualize: true
43
+
44
+ # testing_parameters:
45
+ learn: false
46
+ load_from: null
47
+ skip_type: null
48
+ num_multi_steps_fid: null
49
+ fid_folder: all_fids/latent_diff_imn_fids/
50
+ sampling_batch_size: 50
51
+ sampling_seed: 0
52
+ ref_path: fid-refs/VIRTUAL_imagenet256_labeled.npz
53
+ total_samples: 50000
54
+ save_png: false
55
+ save_pt: true
configs/stable-diffusion/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: False
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
configs/stable_diff_v1-4.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: conditioned_latent_diff
2
+ # model_parameters:
3
+ ckp_path: pretrained/ldm/stable-diffusion-v1/sd-v1-4.ckpt
4
+ config: configs/stable-diffusion/v1-inference.yaml
5
+ solver_name: uni_pc
6
+ unipc_variant: bh2
7
+ steps: 10
8
+ order: 2
9
+ H: 512
10
+ W: 512
11
+ C: 4
12
+ f: 8
13
+ scale: 7.5
14
+ time_mode: time
15
+
16
+ # training_parameters:
17
+ seed: 0
18
+ log_path: all_logs/stable_diff_v1-4_logs/
19
+ data_dir: train_data/train_data_stable_diff_v1-4
20
+ win_rate: 0.5
21
+ prior_bound: 1.0
22
+ fix_bound: false
23
+ loss_type: LPIPS
24
+ main_train_batch_size: 1
25
+ main_valid_batch_size: 1
26
+ num_train: 25
27
+ num_valid: 25
28
+ training_rounds_v1: 2
29
+ training_rounds_v2: 3
30
+ shift_lr: null
31
+ lr_time_1: 0.005
32
+ lr_time_2: 0.001
33
+ min_lr_time_1: 0.00005
34
+ min_lr_time_2: 0.000001
35
+ momentum_time_1: 0.9
36
+ weight_decay_time_1: 0.0
37
+ shift_lr_decay: 0.5
38
+ lr_time_decay: 0.8
39
+ patient: 5
40
+ lr2_patient: 5
41
+ no_v1: false
42
+ visualize: true
43
+
44
+ # testing_parameters:
45
+ learn: false
46
+ load_from: null
47
+ skip_type: null
48
+ num_multi_steps_fid: null
49
+ fid_folder: all_fids/stable_diff_v1-4_fids/
50
+ sampling_batch_size: 50
51
+ sampling_seed: 0
52
+ ref_path: null
53
+ total_samples: 50000
54
+ save_png: false
55
+ save_pt: true
configs/stable_diff_v1-5.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: conditioned_latent_diff
2
+ # model_parameters:
3
+ ckp_path: pretrained/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
4
+ config: configs/stable-diffusion/v1-inference.yaml
5
+ solver_name: uni_pc
6
+ unipc_variant: bh2
7
+ steps: 10
8
+ order: 2
9
+ H: 512
10
+ W: 512
11
+ C: 4
12
+ f: 8
13
+ scale: 7.5
14
+ time_mode: time
15
+
16
+ # training_parameters:
17
+ seed: 0
18
+ log_path: all_logs/stable_diff_v1-5_logs/
19
+ data_dir: train_data/train_data_stable_diff_v1-5
20
+ win_rate: 0.5
21
+ prior_bound: 1.0
22
+ fix_bound: false
23
+ loss_type: LPIPS
24
+ main_train_batch_size: 1
25
+ main_valid_batch_size: 1
26
+ num_train: 25
27
+ num_valid: 25
28
+ training_rounds_v1: 2
29
+ training_rounds_v2: 3
30
+ shift_lr: null
31
+ lr_time_1: 0.005
32
+ lr_time_2: 0.001
33
+ min_lr_time_1: 0.00005
34
+ min_lr_time_2: 0.000001
35
+ momentum_time_1: 0.9
36
+ weight_decay_time_1: 0.0
37
+ shift_lr_decay: 0.5
38
+ lr_time_decay: 0.8
39
+ patient: 5
40
+ lr2_patient: 5
41
+ no_v1: false
42
+ visualize: true
43
+
44
+ # testing_parameters:
45
+ learn: false
46
+ load_from: null
47
+ skip_type: null
48
+ num_multi_steps_fid: null
49
+ fid_folder: all_fids/stable_diff_v1-5_fids/
50
+ sampling_batch_size: 50
51
+ sampling_seed: 0
52
+ ref_path: null
53
+ total_samples: 50000
54
+ save_png: false
55
+ save_pt: true
data/coco_captions.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/prompts.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Two individuals learning to ski along with an instructor.
2
+ A man sitting on a chair that is on a deck over the water.
3
+ A dog sitting at a table in front of a plate.
4
+ Four people sit around eating food outside together.
5
+ A cat dips its paws into a cup on a nightstand.
dataset.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+ import os
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+
6
+
7
+ def load_data_from_dir(
8
+ data_folder: str, limit: int = 200
9
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[Optional[torch.Tensor]], List[Optional[torch.Tensor]]]:
10
+ latents, targets, conditions, unconditions = [], [], [], []
11
+ pt_files = [f for f in os.listdir(data_folder) if f.endswith('pt')]
12
+ for file_name in sorted(pt_files)[:limit]:
13
+ file_path = os.path.join(data_folder, file_name)
14
+ data = torch.load(file_path)
15
+ latents.append(data["latent"])
16
+ targets.append(data["img"])
17
+ conditions.append(data.get("c", None))
18
+ unconditions.append(data.get("uc", None))
19
+ return latents, targets, conditions, unconditions
20
+
21
+
22
+ class LD3Dataset(Dataset):
23
+ def __init__(
24
+ self,
25
+ ori_latent: List[torch.Tensor],
26
+ latent: List[torch.Tensor],
27
+ target: List[torch.Tensor],
28
+ condition: List[Optional[torch.Tensor]],
29
+ uncondition: List[Optional[torch.Tensor]],
30
+ ):
31
+ self.ori_latent = ori_latent
32
+ self.latent = latent
33
+ self.target = target
34
+ self.condition = condition
35
+ self.uncondition = uncondition
36
+
37
+ def __len__(self) -> int:
38
+ return len(self.ori_latent)
39
+
40
+ def __getitem__(self, idx: int):
41
+ img = self.target[idx]
42
+ latent = self.latent[idx]
43
+ ori_latent = self.ori_latent[idx]
44
+ condition = self.condition[idx]
45
+ uncondition = self.uncondition[idx]
46
+ return img, latent, ori_latent, condition, uncondition
dnnlib/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ from .util import EasyDict, make_cache_dir_path
dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Miscellaneous utility classes and functions."""
9
+
10
+ import ctypes
11
+ import fnmatch
12
+ import importlib
13
+ import inspect
14
+ import numpy as np
15
+ import os
16
+ import shutil
17
+ import sys
18
+ import types
19
+ import io
20
+ import pickle
21
+ import re
22
+ import requests
23
+ import html
24
+ import hashlib
25
+ import glob
26
+ import tempfile
27
+ import urllib
28
+ import urllib.request
29
+ import uuid
30
+
31
+ from distutils.util import strtobool
32
+ from typing import Any, List, Tuple, Union, Optional
33
+
34
+
35
+ # Util classes
36
+ # ------------------------------------------------------------------------------------------
37
+
38
+
39
+ class EasyDict(dict):
40
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
41
+
42
+ def __getattr__(self, name: str) -> Any:
43
+ try:
44
+ return self[name]
45
+ except KeyError:
46
+ raise AttributeError(name)
47
+
48
+ def __setattr__(self, name: str, value: Any) -> None:
49
+ self[name] = value
50
+
51
+ def __delattr__(self, name: str) -> None:
52
+ del self[name]
53
+
54
+
55
+ class Logger(object):
56
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
57
+
58
+ def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
59
+ self.file = None
60
+
61
+ if file_name is not None:
62
+ self.file = open(file_name, file_mode)
63
+
64
+ self.should_flush = should_flush
65
+ self.stdout = sys.stdout
66
+ self.stderr = sys.stderr
67
+
68
+ sys.stdout = self
69
+ sys.stderr = self
70
+
71
+ def __enter__(self) -> "Logger":
72
+ return self
73
+
74
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
75
+ self.close()
76
+
77
+ def write(self, text: Union[str, bytes]) -> None:
78
+ """Write text to stdout (and a file) and optionally flush."""
79
+ if isinstance(text, bytes):
80
+ text = text.decode()
81
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
82
+ return
83
+
84
+ if self.file is not None:
85
+ self.file.write(text)
86
+
87
+ self.stdout.write(text)
88
+
89
+ if self.should_flush:
90
+ self.flush()
91
+
92
+ def flush(self) -> None:
93
+ """Flush written text to both stdout and a file, if open."""
94
+ if self.file is not None:
95
+ self.file.flush()
96
+
97
+ self.stdout.flush()
98
+
99
+ def close(self) -> None:
100
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
101
+ self.flush()
102
+
103
+ # if using multiple loggers, prevent closing in wrong order
104
+ if sys.stdout is self:
105
+ sys.stdout = self.stdout
106
+ if sys.stderr is self:
107
+ sys.stderr = self.stderr
108
+
109
+ if self.file is not None:
110
+ self.file.close()
111
+ self.file = None
112
+
113
+
114
+ # Cache directories
115
+ # ------------------------------------------------------------------------------------------
116
+
117
+ _dnnlib_cache_dir = None
118
+
119
+ def set_cache_dir(path: str) -> None:
120
+ global _dnnlib_cache_dir
121
+ _dnnlib_cache_dir = path
122
+
123
+ def make_cache_dir_path(*paths: str) -> str:
124
+ if _dnnlib_cache_dir is not None:
125
+ return os.path.join(_dnnlib_cache_dir, *paths)
126
+ if 'DNNLIB_CACHE_DIR' in os.environ:
127
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
128
+ if 'HOME' in os.environ:
129
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
130
+ if 'USERPROFILE' in os.environ:
131
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
132
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
133
+
134
+ # Small util functions
135
+ # ------------------------------------------------------------------------------------------
136
+
137
+
138
+ def format_time(seconds: Union[int, float]) -> str:
139
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
140
+ s = int(np.rint(seconds))
141
+
142
+ if s < 60:
143
+ return "{0}s".format(s)
144
+ elif s < 60 * 60:
145
+ return "{0}m {1:02}s".format(s // 60, s % 60)
146
+ elif s < 24 * 60 * 60:
147
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
148
+ else:
149
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
150
+
151
+
152
+ def format_time_brief(seconds: Union[int, float]) -> str:
153
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
154
+ s = int(np.rint(seconds))
155
+
156
+ if s < 60:
157
+ return "{0}s".format(s)
158
+ elif s < 60 * 60:
159
+ return "{0}m {1:02}s".format(s // 60, s % 60)
160
+ elif s < 24 * 60 * 60:
161
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
162
+ else:
163
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
164
+
165
+
166
+ def ask_yes_no(question: str) -> bool:
167
+ """Ask the user the question until the user inputs a valid answer."""
168
+ while True:
169
+ try:
170
+ print("{0} [y/n]".format(question))
171
+ return strtobool(input().lower())
172
+ except ValueError:
173
+ pass
174
+
175
+
176
+ def tuple_product(t: Tuple) -> Any:
177
+ """Calculate the product of the tuple elements."""
178
+ result = 1
179
+
180
+ for v in t:
181
+ result *= v
182
+
183
+ return result
184
+
185
+
186
+ _str_to_ctype = {
187
+ "uint8": ctypes.c_ubyte,
188
+ "uint16": ctypes.c_uint16,
189
+ "uint32": ctypes.c_uint32,
190
+ "uint64": ctypes.c_uint64,
191
+ "int8": ctypes.c_byte,
192
+ "int16": ctypes.c_int16,
193
+ "int32": ctypes.c_int32,
194
+ "int64": ctypes.c_int64,
195
+ "float32": ctypes.c_float,
196
+ "float64": ctypes.c_double
197
+ }
198
+
199
+
200
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
201
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
202
+ type_str = None
203
+
204
+ if isinstance(type_obj, str):
205
+ type_str = type_obj
206
+ elif hasattr(type_obj, "__name__"):
207
+ type_str = type_obj.__name__
208
+ elif hasattr(type_obj, "name"):
209
+ type_str = type_obj.name
210
+ else:
211
+ raise RuntimeError("Cannot infer type name from input")
212
+
213
+ assert type_str in _str_to_ctype.keys()
214
+
215
+ my_dtype = np.dtype(type_str)
216
+ my_ctype = _str_to_ctype[type_str]
217
+
218
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
219
+
220
+ return my_dtype, my_ctype
221
+
222
+
223
+ def is_pickleable(obj: Any) -> bool:
224
+ try:
225
+ with io.BytesIO() as stream:
226
+ pickle.dump(obj, stream)
227
+ return True
228
+ except:
229
+ return False
230
+
231
+
232
+ # Functionality to import modules/objects by name, and call functions by name
233
+ # ------------------------------------------------------------------------------------------
234
+
235
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
236
+ """Searches for the underlying module behind the name to some python object.
237
+ Returns the module and the object name (original name with module part removed)."""
238
+
239
+ # allow convenience shorthands, substitute them by full names
240
+ obj_name = re.sub("^np.", "numpy.", obj_name)
241
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
242
+
243
+ # list alternatives for (module_name, local_obj_name)
244
+ parts = obj_name.split(".")
245
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
246
+
247
+ # try each alternative in turn
248
+ for module_name, local_obj_name in name_pairs:
249
+ try:
250
+ module = importlib.import_module(module_name) # may raise ImportError
251
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
252
+ return module, local_obj_name
253
+ except:
254
+ pass
255
+
256
+ # maybe some of the modules themselves contain errors?
257
+ for module_name, _local_obj_name in name_pairs:
258
+ try:
259
+ importlib.import_module(module_name) # may raise ImportError
260
+ except ImportError:
261
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
262
+ raise
263
+
264
+ # maybe the requested attribute is missing?
265
+ for module_name, local_obj_name in name_pairs:
266
+ try:
267
+ module = importlib.import_module(module_name) # may raise ImportError
268
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
269
+ except ImportError:
270
+ pass
271
+
272
+ # we are out of luck, but we have no idea why
273
+ raise ImportError(obj_name)
274
+
275
+
276
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
277
+ """Traverses the object name and returns the last (rightmost) python object."""
278
+ if obj_name == '':
279
+ return module
280
+ obj = module
281
+ for part in obj_name.split("."):
282
+ obj = getattr(obj, part)
283
+ return obj
284
+
285
+
286
+ def get_obj_by_name(name: str) -> Any:
287
+ """Finds the python object with the given name."""
288
+ module, obj_name = get_module_from_obj_name(name)
289
+ return get_obj_from_module(module, obj_name)
290
+
291
+
292
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
293
+ """Finds the python object with the given name and calls it as a function."""
294
+ assert func_name is not None
295
+ func_obj = get_obj_by_name(func_name)
296
+ assert callable(func_obj)
297
+ return func_obj(*args, **kwargs)
298
+
299
+
300
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
301
+ """Finds the python class with the given name and constructs it with the given arguments."""
302
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
303
+
304
+
305
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
306
+ """Get the directory path of the module containing the given object name."""
307
+ module, _ = get_module_from_obj_name(obj_name)
308
+ return os.path.dirname(inspect.getfile(module))
309
+
310
+
311
def is_top_level_function(obj: Any) -> bool:
    """Determine whether *obj* is a function defined with 'def' at module scope."""
    if not callable(obj):
        return False
    owning_module = sys.modules[obj.__module__]
    return obj.__name__ in vars(owning_module)
314
+
315
+
316
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified "module.function" name of a top-level function.

    Functions defined in ``__main__`` are reported under the script's file
    stem instead of the literal module name.
    """
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        script_file = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(script_file))[0]
    return module_name + "." + obj.__name__
323
+
324
+
325
+ # File system helpers
326
+ # ------------------------------------------------------------------------------------------
327
+
328
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Recursively enumerate the files under *dir_path*.

    File and directory names matching any fnmatch pattern in *ignores* are
    skipped (ignored directories are pruned, so their contents are never
    visited). Returns a list of (absolute_path, relative_path) tuples; when
    *add_base_to_relative* is True the relative paths are prefixed with the
    base directory name.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = [] if ignores is None else ignores
    entries = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for pattern in patterns:
            # os.walk only honours pruning done in-place on ``dirs``.
            dirs[:] = [d for d in dirs if not fnmatch.fnmatch(d, pattern)]
            files = [f for f in files if not fnmatch.fnmatch(f, pattern)]

        for f in files:
            abs_path = os.path.join(root, f)
            rel_path = os.path.relpath(abs_path, dir_path)
            if add_base_to_relative:
                rel_path = os.path.join(base_name, rel_path)
            entries.append((abs_path, rel_path))

    return entries
359
+
360
+
361
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Copy each (src, dst) file pair, creating destination directories as needed.

    Fixes over the previous version:
      * ``os.makedirs(..., exist_ok=True)`` removes the race between the
        existence check and the creation (and tolerates concurrent callers).
      * A destination with no directory component used to crash, because
        ``os.path.dirname`` returns ``''`` and ``os.makedirs('')`` raises;
        such destinations are now copied into the current directory.
    """
    for src_path, dst_path in files:
        dst_dir = os.path.dirname(dst_path)
        if dst_dir:
            # Creates all intermediate-level directories in one call.
            os.makedirs(dst_dir, exist_ok=True)
        shutil.copyfile(src_path, dst_path)
372
+
373
+
374
+ # URL helpers
375
+ # ------------------------------------------------------------------------------------------
376
+
377
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether *obj* is a valid URL string.

    A value qualifies when it is a string containing "://" whose scheme and
    netloc survive parsing (the netloc must contain a dot). ``file://`` URLs
    are accepted only when *allow_file_urls* is set. Any parse failure is
    treated as "not a URL".
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        # Also validate the URL resolved against its own root.
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        # Was a bare ``except:`` which also swallowed KeyboardInterrupt /
        # SystemExit; only genuine parse errors should mean "not a URL".
        return False
    return True
393
+
394
+
395
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Behavior established by the code below:
      * Plain local paths (no "scheme://" prefix) are returned/opened as-is.
      * ``file://`` URLs are converted to local filenames (with a Windows
        drive-letter fixup) and never cached.
      * Remote URLs are fetched with up to *num_attempts* retries; successful
        downloads are cached under *cache_dir* (default:
        ``make_cache_dir_path('downloads')`` — defined elsewhere in this
        module) keyed by the MD5 of the URL.
      * Returns an open binary file / ``io.BytesIO``, or — when
        *return_filename* is True (requires *cache*) — the local file path.
    """
    assert num_attempts >= 1
    # A filename can only be returned if the data ends up cached on disk.
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # Cache entries are named "<md5-of-url>_<sanitized-name>".
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages
                    # rather than the real payload; detect and retry.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Follow the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                            raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                # Retry on any other failure until attempts are exhausted.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        # Sanitize and truncate the name so it is filesystem-safe.
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        safe_name = safe_name[:min(len(safe_name), 128)]
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        # Write to a unique temp file first so concurrent downloads cannot
        # observe a partially-written cache entry.
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
gen_data.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from tqdm import tqdm
4
+ from utils import (
5
+ get_solvers,
6
+ parse_arguments,
7
+ prepare_paths,
8
+ adjust_hyper,
9
+ )
10
+ from models import prepare_stuff, prepare_condition_loader
11
+ import time
12
+ import numpy as np
13
+ import PIL.Image
14
+
15
def get_data_inverse_scaler(centered=True):
    """Build the inverse data normalizer.

    When *centered* is True the returned function maps values from [-1, 1]
    back to [0, 1]; otherwise it is the identity.
    """
    if not centered:
        return lambda x: x
    # Rescale [-1, 1] to [0, 1]
    return lambda x: (x + 1.0) / 2.0
22
+
23
+
24
class Generator:
    """Draws samples from a diffusion model using a precomputed timestep plan.

    The two timestep sequences are computed once at construction: either from
    user-supplied lambda values (``timesteps_1`` / ``timesteps_2`` given as
    lists of floats) mapped back through the noise schedule, or by delegating
    to the solver's ``prepare_timesteps``.
    """

    def __init__(
        self,
        noise_schedule,
        solver,
        order,
        skip_type=None,
        load_from=None,
        timesteps_1=None,
        timesteps_2=None,
        steps=35,
        solver_extra_params=None,
        device=None,
    ) -> None:
        self.device = device
        self.noise_schedule = noise_schedule
        self.solver = solver
        self.order = order
        self.skip_type = skip_type
        self.load_from = load_from
        self.timesteps_1 = timesteps_1
        self.timesteps_2 = timesteps_2
        self.steps = steps
        # Normalize to a dict so ``**self.solver_extra_params`` in _sample
        # cannot fail when no extra params are supplied (the old code crashed
        # with TypeError on the default None).
        self.solver_extra_params = solver_extra_params if solver_extra_params is not None else {}

        self._precompute_timesteps()

    @staticmethod
    def _is_float_list(value):
        """True iff *value* is a non-empty list whose first element is a float."""
        # isinstance replaces the old ``type(x) == list/float`` checks and no
        # longer raises IndexError on an empty list.
        return isinstance(value, list) and bool(value) and isinstance(value[0], float)

    def _precompute_timesteps(self):
        """Compute ``self.timesteps`` / ``self.timesteps2`` once, up front."""
        # Custom float lists are interpreted as lambda values and converted
        # back to time via the noise schedule; otherwise the solver decides
        # (possibly loading a saved schedule via ``load_from``).
        if self.load_from is None and self._is_float_list(self.timesteps_1) and self._is_float_list(self.timesteps_2):
            self.timesteps = self.noise_schedule.inverse_lambda(-np.log(self.timesteps_1)).to(self.device).float()
            self.timesteps2 = self.noise_schedule.inverse_lambda(-np.log(self.timesteps_2)).to(self.device).float()
        else:
            self.timesteps, self.timesteps2 = self.solver.prepare_timesteps(
                steps=self.steps,
                t_start=self.noise_schedule.T,
                t_end=self.noise_schedule.eps,
                skip_type=self.skip_type,
                device=self.device,
                load_from=self.load_from,
            )

    def _sample(self, net, decoding_fn, latents, condition=None, unconditional_condition=None):
        """Run the solver from *latents* and decode the terminal state."""
        x = self.noise_schedule.prior_transformation(latents)
        x = self.solver.sample_simple(
            model_fn=net,
            x=x,
            timesteps=self.timesteps,
            timesteps2=self.timesteps2,
            order=self.order,
            NFEs=self.steps,
            condition=condition,
            unconditional_condition=unconditional_condition,
            **self.solver_extra_params,
        )
        return decoding_fn(x)

    def sample(self, net, decoding_fn, latents, condition=None, unconditional_condition=None, no_grad=True):
        """Sample from the model; by default gradients are disabled."""
        if no_grad:
            with torch.no_grad():
                return self._sample(net, decoding_fn, latents, condition, unconditional_condition)
        return self._sample(net, decoding_fn, latents, condition, unconditional_condition)
+
89
def main(args):
    """Generate samples with a diffusion solver and save them as .pt and/or .png.

    Outputs go to ``args.data_dir/<descriptor>``; files are numbered globally
    across batches. Fixes over the previous version: the inner per-sample
    loops no longer shadow the outer batch index, and ``count`` advances by
    the actual batch size so numbering stays correct for a short final batch.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    wrapped_model, model, decoding_fn, noise_schedule, latent_resolution, latent_channel, img_resolution, img_channel = prepare_stuff(args)
    condition_loader = prepare_condition_loader(model_type=args.model,
                                                model=model,
                                                scale=args.scale if hasattr(args, "scale") else None,
                                                condition=args.prompt_path or "random",
                                                sampling_batch_size=args.sampling_batch_size,
                                                num_prompt=args.num_prompts,
                                                num_samples_per_prompt=args.num_samples_per_prompt,
                                                )
    adjust_hyper(args, latent_resolution, latent_channel)
    desc, _, skip_type = prepare_paths(args)
    data_dir = os.path.join(args.data_dir, desc)
    os.makedirs(data_dir, exist_ok=True)

    solver, steps, solver_extra_params = get_solvers(
        args.solver_name,
        NFEs=args.steps,
        order=args.order,
        noise_schedule=noise_schedule,
        unipc_variant=args.unipc_variant,
    )

    generator = Generator(
        noise_schedule=noise_schedule,
        solver=solver,
        order=args.order,
        skip_type=skip_type,
        load_from=args.load_from,
        timesteps_1=args.custom_ts_1,
        timesteps_2=args.custom_ts_2,
        steps=steps,
        solver_extra_params=solver_extra_params,
        device=device,
    )

    print(generator.timesteps, generator.timesteps2)
    inverse_scaler = get_data_inverse_scaler(centered=True)

    start = time.time()
    count = 0
    batch_size = args.sampling_batch_size
    if args.prompt_path is not None:
        # Cannot produce more samples than there are prompts.
        args.total_samples = min(args.total_samples, len(condition_loader.prompts))
    num_batches = (args.total_samples + batch_size - 1) // batch_size

    for batch_idx in tqdm(range(num_batches)):
        # The final batch may be smaller than batch_size.
        current_batch_size = min(batch_size, args.total_samples - batch_idx * batch_size)
        sampling_shape = (current_batch_size, latent_channel, latent_resolution, latent_resolution)
        latents = torch.randn(sampling_shape, device=device)

        if condition_loader is not None:
            conditioning, conditioned_unconditioning = next(condition_loader)
        else:
            conditioning = None
            conditioned_unconditioning = None

        img_teacher = generator.sample(wrapped_model, decoding_fn, latents, conditioning, conditioned_unconditioning)

        img_teacher = img_teacher.detach().cpu().view(current_batch_size, img_channel, img_resolution, img_resolution)
        latents = latents.detach().cpu()

        if args.save_pt:
            # Note: ``j`` (not ``i``) so the outer batch index is not shadowed.
            for j in range(current_batch_size):
                latent = latents[j]
                img = img_teacher[j]
                c = conditioning[j] if conditioning is not None else None
                uc = conditioned_unconditioning[j] if conditioned_unconditioning is not None else None
                data = dict(latent=latent, img=img, c=c, uc=uc)
                torch.save(data, os.path.join(data_dir, f"latent_{(count + j):06d}.pt"))

        if args.save_png:
            samples_raw = inverse_scaler(img_teacher)
            samples = np.clip(
                samples_raw.permute(0, 2, 3, 1).cpu().numpy() * 255.0, 0, 255
            ).astype(np.uint8)
            images_np = samples.reshape((-1, img_resolution, img_resolution, img_channel))

            for j in range(current_batch_size):
                image_np = images_np[j]
                # hpsv2 benchmark expects 5-digit .jpg names.
                if args.prompt_path is not None and args.prompt_path.startswith('hpsv2'):
                    image_path = os.path.join(data_dir, f"{(count + j):05d}.jpg")
                else:
                    image_path = os.path.join(data_dir, f"{(count + j):06d}.png")
                if image_np.shape[2] == 1:
                    PIL.Image.fromarray(image_np[:, :, 0], "L").save(image_path)
                else:
                    PIL.Image.fromarray(image_np, "RGB").save(image_path)

        # Advance by the samples actually produced so numbering stays
        # contiguous even when the final batch is short.
        count += current_batch_size

    end = time.time()
    print(f"Generation time: {end - start}")
+
185
+
186
if __name__ == "__main__":
    # CLI entry point: parse arguments and run generation.
    main(parse_arguments())
ldm/__init__.py ADDED
File without changes
ldm/data/__init__.py ADDED
File without changes
ldm/data/base.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3
+
4
+
5
class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        # num_records drives __len__; sample_ids starts out identical to
        # valid_ids and may be narrowed by subclasses.
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size
        print(f'{type(self).__name__} dataset contains {len(self)} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        # Subclasses must yield individual examples.
        pass
ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, yaml, pickle, shutil, tarfile, glob
2
+ import cv2
3
+ import albumentations
4
+ import PIL
5
+ import numpy as np
6
+ import torchvision.transforms.functional as TF
7
+ from omegaconf import OmegaConf
8
+ from functools import partial
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ from torch.utils.data import Dataset, Subset
12
+
13
+ import taming.data.utils as tdu
14
+ from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
+ from taming.data.imagenet import ImagePaths
16
+
17
+ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
+
19
+
20
def synset2idx(path_to_yaml="data/index_synset.yaml"):
    """Load the index->synset YAML mapping and return its inverse (synset->index)."""
    with open(path_to_yaml) as f:
        # yaml.load() without an explicit Loader is rejected by PyYAML >= 6
        # (and was a deprecation warning before); safe_load also avoids
        # executing arbitrary tags from the file.
        di2s = yaml.safe_load(f)
    return dict((v, k) for k, v in di2s.items())
24
+
25
+
26
class ImageNetBase(Dataset):
    """Common machinery for the ImageNet train/validation datasets.

    Subclasses implement ``_prepare`` and are expected to set ``self.root``,
    ``self.datadir``, ``self.txt_filelist`` and ``self.random_crop`` there;
    this base class then downloads the synset metadata files and builds
    ``self.data`` (an ``ImagePaths`` dataset, or a bare list of file paths
    when ``self.process_images`` is False).
    """
    def __init__(self, config=None):
        # config may be a plain dict or an OmegaConf node; normalize to dict.
        self.config = config or OmegaConf.create()
        if not type(self.config)==dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        # Subclass hook: must set root/datadir/txt_filelist/random_crop.
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        """Drop known-bad files and, if configured, keep only selected synsets."""
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        # Download the synset -> human-readable-name table unless a file of
        # the exact expected size is already present.
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        # Download the class-index -> synset YAML mapping if missing.
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        # Download and parse the "<index>: <label>" table into a dict.
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        """Read the file list, filter it, and assemble labels plus the backing dataset."""
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        # Relative paths look like "<synset>/<file>.JPEG".
        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        # Dense 0..N-1 labels by default; original ImageNet indices when
        # keep_orig_class_label is set (requires sub_indices so that
        # self.synset2idx exists).
        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths
132
+
133
+
134
class ImageNetTrain(ImageNetBase):
    """ILSVRC2012 training split; downloads via academictorrents and extracts on first use."""
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"  # academictorrents id of the tarball
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,  # expected byte size of the tarball, used as an integrity check
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        # Choose the dataset root: explicit data_root, else the XDG cache dir.
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                # Re-download if missing or the size doesn't match exactly.
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                # NOTE(review): extractall on a downloaded archive without a
                # filter trusts the archive's member paths — confirm the
                # source is trusted or pass filter="data" on Python >= 3.12.
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                # The train tarball contains one tar per synset; unpack each
                # into a directory of the same name.
                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            # Write the sorted relative file list consumed by ImageNetBase._load.
            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)
195
+
196
+
197
class ImageNetValidation(ImageNetBase):
    """ILSVRC2012 validation split; downloads, then reorganizes flat files into per-synset folders."""
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"  # academictorrents id of the tarball
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"  # filename -> synset table
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,  # expected tarball size
        1950000,     # expected synset-table size
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        # Choose the dataset root: explicit data_root, else the XDG cache dir.
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                # Re-download if missing or the size doesn't match exactly.
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                # NOTE(review): extractall without a filter trusts the
                # archive's member paths — confirm the source is trusted or
                # pass filter="data" on Python >= 3.12.
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                # Each line maps "<filename> <synset>".
                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                # Move each flat validation image into its synset folder so
                # the layout matches the training split.
                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            # Write the sorted relative file list consumed by ImageNetBase._load.
            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)
269
+
270
+
271
+
272
class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:
        1.  crops a crop of size s from image either as random or center crop
        2.  resizes crop to size with cv2.area_interpolation
        3.  degrades resized crop with degradation_fn

        :param size: resizing to size after cropping
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param min_crop_f: determines crop size s,
          where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param data_root:
        :param random_crop:
        """
        # self.get_base() is provided by subclasses (e.g. ImageNetSRTrain).
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        self.pil_interpolation = False # gets reset later if incase interp_op is from pillow

        # Degradation is either one of the BSRGAN pipelines (defined in
        # ldm.modules.image_degradation) or a plain resize with the named
        # cv2/PIL interpolation.
        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        # example is expected to carry "file_path_" set by the base dataset.
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        # Sample a square crop side length in [min_crop_f, max_crop_f] of the
        # shorter image side.
        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        # NOTE: the cropper is (re)built per item because its size depends on
        # the sampled crop_side_len; storing it on self is a side effect.
        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)

        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)["image"]

        # Normalize both HR and LR to [-1, 1] float32.
        example["image"] = (image/127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)

        return example
373
+
374
+
375
class ImageNetSRTrain(ImageNetSR):
    """Super-resolution training split: ImageNetTrain restricted to precomputed HR-suitable indices."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        # The pickle holds indices of sufficiently high-resolution images;
        # it is a repo-local, trusted file.
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        # process_images=False -> the base dataset only yields file paths.
        dset = ImageNetTrain(process_images=False,)
        return Subset(dset, indices)
384
+
385
+
386
class ImageNetSRValidation(ImageNetSR):
    """Super-resolution validation split: ImageNetValidation restricted to precomputed HR-suitable indices."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        # The pickle holds indices of sufficiently high-resolution images;
        # it is a repo-local, trusted file.
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        # process_images=False -> the base dataset only yields file paths.
        dset = ImageNetValidation(process_images=False,)
        return Subset(dset, indices)
ldm/data/lsun.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+
8
+
9
class LSUNBase(Dataset):
    """Generic LSUN dataset.

    Reads image paths (one per line) from *txt_file*, resolves them against
    *data_root*, and yields dicts with the relative/absolute path plus the
    image center-cropped to a square, resized to *size* (if given), randomly
    flipped with probability *flip_p*, and normalized to [-1, 1] float32.
    """
    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
        }

        self.size = size
        # Pillow >= 9.1 moved the resampling constants into
        # PIL.Image.Resampling, and the bare module-level names (including
        # the old LINEAR alias used here) were removed in Pillow 10; fall
        # back to the module itself on older Pillow versions.
        _resampling = getattr(PIL.Image, "Resampling", PIL.Image)
        self.interpolation = {"linear": _resampling.BILINEAR,  # LINEAR was an alias of BILINEAR
                              "bilinear": _resampling.BILINEAR,
                              "bicubic": _resampling.BICUBIC,
                              "lanczos": _resampling.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing: center-crop to a square of the
        # shorter side, then resize.
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w, = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
              (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Normalize to [-1, 1].
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
60
+
61
+
62
class LSUNChurchesTrain(LSUNBase):
    """LSUN church_outdoor training split; file list and root are fixed repo-relative paths."""
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
+
66
+
67
class LSUNChurchesValidation(LSUNBase):
    """LSUN church_outdoor validation split; horizontal flips are disabled by default."""
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)
71
+
72
+
73
class LSUNBedroomsTrain(LSUNBase):
    """LSUN bedrooms training split; file list and root are fixed repo-relative paths."""
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
+
77
+
78
class LSUNBedroomsValidation(LSUNBase):
    """LSUN bedrooms validation split; horizontal flips are disabled by default."""
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)
82
+
83
+
84
class LSUNCatsTrain(LSUNBase):
    """LSUN cats training split; file list and root are fixed repo-relative paths."""
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
+
88
+
89
class LSUNCatsValidation(LSUNBase):
    """LSUN cats validation split; horizontal flips are disabled by default."""
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
                         flip_p=flip_p, **kwargs)
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
class LambdaWarmUpCosineScheduler:
    """
    LR-multiplier schedule: linear warm-up from lr_start to lr_max over
    warm_up_steps, then cosine decay to lr_min by max_decay_steps (held at
    lr_min afterwards).

    note: use with a base_lr of 1.0
    """

    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.  # most recently returned multiplier (for logging)
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        """Return the LR multiplier for global step n."""
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            # linear ramp from lr_start towards lr_max
            lr = self.lr_start + (self.lr_max - self.lr_start) * n / self.lr_warm_up_steps
        else:
            # cosine decay; t is clamped so the multiplier stays at lr_min afterwards
            t = min((n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps), 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
34
+
35
+
36
class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.

    Each cycle i ramps linearly from f_start[i] to f_max[i] over
    warm_up_steps[i], then decays with a cosine to f_min[i] by the end of
    cycle_lengths[i].
    """

    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        # all per-cycle lists must line up
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # cumulative cycle end points, prefixed with 0
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.  # most recently returned multiplier (for logging)
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Index of the cycle that contains global step n (None past the last cycle)."""
        for interval, boundary in enumerate(self.cum_cycles[1:]):
            if n <= boundary:
                return interval

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]  # step index within the current cycle
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            # linear warm-up
            f = self.f_start[cycle] + (self.f_max[cycle] - self.f_start[cycle]) * n / self.lr_warm_up_steps[cycle]
        else:
            # cosine decay, clamped at the cycle end
            t = min((n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]), 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
79
+
80
+
81
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Same per-cycle warm-up as the parent, but with a linear decay instead of cosine."""

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]  # step index within the current cycle
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            # linear warm-up from f_start to f_max
            f = self.f_start[cycle] + (self.f_max[cycle] - self.f_start[cycle]) * n / self.lr_warm_up_steps[cycle]
        else:
            # linear decay towards f_min over the full cycle length
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
        self.last_f = f
        return f
98
+
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
+
8
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
+
11
+ from ldm.util import instantiate_from_config
12
+
13
+
14
class VQModel(pl.LightningModule):
    """VQGAN-style vector-quantized autoencoder.

    Pipeline: Encoder -> 1x1 quant_conv -> VectorQuantizer -> 1x1
    post_quant_conv -> Decoder. Trained adversarially: the loss module owns
    the discriminator, and training alternates between two optimizers
    (optimizer_idx 0 = autoencoder, 1 = discriminator).

    NOTE(review): this file's visible imports do not define `LitEma`, `np`,
    `version`, or `LambdaLR`, all of which are referenced below — confirm the
    corresponding imports exist at module top.
    """

    def __init__(self,
                 ddconfig,                 # Encoder/Decoder hyperparameters; must contain "z_channels"
                 lossconfig,               # config for the loss module (instantiated below; owns the discriminator)
                 n_embed,                  # codebook size for the vector quantizer
                 embed_dim,                # dimensionality of each codebook entry
                 ckpt_path=None,           # optional checkpoint to restore from
                 ignore_keys=[],           # state-dict key prefixes to drop when restoring (read-only; mutable default is harmless here)
                 image_key="image",        # key under which the input image is found in a batch dict
                 colorize_nlabels=None,    # if set, register a random projection used by to_rgb() for segmentation inputs
                 monitor=None,             # metric name for Lightning checkpoint monitoring
                 batch_resize_range=None,  # optional (lower, upper) bounds for per-batch random resizing
                 scheduler_config=None,    # optional LR-scheduler config (see configure_optimizers)
                 lr_g_factor=1.0,          # multiplier applied to the generator learning rate
                 remap=None,
                 sane_index_shape=False, # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        # 1x1 convs map between the encoder/decoder latent width and the codebook width
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            # NOTE(review): `LitEma` is not imported in this file's visible header
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap EMA weights in; training weights are restored on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load a checkpoint non-strictly, dropping keys matching any prefix in ignore_keys."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        # update EMA shadow weights after every optimizer step
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode image to a quantized latent; returns (quant, emb_loss, info)."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode image to the continuous latent BEFORE quantization."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode a quantized latent back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode directly from codebook indices."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Full reconstruction pass; optionally also returns predicted codebook indices."""
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch images from the batch as float BCHW; optionally random-resize per batch."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                # NOTE(review): `np` is not imported in this file's visible header
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """GAN training: optimizer_idx 0 trains the autoencoder, 1 the discriminator."""
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        # validate with both the raw weights and (when enabled) the EMA weights
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Shared validation body; suffix distinguishes raw vs EMA metrics."""
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # NOTE(review): `version` (packaging.version) is not imported in this file's visible header
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers (generator, discriminator), optional shared LambdaLR schedule."""
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            # NOTE(review): `LambdaLR` is not imported in this file's visible header
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # used by the loss to balance generator/discriminator gradients
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Collect input/reconstruction tensors for image logging; optional EMA variant."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels in [-1, 1] for display."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
261
+
262
+
263
class VQModelInterface(VQModel):
    """VQModel variant whose encode() returns the pre-quantization latent;
    quantization is deferred to decode() (optionally bypassed)."""

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        # continuous latent only — no quantization here
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        quant = h if force_not_quantize else self.quantize(h)[0]
        return self.decoder(self.post_quant_conv(quant))
+
283
+
284
class AutoencoderKL(pl.LightningModule):
    """KL-regularized (continuous) autoencoder.

    Pipeline: Encoder -> 1x1 quant_conv (doubles channels for mean/logvar) ->
    DiagonalGaussianDistribution -> sample/mode -> 1x1 post_quant_conv ->
    Decoder. Trained adversarially with two optimizers (0 = autoencoder,
    1 = discriminator owned by the loss module).
    """

    def __init__(self,
                 ddconfig,               # Encoder/Decoder hyperparameters; must contain "z_channels" and double_z=True
                 lossconfig,             # config for the loss module (owns the discriminator)
                 embed_dim,              # latent channel count after projection
                 ckpt_path=None,         # optional checkpoint to restore from
                 ignore_keys=[],         # state-dict key prefixes to drop when restoring (read-only)
                 image_key="image",      # batch-dict key for input images
                 colorize_nlabels=None,  # if set, register a random projection used by to_rgb()
                 monitor=None,           # metric name for Lightning checkpoint monitoring
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # the encoder must emit 2*z_channels (mean and logvar)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load a checkpoint non-strictly, dropping keys matching any prefix in ignore_keys."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode an image to a DiagonalGaussianDistribution over the latent."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent sample back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct; sample from the posterior (stochastic) or use its mode."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch images from the batch as float BCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """GAN training: optimizer_idx 0 trains encoder/decoder, 1 the discriminator."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Log both loss views (autoencoder and discriminator) on validation data."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers with shared LR: generator parts, then the discriminator."""
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # used by the loss to balance generator/discriminator gradients
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Collect inputs, reconstructions, and prior samples for image logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            # decode noise of the same shape as a posterior sample => "samples"
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels in [-1, 1] for display."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
+ return x
423
+
424
+
425
class IdentityFirstStage(torch.nn.Module):
    """No-op first-stage stand-in: encode/decode/forward return the input
    unchanged, so a latent-diffusion model can run directly in pixel space.

    When vq_interface is True, quantize() mimics the (quant, emb_loss, info)
    return signature of a real vector quantizer so callers can unpack it.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Fix: nn.Module.__init__ must run before any attribute assignment,
        # otherwise __setattr__ operates on a module without its internal
        # bookkeeping dicts (the original set vq_interface first).
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            # mimic a VectorQuantizer's (quant, emb_loss, info) return shape
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
+ return x
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
# Dispatch table from conditioning key to classifier backbone:
# image-level class labels use the encoder-only UNet, dense segmentation
# targets use the full UNet.
__models__ = {
    'class_label': EncoderUNetModel,
    'segmentation': UNetModel
}
20
+
21
+
22
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    # Intentionally ignores `mode` and returns the module unchanged, so a
    # frozen submodule cannot be flipped back into training mode.
    return self
26
+
27
+
28
class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on noisy latents of a frozen latent-diffusion model.

    Loads the diffusion model from its saved project config, freezes it, then
    trains a UNet-based classifier (chosen via __models__ by label_key) on
    q_sample-noised latents at random timesteps.
    """

    def __init__(self,
                 diffusion_path,            # run directory containing configs/*-project.yaml
                 num_classes,               # number of target classes (classifier output channels)
                 ckpt_path=None,            # optional classifier checkpoint
                 pool='attention',          # pooling mode for the class-label encoder UNet
                 label_key=None,            # conditioning key; overridden by the diffusion model's cond_stage_key if present
                 diffusion_ckpt_path=None,  # checkpoint for the frozen diffusion model
                 scheduler_config=None,     # optional LR-scheduler config
                 weight_decay=1.e-2,
                 log_steps=10,              # number of timesteps visualized in log_images
                 monitor='val/loss',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        # number of spatial downsamplings performed by the first-stage encoder
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        """Load a checkpoint non-strictly; only_model restricts loading to the classifier."""
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        """Instantiate the diffusion model, freeze it, and pin it in eval mode."""
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        # disabled_train prevents Lightning from flipping it back to train mode
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        """Build the classifier UNet from the diffusion UNet config, retargeted to num_classes."""
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        # classifier consumes the diffusion UNet's output space
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        """Noise latent x to timestep t via the frozen diffusion model's q_sample."""
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        # classifier logits for a noisy latent at timestep t
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        """Fetch images from the batch as float BCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        """Fetch targets; segmentation maps are downsampled to the latent resolution."""
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        """Top-k accuracy; "mean" returns a float, "none" a per-sample tensor."""
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == "mean":
            return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
        elif reduction == "none":
            return (top_ks == labels[:, None]).float().sum(dim=-1)

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        """Log loss, top-1/top-5 accuracy, global step, and the current LR."""
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
        self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        """Core step: noise the latent (random t unless given), classify, cross-entropy."""
        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            # one-hot segmentation maps -> integer class indices
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        """Reset per-timestep accuracy accumulators used during validation."""
        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Validation loss at random t, plus accuracy tracked at fixed timesteps."""
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

        return loss

    def configure_optimizers(self):
        """AdamW on the classifier only; optional per-step LambdaLR schedule."""
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        """Visualize inputs, labels, and predictions across log_steps noise levels."""
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

        for step in range(self.log_steps):
            current_time = step * self.log_time_interval

            _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

            log[f'inputs@t{current_time}'] = x_noisy

            pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
            pred = rearrange(pred, 'b h w c -> b c h w')

            log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

        # truncate every logged tensor to the first N samples
        for key in log:
            log[key] = log[key][:N]

        return log
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
+ extract_into_tensor
10
+
11
+
12
class DDIMSampler(object):
    """DDIM sampler wrapping a trained DDPM-style model (sampling only).

    Precomputes the DDIM alpha/sigma schedule from the model's training
    schedule and exposes ``sample`` (ancestral DDIM loop), plus
    ``stochastic_encode`` / ``decode`` for image-to-image style use.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store ``attr`` as a plain attribute, moving tensors to CUDA first.

        NOTE(review): the CUDA device is hard-coded here, so this sampler
        cannot run on CPU-only machines as written — confirm whether a
        ``self.model.device`` target was intended.
        """
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute all per-step quantities for DDIM sampling.

        ``ddim_eta=0`` gives deterministic DDIM; ``eta=1`` recovers
        DDPM-like stochasticity via the sigma schedule.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        # Buffers derived below are float32 tensors on the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters (subsampled alphas and the eta-scaled sigmas)
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # Sigmas used when sampling over the *full* original step count.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Public entry point: run ``S`` DDIM steps and return (samples, intermediates).

        ``shape`` is (C, H, W) without the batch dim; ``conditioning`` may be
        a tensor or a dict of tensors whose leading dim should equal
        ``batch_size`` (a mismatch only warns, it does not raise).
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
        """Main reverse-diffusion loop; returns the final image and logged intermediates."""
        device = self.model.betas.device
        b = shape[0]
        # Start from pure Gaussian noise unless an explicit x_T is supplied.
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # Interpret ``timesteps`` as a cap: keep only the first fraction of
            # the precomputed DDIM schedule.
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # Sampling runs from high noise (large t) down to t=0.
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            # ``index`` walks the precomputed DDIM parameter arrays backwards.
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                # Inpainting: keep masked regions locked to the (re-noised)
                # reference image x0 at every step.
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """Single DDIM update x_t -> x_{t-1}; returns (x_prev, pred_x0)."""
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # Classifier-free guidance: evaluate conditional and unconditional
            # noise predictions in one batched forward pass, then extrapolate.
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            # Snap the x0 estimate onto the VQ codebook of the first stage.
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Forward-noise x0 directly to timestep ``t`` (single-shot q_sample)."""
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False):
        """Denoise a partially-noised latent from step ``t_start`` down to 0."""

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
        return x_dec
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager
16
+ from functools import partial
17
+ from tqdm import tqdm
18
+ from torchvision.utils import make_grid
19
+ from pytorch_lightning.utilities.distributed import rank_zero_only
20
+
21
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
22
+ from ldm.modules.ema import LitEma
23
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
24
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
25
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
26
+ from ldm.models.diffusion.ddim import DDIMSampler
27
+ from torch import clamp
28
+
29
# Maps conditioning mode -> the kwarg name used when calling the UNet wrapper
# ('concat' conditioning is channel-concatenated, 'crossattn' goes through
# cross-attention, 'adm' is a class-embedding style label input).
__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}
32
+
33
+
34
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Patching a module's ``train`` attribute with this function freezes its
    current train/eval state: subsequent ``module.train()`` / ``module.eval()``
    calls do nothing and simply return the module, matching the fluent
    interface of the real method.
    """
    return self
38
+
39
+
40
def uniform_on_device(r1, r2, shape, device):
    """Sample a tensor of the given ``shape`` on ``device`` with entries
    uniformly distributed between ``r2`` and ``r1``.

    Implemented as an affine transform of ``torch.rand``: ``u`` in [0, 1)
    is scaled by ``r1 - r2`` and shifted by ``r2``.
    """
    span = r1 - r2
    unit_samples = torch.rand(*shape, device=device)
    return span * unit_samples + r2
42
+
43
+
44
class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 ):
        """Classic pixel-space DDPM as a LightningModule.

        ``parameterization`` selects what the UNet predicts: the added noise
        ("eps") or the clean image ("x0"). The noise schedule is registered
        as buffers in ``register_schedule``.
        """
        super().__init__()
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        # DiffusionWrapper routes conditioning into the UNet according to
        # ``conditioning_key`` (defined later in this module).
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        # Per-timestep log-variance used to weight the NLL term; a trainable
        # parameter only when ``learn_logvar`` is set.
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute and register all diffusion-schedule buffers (betas, alphas, posterior terms)."""
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            # NOTE(review): ``2. * 1 - torch.Tensor(alphas_cumprod)`` evaluates
            # as ``(2.*1) - alphas_cumprod`` due to precedence — confirm this
            # matches the intended weighting (upstream code has the same form).
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        # NOTE(review): ``.all()`` only rejects an all-NaN tensor; ``.any()``
        # would catch a single NaN — confirm which was intended.
        assert not torch.isnan(self.lvlb_weights).all()

    @contextmanager
    def ema_scope(self, context=None):
        """Context manager that temporarily swaps EMA weights into ``self.model``."""
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            # Always restore training weights, even if the body raised.
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        """Load a checkpoint non-strictly, optionally dropping keys by prefix.

        ``only_model=True`` loads into the inner UNet wrapper only.
        """
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the x0 estimate from x_t and a predicted noise eps."""
        return (
                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean/variance of the true posterior q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        """Model posterior p(x_{t-1} | x_t): run the UNet, derive x0, plug into q_posterior."""
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        if clip_denoised:
            # Pixel-space models are trained on data in [-1, 1].
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        """One ancestral sampling step x_t -> x_{t-1}."""
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    def p_sample_loop(self, shape, return_intermediates=False):
        """Full reverse chain from pure noise down to t=0."""
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    def sample(self, batch_size=16, return_intermediates=False):
        """Sample a batch of square images at the configured size/channels."""
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                  return_intermediates=return_intermediates)

    def q_sample(self, x_start, t, noise=None):
        """Forward-noise x0 to timestep t in closed form."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        """L1 or L2 loss between prediction and target; reduced only if ``mean``."""
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            # NOTE(review): message is a plain string, not an f-string, so
            # "{loss_type}" is printed literally — confirm intended.
            raise NotImplementedError("unknown loss type '{loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None):
        """Training loss at timestep t: simple (MSE/L1) term plus ELBO-weighted VLB term."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        # Target depends on what the network is parameterized to predict.
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")

        # Per-sample loss (mean over C, H, W).
        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        """Draw a random timestep per sample and return the training loss."""
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        """Fetch batch[k] and normalize layout to contiguous float NCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            # Add a trailing channel dim for grayscale HxW inputs.
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def shared_step(self, batch):
        """Common train/val step body: extract the image and compute the loss."""
        x = self.get_input(batch, self.first_stage_key)
        loss, loss_dict = self(x)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            # Surface the current lr when a scheduler is active.
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses with both current and EMA weights."""
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    def on_train_batch_end(self, *args, **kwargs):
        # Update EMA shadow weights after every optimizer step.
        if self.use_ema:
            self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        """Turn a list of (b, c, h, w) tensors into one grid image per batch row."""
        n_imgs_per_row = len(samples)
        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        """Build a dict of visualizations: inputs, a forward-diffusion row, and (optionally) EMA samples."""
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        """AdamW over the UNet parameters (plus the learned logvar, if enabled)."""
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.learn_logvar:
            params = params + [self.logvar]
        opt = torch.optim.AdamW(params, lr=lr)
        return opt
422
+
423
+
424
class LatentDiffusion(DDPM):
    """main class

    Diffusion in the latent space of a (frozen) first-stage autoencoder,
    optionally conditioned through a separate conditioning-stage model.
    """
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        """
        :param first_stage_config: config for the first-stage autoencoder.
        :param cond_stage_config: config for the conditioning model, or the
            sentinels '__is_first_stage__' / '__is_unconditional__'.
        :param num_timesteps_cond: number of distinct conditioning timesteps
            (defaults to 1; must not exceed kwargs['timesteps']).
        :param conditioning_key: how conditioning enters the UNet; derived from
            concat_mode when None.
        :param scale_factor: multiplier applied to first-stage latents.
        :param scale_by_std: if True, scale_factor is a buffer set to 1/std of
            the encodings on the first training batch.
        """
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        # was a bare `except:` (would also swallow KeyboardInterrupt/SystemExit);
        # narrowed to Exception: configs without a ddconfig simply get num_downs=0
        except Exception:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            # buffer so the std-derived factor is saved with checkpoints
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
def make_cond_schedule(self, ):
    """Build self.cond_ids: maps each diffusion timestep to a conditioning
    timestep. The first num_timesteps_cond slots hold evenly spaced ids over
    [0, num_timesteps - 1]; all remaining slots hold the last timestep."""
    last = self.num_timesteps - 1
    cond_ids = torch.full(size=(self.num_timesteps,), fill_value=last, dtype=torch.long)
    spaced = torch.round(torch.linspace(0, last, self.num_timesteps_cond)).long()
    cond_ids[:self.num_timesteps_cond] = spaced
    self.cond_ids = cond_ids
@rank_zero_only
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
    """On the very first training batch only, optionally replace scale_factor
    with 1/std of the first-stage encodings (std-rescaling). Runs on rank 0
    only (@rank_zero_only)."""
    # only for very first batch
    if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
        assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
        # set rescale weight to 1./std of encodings
        print("### USING STD-RESCALING ###")
        x = super().get_input(batch, self.first_stage_key)
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()
        # replace the plain attribute with a buffer so the factor is checkpointed
        del self.scale_factor
        self.register_buffer('scale_factor', 1. / z.flatten().std())
        print(f"setting self.scale_factor to {self.scale_factor}")
        print("### USING STD-RESCALING ###")
def register_schedule(self,
                      given_betas=None, beta_schedule="linear", timesteps=1000,
                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Register the beta schedule via the base class, then rebuild the
    shortened conditioning schedule when fewer cond timesteps are used."""
    super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

    uses_short_schedule = self.num_timesteps_cond > 1
    self.shorten_cond_schedule = uses_short_schedule
    if uses_short_schedule:
        self.make_cond_schedule()
def instantiate_first_stage(self, config):
    """Build the first-stage autoencoder and freeze it: eval mode, .train()
    disabled, and no gradients on any parameter."""
    first_stage = instantiate_from_config(config).eval()
    first_stage.train = disabled_train  # make .train(True) a no-op
    for p in first_stage.parameters():
        p.requires_grad = False
    self.first_stage_model = first_stage
def instantiate_cond_stage(self, config):
    """Build the conditioning model. When trainable, it is instantiated as-is;
    otherwise it is frozen. The sentinels '__is_first_stage__' and
    '__is_unconditional__' reuse the first stage or disable conditioning."""
    if self.cond_stage_trainable:
        assert config != '__is_first_stage__'
        assert config != '__is_unconditional__'
        self.cond_stage_model = instantiate_from_config(config)
        return
    if config == "__is_first_stage__":
        print("Using first stage also as cond stage.")
        self.cond_stage_model = self.first_stage_model
    elif config == "__is_unconditional__":
        print(f"Training {self.__class__.__name__} as an unconditional model.")
        self.cond_stage_model = None
        # self.be_unconditional = True
    else:
        frozen = instantiate_from_config(config).eval()
        frozen.train = disabled_train
        for p in frozen.parameters():
            p.requires_grad = False
        self.cond_stage_model = frozen
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
    """Decode each latent in `samples` with the first stage and arrange the
    decoded images into a single grid (one row per batch element)."""
    decoded = []
    for zd in tqdm(samples, desc=desc):
        decoded.append(self.decode_first_stage(zd.to(self.device),
                                               force_not_quantize=force_no_decoder_quantization))
    per_row = len(decoded)
    stacked = torch.stack(decoded)  # n_log_step, n_row, C, H, W
    grid = rearrange(stacked, 'n b c h w -> b n c h w')
    grid = rearrange(grid, 'b n c h w -> (b n) c h w')
    return make_grid(grid, nrow=per_row)
def get_first_stage_encoding(self, encoder_posterior):
    """Turn the first-stage encoder output into a scaled latent: sample from a
    DiagonalGaussianDistribution, or take a plain tensor as-is."""
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        latent = encoder_posterior.sample()
    elif isinstance(encoder_posterior, torch.Tensor):
        latent = encoder_posterior
    else:
        raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
    return self.scale_factor * latent
def get_learned_conditioning(self, c):
    """Encode raw conditioning with the cond-stage model: via the method named
    by cond_stage_forward when set, otherwise via encode() (taking the mode of
    a Gaussian posterior) or a plain forward call."""
    if self.cond_stage_forward is not None:
        assert hasattr(self.cond_stage_model, self.cond_stage_forward)
        return getattr(self.cond_stage_model, self.cond_stage_forward)(c)
    encoder = getattr(self.cond_stage_model, 'encode', None)
    if callable(encoder):
        out = encoder(c)
        if isinstance(out, DiagonalGaussianDistribution):
            out = out.mode()
        return out
    return self.cond_stage_model(c)
def meshgrid(self, h, w):
    """Return an (h, w, 2) tensor whose last axis holds the (y, x) integer
    coordinates of each grid cell."""
    rows = torch.arange(0, h).view(h, 1, 1).expand(h, w, 1)
    cols = torch.arange(0, w).view(1, w, 1).expand(h, w, 1)
    return torch.cat([rows, cols], dim=-1)
def delta_border(self, h, w):
    """
    :param h: height
    :param w: width
    :return: normalized distance to image border,
        with min distance = 0 at border and max dist = 0.5 at image center
    """
    corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    coords = self.meshgrid(h, w) / corner  # normalized (y, x) in [0, 1]
    to_top_left = torch.min(coords, dim=-1, keepdims=True)[0]
    to_bottom_right = torch.min(1 - coords, dim=-1, keepdims=True)[0]
    return torch.min(torch.cat([to_top_left, to_bottom_right], dim=-1), dim=-1)[0]
def get_weighting(self, h, w, Ly, Lx, device):
    """Per-pixel weights for blending Ly*Lx overlapping crops of size (h, w):
    clipped border-distance weights inside each patch, optionally scaled by a
    patch-level tie-breaker so overlapping patches are weighted unequally."""
    w_map = torch.clip(self.delta_border(h, w),
                       self.split_input_params["clip_min_weight"],
                       self.split_input_params["clip_max_weight"], )
    w_map = w_map.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

    if self.split_input_params["tie_braker"]:
        tie = torch.clip(self.delta_border(Ly, Lx),
                         self.split_input_params["clip_min_tie_weight"],
                         self.split_input_params["clip_max_tie_weight"])
        w_map = w_map * tie.view(1, 1, Ly * Lx).to(device)
    return w_map
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
    """
    Build the unfold/fold operators and blending weights used to process an
    image as overlapping crops and stitch the per-crop outputs back together.

    :param x: img of size (bs, c, h, w)
    :param kernel_size: crop size (kh, kw)
    :param stride: crop stride (sh, sw)
    :param uf: upscale factor of the per-crop output (decoding, uf > 1)
    :param df: downscale factor of the per-crop output (encoding, df > 1)
    :return: (fold, unfold, normalization, weighting); dividing the folded
        output by `normalization` (= fold(weighting)) renormalizes the blend.
    """
    bs, nc, h, w = x.shape

    # number of crops in image
    Ly = (h - kernel_size[0]) // stride[0] + 1
    Lx = (w - kernel_size[1]) // stride[1] + 1

    if uf == 1 and df == 1:
        # per-crop output keeps the input resolution
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

        weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

    elif uf > 1 and df == 1:
        # per-crop output is uf x larger (e.g. decoding latents to pixels)
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        # NOTE(review): both fold dimensions use kernel_size[0] here (and in the
        # df branch) — fine for square kernels, suspicious for non-square ones.
        fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                            dilation=1, padding=0,
                            stride=(stride[0] * uf, stride[1] * uf))
        fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

        weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

    elif df > 1 and uf == 1:
        # per-crop output is df x smaller (e.g. encoding pixels to latents)
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                            dilation=1, padding=0,
                            stride=(stride[0] // df, stride[1] // df))
        fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

        weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

    else:
        raise NotImplementedError

    return fold, unfold, normalization, weighting
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
              cond_key=None, return_original_cond=False, bs=None):
    """
    Prepare one batch for latent diffusion: encode the image under key ``k``
    to a scaled, detached latent ``z`` and build the conditioning ``c``.

    :param return_first_stage_outputs: also append the input x and its first-stage reconstruction.
    :param force_c_encode: encode the conditioning even when the cond stage is trainable.
    :param cond_key: overrides self.cond_stage_key.
    :param return_original_cond: also append the raw (unencoded) conditioning.
    :param bs: optional truncation to the first ``bs`` elements.
    :return: list [z, c] (+ [x, xrec] if requested, + xc if requested).
    """
    x = super().get_input(batch, k)
    if bs is not None:
        x = x[:bs]
    x = x.to(self.device)
    encoder_posterior = self.encode_first_stage(x)
    z = self.get_first_stage_encoding(encoder_posterior).detach()

    if self.model.conditioning_key is not None:
        if cond_key is None:
            cond_key = self.cond_stage_key
        if cond_key != self.first_stage_key:
            if cond_key in ['caption', 'coordinates_bbox']:
                # raw text / bbox annotations are passed through untouched
                xc = batch[cond_key]
            elif cond_key == 'class_label':
                # class-conditional embedders expect the whole batch dict
                xc = batch
            else:
                xc = super().get_input(batch, cond_key).to(self.device)
        else:
            # conditioning on the image itself
            xc = x
        if not self.cond_stage_trainable or force_c_encode:
            if isinstance(xc, dict) or isinstance(xc, list):
                # import pudb; pudb.set_trace()
                c = self.get_learned_conditioning(xc)
            else:
                c = self.get_learned_conditioning(xc.to(self.device))
        else:
            # trainable cond stage: encoding happens later inside forward()
            c = xc
        if bs is not None:
            c = c[:bs]

        if self.use_positional_encodings:
            pos_x, pos_y = self.compute_latent_shifts(batch)
            ckey = __conditioning_keys__[self.model.conditioning_key]
            c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

    else:
        # unconditional model
        c = None
        xc = None
        if self.use_positional_encodings:
            pos_x, pos_y = self.compute_latent_shifts(batch)
            c = {'pos_x': pos_x, 'pos_y': pos_y}
    out = [z, c]
    if return_first_stage_outputs:
        xrec = self.decode_first_stage(z)
        out.extend([x, xrec])
    if return_original_cond:
        out.append(xc)
    return out
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """
    Decode latents z back to image space with the (frozen) first stage,
    undoing the latent scale_factor first. With split_input_params set,
    decodes overlapping crops independently and blends them.

    NOTE(review): the sibling below is commented "same as above but without
    decorator" — this method presumably carried a @torch.no_grad() decorator
    originally; confirm before relying on gradients through this path.

    :param predict_cids: z holds codebook logits; take argmax code ids and look
        up the codebook entries before decoding.
    :param force_not_quantize: skip quantization in a VQ decoder.
    """
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    # undo the scaling applied in get_first_stage_encoding
    z = 1. / self.scale_factor * z

    if hasattr(self, "split_input_params"):
        if self.split_input_params["patch_distributed_vq"]:
            ks = self.split_input_params["ks"]  # eg. (128, 128)
            stride = self.split_input_params["stride"]  # eg. (64, 64)
            uf = self.split_input_params["vqf"]
            bs, nc, h, w = z.shape
            # shrink kernel/stride if the latent is smaller than requested
            if ks[0] > h or ks[1] > w:
                ks = (min(ks[0], h), min(ks[1], w))
                print("reducing Kernel")

            if stride[0] > h or stride[1] > w:
                stride = (min(stride[0], h), min(stride[1], w))
                print("reducing stride")

            fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

            z = unfold(z)  # (bn, nc * prod(**ks), L)
            # 1. Reshape to img shape
            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

            # 2. apply model loop over last dim
            if isinstance(self.first_stage_model, VQModelInterface):
                output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                             force_not_quantize=predict_cids or force_not_quantize)
                               for i in range(z.shape[-1])]
            else:

                output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]

            o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
            o = o * weighting
            # Reverse 1. reshape to img shape
            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
            # stitch crops together
            decoded = fold(o)
            decoded = decoded / normalization  # norm is shape (1, 1, h, w)
            return decoded
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)

    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)
# same as above but without decorator
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """
    Identical to decode_first_stage but intended to stay differentiable (per
    the comment above, the sibling originally carried a no-grad decorator).
    """
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    # undo the scaling applied in get_first_stage_encoding
    z = 1. / self.scale_factor * z

    if hasattr(self, "split_input_params"):
        if self.split_input_params["patch_distributed_vq"]:
            ks = self.split_input_params["ks"]  # eg. (128, 128)
            stride = self.split_input_params["stride"]  # eg. (64, 64)
            uf = self.split_input_params["vqf"]
            bs, nc, h, w = z.shape
            # shrink kernel/stride if the latent is smaller than requested
            if ks[0] > h or ks[1] > w:
                ks = (min(ks[0], h), min(ks[1], w))
                print("reducing Kernel")

            if stride[0] > h or stride[1] > w:
                stride = (min(stride[0], h), min(stride[1], w))
                print("reducing stride")

            fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

            z = unfold(z)  # (bn, nc * prod(**ks), L)
            # 1. Reshape to img shape
            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

            # 2. apply model loop over last dim
            if isinstance(self.first_stage_model, VQModelInterface):
                output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                             force_not_quantize=predict_cids or force_not_quantize)
                               for i in range(z.shape[-1])]
            else:

                output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]

            o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
            o = o * weighting
            # Reverse 1. reshape to img shape
            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
            # stitch crops together
            decoded = fold(o)
            decoded = decoded / normalization  # norm is shape (1, 1, h, w)
            return decoded
        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)

    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)
def encode_first_stage(self, x):
    """
    Encode images x with the (frozen) first stage. With split_input_params
    set, encodes overlapping crops independently and blends the results.
    Note the latent scale_factor is NOT applied here — see
    get_first_stage_encoding.
    """
    if hasattr(self, "split_input_params"):
        if self.split_input_params["patch_distributed_vq"]:
            ks = self.split_input_params["ks"]  # eg. (128, 128)
            stride = self.split_input_params["stride"]  # eg. (64, 64)
            df = self.split_input_params["vqf"]
            # remember the pixel-space size (used for bbox rescaling in apply_model)
            self.split_input_params['original_image_size'] = x.shape[-2:]
            bs, nc, h, w = x.shape
            # shrink kernel/stride if the image is smaller than requested
            if ks[0] > h or ks[1] > w:
                ks = (min(ks[0], h), min(ks[1], w))
                print("reducing Kernel")

            if stride[0] > h or stride[1] > w:
                stride = (min(stride[0], h), min(stride[1], w))
                print("reducing stride")

            fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
            z = unfold(x)  # (bn, nc * prod(**ks), L)
            # Reshape to img shape
            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

            output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
                           for i in range(z.shape[-1])]

            o = torch.stack(output_list, axis=-1)
            o = o * weighting

            # Reverse reshape to img shape
            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
            # stitch crops together ("decoded" actually holds blended encodings here)
            decoded = fold(o)
            decoded = decoded / normalization
            return decoded

        else:
            return self.first_stage_model.encode(x)
    else:
        return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
    """Loss for one batch: build (latent, conditioning) from the batch and
    run the diffusion forward pass."""
    latent, cond = self.get_input(batch, self.first_stage_key)
    return self(latent, cond)
def forward(self, x, c, *args, **kwargs):
    """Draw a uniform random timestep per sample, optionally encode / noise
    the conditioning, and return the diffusion losses."""
    batch = x.shape[0]
    t = torch.randint(0, self.num_timesteps, (batch,), device=self.device).long()
    if self.model.conditioning_key is not None:
        assert c is not None
        if self.cond_stage_trainable:
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            tc = self.cond_ids[t].to(self.device)
            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
    return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
    """Map bbox annotations (x, y, w, h) from full-image coordinates into the
    crop described by crop_coordinates, clamping them to the crop extent."""
    cx, cy, cw, ch = crop_coordinates[0], crop_coordinates[1], crop_coordinates[2], crop_coordinates[3]

    def rescale_bbox(bbox):
        x0 = clamp((bbox[0] - cx) / cw)
        y0 = clamp((bbox[1] - cy) / ch)
        w = min(bbox[2] / cw, 1 - x0)
        h = min(bbox[3] / ch, 1 - y0)
        return x0, y0, w, h

    return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False):
    """
    Run the (wrapped) denoiser on x_noisy at timestep t with conditioning
    `cond`. Non-dict conditioning is wrapped into {'c_concat' | 'c_crossattn':
    [cond]} to match DiffusionWrapper. With split_input_params set, the input
    is processed as overlapping crops and the outputs blended back together.
    """
    if isinstance(cond, dict):
        # hybrid case, cond is exptected to be a dict
        pass
    else:
        if not isinstance(cond, list):
            cond = [cond]
        key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
        cond = {key: cond}

    if hasattr(self, "split_input_params"):
        assert len(cond) == 1  # todo can only deal with one conditioning atm
        assert not return_ids
        ks = self.split_input_params["ks"]  # eg. (128, 128)
        stride = self.split_input_params["stride"]  # eg. (64, 64)

        h, w = x_noisy.shape[-2:]

        fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)

        z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)
        # Reshape to img shape
        z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
        z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]

        if self.cond_stage_key in ["image", "LR_image", "segmentation",
                                   'bbox_img'] and self.model.conditioning_key:  # todo check for completeness
            # image-like conditioning: crop it the same way as the input
            c_key = next(iter(cond.keys()))  # get key
            c = next(iter(cond.values()))  # get value
            assert (len(c) == 1)  # todo extend to list with more than one elem
            c = c[0]  # get element

            c = unfold(c)
            c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

            cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]

        elif self.cond_stage_key == 'coordinates_bbox':
            assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'

            # assuming padding of unfold is always 0 and its dilation is always 1
            n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
            full_img_h, full_img_w = self.split_input_params['original_image_size']
            # as we are operating on latents, we need the factor from the original image size to the
            # spatial latent size to properly rescale the crops for regenerating the bbox annotations
            num_downs = self.first_stage_model.encoder.num_resolutions - 1
            rescale_latent = 2 ** (num_downs)

            # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
            # need to rescale the tl patch coordinates to be in between (0,1)
            tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
                                     rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
                                    for patch_nr in range(z.shape[-1])]

            # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
            patch_limits = [(x_tl, y_tl,
                             rescale_latent * ks[0] / full_img_w,
                             rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
            # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]

            # tokenize crop coordinates for the bounding boxes of the respective patches
            patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
                                  for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)
            # NOTE(review): the print() calls below look like leftover debugging
            print(patch_limits_tknzd[0].shape)
            # cut tknzd crop position from conditioning
            assert isinstance(cond, dict), 'cond must be dict to be fed into model'
            cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
            print(cut_cond.shape)

            adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
            adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
            print(adapted_cond.shape)
            adapted_cond = self.get_learned_conditioning(adapted_cond)
            print(adapted_cond.shape)
            adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
            print(adapted_cond.shape)

            cond_list = [{'c_crossattn': [e]} for e in adapted_cond]

        else:
            # global conditioning: reuse the same cond for every crop
            cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient

        # apply model by loop over crops
        output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
        assert not isinstance(output_list[0],
                              tuple)  # todo cant deal with multiple model outputs check this never happens

        o = torch.stack(output_list, axis=-1)
        o = o * weighting
        # Reverse reshape to img shape
        o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
        # stitch crops together
        x_recon = fold(o) / normalization

    else:
        x_recon = self.model(x_noisy, t, **cond)

    if isinstance(x_recon, tuple) and not return_ids:
        return x_recon[0]
    else:
        return x_recon
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    """Invert the x0-parameterization: recover the noise eps implied by x_t
    and a predicted x0, i.e. (sqrt(1/ac_t) * x_t - x0) / sqrt(1/ac_t - 1)."""
    scale = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    denom = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return (scale * x_t - pred_xstart) / denom
def _prior_bpd(self, x_start):
    """
    Get the prior KL term for the variational lower-bound, measured in
    bits-per-dim.
    This term can't be optimized, as it only depends on the encoder.
    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    batch_size = x_start.shape[0]
    t_last = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t_last)
    kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
    return mean_flat(kl_prior) / np.log(2.0)
def p_losses(self, x_start, cond, t, noise=None):
    """
    Diffusion training loss in latent space: noise x_start to step t, predict
    with the model, and combine the simple (eps or x0) loss with the VLB term.

    :return: (loss, loss_dict) with keys '{train|val}/loss_simple',
        '.../loss_vlb', '.../loss' (and '.../loss_gamma', 'logvar' when
        learn_logvar is set).
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_output = self.apply_model(x_noisy, t, cond)

    loss_dict = {}
    prefix = 'train' if self.training else 'val'

    # regression target depends on the parameterization
    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    else:
        raise NotImplementedError()

    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
    loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

    # per-timestep (optionally learned) log-variance weighting
    logvar_t = self.logvar[t].to(self.device)
    loss = loss_simple / torch.exp(logvar_t) + logvar_t
    # loss = loss_simple / torch.exp(self.logvar) + self.logvar
    if self.learn_logvar:
        loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
        loss_dict.update({'logvar': self.logvar.data.mean()})

    loss = self.l_simple_weight * loss.mean()

    # variational lower-bound term, weighted per timestep
    loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
    loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
    loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
    loss += (self.original_elbo_weight * loss_vlb)
    loss_dict.update({f'{prefix}/loss': loss})

    return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                    return_x0=False, score_corrector=None, corrector_kwargs=None):
    """
    Compute the posterior mean / variance of p(x_{t-1} | x_t, c) from the
    model prediction.

    :param clip_denoised: clamp the reconstructed x0 to [-1, 1].
    :param return_codebook_ids: model additionally returns codebook logits.
    :param quantize_denoised: project the reconstructed x0 through the
        first-stage quantizer.
    :param return_x0: additionally return the reconstructed x0.
    :param score_corrector: optional object whose modify_score() adjusts the
        eps prediction (only valid for eps parameterization).
    """
    t_in = t
    model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

    if score_corrector is not None:
        assert self.parameterization == "eps"
        model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

    if return_codebook_ids:
        model_out, logits = model_out

    # reconstruct x0 from the model output
    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    else:
        raise NotImplementedError()

    if clip_denoised:
        x_recon.clamp_(-1., 1.)
    if quantize_denoised:
        x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    if return_codebook_ids:
        return model_mean, posterior_variance, posterior_log_variance, logits
    elif return_x0:
        return model_mean, posterior_variance, posterior_log_variance, x_recon
    else:
        return model_mean, posterior_variance, posterior_log_variance
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
             return_codebook_ids=False, quantize_denoised=False, return_x0=False,
             temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
    """
    Draw one reverse-diffusion step x_{t-1} ~ p(x_{t-1} | x_t, c).
    No noise is added where t == 0; `temperature` scales the noise and
    `noise_dropout` randomly zeroes parts of it.
    """
    b, *_, device = *x.shape, x.device
    outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                   return_codebook_ids=return_codebook_ids,
                                   quantize_denoised=quantize_denoised,
                                   return_x0=return_x0,
                                   score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
    if return_codebook_ids:
        raise DeprecationWarning("Support dropped.")
        # unreachable after the raise above (dead code kept as-is)
        model_mean, _, model_log_variance, logits = outputs
    elif return_x0:
        model_mean, _, model_log_variance, x0 = outputs
    else:
        model_mean, _, model_log_variance = outputs

    noise = noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

    if return_codebook_ids:
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
    if return_x0:
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
    else:
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                          img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                          score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                          log_every_t=None):
    """
    Full reverse-diffusion loop that also collects intermediate x0 predictions
    every `log_every_t` steps.

    :param mask: optional inpainting mask; where mask == 1 the known content
        from x0 (noised to the current step) is kept.
    :param start_T: optionally start from an earlier timestep than num_timesteps.
    :return: (img, intermediates) — final sample and the logged x0 predictions.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    timesteps = self.num_timesteps
    if batch_size is not None:
        b = batch_size if batch_size is not None else shape[0]
        shape = [batch_size] + list(shape)
    else:
        b = batch_size = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=self.device)
    else:
        img = x_T
    intermediates = []
    if cond is not None:
        # truncate conditioning to the batch size (dicts may hold lists of tensors)
        if isinstance(cond, dict):
            cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
            list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
        else:
            cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

    if start_T is not None:
        timesteps = min(timesteps, start_T)
    iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                    total=timesteps) if verbose else reversed(
        range(0, timesteps))
    if type(temperature) == float:
        # broadcast a scalar temperature to a per-step list
        temperature = [temperature] * timesteps

    for i in iterator:
        ts = torch.full((b,), i, device=self.device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img, x0_partial = self.p_sample(img, cond, ts,
                                        clip_denoised=self.clip_denoised,
                                        quantize_denoised=quantize_denoised, return_x0=True,
                                        temperature=temperature[i], noise_dropout=noise_dropout,
                                        score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if mask is not None:
            # inpainting: re-impose the known region at the current noise level
            assert x0 is not None
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(x0_partial)
        if callback: callback(i)
        if img_callback: img_callback(img, i)
    return img, intermediates
def p_sample_loop(self, cond, shape, return_intermediates=False,
                  x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, start_T=None,
                  log_every_t=None):
    """
    Standard ancestral sampling loop from pure noise (or x_T) down to t = 0.

    :param mask: optional inpainting mask; where mask == 1 the known content
        from x0 (noised to the current step) is kept.
    :param timesteps / start_T: optionally shorten the loop.
    :return: final sample, plus the logged intermediates when requested.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    device = self.betas.device
    b = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=device)
    else:
        img = x_T

    intermediates = [img]
    if timesteps is None:
        timesteps = self.num_timesteps

    if start_T is not None:
        timesteps = min(timesteps, start_T)
    iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
        range(0, timesteps))

    if mask is not None:
        assert x0 is not None
        # NOTE(review): this only compares dim 2 (the [2:3] slice), not the
        # full spatial size the comment claims — confirm intent.
        assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

    for i in iterator:
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img = self.p_sample(img, cond, ts,
                            clip_denoised=self.clip_denoised,
                            quantize_denoised=quantize_denoised)
        if mask is not None:
            # inpainting: re-impose the known region at the current noise level
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(img)
        if callback: callback(i)
        if img_callback: img_callback(img, i)

    if return_intermediates:
        return img, intermediates
    return img
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
           verbose=True, timesteps=None, quantize_denoised=False,
           mask=None, x0=None, shape=None, **kwargs):
    """Draw `batch_size` samples via the ancestral sampling loop.

    The conditioning is trimmed to `batch_size` entries (dicts/lists are
    trimmed element-wise); the latent shape defaults to
    (batch_size, channels, image_size, image_size).
    """
    if shape is None:
        shape = (batch_size, self.channels, self.image_size, self.image_size)
    if cond is not None:
        # Trim the conditioning to the requested batch size.
        if isinstance(cond, dict):
            cond = {
                key: (list(map(lambda x: x[:batch_size], value))
                      if isinstance(value, list) else value[:batch_size])
                for key, value in cond.items()
            }
        elif isinstance(cond, list):
            cond = [entry[:batch_size] for entry in cond]
        else:
            cond = cond[:batch_size]
    return self.p_sample_loop(cond,
                              shape,
                              return_intermediates=return_intermediates, x_T=x_T,
                              verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                              mask=mask, x0=x0)
1233
+
1234
+
1235
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
    """Sample either with the DDIM sampler (`ddim=True`) or with the full
    ancestral loop, returning (samples, intermediates)."""
    if ddim:
        sampler = DDIMSampler(self)
        latent_shape = (self.channels, self.image_size, self.image_size)
        samples, intermediates = sampler.sample(ddim_steps, batch_size,
                                                latent_shape, cond, verbose=False, **kwargs)
    else:
        # Plain ancestral sampling through the full chain.
        samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                             return_intermediates=True, **kwargs)
    return samples, intermediates
1248
+
1249
+
1250
+
1251
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
               quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
               plot_diffusion_rows=True, **kwargs):
    """Build a dict of visualization tensors for logging.

    Collects inputs/reconstructions/conditioning, optionally a forward
    diffusion row, sampled images (EMA weights), quantized-x0 samples,
    inpainting/outpainting demos, and a progressive-denoising row. If
    `return_keys` is given and intersects the produced keys, only those
    keys are returned.
    """

    use_ddim = ddim_steps is not None

    log = dict()
    # z: encoded latent, c: encoded conditioning, x: input image,
    # xrec: first-stage reconstruction, xc: raw (unencoded) conditioning.
    z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                       return_first_stage_outputs=True,
                                       force_c_encode=True,
                                       return_original_cond=True,
                                       bs=N)
    # Clamp the requested row/batch counts to what the batch actually holds.
    N = min(x.shape[0], N)
    n_row = min(x.shape[0], n_row)
    log["inputs"] = x
    log["reconstruction"] = xrec
    if self.model.conditioning_key is not None:
        # Render the conditioning in whatever form is most informative.
        if hasattr(self.cond_stage_model, "decode"):
            xc = self.cond_stage_model.decode(c)
            log["conditioning"] = xc
        elif self.cond_stage_key in ["caption"]:
            # Render text prompts as images so they can be logged alongside samples.
            xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
            log["conditioning"] = xc
        elif self.cond_stage_key == 'class_label':
            xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['conditioning'] = xc
        elif isimage(xc):
            log["conditioning"] = xc
        if ismap(xc):
            # Segmentation-style maps get a random RGB projection for display.
            log["original_conditioning"] = self.to_rgb(xc)

    if plot_diffusion_rows:
        # get diffusion row: visualize the forward (noising) process at the
        # logged timesteps by decoding q_sample(z, t).
        diffusion_row = list()
        z_start = z[:n_row]
        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(z_start)
                z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                diffusion_row.append(self.decode_first_stage(z_noisy))

        diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
        diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
        diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
        diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
        log["diffusion_row"] = diffusion_grid

    if sample:
        # get denoise row: sample with EMA weights swapped in.
        with self.ema_scope("Plotting"):
            samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
        x_samples = self.decode_first_stage(samples)
        log["samples"] = x_samples
        if plot_denoise_rows:
            denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
            log["denoise_row"] = denoise_grid

        if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                self.first_stage_model, IdentityFirstStage):
            # also display when quantizing x0 while sampling (VQ first stage only).
            with self.ema_scope("Plotting Quantized Denoised"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta,
                                                         quantize_denoised=True)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                #                                      quantize_denoised=True)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_x0_quantized"] = x_samples

        if inpaint:
            # make a simple center square mask in latent space.
            b, h, w = z.shape[0], z.shape[2], z.shape[3]
            mask = torch.ones(N, h, w).to(self.device)
            # zeros will be filled in
            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
            mask = mask[:, None, ...]
            with self.ema_scope("Plotting Inpaint"):
                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_inpainting"] = x_samples
            log["mask"] = mask

            # outpaint
            # NOTE(review): the mask is NOT inverted before this run, so the
            # "outpainting" sample uses the same mask as inpainting; upstream
            # variants do `mask = 1. - mask` first — confirm intent.
            with self.ema_scope("Plotting Outpaint"):
                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_outpainting"] = x_samples

    if plot_progressive_rows:
        with self.ema_scope("Plotting Progressives"):
            img, progressives = self.progressive_denoising(c,
                                                           shape=(self.channels, self.image_size, self.image_size),
                                                           batch_size=N)
        prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
        log["progressive_row"] = prog_row

    if return_keys:
        if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
            return log
        else:
            return {key: log[key] for key in return_keys}
    return log
1360
+
1361
def configure_optimizers(self):
    """Create the AdamW optimizer over the diffusion model's parameters,
    optionally including conditioner params and the learned logvar, and
    optionally attach a per-step LambdaLR schedule."""
    lr = self.learning_rate
    trainable = list(self.model.parameters())
    if self.cond_stage_trainable:
        print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
        trainable = trainable + list(self.cond_stage_model.parameters())
    if self.learn_logvar:
        print('Diffusion model optimizing logvar')
        trainable.append(self.logvar)
    opt = torch.optim.AdamW(trainable, lr=lr)

    if not self.use_scheduler:
        return opt

    assert 'target' in self.scheduler_config
    schedule_obj = instantiate_from_config(self.scheduler_config)
    print("Setting up LambdaLR scheduler...")
    lr_schedulers = [
        {
            'scheduler': LambdaLR(opt, lr_lambda=schedule_obj.schedule),
            'interval': 'step',
            'frequency': 1,
        }
    ]
    return [opt], lr_schedulers
1384
+
1385
+
1386
def to_rgb(self, x):
    """Project a multi-channel map to a 3-channel image in [-1, 1] using a
    lazily-created random 1x1 convolution (cached on the instance)."""
    feats = x.float()
    if not hasattr(self, "colorize"):
        # Created on first use so the channel count comes from the input.
        self.colorize = torch.randn(3, feats.shape[1], 1, 1).to(feats)
    rgb = nn.functional.conv2d(feats, weight=self.colorize)
    lo, hi = rgb.min(), rgb.max()
    rgb = 2. * (rgb - lo) / (hi - lo) - 1.
    return rgb
1393
+
1394
+
1395
class DiffusionWrapper(pl.LightningModule):
    """Thin wrapper that dispatches UNet calls according to the
    conditioning mode (`concat`, `crossattn`, `hybrid`, `adm`, or none)."""

    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        """Run the inner model, routing the conditioning tensors per mode."""
        key = self.conditioning_key
        if key is None:
            return self.diffusion_model(x, t)
        if key == 'concat':
            # Channel-wise concatenation of the conditioning with the input.
            return self.diffusion_model(torch.cat([x] + c_concat, dim=1), t)
        if key == 'crossattn':
            # Conditioning consumed via cross-attention context.
            return self.diffusion_model(x, t, context=torch.cat(c_crossattn, 1))
        if key == 'hybrid':
            # Both channel concat and cross-attention context.
            xc = torch.cat([x] + c_concat, dim=1)
            return self.diffusion_model(xc, t, context=torch.cat(c_crossattn, 1))
        if key == 'adm':
            # Class-conditional (ADM-style) label conditioning.
            return self.diffusion_model(x, t, y=c_crossattn[0])
        raise NotImplementedError()
1422
+
1423
+
1424
class Layout2ImgDiffusion(LatentDiffusion):
    # TODO: move all layout-specific hacks to this class
    """LatentDiffusion variant conditioned on tokenized bounding-box layouts."""

    def __init__(self, cond_stage_key, *args, **kwargs):
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        """Extend the base logging with a rendered bounding-box image."""
        logs = super().log_images(batch=batch, N=N, *args, **kwargs)

        split = 'train' if self.training else 'validation'
        dset = self.trainer.datamodule.datasets[split]
        mapper = dset.conditional_builders[self.cond_stage_key]

        def map_fn(catno):
            # Category number -> human-readable label for plotting.
            return dset.get_textual_label(dset.get_category_id(catno))

        bbox_imgs = [
            mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
            for tknzd_bbox in batch[self.cond_stage_key][:N]
        ]

        logs['bbox_image'] = torch.stack(bbox_imgs, dim=0)
        return logs
ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampler import DPMSolverSampler
ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
class NoiseScheduleVP:
    """Forward VP-SDE noise schedule used by DPM-Solver.

    Provides alpha_t, sigma_t and lambda_t = log(alpha_t / sigma_t) as
    functions of a continuous time t in [0, T], for three schedule types:
    'discrete' (interpolated from given betas or alphas_cumprod), 'linear',
    and 'cosine'.
    """

    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):
        # `betas` / `alphas_cumprod` are only used for the 'discrete' schedule;
        # the beta bounds only for the 'linear' schedule.
        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )

        self.schedule = schedule
        if schedule == "discrete":
            # log(alpha_bar_t) / 2 at each discrete step, either accumulated
            # from betas or taken directly from alphas_cumprod.
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            # Continuous-time grid (excluding t=0) paired with log_alpha values
            # for piecewise-linear interpolation in marginal_log_mean_coeff.
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape(
                (
                    1,
                    -1,
                )
            )
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            # Largest t for which the cosine schedule's beta stays below
            # cosine_beta_max (used to bound the schedule).
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
            self.schedule = schedule
            if schedule == "cosine":
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.0

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            # Piecewise-linear interpolation over the precomputed table.
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == "cosine":
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        # VP-SDE identity: sigma_t = sqrt(1 - alpha_t^2).
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb, return_scalar=False):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            # Closed-form inverse of the quadratic marginal_lambda.
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            # check if lamb is a scalar
            if not isinstance(lamb, torch.Tensor):
                lamb = torch.tensor(lamb)
            # Invert via the (monotone) log_alpha table, flipped so the
            # interpolation runs over increasing values.
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            if return_scalar:
                return t.reshape((-1,)).item()
            return t.reshape((-1,))
        else:
            # Cosine schedule: analytic inverse through arccos.
            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            t_fn = (
                lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            t = t_fn(log_alpha)
            return t
128
+
129
+
130
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs={},
):
    """Wrap a diffusion model into a continuous-time noise-prediction function for DPM-Solver.

    Args:
        model: the raw diffusion model, called as ``model(x, t_input, cond, **model_kwargs)``.
        noise_schedule: a NoiseScheduleVP-like schedule object.
        model_type: what `model` predicts — "noise", "x_start", "v" or "score";
            non-noise predictions are converted to epsilon predictions.
        model_kwargs: extra keyword arguments forwarded to `model`.
        guidance_type: "uncond", "classifier" or "classifier-free".
        condition / unconditional_condition: conditioning inputs used by guided sampling.
        guidance_scale: guidance strength for classifier / classifier-free guidance.
        classifier_fn / classifier_kwargs: classifier used for classifier guidance.
    Returns:
        A function ``model_fn(x, t_continuous)`` returning the (guided) noise prediction.
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        # Evaluate the raw model and convert its output to a noise prediction.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, None, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            # epsilon = -sigma_t * score.
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1.0 or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # One batched forward pass for both unconditional and conditional branches.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    # Fix: noise_pred_fn already converts score-model outputs to noise, but the
    # original assert rejected model_type == "score"; accept it here (as the
    # upstream DPM-Solver implementation does).
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
214
+
215
+
216
+ class DPM_Solver:
217
def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.0):
    """Construct a DPM-Solver.

    We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
    If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
    If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
    In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
    The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.

    Args:
        model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
            ``
            def model_fn(x, t_continuous):
                return noise
            ``
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
        thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
        max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.

    [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
    """
    # Continuous-time noise prediction function (e.g. from model_wrapper).
    self.model = model_fn
    self.noise_schedule = noise_schedule
    # DPM-Solver++ (data prediction) when True, DPM-Solver (noise) when False.
    self.predict_x0 = predict_x0
    self.thresholding = thresholding
    self.max_val = max_val
244
+
245
def noise_prediction_fn(self, x, t):
    """
    Return the noise prediction model.
    """
    # Thin passthrough to the wrapped continuous-time model function.
    return self.model(x, t)
250
+
251
def data_prediction_fn(self, x, t):
    """Return the x0 (data) prediction, with optional dynamic thresholding."""
    eps = self.noise_prediction_fn(x, t)
    dims = x.dim()
    alpha_t = self.noise_schedule.marginal_alpha(t)
    sigma_t = self.noise_schedule.marginal_std(t)
    x0 = (x - expand_dims(sigma_t, dims) * eps) / expand_dims(alpha_t, dims)
    if not self.thresholding:
        return x0
    # Dynamic thresholding ("Imagen"): clamp to a per-sample quantile and rescale.
    p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
    return torch.clamp(x0, -s, s) / s
265
+
266
def model_fn(self, x, t):
    """Dispatch to the data- or noise-prediction model per solver mode."""
    predict = self.data_prediction_fn if self.predict_x0 else self.noise_prediction_fn
    return predict(x, t)
274
+
275
def get_time_steps(self, skip_type, t_T, t_0, N, device):
    """Compute the intermediate time steps for sampling.

    Args:
        skip_type: spacing of the time steps — 'logSNR' (uniform logSNR),
            'time_uniform' (recommended for high-resolution data), or
            'time_quadratic' (as in DDIM for low-resolution data).
        t_T: starting time of the sampling (default is T).
        t_0: ending time of the sampling (default is epsilon).
        N: total number of spacings of the time steps.
        device: a torch device.
    Returns:
        A pytorch tensor of the time steps, with the shape (N + 1,).
    """
    if skip_type == "logSNR":
        # Uniform grid in half-logSNR space, mapped back to time.
        lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
        logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
        return self.noise_schedule.inverse_lambda(logSNR_steps)
    if skip_type == "time_uniform":
        return torch.linspace(t_T, t_0, N + 1).to(device)
    if skip_type == "time_quadratic":
        root = 2
        # Uniform grid in sqrt(t), squared back to time.
        return torch.linspace(t_T ** (1.0 / root), t_0 ** (1.0 / root), N + 1).pow(root).to(device)
    raise ValueError(
        "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
    )
305
+
306
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
    """Get the solver order of each step for the singlestep DPM-Solver.

    Combines DPM-Solver-1/2/3 so that exactly `steps` function evaluations
    are used ("DPM-Solver-fast"):
    - order == 1: `steps` steps of DPM-Solver-1 (DDIM).
    - order == 2: K = steps // 2 (or +1); mostly DPM-Solver-2, with one
      DPM-Solver-1 step when `steps` is odd.
    - order == 3: K = steps // 3 + 1; mostly DPM-Solver-3, finishing with a
      DPM-Solver-2 and/or DPM-Solver-1 step depending on steps % 3.

    Args:
        steps: total number of function evaluations (NFE).
        order: max solver order (1, 2 or 3).
        skip_type: 'logSNR', 'time_uniform' or 'time_quadratic' (see get_time_steps).
        t_T / t_0: starting / ending sampling times.
        device: a torch device.
    Returns:
        (timesteps_outer, orders): outer time steps and the per-step solver orders.
    """
    if order == 3:
        K = steps // 3 + 1
        if steps % 3 == 0:
            orders = [3] * (K - 2) + [2, 1]
        elif steps % 3 == 1:
            orders = [3] * (K - 1) + [1]
        else:
            orders = [3] * (K - 1) + [2]
    elif order == 2:
        if steps % 2 == 0:
            K = steps // 2
            orders = [2] * K
        else:
            K = steps // 2 + 1
            orders = [2] * (K - 1) + [1]
    elif order == 1:
        K = 1
        orders = [1] * steps
    else:
        raise ValueError("'order' must be '1' or '2' or '3'.")

    if skip_type == "logSNR":
        # To reproduce the results in DPM-Solver paper
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
    else:
        # Take the fine grid of `steps` intervals and keep only the outer
        # boundaries implied by the per-step orders.
        boundary_idx = torch.cumsum(torch.tensor([0] + orders), dim=0).to(device)
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[boundary_idx]
    return timesteps_outer, orders
394
+
395
def denoise_to_zero_fn(self, x, s):
    """
    Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
    """
    # Returning the x0 prediction is exactly that first-order final step.
    return self.data_prediction_fn(x, s)
400
+
401
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
    """DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.

    Args:
        x: initial value at time `s`.
        s: starting time, shape (x.shape[0],).
        t: ending time, shape (x.shape[0],).
        model_s: model evaluated at time `s`; computed from `x` and `s` if None.
        return_intermediate: if True, also return {"model_s": model_s}.
    Returns:
        x_t: the approximated solution at time `t`.
    """
    ns = self.noise_schedule
    dims = x.dim()
    lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
    h = lambda_t - lambda_s  # logSNR step size
    log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
    sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    if model_s is None:
        model_s = self.model_fn(x, s)

    if self.predict_x0:
        # DPM-Solver++ first-order update (data prediction).
        phi_1 = torch.expm1(-h)
        x_t = expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s
    else:
        # DPM-Solver first-order update (noise prediction).
        phi_1 = torch.expm1(h)
        x_t = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
            - expand_dims(sigma_t * phi_1, dims) * model_s
        )

    if return_intermediate:
        return x_t, {"model_s": model_s}
    return x_t
444
+
445
+
446
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
    """
    Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
        t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
        solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    """
    ns = self.noise_schedule
    dims = x.dim()
    # Two most recent model evaluations and their times.
    model_prev_1, model_prev_0 = model_prev_list[-2:]
    t_prev_1, t_prev_0 = t_prev_list[-2:]
    lambda_prev_1, lambda_prev_0, lambda_t = (
        ns.marginal_lambda(t_prev_1),
        ns.marginal_lambda(t_prev_0),
        ns.marginal_lambda(t),
    )
    log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    h_0 = lambda_prev_0 - lambda_prev_1  # previous logSNR step size
    h = lambda_t - lambda_prev_0  # current logSNR step size
    r0 = h_0 / h
    # First-order finite difference of the model values in logSNR time.
    D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
    if self.predict_x0:
        # DPM-Solver++ second-order multistep update.
        if solver_type == "dpm_solver" or solver_type == "dpmsolver":
            x_t = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
                - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0
            )
        elif solver_type == "taylor":
            x_t = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
                + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1_0
            )
    else:
        # DPM-Solver second-order multistep update (noise prediction).
        if solver_type == "dpm_solver" or solver_type == "dpmsolver":
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
                - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0
            )
        elif solver_type == "taylor":
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
                - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1_0
            )
    return x_t
504
+
505
+
506
    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
        """
        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensor. The previous computed model values.
            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
                NOTE(review): `solver_type` is accepted but unused in this third-order
                update (kept for signature symmetry with the second-order update).
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        dims = x.dim()
        # The three most recent model evaluations and their times (index -1 is the latest).
        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_2),
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        # Step sizes in half-logSNR (lambda) space and their ratios to the current step.
        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        # First- and second-order finite differences of the model outputs.
        D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
        D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
        D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1)
        if self.predict_x0:
            # Data-prediction (x0) parameterization.
            x_t = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
                + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
                - expand_dims(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims) * D2
            )
        else:
            # Noise-prediction (epsilon) parameterization.
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
                - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
                - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims) * D2
            )
        return x_t
557
+
558
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"):
559
+ """
560
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
561
+
562
+ Args:
563
+ x: A pytorch tensor. The initial value at time `s`.
564
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
565
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
566
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
567
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
568
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
569
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
570
+ Returns:
571
+ x_t: A pytorch tensor. The approximated solution at time `t`.
572
+ """
573
+ if order == 1:
574
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
575
+ elif order == 2:
576
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
577
+ elif order == 3:
578
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
579
+ else:
580
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
581
+
582
    def dpm_solver_adaptive(
        self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpm_solver"
    ):
        """
        The adaptive step size solver based on singlestep DPM-Solver.

        Args:
            x: A pytorch tensor. The initial value at time `t_T`.
            order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            h_init: A `float`. The initial step size (for logSNR).
            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
        Returns:
            x_0: A pytorch tensor. The approximated solution at time `t_0`.
        Raises:
            ValueError: if `order` is not 2 or 3.

        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
        """
        ns = self.noise_schedule
        s = t_T * torch.ones((x.shape[0],)).to(x)
        lambda_s = ns.marginal_lambda(s)
        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
        h = h_init * torch.ones_like(s).to(x)
        x_prev = x
        nfe = 0  # number of model-function evaluations, for reporting only
        # Pair a lower-order update (used for the local error estimate) with the
        # higher-order update of the requested order.
        if order == 2:
            r1 = 0.5
            lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
                x, s, t, r1=r1, solver_type=solver_type, **kwargs
            )
        elif order == 3:
            r1, r2 = 1.0 / 3.0, 2.0 / 3.0
            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
                x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
            )
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
                x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
            )
        else:
            raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
        while torch.abs((s - t_0)).mean() > t_err:
            t = ns.inverse_lambda(lambda_s + h)
            # The lower-order step also returns intermediate model evaluations so the
            # higher-order step can reuse them (saving model calls).
            x_lower, lower_noise_kwargs = lower_update(x, s, t)
            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
            # Mixed absolute/relative tolerance, following [1].
            delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
            norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
            E = norm_fn((x_higher - x_lower) / delta).max()
            if torch.all(E <= 1.0):
                # Accept the step only when the error estimate is within tolerance.
                x = x_higher
                s = t
                x_prev = x_lower
                lambda_s = ns.marginal_lambda(s)
            # Adapt the step size (shrink on rejection, grow on easy steps); never
            # overshoot the final logSNR lambda_0.
            h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
            nfe += order
        print("adaptive solver nfe", nfe)
        return x
645
+
646
    def sample(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=3,
        skip_type="time_uniform",
        method="singlestep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpm_solver",
        atol=0.0078,
        rtol=0.05,
        flags=None,
    ):
        # Run the multistep DPM-Solver from `x` (noise at the start time) down to the
        # end time. When `flags.learn` is set, a learned time schedule is loaded from
        # disk; otherwise the schedule is built by `self.get_time_steps`.
        #
        # NOTE(review): `method`, `lower_order_final`, `denoise_to_zero`,
        # `solver_type`, `atol` and `rtol` are accepted but unused in this body, and
        # the inner update hard-codes solver_type="dpmsolver" — confirm intentional.
        device = x.device
        with torch.no_grad():
            if flags.learn:
                # Learned schedule: the saved tensor holds two concatenated schedules
                # (solver times `timesteps` and model-evaluation times `timesteps2`).
                load_from = f"{flags.log_path}/NFE-{steps}-256LSUN-dpmsolver++-{order}-decode/best.pt"
                timesteps = torch.load(load_from)['best_t_steps'].to(x.device)
                if flags:
                    # NOTE(review): `flags` is necessarily truthy here (flags.learn was
                    # just read above), so the `else` branch below is dead code.
                    length = timesteps.shape[0] // 2
                    timesteps2 = timesteps[length:]
                    timesteps = timesteps[:length]
                else:
                    timesteps2 = timesteps
            else:
                # Default schedule from the noise schedule's time range.
                t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
                t_T = self.noise_schedule.T if t_start is None else t_start
                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
                timesteps2 = timesteps
            assert timesteps.shape[0] - 1 == steps

            def one_step(t1, t2, t_prev_list, model_prev_list, step, x_next, order, first=True):
                # Advance the state to time t1 with a `step`-order multistep update,
                # evaluate the model at time t2, and push the result into the history.
                x_next = self.multistep_dpm_solver_update(x_next, model_prev_list, t_prev_list, t1, step, solver_type="dpmsolver")
                model_x_next = self.model_fn(x_next, t2)
                update_lists(t_prev_list, model_prev_list, t1, model_x_next, order, first=first)
                return x_next

            def update_lists(t_list, model_list, t_, model_x, order, first=False):
                # During warm-up (`first=True`) the history grows; afterwards it is a
                # fixed-length sliding window: shift left and overwrite the last slot.
                if first:
                    t_list.append(t_)
                    model_list.append(model_x)
                    return
                for m in range(order - 1):
                    t_list[m] = t_list[m + 1]
                    model_list[m] = model_list[m + 1]
                t_list[-1] = t_
                model_list[-1] = model_x

            timesteps1 = timesteps
            step = 0
            vec_t1 = timesteps1[0].expand((x.shape[0]))
            vec_t2 = timesteps2[0].expand((x.shape[0]))
            t_prev_list = [vec_t1]
            model_prev_list = [self.model_fn(x, vec_t2)]


            # Warm-up phase: build up the multistep history with orders 1..order-1.
            for step in range(1, order):
                vec_t1 = timesteps1[step].expand(x.shape[0])
                vec_t2 = timesteps2[step].expand(x.shape[0])
                x = one_step(vec_t1, vec_t2, t_prev_list, model_prev_list, step, x, order, first=True)

            # Main phase: full-order steps, lowering the order for the final steps.
            for step in range(order, steps + 1):
                step_order = min(order, steps + 1 - step)
                vec_t1 = timesteps1[step].expand(x.shape[0])
                vec_t2 = timesteps2[step].expand(x.shape[0])
                x = one_step(vec_t1, vec_t2, t_prev_list, model_prev_list, step_order, x, order, first=False)

            return x
717
+
718
+
719
+ #############################################################
720
+ # other utility functions
721
+ #############################################################
722
+
723
+
724
def interpolate_fn(x, xp, yp):
    """
    Piecewise linear function y = f(x) defined by keypoints (xp, yp), implemented
    entirely with differentiable tensor operations (autograd-friendly).

    f(x) is defined on the whole real line: queries outside the range of xp are
    extrapolated linearly along the outermost segment.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    # Merge each query with the keypoints and sort the merged row; the rank of the
    # query inside the sorted row identifies the segment it falls into.
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, sort_order = torch.sort(merged, dim=2)
    rank = torch.argmin(sort_order, dim=2)
    lower = rank - 1
    at_left = torch.eq(rank, 0)
    at_right = torch.eq(rank, num_keys)
    # Index (within the sorted merged row) of the segment's start point; clamp to
    # the outermost segment when the query lies outside the keypoint range.
    seg_start = torch.where(
        at_left,
        torch.tensor(1, device=x.device),
        torch.where(at_right, torch.tensor(num_keys - 2, device=x.device), lower),
    )
    seg_end = torch.where(torch.eq(seg_start, lower), seg_start + 2, seg_start + 1)
    x0 = torch.gather(merged_sorted, dim=2, index=seg_start.unsqueeze(2)).squeeze(2)
    x1 = torch.gather(merged_sorted, dim=2, index=seg_end.unsqueeze(2)).squeeze(2)
    # Index of the same segment start within the original (un-merged) keypoint arrays.
    key_start = torch.where(
        at_left,
        torch.tensor(0, device=x.device),
        torch.where(at_right, torch.tensor(num_keys - 2, device=x.device), lower),
    )
    yp_rows = yp.unsqueeze(0).expand(batch, -1, -1)
    y0 = torch.gather(yp_rows, dim=2, index=key_start.unsqueeze(2)).squeeze(2)
    y1 = torch.gather(yp_rows, dim=2, index=(key_start + 1).unsqueeze(2)).squeeze(2)
    # Standard linear interpolation on the selected segment.
    return y0 + (x - x0) * (y1 - y0) / (x1 - x0)
768
+
769
+
770
def expand_dims(v, dims):
    """
    Append singleton dimensions to `v` until it has `dims` dimensions in total.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dim`: a `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Build an advanced index of the form (..., None, None, ...) so each `None`
    # inserts one trailing axis of size 1.
    index = [Ellipsis]
    index.extend([None] * (dims - 1))
    return v[tuple(index)]
ldm/models/diffusion/dpm_solver/sampler.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6
+
7
+
8
class DPMSolverSampler(object):
    """Wrapper that samples from a latent-diffusion model with DPM-Solver.

    Holds the model and a float32 copy of its `alphas_cumprod` noise schedule,
    and exposes a `sample` method mirroring the DDIM sampler's interface.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        # Float32, detached copy of the schedule on the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        # Store `attr` as a plain attribute. Tensors are moved to CUDA when a GPU
        # is present; previously this unconditionally called .to("cuda"), which
        # crashed on CPU-only machines. Also use isinstance instead of an exact
        # type comparison so tensor subclasses are handled too.
        if isinstance(attr, torch.Tensor):
            if torch.cuda.is_available() and attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        flags=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Sample `batch_size` latents of the given `shape` in `S` solver steps.

        Uses unconditional sampling (order 3) when `conditioning` is None, and
        classifier-free guidance (order 2) otherwise. Returns (samples, None) to
        match the DDIM sampler's return convention. Many keyword arguments are
        accepted only for interface compatibility and are unused here.
        """
        if conditioning is not None:
            # Sanity-check that the conditioning batch matches the requested batch.
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')

        device = self.model.betas.device
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=unconditional_guidance_scale,
            )
            ORDER = 2

        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
        x = dpm_solver.sample(
            img, steps=S, skip_type=flags.skip_type, method="multistep", order=ORDER, lower_order_final=True, flags=flags
        )
        return x.to(device), None
ldm/models/diffusion/dpm_solver_v3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampler import DPMSolverv3Sampler
ldm/models/diffusion/dpm_solver_v3/dpm_solver_v3.py ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import numpy as np
5
+ import os
6
+
7
+
8
class NoiseScheduleVP:
    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):
        """Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
        We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.

            **Important**: Please pay special attention for the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).


        2. For continuous-time DPMs:

            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
            schedule are the default settings in DDPM and improved-DDPM:

            Args:
                beta_min: A `float` number. The smallest beta for the linear schedule.
                beta_max: A `float` number. The largest beta for the linear schedule.
                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
                T: A `float` number. The ending time of the forward process.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

        """

        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )

        self.schedule = schedule
        if schedule == "discrete":
            # log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}); \hat{alpha_n} is either
            # given directly or accumulated from the betas.
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            # Keypoints for the piecewise-linear interpolation of log_alpha over the
            # continuous times t_i = (i + 1) / N.
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape(
                (
                    1,
                    -1,
                )
            )
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
            self.schedule = schedule
            if schedule == "cosine":
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.0

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            # Piecewise-linear interpolation over the precomputed keypoints.
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == "cosine":
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        # VP property: alpha_t^2 + sigma_t^2 = 1.
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            # Closed-form inverse of the quadratic log_alpha for the linear schedule;
            # written to avoid catastrophic cancellation for large lambdas.
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            # Invert by interpolating t as a function of log_alpha; the arrays are
            # flipped because interpolate_fn expects increasing keypoints.
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            return t.reshape((-1,))
        else:
            # Cosine schedule: analytic inverse via arccos.
            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            t_fn = (
                lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            t = t_fn(log_alpha)
            return t
204
+
205
+
206
+ def model_wrapper(
207
+ model,
208
+ noise_schedule,
209
+ model_type="noise",
210
+ model_kwargs={},
211
+ guidance_type="uncond",
212
+ condition=None,
213
+ unconditional_condition=None,
214
+ guidance_scale=1.0,
215
+ classifier_fn=None,
216
+ classifier_kwargs={},
217
+ ):
218
+ """Create a wrapper function for the noise prediction model.
219
+
220
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
221
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
222
+
223
+ We support four types of the diffusion model by setting `model_type`:
224
+
225
+ 1. "noise": noise prediction model. (Trained by predicting noise).
226
+
227
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
228
+
229
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
230
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
231
+
232
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
233
+ arXiv preprint arXiv:2202.00512 (2022).
234
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
235
+ arXiv preprint arXiv:2210.02303 (2022).
236
+
237
+ 4. "score": marginal score function. (Trained by denoising score matching).
238
+ Note that the score function and the noise prediction model follows a simple relationship:
239
+ ```
240
+ noise(x_t, t) = -sigma_t * score(x_t, t)
241
+ ```
242
+
243
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
244
+ 1. "uncond": unconditional sampling by DPMs.
245
+ The input `model` has the following format:
246
+ ``
247
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
248
+ ``
249
+
250
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
251
+ The input `model` has the following format:
252
+ ``
253
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
254
+ ``
255
+
256
+ The input `classifier_fn` has the following format:
257
+ ``
258
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
259
+ ``
260
+
261
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
262
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
263
+
264
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
265
+ The input `model` has the following format:
266
+ ``
267
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
268
+ ``
269
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
270
+
271
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
272
+ arXiv preprint arXiv:2207.12598 (2022).
273
+
274
+
275
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
276
+ or continuous-time labels (i.e. epsilon to T).
277
+
278
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
279
+ ``
280
+ def model_fn(x, t_continuous) -> noise:
281
+ t_input = get_model_input_time(t_continuous)
282
+ return noise_pred(model, x, t_input, **model_kwargs)
283
+ ``
284
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
285
+
286
+ ===============================================================
287
+
288
+ Args:
289
+ model: A diffusion model with the corresponding format described above.
290
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
291
+ model_type: A `str`. The parameterization type of the diffusion model.
292
+ "noise" or "x_start" or "v" or "score".
293
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
294
+ guidance_type: A `str`. The type of the guidance for sampling.
295
+ "uncond" or "classifier" or "classifier-free".
296
+ condition: A pytorch tensor. The condition for the guided sampling.
297
+ Only used for "classifier" or "classifier-free" guidance type.
298
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
299
+ Only used for "classifier-free" guidance type.
300
+ guidance_scale: A `float`. The scale for the guided sampling.
301
+ classifier_fn: A classifier function. Only used for the classifier guidance.
302
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
303
+ Returns:
304
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
305
+ """
306
+
307
    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """Evaluate the wrapped model and convert its output to a noise (epsilon) prediction."""
        if t_continuous.reshape((-1,)).shape[0] == 1:
            # Broadcast a scalar time to the whole batch.
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, None, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            # eps = (x - alpha_t * x0) / sigma_t
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            # eps = alpha_t * v + sigma_t * x
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            # eps = -sigma_t * score
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            # Classifier guidance: eps_hat = eps - s * sigma_t * grad_x log p(c | x_t).
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1.0 or unconditional_condition is None:
                # No guidance needed: a single conditional forward pass suffices.
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Batch the unconditional and conditional passes into one model call.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    # NOTE(review): noise_pred_fn above also handles model_type == "score",
    # but it is rejected here — confirm whether "score" should be allowed.
    assert model_type in ["noise", "x_start", "v"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
378
+
379
+
380
def weighted_cumsumexp_trapezoid(a, x, b, cumsum=True):
    """Numerically stable trapezoidal integration of b * e^a over x (NumPy).

    Computes y with y_0 = 0 and
        y_n = sum_{i=1}^{n} 0.5 * (x_i - x_{i-1}) * (b_i e^{a_i} + b_{i-1} e^{a_{i-1}})
    for n = 1..N, i.e. the running integral of b * e^a dx along axis 0.

    Args:
        a, x, b: arrays of shape (N+1, ...); `b` may be None (treated as 1).
        cumsum: if True return the running integral of shape (N+1, ...);
            otherwise return only the total integral.
    Returns:
        The (cumulative) trapezoidal approximation of the integral.
    """
    assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
    if b is not None:
        assert a.shape[0] == b.shape[0] and a.ndim == b.ndim

    # Factor out the per-column maximum of `a` before exponentiating so that
    # large exponents do not overflow; it is multiplied back in at the end.
    peak = np.amax(a, axis=0, keepdims=True)
    scaled = np.exp(a - peak)
    if b is not None:
        scaled = np.asarray(b) * scaled

    # Trapezoid areas of the N consecutive intervals along axis 0.
    segments = 0.5 * (x[1:] - x[:-1]) * (scaled[1:] + scaled[:-1])
    if not cumsum:
        return np.sum(segments, axis=0) * np.exp(peak)
    running = np.cumsum(segments, axis=0) * np.exp(peak)
    # Prepend y_0 = 0 so the output is aligned with the input grid.
    return np.concatenate([np.zeros_like(running[[0]]), running], axis=0)
405
+
406
+
407
def weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=True):
    """Torch counterpart of `weighted_cumsumexp_trapezoid`: numerically stable
    trapezoidal (cumulative) integration of b * e^a over x along dim 0.

    Args:
        a, x, b: tensors of shape (N+1, ...); `b` may be None (treated as 1).
        cumsum: if True, return the running integral y (y_0 = 0, shape (N+1, ...));
            otherwise return only the total integral.
    Returns:
        The (cumulative) trapezoidal approximation of the integral of b * e^a dx.
    """
    assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
    if b is not None:
        assert a.shape[0] == b.shape[0] and a.ndim == b.ndim

    # Subtract the max of `a` before exponentiating for numerical stability.
    # Fix: use the documented `keepdim` kwarg — the numpy-style `keepdims`
    # spelling is not part of torch.amax's documented signature.
    a_max = torch.amax(a, dim=0, keepdim=True)

    if b is not None:
        tmp = b * torch.exp(a - a_max)
    else:
        tmp = torch.exp(a - a_max)

    # Trapezoid areas of the N consecutive intervals along dim 0.
    out = 0.5 * (x[1:] - x[:-1]) * (tmp[1:] + tmp[:-1])
    if not cumsum:
        return torch.sum(out, dim=0) * torch.exp(a_max)
    out = torch.cumsum(out, dim=0)
    out *= torch.exp(a_max)
    # Prepend y_0 = 0 so the output is aligned with the input grid.
    return torch.concat([torch.zeros_like(out[[0]]), out], dim=0)
425
+
426
+
427
def index_list(lst, index):
    """Return the elements of `lst` at the positions given by `index`.

    Args:
        lst: any indexable sequence.
        index: an iterable of (possibly negative) positions into `lst`.
    Returns:
        A new list containing `lst[i]` for each `i` in `index`, in order.
    """
    # A comprehension replaces the manual append loop (identical semantics).
    return [lst[i] for i in index]
432
+
433
+
434
class DPM_Solver_v3:
    """DPM-Solver-v3: a multistep predictor-corrector ODE solver for diffusion
    sampling, parameterized by precomputed empirical model statistics (EMS)
    l(t), s(t), b(t) loaded from `statistics_dir` ("l.npz" and "sb.npz").

    From the statistics, __init__ precomputes (by trapezoidal integration over
    a fine logSNR grid) the auxiliary integrals L, S, I, B, C that appear in
    the solver's update formulas, plus the sampling timesteps / grid indexes.
    """

    def __init__(
        self,
        statistics_dir,
        noise_schedule,
        steps=10,
        t_start=None,
        t_end=None,
        skip_type="time_uniform",
        degenerated=False,
        device="cuda",
    ):
        """Load the statistics, build the auxiliary integrals and the timesteps.

        Args:
            statistics_dir: directory containing "l.npz" and "sb.npz".
            noise_schedule: a NoiseScheduleVP-like schedule object.
            steps: number of sampling steps.
            t_start / t_end: sampling time range; defaults to [1/total_N, T].
            skip_type: timestep spacing — "logSNR", "time_uniform",
                "time_quadratic", or "edm".
            degenerated: if True, replace the statistics with trivial values
                (l=1, s=b=0), degenerating the solver (useful for ablation).
            device: device used for the precomputed timesteps.
                NOTE(review): several buffers below are moved with `.cuda()`
                regardless of `device` — confirm before running on non-CUDA
                devices.
        """
        self.device = device
        self.model = None
        self.noise_schedule = noise_schedule
        self.steps = steps
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"

        # Empirical model statistics on a fine grid of (statistics_steps + 1) points.
        l = np.load(os.path.join(statistics_dir, "l.npz"))["l"]
        sb = np.load(os.path.join(statistics_dir, "sb.npz"))
        s, b = sb["s"], sb["b"]
        if degenerated:
            l = np.ones_like(l)
            s = np.zeros_like(s)
            b = np.zeros_like(b)
        self.statistics_steps = l.shape[0] - 1
        # logSNR values of the statistics grid, reshaped to (K+1, 1, 1, 1) for broadcasting.
        ts = noise_schedule.marginal_lambda(
            self.get_time_steps("logSNR", t_T, t_0, self.statistics_steps, "cpu")
        ).numpy()[:, None, None, None]
        self.ts = torch.from_numpy(ts).cuda()
        self.lambda_T = self.ts[0].cpu().item()
        self.lambda_0 = self.ts[-1].cpu().item()
        z = np.zeros_like(l)
        o = np.ones_like(l)
        # Auxiliary integrals over logSNR (trapezoidal rule):
        #   L = ∫ l,  S = ∫ s,  I = ∫ e^{L+S},  B = ∫ b e^{-S},  C = ∫ B e^{L+S}.
        L = weighted_cumsumexp_trapezoid(z, ts, l)
        S = weighted_cumsumexp_trapezoid(z, ts, s)

        I = weighted_cumsumexp_trapezoid(L + S, ts, o)
        B = weighted_cumsumexp_trapezoid(-S, ts, b)
        C = weighted_cumsumexp_trapezoid(L + S, ts, B)
        self.l = torch.from_numpy(l).cuda()
        self.s = torch.from_numpy(s).cuda()
        self.b = torch.from_numpy(b).cuda()
        self.L = torch.from_numpy(L).cuda()
        self.S = torch.from_numpy(S).cuda()
        self.I = torch.from_numpy(I).cuda()
        self.B = torch.from_numpy(B).cuda()
        self.C = torch.from_numpy(C).cuda()

        # precompute timesteps
        if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic":
            self.timesteps = self.get_time_steps(skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            self.indexes = self.convert_to_indexes(self.timesteps)
            # Round-trip through the indexes so the timesteps land exactly on
            # the statistics grid.
            self.timesteps = self.convert_to_timesteps(self.indexes, device)
        elif skip_type == "edm":
            self.indexes, self.timesteps = self.get_timesteps_edm(N=steps, device=device)
            self.timesteps = self.convert_to_timesteps(self.indexes, device)
        else:
            raise ValueError(f"Unsupported timestep strategy {skip_type}")

        print("Indexes", self.indexes)
        print("Time steps", self.timesteps)
        print("LogSNR steps", self.noise_schedule.marginal_lambda(self.timesteps))

        # store high-order exponential coefficients (lazy)
        self.exp_coeffs = {}

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction of the wrapped model at (x, t).
        """
        return self.model(x, t)

    def convert_to_indexes(self, timesteps):
        """Map each timestep to the nearest index on the fine logSNR statistics grid."""
        logSNR_steps = self.noise_schedule.marginal_lambda(timesteps)
        indexes = list(
            (self.statistics_steps * (logSNR_steps - self.lambda_T) / (self.lambda_0 - self.lambda_T))
            .round()
            .cpu()
            .numpy()
            .astype(np.int64)
        )
        return indexes

    def convert_to_timesteps(self, indexes, device):
        """Map statistics-grid indexes back to time values via the inverse logSNR."""
        logSNR_steps = (
            self.lambda_T + (self.lambda_0 - self.lambda_T) * torch.Tensor(indexes).to(device) / self.statistics_steps
        )
        return self.noise_schedule.inverse_lambda(logSNR_steps)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: A `int`. The total number of the spacing of the time steps.
            device: A torch device.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == "logSNR":
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            t_order = 2
            t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
            )

    def get_timesteps_edm(self, N, device):
        """Constructs the noise schedule of Karras et al. (2022)."""

        rho = 7.0  # 7.0 is the value used in the paper

        # sigma = exp(-lambda), so the extreme sigmas follow from the logSNR range.
        sigma_min: float = np.exp(-self.lambda_0)
        sigma_max: float = np.exp(-self.lambda_T)
        ramp = np.linspace(0, 1, N + 1)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        # Karras rho-spacing in sigma, then mapped back to logSNR and time.
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        lambdas = torch.Tensor(-np.log(sigmas)).to(device)
        timesteps = self.noise_schedule.inverse_lambda(lambdas)

        indexes = list(
            (self.statistics_steps * (lambdas - self.lambda_T) / (self.lambda_0 - self.lambda_T))
            .round()
            .cpu()
            .numpy()
            .astype(np.int64)
        )
        return indexes, timesteps

    def get_g(self, f_t, i_s, i_t):
        # Transform f into the solver's g-parameterization, anchored at grid index i_s.
        return torch.exp(self.S[i_s] - self.S[i_t]) * f_t - torch.exp(self.S[i_s]) * (self.B[i_t] - self.B[i_s])

    def compute_exponential_coefficients_high_order(self, i_s, i_t, order=2):
        """Compute (and cache) the exponentially weighted integral
        ∫ e^{(L+S) - (L+S)(i_s)} (λ - λ_s)^(order-1) / (order-1)! dλ between grid
        indexes i_s and i_t, used by the high-order correction terms."""
        key = (i_s, i_t, order)
        if key in self.exp_coeffs.keys():
            coeffs = self.exp_coeffs[key]
        else:
            n = order - 1
            a = self.L[i_s : i_t + 1] + self.S[i_s : i_t + 1] - self.L[i_s] - self.S[i_s]
            x = self.ts[i_s : i_t + 1]
            b = (self.ts[i_s : i_t + 1] - self.ts[i_s]) ** n / math.factorial(n)
            coeffs = weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=False)
            self.exp_coeffs[key] = coeffs
        return coeffs

    def compute_high_order_derivatives(self, n, lambda_0n, g_0n, pseudo=False):
        # return g^(1), ..., g^(n)
        if pseudo:
            # Pseudo-order: Newton-style divided differences of g over lambda.
            D = [[] for _ in range(n + 1)]
            D[0] = g_0n
            for i in range(1, n + 1):
                for j in range(n - i + 1):
                    D[i].append((D[i - 1][j] - D[i - 1][j + 1]) / (lambda_0n[j] - lambda_0n[i + j]))

            return [D[i][0] * math.factorial(i) for i in range(1, n + 1)]
        else:
            # Full order: solve the Vandermonde-like system R @ d = (g_i - g_0)
            # for the Taylor coefficients of g around lambda_0n[0].
            R = []
            for i in range(1, n + 1):
                R.append(torch.pow(lambda_0n[1:] - lambda_0n[0], i))
            R = torch.stack(R).t()
            B = (torch.stack(g_0n[1:]) - g_0n[0]).reshape(n, -1)
            shape = g_0n[0].shape
            solution = torch.linalg.inv(R) @ B
            solution = solution.reshape([n] + list(shape))
            return [solution[i - 1] * math.factorial(i) for i in range(1, n + 1)]

    def multistep_predictor_update(self, x_lst, eps_lst, time_lst, index_lst, t, i_t, order=1, pseudo=False):
        """One multistep predictor step from the most recent cached state to time t
        (grid index i_t), using up to `order` previous evaluations."""
        # x_lst: [..., x_s]
        # eps_lst: [..., eps_s]
        # time_lst: [..., time_s]
        ns = self.noise_schedule
        n = order - 1
        # Take the last n+1 cached states, newest first.
        indexes = [-i - 1 for i in range(n + 1)]
        x_0n = index_list(x_lst, indexes)
        eps_0n = index_list(eps_lst, indexes)
        time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()
        index_0n = index_list(index_lst, indexes)
        lambda_0n = ns.marginal_lambda(time_0n)
        alpha_0n = ns.marginal_alpha(time_0n)
        sigma_0n = ns.marginal_std(time_0n)

        alpha_s, alpha_t = alpha_0n[0], ns.marginal_alpha(t)
        i_s = index_0n[0]
        x_s = x_0n[0]
        # Convert each cached eps prediction to the g-parameterization.
        g_0n = []
        for i in range(n + 1):
            f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
            g_i = self.get_g(f_i, index_0n[0], index_0n[i])
            g_0n.append(g_i)
        g_0 = g_0n[0]
        # First-order (exact-integral) part of the update.
        x_t = (
            alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
            - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
            - alpha_t
            * torch.exp(-self.L[i_t])
            * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
        )
        if order > 1:
            # High-order corrections from estimated derivatives of g.
            g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
            for i in range(order - 1):
                x_t = (
                    x_t
                    - alpha_t
                    * torch.exp(self.L[i_s] - self.L[i_t])
                    * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
                    * g_d[i]
                )
        return x_t

    def multistep_corrector_update(self, x_lst, eps_lst, time_lst, index_lst, order=1, pseudo=False):
        """Recompute the latest state x_t using the model evaluation AT t
        (corrector), improving on the predictor's extrapolation."""
        # x_lst: [..., x_s, x_t]
        # eps_lst: [..., eps_s, eps_t]
        # lambda_lst: [..., lambda_s, lambda_t]
        ns = self.noise_schedule
        n = order - 1
        # Anchor at the previous step (-2); the newest evaluation (-1) supplies
        # the extra interpolation point for the corrector.
        indexes = [-i - 1 for i in range(n + 1)]
        indexes[0] = -2
        indexes[1] = -1
        x_0n = index_list(x_lst, indexes)
        eps_0n = index_list(eps_lst, indexes)
        time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()
        index_0n = index_list(index_lst, indexes)
        lambda_0n = ns.marginal_lambda(time_0n)
        alpha_0n = ns.marginal_alpha(time_0n)
        sigma_0n = ns.marginal_std(time_0n)

        alpha_s, alpha_t = alpha_0n[0], alpha_0n[1]
        i_s, i_t = index_0n[0], index_0n[1]
        x_s = x_0n[0]
        g_0n = []
        for i in range(n + 1):
            f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
            g_i = self.get_g(f_i, index_0n[0], index_0n[i])
            g_0n.append(g_i)
        g_0 = g_0n[0]
        # Same update formula as the predictor, but with the corrected stencil.
        x_t_new = (
            alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
            - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
            - alpha_t
            * torch.exp(-self.L[i_t])
            * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
        )
        if order > 1:
            g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
            for i in range(order - 1):
                x_t_new = (
                    x_t_new
                    - alpha_t
                    * torch.exp(self.L[i_s] - self.L[i_t])
                    * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
                    * g_d[i]
                )
        return x_t_new

    def sample(
        self,
        x,
        model_fn,
        order,
        p_pseudo,
        use_corrector,
        c_pseudo,
        lower_order_final,
        half=False,
        return_intermediate=False,
    ):
        """Run the full predictor(-corrector) sampling loop from x_T to x_0.

        Args:
            x: the initial noise sample at the starting time.
            model_fn: noise-prediction function taking (x, t_continuous).
            order: maximum predictor order.
            p_pseudo / c_pseudo: use pseudo-order derivatives in the
                predictor / corrector.
            use_corrector: enable the corrector update after each model call.
            lower_order_final: reduce the order near the end of sampling.
            half: if True, skip the corrector once timesteps drop below 0.5
                (a heuristic switch-off point).
            return_intermediate: also return the cached intermediate states.
        Returns:
            The final sample, and optionally the list of intermediate x values.
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        steps = self.steps
        cached_x = []
        cached_model_output = []
        cached_time = []
        cached_index = []
        indexes, timesteps = self.indexes, self.timesteps
        step_p_order = 0

        for step in range(1, steps + 1):
            # Evaluate the model at the current state and cache everything.
            cached_x.append(x)
            cached_model_output.append(self.noise_prediction_fn(x, timesteps[step - 1]))
            cached_time.append(timesteps[step - 1])
            cached_index.append(indexes[step - 1])
            if use_corrector and (timesteps[step - 1] > 0.5 or not half):
                step_c_order = step_p_order + c_pseudo
                if step_c_order > 1:
                    # Correct the latest cached state, then make the cached eps
                    # consistent with it (keeping the network output N fixed).
                    x_new = self.multistep_corrector_update(
                        cached_x, cached_model_output, cached_time, cached_index, order=step_c_order, pseudo=c_pseudo
                    )
                    sigma_t = self.noise_schedule.marginal_std(cached_time[-1])
                    l_t = self.l[cached_index[-1]]
                    N_old = sigma_t * cached_model_output[-1] - l_t * cached_x[-1]
                    cached_x[-1] = x_new
                    cached_model_output[-1] = (N_old + l_t * cached_x[-1]) / sigma_t
            # Warm-up: the usable predictor order grows with the step count.
            if step < order:
                step_p_order = step
            else:
                step_p_order = order
            if lower_order_final:
                step_p_order = min(step_p_order, steps + 1 - step)
            t = timesteps[step]
            i_t = indexes[step]

            x = self.multistep_predictor_update(
                cached_x, cached_model_output, cached_time, cached_index, t, i_t, order=step_p_order, pseudo=p_pseudo
            )

        if return_intermediate:
            return x, cached_x
        else:
            return x
761
+
762
+
763
+ #############################################################
764
+ # other utility functions
765
+ #############################################################
766
+
767
+
768
def interpolate_fn(x, xp, yp):
    """Differentiable piecewise-linear function y = f(x) with keypoints (xp, yp).

    f is well-defined on the whole real line: outside the range of `xp`, the
    outermost segment is extrapolated linearly, so the result stays usable
    under autograd.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keypoints = x.shape[0], xp.shape[1]
    # Insert each query value among the keypoints and sort; the query's sorted
    # position identifies the segment it falls into.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    sorted_combined, sort_order = torch.sort(combined, dim=2)
    query_pos = torch.argmin(sort_order, dim=2)  # sorted position of the query value
    seg_candidate = query_pos - 1
    # Clamp to the outermost segment when the query lies outside [xp[0], xp[-1]].
    left_idx = torch.where(
        torch.eq(query_pos, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(query_pos, num_keypoints),
            torch.tensor(num_keypoints - 2, device=x.device),
            seg_candidate,
        ),
    )
    # Skip over the query's own slot when picking the segment's right endpoint.
    right_idx = torch.where(torch.eq(left_idx, seg_candidate), left_idx + 2, left_idx + 1)
    seg_start_x = torch.gather(sorted_combined, dim=2, index=left_idx.unsqueeze(2)).squeeze(2)
    seg_end_x = torch.gather(sorted_combined, dim=2, index=right_idx.unsqueeze(2)).squeeze(2)
    # Index of the segment's left keypoint in the ORIGINAL (unsorted) keypoint array.
    y_left_idx = torch.where(
        torch.eq(query_pos, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(query_pos, num_keypoints),
            torch.tensor(num_keypoints - 2, device=x.device),
            seg_candidate,
        ),
    )
    yp_expanded = yp.unsqueeze(0).expand(batch, -1, -1)
    seg_start_y = torch.gather(yp_expanded, dim=2, index=y_left_idx.unsqueeze(2)).squeeze(2)
    seg_end_y = torch.gather(yp_expanded, dim=2, index=(y_left_idx + 1).unsqueeze(2)).squeeze(2)
    # Standard linear interpolation on the selected segment.
    slope = (seg_end_y - seg_start_y) / (seg_end_x - seg_start_x)
    return seg_start_y + (x - seg_start_x) * slope
812
+
813
+
814
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims` by appending trailing singleton axes.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the desired total number of dimensions.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    trailing = (None,) * (dims - 1)
    return v[(Ellipsis,) + trailing]
ldm/models/diffusion/dpm_solver_v3/sampler.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3
6
+
7
+
8
class DPMSolverv3Sampler:
    """Adapter exposing DPM-Solver-v3 through the common latent-diffusion sampler API.

    Wraps a diffusion `model`, builds a discrete NoiseScheduleVP from its
    `alphas_cumprod`, and constructs a DPM_Solver_v3 from precomputed statistics.
    """

    def __init__(self, ckp_path, stats_dir, model, steps, guidance_scale, **kwargs):
        """
        Args:
            ckp_path: unused here; kept for interface compatibility with other samplers.
            stats_dir: directory holding the DPM-Solver-v3 statistics (l.npz / sb.npz).
            stats_dir must not be None.
            model: the diffusion model (provides alphas_cumprod, betas, apply_model).
            steps: number of sampling steps.
            guidance_scale: classifier-free guidance scale.
        """
        super().__init__()
        self.model = model
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.alphas_cumprod = to_torch(model.alphas_cumprod)
        self.device = self.model.betas.device
        self.guidance_scale = guidance_scale

        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        # Bug fix: the original assertion message interpolated an undefined name
        # (`stats_base`), so a missing stats_dir raised NameError instead of the
        # intended AssertionError.
        assert stats_dir is not None, "No statistics directory given (stats_dir is None)."
        print("Use statistics", stats_dir)
        self.dpm_solver_v3 = DPM_Solver_v3(
            statistics_dir=stats_dir,
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type="time_uniform",
            degenerated=False,
            device=self.device,
        )
        self.steps = steps

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Draw `batch_size` samples of latent `shape` (C, H, W) with DPM-Solver-v3.

        Uses classifier-free guidance when `conditioning` is given, otherwise
        unconditional sampling. Returns (samples, None); the None placeholder
        keeps the (samples, intermediates) convention of the other samplers.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3  # unconditional sampling tolerates a higher solver order
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=self.guidance_scale,
            )
            ORDER = 2  # guided sampling is more sensitive; use a lower order

        x = self.dpm_solver_v3.sample(
            img,
            model_fn,
            order=ORDER,
            p_pseudo=False,
            c_pseudo=True,
            lower_order_final=True,
            use_corrector=use_corrector,
            half=half,
        )

        return x.to(self.device), None
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
class PLMSSampler(object):
    """Pseudo Linear Multistep (PLMS) sampler for latent-diffusion models.

    Precomputes DDIM-style schedule buffers in `make_schedule`, then runs an
    Adams-Bashforth-style multistep update (`p_sample_plms`) over the chosen
    timesteps, keeping a window of the last few eps predictions.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        # model: the wrapped DDPM/LatentDiffusion model (provides apply_model,
        # alphas_cumprod, betas, num_timesteps).
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Plain attribute assignment (this is not an nn.Module); tensors are
        # moved to CUDA first. NOTE(review): hardcodes CUDA — confirm for CPU runs.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timesteps and all per-step schedule buffers.

        PLMS is a deterministic solver, so only eta = 0 is valid.
        """
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # Sigmas for running at the original (full) number of DDPM steps;
        # with ddim_eta == 0 these are all zero (deterministic sampling).
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Public entry point: build the schedule for S steps and run PLMS sampling.

        Returns:
            (samples, intermediates) — the final latents and a dict of logged
            intermediate x / pred_x0 tensors.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
        """Main PLMS loop: iterate the timesteps in reverse, keeping a sliding
        window of the last <= 3 eps predictions for the multistep update."""
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []  # sliding window of previous eps predictions (max length 3)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # t_next is needed by the first (Improved-Euler) step of p_sample_plms.
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                # Inpainting: keep the masked region pinned to a noised x0.
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        """One PLMS step: combine the current eps with up to 3 previous eps
        predictions (Adams-Bashforth weights) and take a deterministic DDIM step."""
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # Eps prediction with optional classifier-free guidance and score correction.
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                # Batch unconditional and conditional passes into one model call.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order): needs an extra model call at t_next.
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
ldm/models/diffusion/uni_pc/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampler import UniPCSampler
ldm/models/diffusion/uni_pc/sampler.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
6
+
7
+
8
class UniPCSampler(object):
    """Thin sampling wrapper that runs the UniPC ODE solver on a latent-diffusion model.

    Mirrors the DDIM/PLMS sampler API used elsewhere in this codebase:
    construct with the diffusion `model`, then call `sample(...)`, which
    returns a `(samples, intermediates)` tuple (`intermediates` is always
    `None` here).
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        # Keep a float32 copy of the model's alphas_cumprod for the solver.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        """Store `attr` as a plain attribute, moving tensors onto CUDA when possible.

        Fix: the previous version unconditionally moved tensors to "cuda",
        which raised on CPU-only machines; we now only move when CUDA is
        actually available.
        """
        if type(attr) == torch.Tensor:
            if torch.cuda.is_available() and attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        flags=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Draw `batch_size` samples of latent `shape` = (C, H, W) in `S` solver steps.

        Args:
            S: number of function evaluations (solver steps).
            batch_size: number of samples.
            shape: (C, H, W) of a single latent sample.
            conditioning: optional conditioning tensor/dict (classifier-free guidance).
            x_T: optional starting noise; sampled from N(0, I) when None.
            unconditional_guidance_scale: CFG scale (>1 enables guidance).
            unconditional_conditioning: the "empty" conditioning for CFG.
            flags: options object consumed by UniPC.sample (e.g. `skip_type`,
                `learn`, `vs`) — assumed non-None by the call below; confirm at call sites.
        Returns:
            (samples, None) — `None` keeps the return arity of the other samplers.
        """
        if conditioning is not None:
            # Sanity check: conditioning batch must match the requested batch size.
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        device = self.model.betas.device
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        if conditioning is None:
            # Unconditional sampling; order-3 multistep works well here.
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            # Classifier-free guidance; order 2 is the recommended setting.
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=unconditional_guidance_scale,  # 7.5
            )
            ORDER = 2

        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant="bh2")
        x = uni_pc.sample(
            img, steps=S, skip_type=flags.skip_type, method="multistep", order=ORDER, lower_order_final=True, flags=flags
        )

        return x.to(device), None
82
+
83
+
ldm/models/diffusion/uni_pc/uni_pc.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
class NoiseScheduleVP:
    """Wrapper around a variance-preserving (VP) diffusion noise schedule.

    Supports three schedules:
      - "discrete": built from `betas` or `alphas_cumprod` of a trained DPM;
      - "linear":   continuous-time linear beta schedule with `beta_0`, `beta_1`;
      - "cosine":   continuous-time cosine schedule.
    Provides alpha_t, sigma_t, lambda_t = log(alpha_t/sigma_t), and the
    inverse map lambda -> t used by solvers such as DPM-Solver / UniPC.
    """

    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):

        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )

        self.schedule = schedule
        if schedule == "discrete":
            # log(alpha_t) per step, either accumulated from betas or taken
            # directly from the precomputed cumulative product.
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            # Continuous-time grid (0, 1] aligned with the discrete steps;
            # shaped (1, N) for use with interpolate_fn.
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape(
                (
                    1,
                    -1,
                )
            )
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            # NOTE(review): cosine_t_max is computed but not used elsewhere in
            # this class — kept for parity with the reference implementation.
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
            self.schedule = schedule
            if schedule == "cosine":
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.0

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            # Piecewise-linear interpolation over the precomputed grid.
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == "cosine":
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        # VP property: alpha_t^2 + sigma_t^2 = 1.
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            # Closed-form inversion of the quadratic log-alpha expression.
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            # Invert via interpolation on the (flipped, so monotone-increasing)
            # log-alpha grid.
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            return t.reshape((-1,))
        else:
            # Cosine schedule: analytic inversion via arccos.
            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            t_fn = (
                lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            t = t_fn(log_alpha)
            return t
123
+
124
+
125
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},  # NOTE(review): mutable default — safe only because it is never mutated here.
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs={},  # NOTE(review): mutable default — never mutated here.
):
    """Wrap a diffusion `model` into a continuous-time noise-prediction function.

    Returns `model_fn(x, t_continuous)` that always predicts noise, converting
    from `model_type` in {"noise", "x_start", "v"} and applying the chosen
    guidance ("uncond", "classifier", or "classifier-free").
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        # Broadcast a scalar time to the batch.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, None, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        # Convert the model's native parameterization to noise prediction.
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            # NOTE(review): "score" is handled here but rejected by the assert
            # at the bottom of model_wrapper, so this branch is unreachable
            # through the public entry point — confirm before relying on it.
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise predicition model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            # Classifier guidance: shift the noise prediction by the scaled
            # classifier gradient (Dhariwal & Nichol style).
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1.0 or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # One batched forward pass for the unconditional and
                # conditional branches, then the CFG combination.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    assert model_type in ["noise", "x_start", "v"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
209
+
210
+
211
class UniPC:
    """UniPC (unified predictor-corrector) multistep ODE solver for diffusion sampling.

    Wraps a continuous-time noise-prediction function (from `model_wrapper`)
    and a `NoiseScheduleVP`, and integrates the diffusion ODE with the B(h)
    family of update rules ("bh1"/"bh2" variants).
    """

    def __init__(self, model_fn, noise_schedule, predict_x0=True, thresholding=False, max_val=1.0, variant="bh1"):
        """Construct a UniPC.

        We support both data_prediction and noise_prediction.
        """
        self.model = model_fn
        self.noise_schedule = noise_schedule
        self.variant = variant
        # predict_x0=True integrates in data-prediction form; False in
        # noise-prediction form (different update formulas below).
        self.predict_x0 = predict_x0
        self.thresholding = thresholding
        self.max_val = max_val

    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method.

        NOTE(review): reads `self.dynamic_thresholding_ratio` and
        `self.thresholding_max_val`, which are never set in __init__ — this
        method would raise AttributeError if called; confirm it is unused.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with thresholding).
        """
        noise = self.noise_prediction_fn(x, t)
        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        # x0 = (x_t - sigma_t * eps) / alpha_t.
        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
        if self.thresholding:
            p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
            x0 = torch.clamp(x0, -s, s) / s
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling."""
        if skip_type == "logSNR":
            # Uniform spacing in half-logSNR (lambda) space.
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            t_order = 2
            t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
            )

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, t2, order, **kwargs):
        # Dispatch to the B(h) update; returns None implicitly for variants
        # without "bh" in the name (only "bh1"/"bh2" are supported upstream).
        if len(t.shape) == 0:
            t = t.view(-1)
        if "bh" in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, t2, order, **kwargs)

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, t2, order, x_t=None, use_corrector=True):
        # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        # Step size in lambda (half-logSNR) space.
        h = lambda_t - lambda_prev_0

        # rks: normalized lambda offsets of the previous points;
        # D1s: scaled first differences of the previous model outputs.
        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=x.device)

        # Build the Vandermonde-like system R * rho = b for the update weights.
        R = []
        b = []

        hh = -h[0] if self.predict_x0 else h[0]
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == "bh1":
            B_h = hh
        elif self.variant == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            # print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None
        if self.predict_x0:
            # Data-prediction form of the update.
            x_t_ = expand_dims(sigma_t / sigma_prev_0, dims) * x - expand_dims(alpha_t * h_phi_1, dims) * model_prev_0

            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res

            if use_corrector:
                # Corrector re-evaluates the model at the new point and refines x_t.
                model_t = self.model_fn(x_t, t2)
                if D1s is not None:
                    corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = model_t - model_prev_0
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        else:
            # Noise-prediction form of the update.
            x_t_ = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t2)
                if D1s is not None:
                    corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = model_t - model_prev_0
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t

    def sample(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=3,
        skip_type="time_uniform",
        method="singlestep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpm_solver",
        atol=0.0078,
        rtol=0.05,
        corrector=False,
        flags=None,
    ):
        """Integrate from noise `x` to a sample using multistep UniPC.

        `flags.learn` loads a precomputed (learned) timestep schedule from
        disk; `flags.vs` splits it into two schedules (one for the state
        update, one for model evaluation). Otherwise timesteps come from
        `get_time_steps`. NOTE(review): several parameters (`method`,
        `lower_order_final`, `denoise_to_zero`, `solver_type`, `atol`,
        `rtol`, `corrector`) are accepted but unused in this body.
        """

        device = x.device
        assert steps >= order
        with torch.no_grad():
            if flags.learn:
                load_from = f"{flags.log_path}/NFE-{steps}-256LSUN-uni_pc-{order}-decode/best.pt"
                timesteps = torch.load(load_from)['best_t_steps'].to(x.device)
                if flags.vs:
                    # First half: update timesteps; second half: evaluation timesteps.
                    length = timesteps.shape[0] // 2
                    timesteps2 = timesteps[length:]
                    timesteps = timesteps[:length]
                else:
                    timesteps2 = timesteps
            else:
                t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
                t_T = self.noise_schedule.T if t_start is None else t_start
                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
                timesteps2 = timesteps
            assert timesteps.shape[0] - 1 == steps

            def one_step(t1, t2, t_prev_list, model_prev_list, step, x_next, order, first=True, use_corrector=True):
                # One predictor(-corrector) update, then roll the history lists.
                x_next, model_x_next = self.multistep_uni_pc_update(x_next, model_prev_list, t_prev_list, t1, t2, step, use_corrector=use_corrector)
                if model_x_next is None:
                    model_x_next = self.model_fn(x_next, t2)
                update_lists(t_prev_list, model_prev_list, t1, model_x_next, order, first=first)
                return x_next

            def update_lists(t_list, model_list, t_, model_x, order, first=False):
                # Warm-up phase appends; steady state shifts the window left.
                if first:
                    t_list.append(t_)
                    model_list.append(model_x)
                    return
                for m in range(order - 1):
                    t_list[m] = t_list[m + 1]
                    model_list[m] = model_list[m + 1]
                t_list[-1] = t_
                model_list[-1] = model_x

            timesteps1 = timesteps
            step = 0
            vec_t1 = timesteps1[0].expand((x.shape[0]))  # bs
            vec_t2 = timesteps2[0].expand((x.shape[0]))  # bs
            t_prev_list = [vec_t1]
            model_prev_list = [self.model_fn(x, vec_t2)]

            # Warm-up: build up history with increasing order 1..order-1.
            for step in range(1, order):
                vec_t1 = timesteps1[step].expand((x.shape[0]))
                vec_t2 = timesteps2[step].expand((x.shape[0]))
                x = one_step(vec_t1, vec_t2, t_prev_list, model_prev_list, step, x, order, first=True)

            # Main loop at full order; order is lowered near the end, and the
            # corrector is disabled on the final step (it would waste an NFE).
            for step in range(order, steps + 1):
                step_order = min(order, steps + 1 - step)
                vec_t1 = timesteps1[step].expand((x.shape[0]))
                vec_t2 = timesteps2[step].expand((x.shape[0]))
                use_corrector = True
                if step == steps:
                    use_corrector = False
                x = one_step(vec_t1, vec_t2, t_prev_list, model_prev_list, step_order, x, order, first=False, use_corrector=use_corrector)
        return x
484
+
485
+
486
+ #############################################################
487
+ # other utility functions
488
+ #############################################################
489
+
490
+
491
def interpolate_fn(x, xp, yp):
    """Differentiable piecewise-linear function y = f(x) with keypoints (xp, yp).

    Outside the keypoint range the outermost segment is extrapolated, so f is
    defined on the whole axis and remains usable with autograd.

    Args:
        x:  tensor of shape [N, C] (C = 1 for DPM-Solver/UniPC usage).
        xp: keypoint x-coordinates, shape [C, K].
        yp: keypoint y-coordinates, shape [C, K].
    Returns:
        f(x) with shape [N, C].
    """
    batch, num_kp = x.shape[0], xp.shape[1]
    # Concatenate the query with the keypoints and sort, so the query's rank
    # inside the sorted row tells us which segment it falls into.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    sorted_vals, sort_order = torch.sort(combined, dim=2)
    rank = torch.argmin(sort_order, dim=2)  # position of the query after sorting
    cand_left = rank - 1
    # Clamp to the outermost segment when the query sits before/after all keypoints.
    left = torch.where(
        torch.eq(rank, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(rank, num_kp),
            torch.tensor(num_kp - 2, device=x.device),
            cand_left,
        ),
    )
    right = torch.where(torch.eq(left, cand_left), left + 2, left + 1)
    seg_x0 = torch.gather(sorted_vals, dim=2, index=left.unsqueeze(2)).squeeze(2)
    seg_x1 = torch.gather(sorted_vals, dim=2, index=right.unsqueeze(2)).squeeze(2)
    # Index of the segment's left keypoint in the ORIGINAL (unsorted) yp order.
    y_left = torch.where(
        torch.eq(rank, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(rank, num_kp),
            torch.tensor(num_kp - 2, device=x.device),
            cand_left,
        ),
    )
    yp_rows = yp.unsqueeze(0).expand(batch, -1, -1)
    seg_y0 = torch.gather(yp_rows, dim=2, index=y_left.unsqueeze(2)).squeeze(2)
    seg_y1 = torch.gather(yp_rows, dim=2, index=(y_left + 1).unsqueeze(2)).squeeze(2)
    # Standard linear interpolation on the selected segment.
    return seg_y0 + (x - seg_x0) * (seg_y1 - seg_y0) / (seg_x1 - seg_x0)
535
+
536
+
537
def expand_dims(v, dims):
    """Append trailing singleton axes to `v` until it has `dims` dimensions.

    Args:
        v: a PyTorch tensor with shape [N].
        dims: the target number of dimensions (int).
    Returns:
        a view of `v` with shape [N, 1, 1, ..., 1] totalling `dims` dimensions.
    """
    return v.reshape(v.shape + (1,) * (dims - 1))
ldm/modules/attention.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from ldm.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
def exists(val):
    """True when `val` is not None (falsy values like 0 or "" still exist)."""
    return not (val is None)
13
+
14
+
15
def uniq(arr):
    """Order-preserving de-duplication; returns a dict keys view of unique elements."""
    return dict.fromkeys(arr, True).keys()
17
+
18
+
19
def default(val, d):
    """Return `val` unless it is None, else fall back to `d`.

    If `d` is a function (e.g. a lambda), it is called lazily to produce the
    fallback; otherwise `d` is returned as-is.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
23
+
24
+
25
def max_neg_value(t):
    """Most negative finite value representable in `t`'s dtype (used for masking)."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max
27
+
28
+
29
def init_(tensor):
    """In-place uniform init in [-1/sqrt(d), 1/sqrt(d)], d = last dim size.

    Returns the same tensor to allow chaining.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
class FeedForward(nn.Module):
    """Transformer MLP block: project to `mult * dim`, nonlinearity, dropout,
    project to `dim_out` (defaults to `dim`). With `glu=True` the input
    projection is a GEGLU gate instead of Linear+GELU.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)
58
+
59
+
60
def zero_module(module):
    """
    Zero out the parameters of a module (in place) and return it.

    Commonly used so that a residual branch starts as the identity.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
67
+
68
+
69
def Normalize(in_channels):
    """GroupNorm with 32 groups over `in_channels` channels (affine, eps=1e-6)."""
    return torch.nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
71
+
72
+
73
class LinearAttention(nn.Module):
    """Linear (kernelized) attention over 2D feature maps.

    Softmax is applied to the keys over the spatial axis, then keys/values are
    contracted first (O(d^2 * n) instead of O(n^2 * d)).
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # Single conv produces q, k, v stacked along channels.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        # Split into q/k/v and flatten the spatial dims: (b, heads, c, h*w).
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        # Aggregate values against normalized keys, then attend with queries.
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        return self.to_out(out)
90
+
91
+
92
class SpatialSelfAttention(nn.Module):
    """Full (quadratic) self-attention over the spatial positions of a 2D
    feature map, with a residual connection. 1x1 convs produce q/k/v.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        # (hw x c) @ (c x hw) -> full hw x hw attention logits.
        w_ = torch.einsum("bij,bjk->bik", q, k)

        # Scaled dot-product: 1/sqrt(c) scaling, softmax over keys.
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        # Residual connection around the attention block.
        return x + h_
127
+
128
+
129
class CrossAttention(nn.Module):
    """Multi-head cross-attention over token sequences.

    Queries come from `x`; keys/values come from `context` (falls back to `x`,
    making it self-attention). Optional boolean `mask` (True = keep) masks
    keys before the softmax.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Fold heads into the batch dimension: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            # Broadcast the key mask over heads and fill masked logits with
            # the most negative finite value so softmax zeroes them.
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
168
+
169
+
170
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention (to
    `context`), then a feed-forward MLP, each with a residual connection.

    NOTE(review): the `checkpoint` constructor parameter shadows the imported
    `checkpoint` function inside __init__; the stored boolean flag selects
    gradient checkpointing in `forward`.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=False):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # Route through the project's checkpoint util (no-op when the flag is False).
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
191
+
192
+
193
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        # 1x1 conv projections in and out of the transformer's token dim.
        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) for d in range(depth)]
        )

        # Output projection starts at zero so the block is initially the identity.
        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        # Flatten spatial dims into a token sequence for the transformer blocks.
        x = rearrange(x, "b c h w -> b (h w) c")
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)
        # Residual connection around the whole transformer stack.
        return x + x_in
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from ldm.util import instantiate_from_config
9
+ from ldm.modules.attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
+ return emb
31
+
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x*torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
+
41
+
42
+ class Upsample(nn.Module):
43
+ def __init__(self, in_channels, with_conv):
44
+ super().__init__()
45
+ self.with_conv = with_conv
46
+ if self.with_conv:
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=1,
51
+ padding=1)
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels,
67
+ in_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=0)
71
+
72
+ def forward(self, x):
73
+ if self.with_conv:
74
+ pad = (0,1,0,1)
75
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
+ x = self.conv(x)
77
+ else:
78
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
+ return x
80
+
81
+
82
+ class ResnetBlock(nn.Module):
83
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
+ dropout, temb_channels=512):
85
+ super().__init__()
86
+ self.in_channels = in_channels
87
+ out_channels = in_channels if out_channels is None else out_channels
88
+ self.out_channels = out_channels
89
+ self.use_conv_shortcut = conv_shortcut
90
+
91
+ self.norm1 = Normalize(in_channels)
92
+ self.conv1 = torch.nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels,
99
+ out_channels)
100
+ self.norm2 = Normalize(out_channels)
101
+ self.dropout = torch.nn.Dropout(dropout)
102
+ self.conv2 = torch.nn.Conv2d(out_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if self.in_channels != self.out_channels:
108
+ if self.use_conv_shortcut:
109
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ else:
115
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+
121
+ def forward(self, x, temb):
122
+ h = x
123
+ h = self.norm1(h)
124
+ h = nonlinearity(h)
125
+ h = self.conv1(h)
126
+
127
+ if temb is not None:
128
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
+
130
+ h = self.norm2(h)
131
+ h = nonlinearity(h)
132
+ h = self.dropout(h)
133
+ h = self.conv2(h)
134
+
135
+ if self.in_channels != self.out_channels:
136
+ if self.use_conv_shortcut:
137
+ x = self.conv_shortcut(x)
138
+ else:
139
+ x = self.nin_shortcut(x)
140
+
141
+ return x+h
142
+
143
+
144
+ class LinAttnBlock(LinearAttention):
145
+ """to match AttnBlock usage"""
146
+ def __init__(self, in_channels):
147
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
+
149
+
150
+ class AttnBlock(nn.Module):
151
+ def __init__(self, in_channels):
152
+ super().__init__()
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = Normalize(in_channels)
156
+ self.q = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.k = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.v = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+ self.proj_out = torch.nn.Conv2d(in_channels,
172
+ in_channels,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0)
176
+
177
+
178
+ def forward(self, x):
179
+ h_ = x
180
+ h_ = self.norm(h_)
181
+ q = self.q(h_)
182
+ k = self.k(h_)
183
+ v = self.v(h_)
184
+
185
+ # compute attention
186
+ b,c,h,w = q.shape
187
+ q = q.reshape(b,c,h*w)
188
+ q = q.permute(0,2,1) # b,hw,c
189
+ k = k.reshape(b,c,h*w) # b,c,hw
190
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
+ w_ = w_ * (int(c)**(-0.5))
192
+ w_ = torch.nn.functional.softmax(w_, dim=2)
193
+
194
+ # attend to values
195
+ v = v.reshape(b,c,h*w)
196
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
+ h_ = h_.reshape(b,c,h,w)
199
+
200
+ h_ = self.proj_out(h_)
201
+
202
+ return x+h_
203
+
204
+
205
+ def make_attn(in_channels, attn_type="vanilla"):
206
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
+ if attn_type == "vanilla":
209
+ return AttnBlock(in_channels)
210
+ elif attn_type == "none":
211
+ return nn.Identity(in_channels)
212
+ else:
213
+ return LinAttnBlock(in_channels)
214
+
215
+
216
+ class Model(nn.Module):
217
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
+ super().__init__()
221
+ if use_linear_attn: attn_type = "linear"
222
+ self.ch = ch
223
+ self.temb_ch = self.ch*4
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.in_channels = in_channels
228
+
229
+ self.use_timestep = use_timestep
230
+ if self.use_timestep:
231
+ # timestep embedding
232
+ self.temb = nn.Module()
233
+ self.temb.dense = nn.ModuleList([
234
+ torch.nn.Linear(self.ch,
235
+ self.temb_ch),
236
+ torch.nn.Linear(self.temb_ch,
237
+ self.temb_ch),
238
+ ])
239
+
240
+ # downsampling
241
+ self.conv_in = torch.nn.Conv2d(in_channels,
242
+ self.ch,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ curr_res = resolution
248
+ in_ch_mult = (1,)+tuple(ch_mult)
249
+ self.down = nn.ModuleList()
250
+ for i_level in range(self.num_resolutions):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_in = ch*in_ch_mult[i_level]
254
+ block_out = ch*ch_mult[i_level]
255
+ for i_block in range(self.num_res_blocks):
256
+ block.append(ResnetBlock(in_channels=block_in,
257
+ out_channels=block_out,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout))
260
+ block_in = block_out
261
+ if curr_res in attn_resolutions:
262
+ attn.append(make_attn(block_in, attn_type=attn_type))
263
+ down = nn.Module()
264
+ down.block = block
265
+ down.attn = attn
266
+ if i_level != self.num_resolutions-1:
267
+ down.downsample = Downsample(block_in, resamp_with_conv)
268
+ curr_res = curr_res // 2
269
+ self.down.append(down)
270
+
271
+ # middle
272
+ self.mid = nn.Module()
273
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
+ out_channels=block_in,
275
+ temb_channels=self.temb_ch,
276
+ dropout=dropout)
277
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
+ out_channels=block_in,
280
+ temb_channels=self.temb_ch,
281
+ dropout=dropout)
282
+
283
+ # upsampling
284
+ self.up = nn.ModuleList()
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ block = nn.ModuleList()
287
+ attn = nn.ModuleList()
288
+ block_out = ch*ch_mult[i_level]
289
+ skip_in = ch*ch_mult[i_level]
290
+ for i_block in range(self.num_res_blocks+1):
291
+ if i_block == self.num_res_blocks:
292
+ skip_in = ch*in_ch_mult[i_level]
293
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
294
+ out_channels=block_out,
295
+ temb_channels=self.temb_ch,
296
+ dropout=dropout))
297
+ block_in = block_out
298
+ if curr_res in attn_resolutions:
299
+ attn.append(make_attn(block_in, attn_type=attn_type))
300
+ up = nn.Module()
301
+ up.block = block
302
+ up.attn = attn
303
+ if i_level != 0:
304
+ up.upsample = Upsample(block_in, resamp_with_conv)
305
+ curr_res = curr_res * 2
306
+ self.up.insert(0, up) # prepend to get consistent order
307
+
308
+ # end
309
+ self.norm_out = Normalize(block_in)
310
+ self.conv_out = torch.nn.Conv2d(block_in,
311
+ out_ch,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ def forward(self, x, t=None, context=None):
317
+ #assert x.shape[2] == x.shape[3] == self.resolution
318
+ if context is not None:
319
+ # assume aligned context, cat along channel axis
320
+ x = torch.cat((x, context), dim=1)
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ assert t is not None
324
+ temb = get_timestep_embedding(t, self.ch)
325
+ temb = self.temb.dense[0](temb)
326
+ temb = nonlinearity(temb)
327
+ temb = self.temb.dense[1](temb)
328
+ else:
329
+ temb = None
330
+
331
+ # downsampling
332
+ hs = [self.conv_in(x)]
333
+ for i_level in range(self.num_resolutions):
334
+ for i_block in range(self.num_res_blocks):
335
+ h = self.down[i_level].block[i_block](hs[-1], temb)
336
+ if len(self.down[i_level].attn) > 0:
337
+ h = self.down[i_level].attn[i_block](h)
338
+ hs.append(h)
339
+ if i_level != self.num_resolutions-1:
340
+ hs.append(self.down[i_level].downsample(hs[-1]))
341
+
342
+ # middle
343
+ h = hs[-1]
344
+ h = self.mid.block_1(h, temb)
345
+ h = self.mid.attn_1(h)
346
+ h = self.mid.block_2(h, temb)
347
+
348
+ # upsampling
349
+ for i_level in reversed(range(self.num_resolutions)):
350
+ for i_block in range(self.num_res_blocks+1):
351
+ h = self.up[i_level].block[i_block](
352
+ torch.cat([h, hs.pop()], dim=1), temb)
353
+ if len(self.up[i_level].attn) > 0:
354
+ h = self.up[i_level].attn[i_block](h)
355
+ if i_level != 0:
356
+ h = self.up[i_level].upsample(h)
357
+
358
+ # end
359
+ h = self.norm_out(h)
360
+ h = nonlinearity(h)
361
+ h = self.conv_out(h)
362
+ return h
363
+
364
+ def get_last_layer(self):
365
+ return self.conv_out.weight
366
+
367
+
368
+ class Encoder(nn.Module):
369
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
+ **ignore_kwargs):
373
+ super().__init__()
374
+ if use_linear_attn: attn_type = "linear"
375
+ self.ch = ch
376
+ self.temb_ch = 0
377
+ self.num_resolutions = len(ch_mult)
378
+ self.num_res_blocks = num_res_blocks
379
+ self.resolution = resolution
380
+ self.in_channels = in_channels
381
+
382
+ # downsampling
383
+ self.conv_in = torch.nn.Conv2d(in_channels,
384
+ self.ch,
385
+ kernel_size=3,
386
+ stride=1,
387
+ padding=1)
388
+
389
+ curr_res = resolution
390
+ in_ch_mult = (1,)+tuple(ch_mult)
391
+ self.in_ch_mult = in_ch_mult
392
+ self.down = nn.ModuleList()
393
+ for i_level in range(self.num_resolutions):
394
+ block = nn.ModuleList()
395
+ attn = nn.ModuleList()
396
+ block_in = ch*in_ch_mult[i_level]
397
+ block_out = ch*ch_mult[i_level]
398
+ for i_block in range(self.num_res_blocks):
399
+ block.append(ResnetBlock(in_channels=block_in,
400
+ out_channels=block_out,
401
+ temb_channels=self.temb_ch,
402
+ dropout=dropout))
403
+ block_in = block_out
404
+ if curr_res in attn_resolutions:
405
+ attn.append(make_attn(block_in, attn_type=attn_type))
406
+ down = nn.Module()
407
+ down.block = block
408
+ down.attn = attn
409
+ if i_level != self.num_resolutions-1:
410
+ down.downsample = Downsample(block_in, resamp_with_conv)
411
+ curr_res = curr_res // 2
412
+ self.down.append(down)
413
+
414
+ # middle
415
+ self.mid = nn.Module()
416
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
+ out_channels=block_in,
418
+ temb_channels=self.temb_ch,
419
+ dropout=dropout)
420
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
+ out_channels=block_in,
423
+ temb_channels=self.temb_ch,
424
+ dropout=dropout)
425
+
426
+ # end
427
+ self.norm_out = Normalize(block_in)
428
+ self.conv_out = torch.nn.Conv2d(block_in,
429
+ 2*z_channels if double_z else z_channels,
430
+ kernel_size=3,
431
+ stride=1,
432
+ padding=1)
433
+
434
+ def forward(self, x):
435
+ # timestep embedding
436
+ temb = None
437
+
438
+ # downsampling
439
+ hs = [self.conv_in(x)]
440
+ for i_level in range(self.num_resolutions):
441
+ for i_block in range(self.num_res_blocks):
442
+ h = self.down[i_level].block[i_block](hs[-1], temb)
443
+ if len(self.down[i_level].attn) > 0:
444
+ h = self.down[i_level].attn[i_block](h)
445
+ hs.append(h)
446
+ if i_level != self.num_resolutions-1:
447
+ hs.append(self.down[i_level].downsample(hs[-1]))
448
+
449
+ # middle
450
+ h = hs[-1]
451
+ h = self.mid.block_1(h, temb)
452
+ h = self.mid.attn_1(h)
453
+ h = self.mid.block_2(h, temb)
454
+
455
+ # end
456
+ h = self.norm_out(h)
457
+ h = nonlinearity(h)
458
+ h = self.conv_out(h)
459
+ return h
460
+
461
+
462
+ class Decoder(nn.Module):
463
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
+ attn_type="vanilla", **ignorekwargs):
467
+ super().__init__()
468
+ if use_linear_attn: attn_type = "linear"
469
+ self.ch = ch
470
+ self.temb_ch = 0
471
+ self.num_resolutions = len(ch_mult)
472
+ self.num_res_blocks = num_res_blocks
473
+ self.resolution = resolution
474
+ self.in_channels = in_channels
475
+ self.give_pre_end = give_pre_end
476
+ self.tanh_out = tanh_out
477
+
478
+ # compute in_ch_mult, block_in and curr_res at lowest res
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ block_in = ch*ch_mult[self.num_resolutions-1]
481
+ curr_res = resolution // 2**(self.num_resolutions-1)
482
+ self.z_shape = (1,z_channels,curr_res,curr_res)
483
+ print("Working with z of shape {} = {} dimensions.".format(
484
+ self.z_shape, np.prod(self.z_shape)))
485
+
486
+ # z to block_in
487
+ self.conv_in = torch.nn.Conv2d(z_channels,
488
+ block_in,
489
+ kernel_size=3,
490
+ stride=1,
491
+ padding=1)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
+ out_channels=block_in,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout)
499
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch*ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks+1):
512
+ block.append(ResnetBlock(in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout))
516
+ block_in = block_out
517
+ if curr_res in attn_resolutions:
518
+ attn.append(make_attn(block_in, attn_type=attn_type))
519
+ up = nn.Module()
520
+ up.block = block
521
+ up.attn = attn
522
+ if i_level != 0:
523
+ up.upsample = Upsample(block_in, resamp_with_conv)
524
+ curr_res = curr_res * 2
525
+ self.up.insert(0, up) # prepend to get consistent order
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(block_in,
530
+ out_ch,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1)
534
+
535
+ def forward(self, z):
536
+ #assert z.shape[1:] == self.z_shape[1:]
537
+ self.last_z_shape = z.shape
538
+
539
+ # timestep embedding
540
+ temb = None
541
+
542
+ # z to block_in
543
+ h = self.conv_in(z)
544
+
545
+ # middle
546
+ h = self.mid.block_1(h, temb)
547
+ h = self.mid.attn_1(h)
548
+ h = self.mid.block_2(h, temb)
549
+
550
+ # upsampling
551
+ for i_level in reversed(range(self.num_resolutions)):
552
+ for i_block in range(self.num_res_blocks+1):
553
+ h = self.up[i_level].block[i_block](h, temb)
554
+ if len(self.up[i_level].attn) > 0:
555
+ h = self.up[i_level].attn[i_block](h)
556
+ if i_level != 0:
557
+ h = self.up[i_level].upsample(h)
558
+
559
+ # end
560
+ if self.give_pre_end:
561
+ return h
562
+
563
+ h = self.norm_out(h)
564
+ h = nonlinearity(h)
565
+ h = self.conv_out(h)
566
+ if self.tanh_out:
567
+ h = torch.tanh(h)
568
+ return h
569
+
570
+
571
+ class SimpleDecoder(nn.Module):
572
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
573
+ super().__init__()
574
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
575
+ ResnetBlock(in_channels=in_channels,
576
+ out_channels=2 * in_channels,
577
+ temb_channels=0, dropout=0.0),
578
+ ResnetBlock(in_channels=2 * in_channels,
579
+ out_channels=4 * in_channels,
580
+ temb_channels=0, dropout=0.0),
581
+ ResnetBlock(in_channels=4 * in_channels,
582
+ out_channels=2 * in_channels,
583
+ temb_channels=0, dropout=0.0),
584
+ nn.Conv2d(2*in_channels, in_channels, 1),
585
+ Upsample(in_channels, with_conv=True)])
586
+ # end
587
+ self.norm_out = Normalize(in_channels)
588
+ self.conv_out = torch.nn.Conv2d(in_channels,
589
+ out_channels,
590
+ kernel_size=3,
591
+ stride=1,
592
+ padding=1)
593
+
594
+ def forward(self, x):
595
+ for i, layer in enumerate(self.model):
596
+ if i in [1,2,3]:
597
+ x = layer(x, None)
598
+ else:
599
+ x = layer(x)
600
+
601
+ h = self.norm_out(x)
602
+ h = nonlinearity(h)
603
+ x = self.conv_out(h)
604
+ return x
605
+
606
+
607
+ class UpsampleDecoder(nn.Module):
608
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
609
+ ch_mult=(2,2), dropout=0.0):
610
+ super().__init__()
611
+ # upsampling
612
+ self.temb_ch = 0
613
+ self.num_resolutions = len(ch_mult)
614
+ self.num_res_blocks = num_res_blocks
615
+ block_in = in_channels
616
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
617
+ self.res_blocks = nn.ModuleList()
618
+ self.upsample_blocks = nn.ModuleList()
619
+ for i_level in range(self.num_resolutions):
620
+ res_block = []
621
+ block_out = ch * ch_mult[i_level]
622
+ for i_block in range(self.num_res_blocks + 1):
623
+ res_block.append(ResnetBlock(in_channels=block_in,
624
+ out_channels=block_out,
625
+ temb_channels=self.temb_ch,
626
+ dropout=dropout))
627
+ block_in = block_out
628
+ self.res_blocks.append(nn.ModuleList(res_block))
629
+ if i_level != self.num_resolutions - 1:
630
+ self.upsample_blocks.append(Upsample(block_in, True))
631
+ curr_res = curr_res * 2
632
+
633
+ # end
634
+ self.norm_out = Normalize(block_in)
635
+ self.conv_out = torch.nn.Conv2d(block_in,
636
+ out_channels,
637
+ kernel_size=3,
638
+ stride=1,
639
+ padding=1)
640
+
641
+ def forward(self, x):
642
+ # upsampling
643
+ h = x
644
+ for k, i_level in enumerate(range(self.num_resolutions)):
645
+ for i_block in range(self.num_res_blocks + 1):
646
+ h = self.res_blocks[i_level][i_block](h, None)
647
+ if i_level != self.num_resolutions - 1:
648
+ h = self.upsample_blocks[k](h)
649
+ h = self.norm_out(h)
650
+ h = nonlinearity(h)
651
+ h = self.conv_out(h)
652
+ return h
653
+
654
+
655
+ class LatentRescaler(nn.Module):
656
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
657
+ super().__init__()
658
+ # residual block, interpolate, residual block
659
+ self.factor = factor
660
+ self.conv_in = nn.Conv2d(in_channels,
661
+ mid_channels,
662
+ kernel_size=3,
663
+ stride=1,
664
+ padding=1)
665
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
666
+ out_channels=mid_channels,
667
+ temb_channels=0,
668
+ dropout=0.0) for _ in range(depth)])
669
+ self.attn = AttnBlock(mid_channels)
670
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
671
+ out_channels=mid_channels,
672
+ temb_channels=0,
673
+ dropout=0.0) for _ in range(depth)])
674
+
675
+ self.conv_out = nn.Conv2d(mid_channels,
676
+ out_channels,
677
+ kernel_size=1,
678
+ )
679
+
680
+ def forward(self, x):
681
+ x = self.conv_in(x)
682
+ for block in self.res_block1:
683
+ x = block(x, None)
684
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
685
+ x = self.attn(x)
686
+ for block in self.res_block2:
687
+ x = block(x, None)
688
+ x = self.conv_out(x)
689
+ return x
690
+
691
+
692
+ class MergedRescaleEncoder(nn.Module):
693
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
694
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
695
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
696
+ super().__init__()
697
+ intermediate_chn = ch * ch_mult[-1]
698
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
699
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
700
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
701
+ out_ch=None)
702
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
703
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
704
+
705
+ def forward(self, x):
706
+ x = self.encoder(x)
707
+ x = self.rescaler(x)
708
+ return x
709
+
710
+
711
+ class MergedRescaleDecoder(nn.Module):
712
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
713
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
714
+ super().__init__()
715
+ tmp_chn = z_channels*ch_mult[-1]
716
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
717
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
718
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
719
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
720
+ out_channels=tmp_chn, depth=rescale_module_depth)
721
+
722
+ def forward(self, x):
723
+ x = self.rescaler(x)
724
+ x = self.decoder(x)
725
+ return x
726
+
727
+
728
+ class Upsampler(nn.Module):
729
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
730
+ super().__init__()
731
+ assert out_size >= in_size
732
+ num_blocks = int(np.log2(out_size//in_size))+1
733
+ factor_up = 1.+ (out_size % in_size)
734
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
735
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
736
+ out_channels=in_channels)
737
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
738
+ attn_resolutions=[], in_channels=None, ch=in_channels,
739
+ ch_mult=[ch_mult for _ in range(num_blocks)])
740
+
741
+ def forward(self, x):
742
+ x = self.rescaler(x)
743
+ x = self.decoder(x)
744
+ return x
745
+
746
+
747
+ class Resize(nn.Module):
748
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
749
+ super().__init__()
750
+ self.with_conv = learned
751
+ self.mode = mode
752
+ if self.with_conv:
753
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
754
+ raise NotImplementedError()
755
+ assert in_channels is not None
756
+ # no asymmetric padding in torch conv, must do it ourselves
757
+ self.conv = torch.nn.Conv2d(in_channels,
758
+ in_channels,
759
+ kernel_size=4,
760
+ stride=2,
761
+ padding=1)
762
+
763
+ def forward(self, x, scale_factor=1.0):
764
+ if scale_factor==1.0:
765
+ return x
766
+ else:
767
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
+ return x
769
+
770
+ class FirstStagePostProcessor(nn.Module):
771
+
772
+ def __init__(self, ch_mult:list, in_channels,
773
+ pretrained_model:nn.Module=None,
774
+ reshape=False,
775
+ n_channels=None,
776
+ dropout=0.,
777
+ pretrained_config=None):
778
+ super().__init__()
779
+ if pretrained_config is None:
780
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
+ self.pretrained_model = pretrained_model
782
+ else:
783
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
+ self.instantiate_pretrained(pretrained_config)
785
+
786
+ self.do_reshape = reshape
787
+
788
+ if n_channels is None:
789
+ n_channels = self.pretrained_model.encoder.ch
790
+
791
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
+ stride=1,padding=1)
794
+
795
+ blocks = []
796
+ downs = []
797
+ ch_in = n_channels
798
+ for m in ch_mult:
799
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
+ ch_in = m * n_channels
801
+ downs.append(Downsample(ch_in, with_conv=False))
802
+
803
+ self.model = nn.ModuleList(blocks)
804
+ self.downsampler = nn.ModuleList(downs)
805
+
806
+
807
+ def instantiate_pretrained(self, config):
808
+ model = instantiate_from_config(config)
809
+ self.pretrained_model = model.eval()
810
+ # self.pretrained_model.train = False
811
+ for param in self.pretrained_model.parameters():
812
+ param.requires_grad = False
813
+
814
+
815
+ @torch.no_grad()
816
+ def encode_with_pretrained(self,x):
817
+ c = self.pretrained_model.encode(x)
818
+ if isinstance(c, DiagonalGaussianDistribution):
819
+ c = c.mode()
820
+ return c
821
+
822
+ def forward(self,x):
823
+ z_fs = self.encode_with_pretrained(x)
824
+ z = self.proj_norm(z_fs)
825
+ z = self.proj(z)
826
+ z = nonlinearity(z)
827
+
828
+ for submodel, downmodel in zip(self.model,self.downsampler):
829
+ z = submodel(z,temb=None)
830
+ z = downmodel(z)
831
+
832
+ if self.do_reshape:
833
+ z = rearrange(z,'b c h w -> b (h w) c')
834
+ return z
835
+