Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .gitignore +185 -0
- README.md +177 -3
- compute_fid.py +158 -0
- configs/afhqv2.yml +49 -0
- configs/cifar10.yml +49 -0
- configs/cifar10_order2.yml +49 -0
- configs/ffhq.yml +49 -0
- configs/latent-diffusion/celebahq-ldm-vq-4.yaml +86 -0
- configs/latent-diffusion/cin-ldm-vq-f8.yaml +98 -0
- configs/latent-diffusion/cin256-v2.yaml +68 -0
- configs/latent-diffusion/ffhq-ldm-vq-4.yaml +85 -0
- configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml +85 -0
- configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml +91 -0
- configs/latent-diffusion/txt2img-1p4B-eval.yaml +71 -0
- configs/latent_diff_LSUN.yml +55 -0
- configs/latent_diff_imn.yml +55 -0
- configs/stable-diffusion/v1-inference.yaml +70 -0
- configs/stable_diff_v1-4.yml +55 -0
- configs/stable_diff_v1-5.yml +55 -0
- data/coco_captions.txt +0 -0
- data/prompts.txt +5 -0
- dataset.py +46 -0
- dnnlib/__init__.py +8 -0
- dnnlib/util.py +491 -0
- gen_data.py +188 -0
- ldm/__init__.py +0 -0
- ldm/data/__init__.py +0 -0
- ldm/data/base.py +23 -0
- ldm/data/imagenet.py +394 -0
- ldm/data/lsun.py +92 -0
- ldm/lr_scheduler.py +98 -0
- ldm/models/autoencoder.py +442 -0
- ldm/models/diffusion/__init__.py +0 -0
- ldm/models/diffusion/classifier.py +267 -0
- ldm/models/diffusion/ddim.py +241 -0
- ldm/models/diffusion/ddpm.py +1445 -0
- ldm/models/diffusion/dpm_solver/__init__.py +1 -0
- ldm/models/diffusion/dpm_solver/dpm_solver.py +780 -0
- ldm/models/diffusion/dpm_solver/sampler.py +95 -0
- ldm/models/diffusion/dpm_solver_v3/__init__.py +1 -0
- ldm/models/diffusion/dpm_solver_v3/dpm_solver_v3.py +824 -0
- ldm/models/diffusion/dpm_solver_v3/sampler.py +95 -0
- ldm/models/diffusion/plms.py +236 -0
- ldm/models/diffusion/uni_pc/__init__.py +1 -0
- ldm/models/diffusion/uni_pc/sampler.py +83 -0
- ldm/models/diffusion/uni_pc/uni_pc.py +547 -0
- ldm/modules/attention.py +227 -0
- ldm/modules/diffusionmodules/__init__.py +0 -0
- ldm/modules/diffusionmodules/model.py +835 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
src/taming-transformers/scripts/reconstruction_usage.ipynb filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data file
|
| 2 |
+
train_data*
|
| 3 |
+
test_gen_data
|
| 4 |
+
sampling_data
|
| 5 |
+
val_coco_images
|
| 6 |
+
temp
|
| 7 |
+
tmp
|
| 8 |
+
pretrained
|
| 9 |
+
fid-refs
|
| 10 |
+
captions_val2014.json
|
| 11 |
+
*json
|
| 12 |
+
*gz
|
| 13 |
+
!src/clip/clip/bpe_simple_vocab_16e6.txt.gz
|
| 14 |
+
*zip
|
| 15 |
+
logs/
|
| 16 |
+
*pt
|
| 17 |
+
*pkl
|
| 18 |
+
*png
|
| 19 |
+
*err
|
| 20 |
+
*out
|
| 21 |
+
test_scripts/*
|
| 22 |
+
all_logs
|
| 23 |
+
all_fids
|
| 24 |
+
# Byte-compiled / optimized / DLL files
|
| 25 |
+
__pycache__/
|
| 26 |
+
*.py[cod]
|
| 27 |
+
*$py.class
|
| 28 |
+
|
| 29 |
+
# C extensions
|
| 30 |
+
*.so
|
| 31 |
+
|
| 32 |
+
# Distribution / packaging
|
| 33 |
+
.Python
|
| 34 |
+
build/
|
| 35 |
+
develop-eggs/
|
| 36 |
+
dist/
|
| 37 |
+
downloads/
|
| 38 |
+
eggs/
|
| 39 |
+
.eggs/
|
| 40 |
+
lib/
|
| 41 |
+
lib64/
|
| 42 |
+
parts/
|
| 43 |
+
sdist/
|
| 44 |
+
var/
|
| 45 |
+
wheels/
|
| 46 |
+
share/python-wheels/
|
| 47 |
+
*.egg-info/
|
| 48 |
+
.installed.cfg
|
| 49 |
+
*.egg
|
| 50 |
+
MANIFEST
|
| 51 |
+
|
| 52 |
+
# PyInstaller
|
| 53 |
+
# Usually these files are written by a python script from a template
|
| 54 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 55 |
+
*.manifest
|
| 56 |
+
*.spec
|
| 57 |
+
|
| 58 |
+
# Installer logs
|
| 59 |
+
pip-log.txt
|
| 60 |
+
pip-delete-this-directory.txt
|
| 61 |
+
|
| 62 |
+
# Unit test / coverage reports
|
| 63 |
+
htmlcov/
|
| 64 |
+
.tox/
|
| 65 |
+
.nox/
|
| 66 |
+
.coverage
|
| 67 |
+
.coverage.*
|
| 68 |
+
.cache
|
| 69 |
+
nosetests.xml
|
| 70 |
+
coverage.xml
|
| 71 |
+
*.cover
|
| 72 |
+
*.py,cover
|
| 73 |
+
.hypothesis/
|
| 74 |
+
.pytest_cache/
|
| 75 |
+
cover/
|
| 76 |
+
|
| 77 |
+
# Translations
|
| 78 |
+
*.mo
|
| 79 |
+
*.pot
|
| 80 |
+
|
| 81 |
+
# Django stuff:
|
| 82 |
+
*.log
|
| 83 |
+
local_settings.py
|
| 84 |
+
db.sqlite3
|
| 85 |
+
db.sqlite3-journal
|
| 86 |
+
|
| 87 |
+
# Flask stuff:
|
| 88 |
+
instance/
|
| 89 |
+
.webassets-cache
|
| 90 |
+
|
| 91 |
+
# Scrapy stuff:
|
| 92 |
+
.scrapy
|
| 93 |
+
|
| 94 |
+
# Sphinx documentation
|
| 95 |
+
docs/_build/
|
| 96 |
+
|
| 97 |
+
# PyBuilder
|
| 98 |
+
.pybuilder/
|
| 99 |
+
target/
|
| 100 |
+
|
| 101 |
+
# Jupyter Notebook
|
| 102 |
+
.ipynb_checkpoints
|
| 103 |
+
|
| 104 |
+
# IPython
|
| 105 |
+
profile_default/
|
| 106 |
+
ipython_config.py
|
| 107 |
+
|
| 108 |
+
# pyenv
|
| 109 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 110 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 111 |
+
# .python-version
|
| 112 |
+
|
| 113 |
+
# pipenv
|
| 114 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 115 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 116 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 117 |
+
# install all needed dependencies.
|
| 118 |
+
#Pipfile.lock
|
| 119 |
+
|
| 120 |
+
# poetry
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 122 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 123 |
+
# commonly ignored for libraries.
|
| 124 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 125 |
+
#poetry.lock
|
| 126 |
+
|
| 127 |
+
# pdm
|
| 128 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 129 |
+
#pdm.lock
|
| 130 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 131 |
+
# in version control.
|
| 132 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 133 |
+
.pdm.toml
|
| 134 |
+
.pdm-python
|
| 135 |
+
.pdm-build/
|
| 136 |
+
|
| 137 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 138 |
+
__pypackages__/
|
| 139 |
+
|
| 140 |
+
# Celery stuff
|
| 141 |
+
celerybeat-schedule
|
| 142 |
+
celerybeat.pid
|
| 143 |
+
|
| 144 |
+
# SageMath parsed files
|
| 145 |
+
*.sage.py
|
| 146 |
+
|
| 147 |
+
# Environments
|
| 148 |
+
.env
|
| 149 |
+
.venv
|
| 150 |
+
env/
|
| 151 |
+
venv/
|
| 152 |
+
ENV/
|
| 153 |
+
env.bak/
|
| 154 |
+
venv.bak/
|
| 155 |
+
|
| 156 |
+
# Spyder project settings
|
| 157 |
+
.spyderproject
|
| 158 |
+
.spyproject
|
| 159 |
+
|
| 160 |
+
# Rope project settings
|
| 161 |
+
.ropeproject
|
| 162 |
+
|
| 163 |
+
# mkdocs documentation
|
| 164 |
+
/site
|
| 165 |
+
|
| 166 |
+
# mypy
|
| 167 |
+
.mypy_cache/
|
| 168 |
+
.dmypy.json
|
| 169 |
+
dmypy.json
|
| 170 |
+
|
| 171 |
+
# Pyre type checker
|
| 172 |
+
.pyre/
|
| 173 |
+
|
| 174 |
+
# pytype static type analyzer
|
| 175 |
+
.pytype/
|
| 176 |
+
|
| 177 |
+
# Cython debug symbols
|
| 178 |
+
cython_debug/
|
| 179 |
+
|
| 180 |
+
# PyCharm
|
| 181 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 182 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 183 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 184 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 185 |
+
#.idea/
|
README.md
CHANGED
|
@@ -1,3 +1,177 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Learning to Discretize Denoising Diffusion ODEs
|
| 2 |
+
|
| 3 |
+
🏆 
|
| 4 |
+
|
| 5 |
+
### [Paper on OpenReview](https://openreview.net/forum?id=xDrFWUmCne)
|
| 6 |
+
|
| 7 |
+
Implementation of LD3, a lightweight framework designed to learn the optimal time discretization for sampling from pre-trained Diffusion Probabilistic Models (DPMs). LD3 can be combined with various samplers and consistently improves generation quality without having to retrain resource-intensive neural networks. LD3 offers an efficient approach to sampling from pre-trained diffusion models.
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
## 🔥 Latest News
|
| 12 |
+
|
| 13 |
+
- **March 2025**: We have successfully applied **LD3** to the **Flux-dev** model and observed promising results.
|
| 14 |
+
- We are releasing the trained time steps for the Flux model soon! Stay tuned for updates.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Setup Environment
|
| 18 |
+
|
| 19 |
+
We will set up the environment using [Anaconda](https://docs.anaconda.com/anaconda/install/index.html).
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
conda env create -f requirements.yml
|
| 23 |
+
conda activate ld3
|
| 24 |
+
pip install -e ./src/clip/
|
| 25 |
+
pip install -e ./src/taming-transformers/
|
| 26 |
+
pip install omegaconf
|
| 27 |
+
pip install PyYAML
|
| 28 |
+
pip install requests
|
| 29 |
+
pip install scipy
|
| 30 |
+
pip install torchmetrics
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Download Pretrained Models and FID Reference Sets
|
| 34 |
+
|
| 35 |
+
All necessary data will be automatically downloaded by the script. Note that this process may take some time. If you wish to skip certain downloads, you can comment out the corresponding lines in the script.
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
bash scripts/download_model.sh
|
| 39 |
+
wget https://raw.githubusercontent.com/tylin/coco-caption/master/annotations/captions_val2014.json
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## 🚀 Generating Training Data for LD3
|
| 44 |
+
|
| 45 |
+
Before training **LD3**, we first need to generate training data using the teacher solver. The script `gen_data.py` handles this process. Below is an example of generating training data with **20 sampling steps** for **CIFAR-10**, using the `uni_pc` solver and `time-edm` discretization.
|
| 46 |
+
|
| 47 |
+
### 📌 Example: Generating CIFAR-10 Training Data
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
CUDA_VISIBLE_DEVICES=0 python3 gen_data.py \
|
| 51 |
+
--all_config configs/cifar10.yml \
|
| 52 |
+
--total_samples 100 \
|
| 53 |
+
--sampling_batch_size 10 \
|
| 54 |
+
--steps 20 \
|
| 55 |
+
--solver_name uni_pc \
|
| 56 |
+
--skip_type edm \
|
| 57 |
+
--save_pt --save_png --data_dir train_data/train_data_cifar10 \
|
| 58 |
+
--low_gpu
|
| 59 |
+
```
|
| 60 |
+
#### 📌 Key Arguments:
|
| 61 |
+
|
| 62 |
+
- `all_config`: Path to the default configuration file (mandatory). If other arguments are not specified, their values will be taken from this file.
|
| 63 |
+
- `solver_name`: Solver to use. Options include `uni_pc`, `dpm_solver++`, `euler`, and `ipndm`.
|
| 64 |
+
- `skip_type`: Discretization method. Options include `edm`, `time_uniform`, and `time_quadratic`.
|
| 65 |
+
- `low_gpu`: Enables the use of PyTorch's `checkpoint` feature to reduce GPU memory usage.
|
| 66 |
+
- `data_dir`: Root directory for saving the generated data. The script will create a subdirectory within this path using the naming format `${solver_name}_NFE${steps}_${skip_type}`.
|
| 67 |
+
|
| 68 |
+
### 📌 Example: Generating Stable Diffusion Training Data
|
| 69 |
+
|
| 70 |
+
For Stable Diffusion, you must additionally specify the prompt file and the number of prompts. Below is an example:
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
CUDA_VISIBLE_DEVICES=0 python3 gen_data.py \
|
| 74 |
+
--all_config configs/stable_diff_v1-4.yml \
|
| 75 |
+
--total_samples 100 \
|
| 76 |
+
--sampling_batch_size 2 \
|
| 77 |
+
--steps 6 \
|
| 78 |
+
--solver_name uni_pc \
|
| 79 |
+
--skip_type time_uniform \
|
| 80 |
+
--save_pt --save_png --data_dir train_data/train_data_stable_diff_v1-4 \
|
| 81 |
+
--low_gpu \
|
| 82 |
+
--num_prompts 5 --prompt_path captions_val2014.json
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Training LD3
|
| 86 |
+
After generating the training data, you can train **LD3** using the `main.py` script. Below is an example of training **LD3** on **CIFAR-10** with the following configurations:
|
| 87 |
+
- **Teacher**: 20 sampling steps, `uni_pc` solver, and `time-edm` discretization.
|
| 88 |
+
- **Student**: 10 sampling steps, `dpm_solver++` solver.
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
| 92 |
+
--all_config configs/cifar10.yml \
|
| 93 |
+
--data_dir train_data/train_data_cifar10/uni_pc_NFE20_edm \
|
| 94 |
+
--num_train 50 --num_valid 50 \
|
| 95 |
+
--main_train_batch_size 1 \
|
| 96 |
+
--main_valid_batch_size 10 \
|
| 97 |
+
--solver_name dpm_solver++ \
|
| 98 |
+
--training_rounds_v1 2 \
|
| 99 |
+
--training_rounds_v2 5 \
|
| 100 |
+
--steps 10 \
|
| 101 |
+
--log_path logs/logs_cifar10
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
**Trained timesteps are available [here](https://docs.google.com/spreadsheets/d/1nUrTDvvtpPHZuRuJcn3zzxGmVKrNX4fFHu8wYIIoGSM/edit?usp=sharing) and are still being updated.**
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
#### 📌 Key Arguments:
|
| 109 |
+
|
| 110 |
+
- `data_dir`: The full path to the training data directory (unlike the root directory used during data generation).
|
| 111 |
+
- `log_path`: The root directory for saving logs and models. The script will create a subdirectory within this path using the naming format: `${solver_name}-N${steps}-b${bound}-${loss_type}-lr2${lr2}rv1${rv1}-rv2${rv2}`, for example, `uni_pc-N10-b0.03072-LPIPS-lr20.01rv12-rv25`
|
| 112 |
+
|
| 113 |
+
## FID Evaluation
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
### ⚠️ Different FID Scores
|
| 117 |
+
|
| 118 |
+
It is important to note that FID (Fréchet Inception Distance) scores can vary significantly depending on the processing pipeline used. To ensure transparency and reproducibility, our framework provides a script `compute_fid.py` that supports FID evaluation for both EDM and Latent-Diffusion.
|
| 119 |
+
|
| 120 |
+
### 📌 How FID Evaluation Works
|
| 121 |
+
|
| 122 |
+
The `compute_fid.py` script is a streamlined version of `gen_data.py` with a few differences:
|
| 123 |
+
|
| 124 |
+
The `--save_dir`, `--save_pt`, and `--save_png` arguments are ignored because the generated data is directly processed for FID calculation without being saved.
|
| 125 |
+
|
| 126 |
+
The data is automatically forwarded to the FID computation module to extract features.
|
| 127 |
+
|
| 128 |
+
Optionally, you can pass your own timesteps via `--custom_ts_1` and `--custom_ts_2`. If `custom_ts_2` is not specified, it will be set the same as `custom_ts_1`
|
| 129 |
+
### 📌 Example: Computing FID for Stable Diffusion
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
CUDA_VISIBLE_DEVICES=0 python3 compute_fid.py \
|
| 133 |
+
--all_config configs/stable_diff_v1-5.yml \
|
| 134 |
+
--total_samples 100 \
|
| 135 |
+
--sampling_batch_size 2 \
|
| 136 |
+
--steps 6 \
|
| 137 |
+
--solver_name uni_pc \
|
| 138 |
+
--skip_type time_uniform \
|
| 139 |
+
--low_gpu \
|
| 140 |
+
--num_prompts 5 --prompt_path captions_val2014.json
|
| 141 |
+
|
| 142 |
+
CUDA_VISIBLE_DEVICES=0 python3 compute_fid.py \
|
| 143 |
+
--all_config configs/stable_diff_v1-5.yml \
|
| 144 |
+
--total_samples 100 \
|
| 145 |
+
--sampling_batch_size 2 \
|
| 146 |
+
--steps 4 \
|
| 147 |
+
--solver_name ipndm \
|
| 148 |
+
--skip_type custom \
|
| 149 |
+
--custom_ts_1 [1.0000e+00,7.6668e-01,4.8113e-01,1.8417e-01,1.0000e-03] \
|
| 150 |
+
--custom_ts_2 [1.0000e+00,7.6706e-01,4.8103e-01,1.8396e-01,1.0000e-03] \
|
| 151 |
+
--low_gpu \
|
| 152 |
+
--num_prompts 5 --prompt_path captions_val2014.json
|
| 153 |
+
|
| 154 |
+
```
|
| 155 |
+
## Citation
|
| 156 |
+
|
| 157 |
+
```
|
| 158 |
+
@inproceedings{tong2024learning,
|
| 159 |
+
title = {Learning to Discretize Denoising Diffusion ODEs},
|
| 160 |
+
author = {Tong, Vinh and Hoang, Trung-Dung and Liu, Anji and Van den Broeck, Guy and Niepert, Mathias},
|
| 161 |
+
booktitle = {Proceedings of the 13th International Conference on Learning Representations},
|
| 162 |
+
year = {2025}
|
| 163 |
+
}
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
```
|
| 167 |
+
@article{tong2024learning,
|
| 168 |
+
title={Learning to Discretize Denoising Diffusion ODEs},
|
| 169 |
+
author={Tong, Vinh and Hoang, Trung-Dung and Liu, Anji and Broeck, Guy Van den and Niepert, Mathias},
|
| 170 |
+
journal={arXiv preprint arXiv:2405.15506},
|
| 171 |
+
year={2024}
|
| 172 |
+
}
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
## License
|
| 176 |
+
MIT
|
| 177 |
+
|
compute_fid.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.autograd import Variable
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
import pickle
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from utils import (
|
| 11 |
+
parse_arguments,
|
| 12 |
+
check_fid_file,
|
| 13 |
+
prepare_paths,
|
| 14 |
+
adjust_hyper,
|
| 15 |
+
get_solvers,
|
| 16 |
+
set_seed_everything,
|
| 17 |
+
)
|
| 18 |
+
from models import prepare_stuff, prepare_condition_loader
|
| 19 |
+
import math
|
| 20 |
+
import dnnlib
|
| 21 |
+
import pickle
|
| 22 |
+
import scipy
|
| 23 |
+
|
| 24 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
| 25 |
+
from pytorch_fid.inception import InceptionV3
|
| 26 |
+
|
| 27 |
+
from gen_data import Generator, get_data_inverse_scaler
|
| 28 |
+
|
| 29 |
+
def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
|
| 30 |
+
m = np.square(mu - mu_ref).sum()
|
| 31 |
+
s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
|
| 32 |
+
fid = m + np.trace(sigma + sigma_ref - s * 2)
|
| 33 |
+
return float(np.real(fid))
|
| 34 |
+
|
| 35 |
+
def main(args):
|
| 36 |
+
|
| 37 |
+
if not args.use_ema:
|
| 38 |
+
print("Auto update use_ema to True for evaluation")
|
| 39 |
+
args.use_ema = True
|
| 40 |
+
|
| 41 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
+
|
| 43 |
+
print("Start sampling...")
|
| 44 |
+
|
| 45 |
+
# laten-diff evaluation
|
| 46 |
+
FEATURE_DIM = 2048
|
| 47 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM]
|
| 48 |
+
fid_model = InceptionV3([block_idx]).to(device)
|
| 49 |
+
fid_model.eval()
|
| 50 |
+
|
| 51 |
+
# edm evalutaion
|
| 52 |
+
DETECTOR_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl"
|
| 53 |
+
with dnnlib.util.open_url(DETECTOR_URL, verbose=True) as f:
|
| 54 |
+
detector_net = pickle.load(f).to(device)
|
| 55 |
+
|
| 56 |
+
with dnnlib.util.open_url(args.ref_path) as f:
|
| 57 |
+
ref = dict(np.load(f))
|
| 58 |
+
|
| 59 |
+
wrapped_model, model, decoding_fn, noise_schedule, latent_resolution, latent_channel, _, _ = prepare_stuff(args)
|
| 60 |
+
condition_loader = prepare_condition_loader(model_type=args.model,
|
| 61 |
+
model=model,
|
| 62 |
+
scale=args.scale if hasattr(args, "scale") else None,
|
| 63 |
+
condition=args.prompt_path or "uniform",
|
| 64 |
+
sampling_batch_size=args.sampling_batch_size,
|
| 65 |
+
num_prompt=None,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
adjust_hyper(args, latent_resolution, latent_channel)
|
| 69 |
+
_, _, skip_type = prepare_paths(args)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 73 |
+
solver, steps, solver_extra_params = get_solvers(
|
| 74 |
+
args.solver_name,
|
| 75 |
+
NFEs=args.steps,
|
| 76 |
+
order=args.order,
|
| 77 |
+
noise_schedule=noise_schedule,
|
| 78 |
+
unipc_variant=args.unipc_variant,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
generator = Generator(
|
| 82 |
+
noise_schedule=noise_schedule,
|
| 83 |
+
solver=solver,
|
| 84 |
+
order=args.order,
|
| 85 |
+
skip_type=skip_type,
|
| 86 |
+
load_from=args.load_from,
|
| 87 |
+
timesteps_1=args.custom_ts_1,
|
| 88 |
+
timesteps_2=args.custom_ts_2,
|
| 89 |
+
steps=steps,
|
| 90 |
+
solver_extra_params=solver_extra_params,
|
| 91 |
+
device=device,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
print(generator.timesteps, generator.timesteps2)
|
| 95 |
+
inverse_scalar = get_data_inverse_scaler(centered=True)
|
| 96 |
+
|
| 97 |
+
num_batches = math.ceil(args.total_samples / args.sampling_batch_size)
|
| 98 |
+
batch_size = args.sampling_batch_size
|
| 99 |
+
n_total_samples = batch_size * num_batches
|
| 100 |
+
|
| 101 |
+
mu = torch.zeros([FEATURE_DIM], dtype=torch.float64, device=device)
|
| 102 |
+
sigma = torch.zeros([FEATURE_DIM, FEATURE_DIM], dtype=torch.float64, device=device)
|
| 103 |
+
act_arr = np.empty((n_total_samples, FEATURE_DIM))
|
| 104 |
+
start_idx=0
|
| 105 |
+
with torch.no_grad():
|
| 106 |
+
for index in tqdm(range(num_batches)):
|
| 107 |
+
current_batch_size = min(batch_size, args.total_samples - index * batch_size)
|
| 108 |
+
sampling_shape = (current_batch_size, latent_channel, latent_resolution, latent_resolution)
|
| 109 |
+
latents = torch.randn(sampling_shape, device=device)
|
| 110 |
+
|
| 111 |
+
if condition_loader is not None:
|
| 112 |
+
conditioning, conditioned_unconditioning = next(condition_loader)
|
| 113 |
+
else:
|
| 114 |
+
conditioning = None
|
| 115 |
+
conditioned_unconditioning = None
|
| 116 |
+
|
| 117 |
+
img_teacher = generator.sample(wrapped_model, decoding_fn, latents, conditioning, conditioned_unconditioning)
|
| 118 |
+
img_teacher = inverse_scalar(img_teacher)
|
| 119 |
+
samples_edm = 255 * img_teacher
|
| 120 |
+
images = torch.clip(samples_edm, 0, 255).to(torch.uint8)
|
| 121 |
+
features = detector_net(images.to(device), return_features=True).to(
|
| 122 |
+
torch.float64
|
| 123 |
+
)
|
| 124 |
+
mu += features.sum(0)
|
| 125 |
+
sigma += features.T @ features
|
| 126 |
+
|
| 127 |
+
samples_latent_diff = torch.clamp(img_teacher, min=0.0, max=1.0)
|
| 128 |
+
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
pred = fid_model(samples_latent_diff.float())[0]
|
| 131 |
+
|
| 132 |
+
# If model output is not scalar, apply global spatial average pooling.
|
| 133 |
+
# This happens if you choose a dimensionality not equal 2048.
|
| 134 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
| 135 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
| 136 |
+
|
| 137 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
| 138 |
+
act_arr[start_idx:start_idx + pred.shape[0]] = pred
|
| 139 |
+
start_idx = start_idx + pred.shape[0]
|
| 140 |
+
|
| 141 |
+
mu /= n_total_samples
|
| 142 |
+
sigma -= mu.ger(mu) * n_total_samples
|
| 143 |
+
sigma /= n_total_samples - 1
|
| 144 |
+
mu = mu.cpu().numpy()
|
| 145 |
+
sigma = sigma.cpu().numpy()
|
| 146 |
+
fid_edm = calculate_fid_from_inception_stats(mu, sigma, ref["mu"], ref["sigma"])
|
| 147 |
+
|
| 148 |
+
mu = np.mean(act_arr, axis=0)
|
| 149 |
+
sigma = np.cov(act_arr, rowvar=False)
|
| 150 |
+
fid_latent_diff = calculate_fid_from_inception_stats(mu, sigma, ref["mu"], ref["sigma"])
|
| 151 |
+
|
| 152 |
+
print("FID EDM: {:.4f}".format(fid_edm))
|
| 153 |
+
print("FID LD: {:.4f}".format(fid_latent_diff))
|
| 154 |
+
|
| 155 |
+
if __name__ == "__main__":
|
| 156 |
+
args = parse_arguments()
|
| 157 |
+
set_seed_everything(args.seed)
|
| 158 |
+
main(args)
|
configs/afhqv2.yml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: edm
|
| 2 |
+
# model_parameters:
|
| 3 |
+
ckp_path: pretrained/edm-afhqv2-64x64-uncond-vp.pkl
|
| 4 |
+
solver_name: uni_pc
|
| 5 |
+
unipc_variant: bh1
|
| 6 |
+
steps: 10
|
| 7 |
+
order: 3
|
| 8 |
+
time_mode: lambda
|
| 9 |
+
|
| 10 |
+
# training_parameters:
|
| 11 |
+
seed: 0
|
| 12 |
+
log_path: all_logs/afhqv2_logs/
|
| 13 |
+
data_dir: train_data/train_data_afhqv2
|
| 14 |
+
win_rate: 0.5
|
| 15 |
+
prior_bound: 1.0
|
| 16 |
+
fix_bound: false
|
| 17 |
+
loss_type: LPIPS
|
| 18 |
+
main_train_batch_size: 2
|
| 19 |
+
main_valid_batch_size: 150
|
| 20 |
+
num_train: 50
|
| 21 |
+
num_valid: 50
|
| 22 |
+
training_rounds_v1: 2
|
| 23 |
+
training_rounds_v2: 5
|
| 24 |
+
shift_lr: null
|
| 25 |
+
lr_time_1: 0.005
|
| 26 |
+
lr_time_2: 0.1
|
| 27 |
+
min_lr_time_1: 0.00005
|
| 28 |
+
min_lr_time_2: 0.000001
|
| 29 |
+
momentum_time_1: 0.9
|
| 30 |
+
weight_decay_time_1: 0.0
|
| 31 |
+
shift_lr_decay: 0.5
|
| 32 |
+
lr_time_decay: 0.8
|
| 33 |
+
patient: 5
|
| 34 |
+
lr2_patient: 5
|
| 35 |
+
no_v1: false
|
| 36 |
+
visualize: true
|
| 37 |
+
|
| 38 |
+
# testing_parameters:
|
| 39 |
+
learn: false
|
| 40 |
+
load_from: null
|
| 41 |
+
skip_type: null
|
| 42 |
+
num_multi_steps_fid: null
|
| 43 |
+
fid_folder: all_fids/afhqv2_fids/
|
| 44 |
+
sampling_batch_size: 150
|
| 45 |
+
sampling_seed: 0
|
| 46 |
+
ref_path: fid-refs/afhqv2-64x64.npz
|
| 47 |
+
total_samples: 50000
|
| 48 |
+
save_png: false
|
| 49 |
+
save_pt: true
|
configs/cifar10.yml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: edm
|
| 2 |
+
# model_parameters:
|
| 3 |
+
ckp_path: pretrained/edm-cifar10-32x32-uncond-vp.pkl
|
| 4 |
+
solver_name: uni_pc
|
| 5 |
+
unipc_variant: bh1
|
| 6 |
+
steps: 10
|
| 7 |
+
order: 3
|
| 8 |
+
time_mode: lambda
|
| 9 |
+
|
| 10 |
+
# training_parameters:
|
| 11 |
+
seed: 0
|
| 12 |
+
log_path: all_logs/cifar10_logs/
|
| 13 |
+
data_dir: train_data/train_data_cifar
|
| 14 |
+
win_rate: 0.5
|
| 15 |
+
prior_bound: 1.0
|
| 16 |
+
fix_bound: false
|
| 17 |
+
loss_type: LPIPS
|
| 18 |
+
main_train_batch_size: 2
|
| 19 |
+
main_valid_batch_size: 150
|
| 20 |
+
num_train: 50
|
| 21 |
+
num_valid: 50
|
| 22 |
+
training_rounds_v1: 2
|
| 23 |
+
training_rounds_v2: 5
|
| 24 |
+
shift_lr: null
|
| 25 |
+
lr_time_1: 0.005
|
| 26 |
+
lr_time_2: 0.1
|
| 27 |
+
min_lr_time_1: 0.00005
|
| 28 |
+
min_lr_time_2: 0.000001
|
| 29 |
+
momentum_time_1: 0.9
|
| 30 |
+
weight_decay_time_1: 0.0
|
| 31 |
+
shift_lr_decay: 0.5
|
| 32 |
+
lr_time_decay: 0.8
|
| 33 |
+
patient: 5
|
| 34 |
+
lr2_patient: 5
|
| 35 |
+
no_v1: false
|
| 36 |
+
visualize: true
|
| 37 |
+
|
| 38 |
+
# testing_parameters:
|
| 39 |
+
learn: false
|
| 40 |
+
load_from: null
|
| 41 |
+
skip_type: null
|
| 42 |
+
num_multi_steps_fid: null
|
| 43 |
+
fid_folder: all_fids/cifar10_fids/
|
| 44 |
+
sampling_batch_size: 150
|
| 45 |
+
sampling_seed: 0
|
| 46 |
+
ref_path: fid-refs/cifar10-32x32.npz
|
| 47 |
+
total_samples: 50000
|
| 48 |
+
save_png: false
|
| 49 |
+
save_pt: true
|
configs/cifar10_order2.yml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: edm
|
| 2 |
+
# model_parameters:
|
| 3 |
+
ckp_path: pretrained/edm-cifar10-32x32-uncond-vp.pkl
|
| 4 |
+
solver_name: uni_pc
|
| 5 |
+
unipc_variant: bh1
|
| 6 |
+
steps: 10
|
| 7 |
+
order: 2
|
| 8 |
+
time_mode: lambda
|
| 9 |
+
|
| 10 |
+
# training_parameters:
|
| 11 |
+
seed: 0
|
| 12 |
+
log_path: all_logs/cifar10_logs/
|
| 13 |
+
data_dir: train_data/train_data_cifar
|
| 14 |
+
win_rate: 0.5
|
| 15 |
+
prior_bound: 1.0
|
| 16 |
+
fix_bound: false
|
| 17 |
+
loss_type: LPIPS
|
| 18 |
+
main_train_batch_size: 2
|
| 19 |
+
main_valid_batch_size: 150
|
| 20 |
+
num_train: 50
|
| 21 |
+
num_valid: 50
|
| 22 |
+
training_rounds_v1: 2
|
| 23 |
+
training_rounds_v2: 5
|
| 24 |
+
shift_lr: null
|
| 25 |
+
lr_time_1: 0.005
|
| 26 |
+
lr_time_2: 0.1
|
| 27 |
+
min_lr_time_1: 0.00005
|
| 28 |
+
min_lr_time_2: 0.000001
|
| 29 |
+
momentum_time_1: 0.9
|
| 30 |
+
weight_decay_time_1: 0.0
|
| 31 |
+
shift_lr_decay: 0.5
|
| 32 |
+
lr_time_decay: 0.8
|
| 33 |
+
patient: 5
|
| 34 |
+
lr2_patient: 5
|
| 35 |
+
no_v1: false
|
| 36 |
+
visualize: true
|
| 37 |
+
|
| 38 |
+
# testing_parameters:
|
| 39 |
+
learn: false
|
| 40 |
+
load_from: null
|
| 41 |
+
skip_type: null
|
| 42 |
+
num_multi_steps_fid: null
|
| 43 |
+
fid_folder: all_fids/cifar10_fids/
|
| 44 |
+
sampling_batch_size: 150
|
| 45 |
+
sampling_seed: 0
|
| 46 |
+
ref_path: fid-refs/cifar10-32x32.npz
|
| 47 |
+
total_samples: 50000
|
| 48 |
+
save_png: false
|
| 49 |
+
save_pt: true
|
configs/ffhq.yml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: edm
|
| 2 |
+
# model_parameters:
|
| 3 |
+
ckp_path: pretrained/edm-ffhq-64x64-uncond-vp.pkl
|
| 4 |
+
solver_name: uni_pc
|
| 5 |
+
unipc_variant: bh1
|
| 6 |
+
steps: 10
|
| 7 |
+
order: 3
|
| 8 |
+
time_mode: lambda
|
| 9 |
+
|
| 10 |
+
# training_parameters:
|
| 11 |
+
seed: 0
|
| 12 |
+
log_path: all_logs/ffhq_logs/
|
| 13 |
+
data_dir: train_data/train_data_ffhq
|
| 14 |
+
win_rate: 0.5
|
| 15 |
+
prior_bound: 1.0
|
| 16 |
+
fix_bound: false
|
| 17 |
+
loss_type: LPIPS
|
| 18 |
+
main_train_batch_size: 2
|
| 19 |
+
main_valid_batch_size: 150
|
| 20 |
+
num_train: 50
|
| 21 |
+
num_valid: 50
|
| 22 |
+
training_rounds_v1: 2
|
| 23 |
+
training_rounds_v2: 5
|
| 24 |
+
shift_lr: null
|
| 25 |
+
lr_time_1: 0.005
|
| 26 |
+
lr_time_2: 0.1
|
| 27 |
+
min_lr_time_1: 0.00005
|
| 28 |
+
min_lr_time_2: 0.000001
|
| 29 |
+
momentum_time_1: 0.9
|
| 30 |
+
weight_decay_time_1: 0.0
|
| 31 |
+
shift_lr_decay: 0.5
|
| 32 |
+
lr_time_decay: 0.8
|
| 33 |
+
patient: 5
|
| 34 |
+
lr2_patient: 5
|
| 35 |
+
no_v1: false
|
| 36 |
+
visualize: true
|
| 37 |
+
|
| 38 |
+
# testing_parameters:
|
| 39 |
+
learn: false
|
| 40 |
+
load_from: null
|
| 41 |
+
skip_type: null
|
| 42 |
+
num_multi_steps_fid: null
|
| 43 |
+
fid_folder: all_fids/ffhq_fids/
|
| 44 |
+
sampling_batch_size: 150
|
| 45 |
+
sampling_seed: 0
|
| 46 |
+
ref_path: fid-refs/ffhq-64x64.npz
|
| 47 |
+
total_samples: 50000
|
| 48 |
+
save_png: false
|
| 49 |
+
save_pt: true
|
configs/latent-diffusion/celebahq-ldm-vq-4.yaml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 2.0e-06
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.0015
|
| 6 |
+
linear_end: 0.0195
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
first_stage_key: image
|
| 11 |
+
image_size: 64
|
| 12 |
+
channels: 3
|
| 13 |
+
monitor: val/loss_simple_ema
|
| 14 |
+
|
| 15 |
+
unet_config:
|
| 16 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 17 |
+
params:
|
| 18 |
+
image_size: 64
|
| 19 |
+
in_channels: 3
|
| 20 |
+
out_channels: 3
|
| 21 |
+
model_channels: 224
|
| 22 |
+
attention_resolutions:
|
| 23 |
+
# note: this isn\t actually the resolution but
|
| 24 |
+
# the downsampling factor, i.e. this corresnponds to
|
| 25 |
+
# attention on spatial resolution 8,16,32, as the
|
| 26 |
+
# spatial reolution of the latents is 64 for f4
|
| 27 |
+
- 8
|
| 28 |
+
- 4
|
| 29 |
+
- 2
|
| 30 |
+
num_res_blocks: 2
|
| 31 |
+
channel_mult:
|
| 32 |
+
- 1
|
| 33 |
+
- 2
|
| 34 |
+
- 3
|
| 35 |
+
- 4
|
| 36 |
+
num_head_channels: 32
|
| 37 |
+
first_stage_config:
|
| 38 |
+
target: ldm.models.autoencoder.VQModelInterface
|
| 39 |
+
params:
|
| 40 |
+
embed_dim: 3
|
| 41 |
+
n_embed: 8192
|
| 42 |
+
ckpt_path: models/first_stage_models/vq-f4/model.ckpt
|
| 43 |
+
ddconfig:
|
| 44 |
+
double_z: false
|
| 45 |
+
z_channels: 3
|
| 46 |
+
resolution: 256
|
| 47 |
+
in_channels: 3
|
| 48 |
+
out_ch: 3
|
| 49 |
+
ch: 128
|
| 50 |
+
ch_mult:
|
| 51 |
+
- 1
|
| 52 |
+
- 2
|
| 53 |
+
- 4
|
| 54 |
+
num_res_blocks: 2
|
| 55 |
+
attn_resolutions: []
|
| 56 |
+
dropout: 0.0
|
| 57 |
+
lossconfig:
|
| 58 |
+
target: torch.nn.Identity
|
| 59 |
+
cond_stage_config: __is_unconditional__
|
| 60 |
+
data:
|
| 61 |
+
target: main.DataModuleFromConfig
|
| 62 |
+
params:
|
| 63 |
+
batch_size: 48
|
| 64 |
+
num_workers: 5
|
| 65 |
+
wrap: false
|
| 66 |
+
train:
|
| 67 |
+
target: taming.data.faceshq.CelebAHQTrain
|
| 68 |
+
params:
|
| 69 |
+
size: 256
|
| 70 |
+
validation:
|
| 71 |
+
target: taming.data.faceshq.CelebAHQValidation
|
| 72 |
+
params:
|
| 73 |
+
size: 256
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
lightning:
|
| 77 |
+
callbacks:
|
| 78 |
+
image_logger:
|
| 79 |
+
target: main.ImageLogger
|
| 80 |
+
params:
|
| 81 |
+
batch_frequency: 5000
|
| 82 |
+
max_images: 8
|
| 83 |
+
increase_log_steps: False
|
| 84 |
+
|
| 85 |
+
trainer:
|
| 86 |
+
benchmark: True
|
configs/latent-diffusion/cin-ldm-vq-f8.yaml
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 1.0e-06
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.0015
|
| 6 |
+
linear_end: 0.0195
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
first_stage_key: image
|
| 11 |
+
cond_stage_key: class_label
|
| 12 |
+
image_size: 32
|
| 13 |
+
channels: 4
|
| 14 |
+
cond_stage_trainable: true
|
| 15 |
+
conditioning_key: crossattn
|
| 16 |
+
monitor: val/loss_simple_ema
|
| 17 |
+
unet_config:
|
| 18 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 19 |
+
params:
|
| 20 |
+
image_size: 32
|
| 21 |
+
in_channels: 4
|
| 22 |
+
out_channels: 4
|
| 23 |
+
model_channels: 256
|
| 24 |
+
attention_resolutions:
|
| 25 |
+
#note: this isn\t actually the resolution but
|
| 26 |
+
# the downsampling factor, i.e. this corresnponds to
|
| 27 |
+
# attention on spatial resolution 8,16,32, as the
|
| 28 |
+
# spatial reolution of the latents is 32 for f8
|
| 29 |
+
- 4
|
| 30 |
+
- 2
|
| 31 |
+
- 1
|
| 32 |
+
num_res_blocks: 2
|
| 33 |
+
channel_mult:
|
| 34 |
+
- 1
|
| 35 |
+
- 2
|
| 36 |
+
- 4
|
| 37 |
+
num_head_channels: 32
|
| 38 |
+
use_spatial_transformer: true
|
| 39 |
+
transformer_depth: 1
|
| 40 |
+
context_dim: 512
|
| 41 |
+
first_stage_config:
|
| 42 |
+
target: ldm.models.autoencoder.VQModelInterface
|
| 43 |
+
params:
|
| 44 |
+
embed_dim: 4
|
| 45 |
+
n_embed: 16384
|
| 46 |
+
ckpt_path: configs/first_stage_models/vq-f8/model.yaml
|
| 47 |
+
ddconfig:
|
| 48 |
+
double_z: false
|
| 49 |
+
z_channels: 4
|
| 50 |
+
resolution: 256
|
| 51 |
+
in_channels: 3
|
| 52 |
+
out_ch: 3
|
| 53 |
+
ch: 128
|
| 54 |
+
ch_mult:
|
| 55 |
+
- 1
|
| 56 |
+
- 2
|
| 57 |
+
- 2
|
| 58 |
+
- 4
|
| 59 |
+
num_res_blocks: 2
|
| 60 |
+
attn_resolutions:
|
| 61 |
+
- 32
|
| 62 |
+
dropout: 0.0
|
| 63 |
+
lossconfig:
|
| 64 |
+
target: torch.nn.Identity
|
| 65 |
+
cond_stage_config:
|
| 66 |
+
target: ldm.modules.encoders.modules.ClassEmbedder
|
| 67 |
+
params:
|
| 68 |
+
embed_dim: 512
|
| 69 |
+
key: class_label
|
| 70 |
+
data:
|
| 71 |
+
target: main.DataModuleFromConfig
|
| 72 |
+
params:
|
| 73 |
+
batch_size: 64
|
| 74 |
+
num_workers: 12
|
| 75 |
+
wrap: false
|
| 76 |
+
train:
|
| 77 |
+
target: ldm.data.imagenet.ImageNetTrain
|
| 78 |
+
params:
|
| 79 |
+
config:
|
| 80 |
+
size: 256
|
| 81 |
+
validation:
|
| 82 |
+
target: ldm.data.imagenet.ImageNetValidation
|
| 83 |
+
params:
|
| 84 |
+
config:
|
| 85 |
+
size: 256
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
lightning:
|
| 89 |
+
callbacks:
|
| 90 |
+
image_logger:
|
| 91 |
+
target: main.ImageLogger
|
| 92 |
+
params:
|
| 93 |
+
batch_frequency: 5000
|
| 94 |
+
max_images: 8
|
| 95 |
+
increase_log_steps: False
|
| 96 |
+
|
| 97 |
+
trainer:
|
| 98 |
+
benchmark: True
|
configs/latent-diffusion/cin256-v2.yaml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 0.0001
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.0015
|
| 6 |
+
linear_end: 0.0195
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
first_stage_key: image
|
| 11 |
+
cond_stage_key: class_label
|
| 12 |
+
image_size: 64
|
| 13 |
+
channels: 3
|
| 14 |
+
cond_stage_trainable: true
|
| 15 |
+
conditioning_key: crossattn
|
| 16 |
+
monitor: val/loss
|
| 17 |
+
use_ema: False
|
| 18 |
+
|
| 19 |
+
unet_config:
|
| 20 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 21 |
+
params:
|
| 22 |
+
image_size: 64
|
| 23 |
+
in_channels: 3
|
| 24 |
+
out_channels: 3
|
| 25 |
+
model_channels: 192
|
| 26 |
+
attention_resolutions:
|
| 27 |
+
- 8
|
| 28 |
+
- 4
|
| 29 |
+
- 2
|
| 30 |
+
num_res_blocks: 2
|
| 31 |
+
channel_mult:
|
| 32 |
+
- 1
|
| 33 |
+
- 2
|
| 34 |
+
- 3
|
| 35 |
+
- 5
|
| 36 |
+
num_heads: 1
|
| 37 |
+
use_spatial_transformer: true
|
| 38 |
+
transformer_depth: 1
|
| 39 |
+
context_dim: 512
|
| 40 |
+
|
| 41 |
+
first_stage_config:
|
| 42 |
+
target: ldm.models.autoencoder.VQModelInterface
|
| 43 |
+
params:
|
| 44 |
+
embed_dim: 3
|
| 45 |
+
n_embed: 8192
|
| 46 |
+
ddconfig:
|
| 47 |
+
double_z: false
|
| 48 |
+
z_channels: 3
|
| 49 |
+
resolution: 256
|
| 50 |
+
in_channels: 3
|
| 51 |
+
out_ch: 3
|
| 52 |
+
ch: 128
|
| 53 |
+
ch_mult:
|
| 54 |
+
- 1
|
| 55 |
+
- 2
|
| 56 |
+
- 4
|
| 57 |
+
num_res_blocks: 2
|
| 58 |
+
attn_resolutions: []
|
| 59 |
+
dropout: 0.0
|
| 60 |
+
lossconfig:
|
| 61 |
+
target: torch.nn.Identity
|
| 62 |
+
|
| 63 |
+
cond_stage_config:
|
| 64 |
+
target: ldm.modules.encoders.modules.ClassEmbedder
|
| 65 |
+
params:
|
| 66 |
+
n_classes: 1001
|
| 67 |
+
embed_dim: 512
|
| 68 |
+
key: class_label
|
configs/latent-diffusion/ffhq-ldm-vq-4.yaml
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 2.0e-06
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.0015
|
| 6 |
+
linear_end: 0.0195
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
first_stage_key: image
|
| 11 |
+
image_size: 64
|
| 12 |
+
channels: 3
|
| 13 |
+
monitor: val/loss_simple_ema
|
| 14 |
+
unet_config:
|
| 15 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 16 |
+
params:
|
| 17 |
+
image_size: 64
|
| 18 |
+
in_channels: 3
|
| 19 |
+
out_channels: 3
|
| 20 |
+
model_channels: 224
|
| 21 |
+
attention_resolutions:
|
| 22 |
+
# note: this isn\t actually the resolution but
|
| 23 |
+
# the downsampling factor, i.e. this corresnponds to
|
| 24 |
+
# attention on spatial resolution 8,16,32, as the
|
| 25 |
+
# spatial reolution of the latents is 64 for f4
|
| 26 |
+
- 8
|
| 27 |
+
- 4
|
| 28 |
+
- 2
|
| 29 |
+
num_res_blocks: 2
|
| 30 |
+
channel_mult:
|
| 31 |
+
- 1
|
| 32 |
+
- 2
|
| 33 |
+
- 3
|
| 34 |
+
- 4
|
| 35 |
+
num_head_channels: 32
|
| 36 |
+
first_stage_config:
|
| 37 |
+
target: ldm.models.autoencoder.VQModelInterface
|
| 38 |
+
params:
|
| 39 |
+
embed_dim: 3
|
| 40 |
+
n_embed: 8192
|
| 41 |
+
ckpt_path: configs/first_stage_models/vq-f4/model.yaml
|
| 42 |
+
ddconfig:
|
| 43 |
+
double_z: false
|
| 44 |
+
z_channels: 3
|
| 45 |
+
resolution: 256
|
| 46 |
+
in_channels: 3
|
| 47 |
+
out_ch: 3
|
| 48 |
+
ch: 128
|
| 49 |
+
ch_mult:
|
| 50 |
+
- 1
|
| 51 |
+
- 2
|
| 52 |
+
- 4
|
| 53 |
+
num_res_blocks: 2
|
| 54 |
+
attn_resolutions: []
|
| 55 |
+
dropout: 0.0
|
| 56 |
+
lossconfig:
|
| 57 |
+
target: torch.nn.Identity
|
| 58 |
+
cond_stage_config: __is_unconditional__
|
| 59 |
+
data:
|
| 60 |
+
target: main.DataModuleFromConfig
|
| 61 |
+
params:
|
| 62 |
+
batch_size: 42
|
| 63 |
+
num_workers: 5
|
| 64 |
+
wrap: false
|
| 65 |
+
train:
|
| 66 |
+
target: taming.data.faceshq.FFHQTrain
|
| 67 |
+
params:
|
| 68 |
+
size: 256
|
| 69 |
+
validation:
|
| 70 |
+
target: taming.data.faceshq.FFHQValidation
|
| 71 |
+
params:
|
| 72 |
+
size: 256
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
lightning:
|
| 76 |
+
callbacks:
|
| 77 |
+
image_logger:
|
| 78 |
+
target: main.ImageLogger
|
| 79 |
+
params:
|
| 80 |
+
batch_frequency: 5000
|
| 81 |
+
max_images: 8
|
| 82 |
+
increase_log_steps: False
|
| 83 |
+
|
| 84 |
+
trainer:
|
| 85 |
+
benchmark: True
|
configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 2.0e-06
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.0015
|
| 6 |
+
linear_end: 0.0195
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
first_stage_key: image
|
| 11 |
+
image_size: 64
|
| 12 |
+
channels: 3
|
| 13 |
+
monitor: val/loss_simple_ema
|
| 14 |
+
unet_config:
|
| 15 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 16 |
+
params:
|
| 17 |
+
image_size: 64
|
| 18 |
+
in_channels: 3
|
| 19 |
+
out_channels: 3
|
| 20 |
+
model_channels: 224
|
| 21 |
+
attention_resolutions:
|
| 22 |
+
# note: this isn\t actually the resolution but
|
| 23 |
+
# the downsampling factor, i.e. this corresnponds to
|
| 24 |
+
# attention on spatial resolution 8,16,32, as the
|
| 25 |
+
# spatial reolution of the latents is 64 for f4
|
| 26 |
+
- 8
|
| 27 |
+
- 4
|
| 28 |
+
- 2
|
| 29 |
+
num_res_blocks: 2
|
| 30 |
+
channel_mult:
|
| 31 |
+
- 1
|
| 32 |
+
- 2
|
| 33 |
+
- 3
|
| 34 |
+
- 4
|
| 35 |
+
num_head_channels: 32
|
| 36 |
+
first_stage_config:
|
| 37 |
+
target: ldm.models.autoencoder.VQModelInterface
|
| 38 |
+
params:
|
| 39 |
+
ckpt_path: pretrained/first_stage_models/vq-f4/model.ckpt
|
| 40 |
+
embed_dim: 3
|
| 41 |
+
n_embed: 8192
|
| 42 |
+
ddconfig:
|
| 43 |
+
double_z: false
|
| 44 |
+
z_channels: 3
|
| 45 |
+
resolution: 256
|
| 46 |
+
in_channels: 3
|
| 47 |
+
out_ch: 3
|
| 48 |
+
ch: 128
|
| 49 |
+
ch_mult:
|
| 50 |
+
- 1
|
| 51 |
+
- 2
|
| 52 |
+
- 4
|
| 53 |
+
num_res_blocks: 2
|
| 54 |
+
attn_resolutions: []
|
| 55 |
+
dropout: 0.0
|
| 56 |
+
lossconfig:
|
| 57 |
+
target: torch.nn.Identity
|
| 58 |
+
cond_stage_config: __is_unconditional__
|
| 59 |
+
data:
|
| 60 |
+
target: main.DataModuleFromConfig
|
| 61 |
+
params:
|
| 62 |
+
batch_size: 48
|
| 63 |
+
num_workers: 5
|
| 64 |
+
wrap: false
|
| 65 |
+
train:
|
| 66 |
+
target: ldm.data.lsun.LSUNBedroomsTrain
|
| 67 |
+
params:
|
| 68 |
+
size: 256
|
| 69 |
+
validation:
|
| 70 |
+
target: ldm.data.lsun.LSUNBedroomsValidation
|
| 71 |
+
params:
|
| 72 |
+
size: 256
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
lightning:
|
| 76 |
+
callbacks:
|
| 77 |
+
image_logger:
|
| 78 |
+
target: main.ImageLogger
|
| 79 |
+
params:
|
| 80 |
+
batch_frequency: 5000
|
| 81 |
+
max_images: 8
|
| 82 |
+
increase_log_steps: False
|
| 83 |
+
|
| 84 |
+
trainer:
|
| 85 |
+
benchmark: True
|
configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.0015
|
| 6 |
+
linear_end: 0.0155
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
loss_type: l1
|
| 11 |
+
first_stage_key: "image"
|
| 12 |
+
cond_stage_key: "image"
|
| 13 |
+
image_size: 32
|
| 14 |
+
channels: 4
|
| 15 |
+
cond_stage_trainable: False
|
| 16 |
+
concat_mode: False
|
| 17 |
+
scale_by_std: True
|
| 18 |
+
monitor: 'val/loss_simple_ema'
|
| 19 |
+
|
| 20 |
+
scheduler_config: # 10000 warmup steps
|
| 21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
| 22 |
+
params:
|
| 23 |
+
warm_up_steps: [10000]
|
| 24 |
+
cycle_lengths: [10000000000000]
|
| 25 |
+
f_start: [1.e-6]
|
| 26 |
+
f_max: [1.]
|
| 27 |
+
f_min: [ 1.]
|
| 28 |
+
|
| 29 |
+
unet_config:
|
| 30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 31 |
+
params:
|
| 32 |
+
image_size: 32
|
| 33 |
+
in_channels: 4
|
| 34 |
+
out_channels: 4
|
| 35 |
+
model_channels: 192
|
| 36 |
+
attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
|
| 37 |
+
num_res_blocks: 2
|
| 38 |
+
channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
|
| 39 |
+
num_heads: 8
|
| 40 |
+
use_scale_shift_norm: True
|
| 41 |
+
resblock_updown: True
|
| 42 |
+
|
| 43 |
+
first_stage_config:
|
| 44 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
| 45 |
+
params:
|
| 46 |
+
embed_dim: 4
|
| 47 |
+
monitor: "val/rec_loss"
|
| 48 |
+
ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
|
| 49 |
+
ddconfig:
|
| 50 |
+
double_z: True
|
| 51 |
+
z_channels: 4
|
| 52 |
+
resolution: 256
|
| 53 |
+
in_channels: 3
|
| 54 |
+
out_ch: 3
|
| 55 |
+
ch: 128
|
| 56 |
+
ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
|
| 57 |
+
num_res_blocks: 2
|
| 58 |
+
attn_resolutions: [ ]
|
| 59 |
+
dropout: 0.0
|
| 60 |
+
lossconfig:
|
| 61 |
+
target: torch.nn.Identity
|
| 62 |
+
|
| 63 |
+
cond_stage_config: "__is_unconditional__"
|
| 64 |
+
|
| 65 |
+
data:
|
| 66 |
+
target: main.DataModuleFromConfig
|
| 67 |
+
params:
|
| 68 |
+
batch_size: 96
|
| 69 |
+
num_workers: 5
|
| 70 |
+
wrap: False
|
| 71 |
+
train:
|
| 72 |
+
target: ldm.data.lsun.LSUNChurchesTrain
|
| 73 |
+
params:
|
| 74 |
+
size: 256
|
| 75 |
+
validation:
|
| 76 |
+
target: ldm.data.lsun.LSUNChurchesValidation
|
| 77 |
+
params:
|
| 78 |
+
size: 256
|
| 79 |
+
|
| 80 |
+
lightning:
|
| 81 |
+
callbacks:
|
| 82 |
+
image_logger:
|
| 83 |
+
target: main.ImageLogger
|
| 84 |
+
params:
|
| 85 |
+
batch_frequency: 5000
|
| 86 |
+
max_images: 8
|
| 87 |
+
increase_log_steps: False
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
trainer:
|
| 91 |
+
benchmark: True
|
configs/latent-diffusion/txt2img-1p4B-eval.yaml
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 5.0e-05
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.00085
|
| 6 |
+
linear_end: 0.012
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
first_stage_key: image
|
| 11 |
+
cond_stage_key: caption
|
| 12 |
+
image_size: 32
|
| 13 |
+
channels: 4
|
| 14 |
+
cond_stage_trainable: true
|
| 15 |
+
conditioning_key: crossattn
|
| 16 |
+
monitor: val/loss_simple_ema
|
| 17 |
+
scale_factor: 0.18215
|
| 18 |
+
use_ema: False
|
| 19 |
+
|
| 20 |
+
unet_config:
|
| 21 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 22 |
+
params:
|
| 23 |
+
image_size: 32
|
| 24 |
+
in_channels: 4
|
| 25 |
+
out_channels: 4
|
| 26 |
+
model_channels: 320
|
| 27 |
+
attention_resolutions:
|
| 28 |
+
- 4
|
| 29 |
+
- 2
|
| 30 |
+
- 1
|
| 31 |
+
num_res_blocks: 2
|
| 32 |
+
channel_mult:
|
| 33 |
+
- 1
|
| 34 |
+
- 2
|
| 35 |
+
- 4
|
| 36 |
+
- 4
|
| 37 |
+
num_heads: 8
|
| 38 |
+
use_spatial_transformer: true
|
| 39 |
+
transformer_depth: 1
|
| 40 |
+
context_dim: 1280
|
| 41 |
+
use_checkpoint: true
|
| 42 |
+
legacy: False
|
| 43 |
+
|
| 44 |
+
first_stage_config:
|
| 45 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
| 46 |
+
params:
|
| 47 |
+
embed_dim: 4
|
| 48 |
+
monitor: val/rec_loss
|
| 49 |
+
ddconfig:
|
| 50 |
+
double_z: true
|
| 51 |
+
z_channels: 4
|
| 52 |
+
resolution: 256
|
| 53 |
+
in_channels: 3
|
| 54 |
+
out_ch: 3
|
| 55 |
+
ch: 128
|
| 56 |
+
ch_mult:
|
| 57 |
+
- 1
|
| 58 |
+
- 2
|
| 59 |
+
- 4
|
| 60 |
+
- 4
|
| 61 |
+
num_res_blocks: 2
|
| 62 |
+
attn_resolutions: []
|
| 63 |
+
dropout: 0.0
|
| 64 |
+
lossconfig:
|
| 65 |
+
target: torch.nn.Identity
|
| 66 |
+
|
| 67 |
+
cond_stage_config:
|
| 68 |
+
target: ldm.modules.encoders.modules.BERTEmbedder
|
| 69 |
+
params:
|
| 70 |
+
n_embed: 1280
|
| 71 |
+
n_layer: 32
|
configs/latent_diff_LSUN.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: latent_diff
|
| 2 |
+
# model_parameters:
|
| 3 |
+
ckp_path: pretrained/ldm/lsun_beds256/model.ckpt
|
| 4 |
+
config: configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
|
| 5 |
+
solver_name: uni_pc
|
| 6 |
+
unipc_variant: bh2
|
| 7 |
+
steps: 10
|
| 8 |
+
order: 3
|
| 9 |
+
H: 256
|
| 10 |
+
W: 256
|
| 11 |
+
C: 3
|
| 12 |
+
f: 4
|
| 13 |
+
scale: 0.0
|
| 14 |
+
time_mode: time
|
| 15 |
+
|
| 16 |
+
# training_parameters:
|
| 17 |
+
seed: 0
|
| 18 |
+
log_path: all_logs/latent_diff_LSUN_logs/
|
| 19 |
+
data_dir: train_data/train_data_LSUN
|
| 20 |
+
win_rate: 0.5
|
| 21 |
+
prior_bound: 1.0
|
| 22 |
+
fix_bound: false
|
| 23 |
+
loss_type: LPIPS
|
| 24 |
+
main_train_batch_size: 1
|
| 25 |
+
main_valid_batch_size: 50
|
| 26 |
+
num_train: 50
|
| 27 |
+
num_valid: 50
|
| 28 |
+
training_rounds_v1: 2
|
| 29 |
+
training_rounds_v2: 3
|
| 30 |
+
shift_lr: null
|
| 31 |
+
lr_time_1: 0.005
|
| 32 |
+
lr_time_2: 0.001
|
| 33 |
+
min_lr_time_1: 0.00005
|
| 34 |
+
min_lr_time_2: 0.000001
|
| 35 |
+
momentum_time_1: 0.9
|
| 36 |
+
weight_decay_time_1: 0.0
|
| 37 |
+
shift_lr_decay: 0.5
|
| 38 |
+
lr_time_decay: 0.8
|
| 39 |
+
patient: 5
|
| 40 |
+
lr2_patient: 5
|
| 41 |
+
no_v1: false
|
| 42 |
+
visualize: true
|
| 43 |
+
|
| 44 |
+
# testing_parameters:
|
| 45 |
+
learn: false
|
| 46 |
+
load_from: null
|
| 47 |
+
skip_type: null
|
| 48 |
+
num_multi_steps_fid: null
|
| 49 |
+
fid_folder: all_fids/latent_diff_LSUN_fids/
|
| 50 |
+
sampling_batch_size: 50
|
| 51 |
+
sampling_seed: 0
|
| 52 |
+
ref_path: fid-refs/VIRTUAL_lsun_bedroom256.npz
|
| 53 |
+
total_samples: 50000
|
| 54 |
+
save_png: false
|
| 55 |
+
save_pt: true
|
configs/latent_diff_imn.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: conditioned_latent_diff
|
| 2 |
+
# model_parameters:
|
| 3 |
+
ckp_path: pretrained/ldm/cin256-v2/model.ckpt
|
| 4 |
+
config: configs/latent-diffusion/cin256-v2.yaml
|
| 5 |
+
solver_name: uni_pc
|
| 6 |
+
unipc_variant: bh2
|
| 7 |
+
steps: 10
|
| 8 |
+
order: 3
|
| 9 |
+
H: 256
|
| 10 |
+
W: 256
|
| 11 |
+
C: 3
|
| 12 |
+
f: 4
|
| 13 |
+
scale: 2.0
|
| 14 |
+
time_mode: time
|
| 15 |
+
|
| 16 |
+
# training_parameters:
|
| 17 |
+
seed: 0
|
| 18 |
+
log_path: all_logs/latent_diff_imn_logs/
|
| 19 |
+
data_dir: train_data/train_data_imn
|
| 20 |
+
win_rate: 0.5
|
| 21 |
+
prior_bound: 1.0
|
| 22 |
+
fix_bound: false
|
| 23 |
+
loss_type: LPIPS
|
| 24 |
+
main_train_batch_size: 1
|
| 25 |
+
main_valid_batch_size: 50
|
| 26 |
+
num_train: 50
|
| 27 |
+
num_valid: 50
|
| 28 |
+
training_rounds_v1: 2
|
| 29 |
+
training_rounds_v2: 3
|
| 30 |
+
shift_lr: null
|
| 31 |
+
lr_time_1: 0.005
|
| 32 |
+
lr_time_2: 0.001
|
| 33 |
+
min_lr_time_1: 0.00005
|
| 34 |
+
min_lr_time_2: 0.000001
|
| 35 |
+
momentum_time_1: 0.9
|
| 36 |
+
weight_decay_time_1: 0.0
|
| 37 |
+
shift_lr_decay: 0.5
|
| 38 |
+
lr_time_decay: 0.8
|
| 39 |
+
patient: 5
|
| 40 |
+
lr2_patient: 5
|
| 41 |
+
no_v1: false
|
| 42 |
+
visualize: true
|
| 43 |
+
|
| 44 |
+
# testing_parameters:
|
| 45 |
+
learn: false
|
| 46 |
+
load_from: null
|
| 47 |
+
skip_type: null
|
| 48 |
+
num_multi_steps_fid: null
|
| 49 |
+
fid_folder: all_fids/latent_diff_imn_fids/
|
| 50 |
+
sampling_batch_size: 50
|
| 51 |
+
sampling_seed: 0
|
| 52 |
+
ref_path: fid-refs/VIRTUAL_imagenet256_labeled.npz
|
| 53 |
+
total_samples: 50000
|
| 54 |
+
save_png: false
|
| 55 |
+
save_pt: true
|
configs/stable-diffusion/v1-inference.yaml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 1.0e-04
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.00085
|
| 6 |
+
linear_end: 0.0120
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
first_stage_key: "jpg"
|
| 11 |
+
cond_stage_key: "txt"
|
| 12 |
+
image_size: 64
|
| 13 |
+
channels: 4
|
| 14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
| 15 |
+
conditioning_key: crossattn
|
| 16 |
+
monitor: val/loss_simple_ema
|
| 17 |
+
scale_factor: 0.18215
|
| 18 |
+
use_ema: False
|
| 19 |
+
|
| 20 |
+
scheduler_config: # 10000 warmup steps
|
| 21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
| 22 |
+
params:
|
| 23 |
+
warm_up_steps: [ 10000 ]
|
| 24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
| 25 |
+
f_start: [ 1.e-6 ]
|
| 26 |
+
f_max: [ 1. ]
|
| 27 |
+
f_min: [ 1. ]
|
| 28 |
+
|
| 29 |
+
unet_config:
|
| 30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 31 |
+
params:
|
| 32 |
+
image_size: 32 # unused
|
| 33 |
+
in_channels: 4
|
| 34 |
+
out_channels: 4
|
| 35 |
+
model_channels: 320
|
| 36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
| 37 |
+
num_res_blocks: 2
|
| 38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
| 39 |
+
num_heads: 8
|
| 40 |
+
use_spatial_transformer: True
|
| 41 |
+
transformer_depth: 1
|
| 42 |
+
context_dim: 768
|
| 43 |
+
use_checkpoint: False
|
| 44 |
+
legacy: False
|
| 45 |
+
|
| 46 |
+
first_stage_config:
|
| 47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
| 48 |
+
params:
|
| 49 |
+
embed_dim: 4
|
| 50 |
+
monitor: val/rec_loss
|
| 51 |
+
ddconfig:
|
| 52 |
+
double_z: true
|
| 53 |
+
z_channels: 4
|
| 54 |
+
resolution: 256
|
| 55 |
+
in_channels: 3
|
| 56 |
+
out_ch: 3
|
| 57 |
+
ch: 128
|
| 58 |
+
ch_mult:
|
| 59 |
+
- 1
|
| 60 |
+
- 2
|
| 61 |
+
- 4
|
| 62 |
+
- 4
|
| 63 |
+
num_res_blocks: 2
|
| 64 |
+
attn_resolutions: []
|
| 65 |
+
dropout: 0.0
|
| 66 |
+
lossconfig:
|
| 67 |
+
target: torch.nn.Identity
|
| 68 |
+
|
| 69 |
+
cond_stage_config:
|
| 70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
configs/stable_diff_v1-4.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: conditioned_latent_diff
|
| 2 |
+
# model_parameters:
|
| 3 |
+
ckp_path: pretrained/ldm/stable-diffusion-v1/sd-v1-4.ckpt
|
| 4 |
+
config: configs/stable-diffusion/v1-inference.yaml
|
| 5 |
+
solver_name: uni_pc
|
| 6 |
+
unipc_variant: bh2
|
| 7 |
+
steps: 10
|
| 8 |
+
order: 2
|
| 9 |
+
H: 512
|
| 10 |
+
W: 512
|
| 11 |
+
C: 4
|
| 12 |
+
f: 8
|
| 13 |
+
scale: 7.5
|
| 14 |
+
time_mode: time
|
| 15 |
+
|
| 16 |
+
# training_parameters:
|
| 17 |
+
seed: 0
|
| 18 |
+
log_path: all_logs/stable_diff_v1-4_logs/
|
| 19 |
+
data_dir: train_data/train_data_stable_diff_v1-4
|
| 20 |
+
win_rate: 0.5
|
| 21 |
+
prior_bound: 1.0
|
| 22 |
+
fix_bound: false
|
| 23 |
+
loss_type: LPIPS
|
| 24 |
+
main_train_batch_size: 1
|
| 25 |
+
main_valid_batch_size: 1
|
| 26 |
+
num_train: 25
|
| 27 |
+
num_valid: 25
|
| 28 |
+
training_rounds_v1: 2
|
| 29 |
+
training_rounds_v2: 3
|
| 30 |
+
shift_lr: null
|
| 31 |
+
lr_time_1: 0.005
|
| 32 |
+
lr_time_2: 0.001
|
| 33 |
+
min_lr_time_1: 0.00005
|
| 34 |
+
min_lr_time_2: 0.000001
|
| 35 |
+
momentum_time_1: 0.9
|
| 36 |
+
weight_decay_time_1: 0.0
|
| 37 |
+
shift_lr_decay: 0.5
|
| 38 |
+
lr_time_decay: 0.8
|
| 39 |
+
patient: 5
|
| 40 |
+
lr2_patient: 5
|
| 41 |
+
no_v1: false
|
| 42 |
+
visualize: true
|
| 43 |
+
|
| 44 |
+
# testing_parameters:
|
| 45 |
+
learn: false
|
| 46 |
+
load_from: null
|
| 47 |
+
skip_type: null
|
| 48 |
+
num_multi_steps_fid: null
|
| 49 |
+
fid_folder: all_fids/stable_diff_v1-4_fids/
|
| 50 |
+
sampling_batch_size: 50
|
| 51 |
+
sampling_seed: 0
|
| 52 |
+
ref_path: null
|
| 53 |
+
total_samples: 50000
|
| 54 |
+
save_png: false
|
| 55 |
+
save_pt: true
|
configs/stable_diff_v1-5.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: conditioned_latent_diff
|
| 2 |
+
# model_parameters:
|
| 3 |
+
ckp_path: pretrained/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
| 4 |
+
config: configs/stable-diffusion/v1-inference.yaml
|
| 5 |
+
solver_name: uni_pc
|
| 6 |
+
unipc_variant: bh2
|
| 7 |
+
steps: 10
|
| 8 |
+
order: 2
|
| 9 |
+
H: 512
|
| 10 |
+
W: 512
|
| 11 |
+
C: 4
|
| 12 |
+
f: 8
|
| 13 |
+
scale: 7.5
|
| 14 |
+
time_mode: time
|
| 15 |
+
|
| 16 |
+
# training_parameters:
|
| 17 |
+
seed: 0
|
| 18 |
+
log_path: all_logs/stable_diff_v1-5_logs/
|
| 19 |
+
data_dir: train_data/train_data_stable_diff_v1-5
|
| 20 |
+
win_rate: 0.5
|
| 21 |
+
prior_bound: 1.0
|
| 22 |
+
fix_bound: false
|
| 23 |
+
loss_type: LPIPS
|
| 24 |
+
main_train_batch_size: 1
|
| 25 |
+
main_valid_batch_size: 1
|
| 26 |
+
num_train: 25
|
| 27 |
+
num_valid: 25
|
| 28 |
+
training_rounds_v1: 2
|
| 29 |
+
training_rounds_v2: 3
|
| 30 |
+
shift_lr: null
|
| 31 |
+
lr_time_1: 0.005
|
| 32 |
+
lr_time_2: 0.001
|
| 33 |
+
min_lr_time_1: 0.00005
|
| 34 |
+
min_lr_time_2: 0.000001
|
| 35 |
+
momentum_time_1: 0.9
|
| 36 |
+
weight_decay_time_1: 0.0
|
| 37 |
+
shift_lr_decay: 0.5
|
| 38 |
+
lr_time_decay: 0.8
|
| 39 |
+
patient: 5
|
| 40 |
+
lr2_patient: 5
|
| 41 |
+
no_v1: false
|
| 42 |
+
visualize: true
|
| 43 |
+
|
| 44 |
+
# testing_parameters:
|
| 45 |
+
learn: false
|
| 46 |
+
load_from: null
|
| 47 |
+
skip_type: null
|
| 48 |
+
num_multi_steps_fid: null
|
| 49 |
+
fid_folder: all_fids/stable_diff_v1-5_fids/
|
| 50 |
+
sampling_batch_size: 50
|
| 51 |
+
sampling_seed: 0
|
| 52 |
+
ref_path: null
|
| 53 |
+
total_samples: 50000
|
| 54 |
+
save_png: false
|
| 55 |
+
save_pt: true
|
data/coco_captions.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/prompts.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Two individuals learning to ski along with an instructor.
|
| 2 |
+
A man sitting on a chair that is on a deck over the water.
|
| 3 |
+
A dog sitting at a table in front of a plate.
|
| 4 |
+
Four people sit around eating food outside together.
|
| 5 |
+
A cat dips its paws into a cup on a nightstand.
|
dataset.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def load_data_from_dir(
    data_folder: str, limit: int = 200
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[Optional[torch.Tensor]], List[Optional[torch.Tensor]]]:
    """Load up to ``limit`` cached samples from a directory of ``.pt`` files.

    Each file is expected to be a dict with required keys ``"latent"`` and
    ``"img"`` and optional keys ``"c"`` / ``"uc"`` holding the conditional /
    unconditional embeddings (``None`` when absent).

    Args:
        data_folder: Directory containing the ``.pt`` sample files.
        limit: Maximum number of files to load (sorted by filename).

    Returns:
        Four parallel lists: (latents, targets, conditions, unconditions).
    """
    latents, targets, conditions, unconditions = [], [], [], []
    # Match the ".pt" extension exactly; the previous check endswith('pt')
    # also matched unrelated files such as "*.ckpt".
    pt_files = [f for f in os.listdir(data_folder) if f.endswith('.pt')]
    # Sort for a deterministic ordering independent of the file system.
    for file_name in sorted(pt_files)[:limit]:
        file_path = os.path.join(data_folder, file_name)
        data = torch.load(file_path)
        latents.append(data["latent"])
        targets.append(data["img"])
        conditions.append(data.get("c", None))
        unconditions.append(data.get("uc", None))
    return latents, targets, conditions, unconditions
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class LD3Dataset(Dataset):
    """Dataset pairing original/perturbed latents with target images and
    optional (un)conditional embeddings.

    All five sequences are expected to be parallel (same length); entries
    of ``condition`` / ``uncondition`` may be ``None`` for unconditional
    models.
    """

    def __init__(
        self,
        ori_latent: List[torch.Tensor],
        latent: List[torch.Tensor],
        target: List[torch.Tensor],
        condition: List[Optional[torch.Tensor]],
        uncondition: List[Optional[torch.Tensor]],
    ):
        self.ori_latent = ori_latent
        self.latent = latent
        self.target = target
        self.condition = condition
        self.uncondition = uncondition

    def __len__(self) -> int:
        # Length is defined by the original latents; all lists are parallel.
        return len(self.ori_latent)

    def __getitem__(self, idx: int):
        """Return (img, latent, ori_latent, condition, uncondition) at idx."""
        return (
            self.target[idx],
            self.latent[idx],
            self.ori_latent[idx],
            self.condition[idx],
            self.uncondition[idx],
        )
|
dnnlib/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
from .util import EasyDict, make_cache_dir_path
|
dnnlib/util.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Miscellaneous utility classes and functions."""
|
| 9 |
+
|
| 10 |
+
import ctypes
|
| 11 |
+
import fnmatch
|
| 12 |
+
import importlib
|
| 13 |
+
import inspect
|
| 14 |
+
import numpy as np
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import sys
|
| 18 |
+
import types
|
| 19 |
+
import io
|
| 20 |
+
import pickle
|
| 21 |
+
import re
|
| 22 |
+
import requests
|
| 23 |
+
import html
|
| 24 |
+
import hashlib
|
| 25 |
+
import glob
|
| 26 |
+
import tempfile
|
| 27 |
+
import urllib
|
| 28 |
+
import urllib.request
|
| 29 |
+
import uuid
|
| 30 |
+
|
| 31 |
+
from distutils.util import strtobool
|
| 32 |
+
from typing import Any, List, Tuple, Union, Optional
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Util classes
|
| 36 |
+
# ------------------------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        # Only called for names not found through normal attribute lookup;
        # fall back to the dict entry, translating a miss to AttributeError
        # as the attribute protocol requires.
        if name not in self:
            raise AttributeError(name)
        return self[name]

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment writes straight into the dict.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the dict entry.
        del self[name]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.

    Installing an instance rebinds both ``sys.stdout`` and ``sys.stderr``
    to itself (a process-global side effect); ``close()`` restores the
    streams that were active at construction time.
    """

    def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
        # Optional mirror file; stays None when no file_name is given.
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Remember the streams active at construction so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # Install this object as both stdout and stderr.
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        # Accept bytes for callers that write binary data to stdout.
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Cache directories
|
| 115 |
+
# ------------------------------------------------------------------------------------------
|
| 116 |
+
|
| 117 |
+
# Explicit override for the dnnlib cache directory; None means
# "derive the location from environment variables".
_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Override the cache directory used by make_cache_dir_path()."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Return a path inside the dnnlib cache directory.

    Resolution order: explicit set_cache_dir() override, then the
    DNNLIB_CACHE_DIR environment variable (used verbatim), then a
    ``.cache/dnnlib`` subdirectory of HOME or USERPROFILE, and finally
    the system temp directory.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    env = os.environ
    if 'DNNLIB_CACHE_DIR' in env:
        return os.path.join(env['DNNLIB_CACHE_DIR'], *paths)
    for var in ('HOME', 'USERPROFILE'):
        if var in env:
            return os.path.join(env[var], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
| 133 |
+
|
| 134 |
+
# Small util functions
|
| 135 |
+
# ------------------------------------------------------------------------------------------
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    # Round to the nearest whole second before formatting.
    t = int(np.rint(seconds))

    if t < 60:
        return f"{t}s"
    if t < 60 * 60:
        return f"{t // 60}m {t % 60:02}s"
    if t < 24 * 60 * 60:
        return f"{t // 3600}h {(t // 60) % 60:02}m {t % 60:02}s"
    # A day or more: drop the seconds component.
    return f"{t // 86400}d {(t // 3600) % 24:02}h {(t // 60) % 60:02}m"
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    # Round to the nearest whole second; keep only the two most
    # significant components at each magnitude.
    t = int(np.rint(seconds))

    if t < 60:
        return f"{t}s"
    if t < 60 * 60:
        return f"{t // 60}m {t % 60:02}s"
    if t < 24 * 60 * 60:
        return f"{t // 3600}h {(t // 60) % 60:02}m"
    return f"{t // 86400}d {(t // 3600) % 24:02}h"
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Blocks on stdin and re-prompts whenever strtobool() cannot parse the
    input. NOTE(review): distutils is deprecated and removed in Python
    3.12 -- strtobool may need a local replacement on newer interpreters.
    """
    while True:
        try:
            print("{0} [y/n]".format(question))
            # strtobool raises ValueError on unrecognized input, which
            # simply restarts the prompt loop.
            return strtobool(input().lower())
        except ValueError:
            pass
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements."""
    # Empty input yields the multiplicative identity, 1.
    acc = 1
    for item in t:
        acc *= item
    return acc
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# Mapping from numpy-style dtype name to the ctypes type of the same
# byte size; consumed by get_dtype_and_ctype() below.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    # Resolve the input to a plain dtype name string.
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]

    # Sanity check: numpy and ctypes views must agree on the byte size.
    assert dtype.itemsize == ctypes.sizeof(ctype)

    return dtype, ctype
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def is_pickleable(obj: Any) -> bool:
    """Return True if `obj` can be serialized with pickle.

    Any failure during pickling is treated as "not pickleable". The handler
    is narrowed from a bare ``except:`` to ``except Exception:`` so that
    KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        return False
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# Functionality to import modules/objects by name, and call functions by name
|
| 233 |
+
# ------------------------------------------------------------------------------------------
|
| 234 |
+
|
| 235 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    # Longest module prefix first: "a.b.c" tries ("a.b.c", ""), ("a.b", "c"), ("a", "b.c").
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    # Re-import each candidate so that a *broken* module (one that raised on
    # import for reasons other than being missing) surfaces its real error.
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    # Re-run the attribute traversal so a genuine AttributeError propagates
    # with its original message instead of the generic failure below.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    # An empty name refers to the module itself.
    if not obj_name:
        return module
    current = module
    for attr in obj_name.split("."):
        current = getattr(current, attr)
    return current
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    mod, local_name = get_module_from_obj_name(name)
    return get_obj_from_module(mod, local_name)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    return target(*args, **kwargs)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    # A class is just a callable: delegate to the generic call helper.
    return call_func_by_name(*args, func_name=class_name, **kwargs)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    containing_module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(containing_module))
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    if not callable(obj):
        return False
    # A module-scope function's name appears in its module's namespace.
    return obj.__name__ in sys.modules[obj.__module__].__dict__
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module_name = obj.__module__
    # For functions defined in the entry script, substitute the script's
    # file stem for the unhelpful '__main__'.
    if module_name == '__main__':
        script_path = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(script_path))[0]
    return f"{module_name}.{obj.__name__}"
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# File system helpers
|
| 326 |
+
# ------------------------------------------------------------------------------------------
|
| 327 |
+
|
| 328 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths.

    `ignores` entries are fnmatch-style patterns applied to both directory
    and file basenames. When `add_base_to_relative` is True, relative paths
    are prefixed with the basename of `dir_path`.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place
            # (os.walk with topdown=True re-reads `dirs` to decide which
            # subdirectories to descend into, so rebinding would not prune).
            for d in dirs_to_remove:
                dirs.remove(d)

            # Files only need filtering for the listing; rebinding is fine
            # and accumulates across successive ignore patterns.
            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories.

    Uses ``os.makedirs(..., exist_ok=True)`` to avoid the TOCTOU race of
    the previous exists()-then-makedirs() sequence, and skips directory
    creation entirely for bare-filename destinations (where dirname is ''
    and makedirs would raise).
    """
    for src_path, dst_path in files:
        target_dir_name = os.path.dirname(dst_path)

        # will create all intermediate-level directories
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src_path, dst_path)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# URL helpers
|
| 375 |
+
# ------------------------------------------------------------------------------------------
|
| 376 |
+
|
| 377 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    Heuristic check: requires a "://" separator, a scheme, and a netloc
    containing a dot, both for the URL itself and for its site root.
    """
    if not isinstance(obj, str) or not "://" in obj:
        return False
    # file:// URLs have no dotted netloc, so accept them before the checks below.
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
        # Also validate the site root ("/" joined onto the URL) parses cleanly.
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
    except:
        return False
    return True
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Plain local paths and file:// URLs are opened directly without caching.
    Remote URLs are retried up to `num_attempts` times and, when `cache` is
    True, stored under `cache_dir` keyed by the MD5 of the URL. With
    `return_filename` True a path is returned instead of a file object
    (only valid together with caching).
    """
    assert num_attempts >= 1
    # return_filename requires a file on disk, i.e. caching must be on.
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs.  This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid.  Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        # A unique cached copy (md5-prefixed) short-circuits the download.
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages
                    # rather than the payload; detect and retry accordingly.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Follow the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                            raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    # Prefer the server-provided filename for the cache entry.
                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                # Out of attempts: propagate the last error.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        # Sanitize and truncate the name so it is a safe filename component.
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        safe_name = safe_name[:min(len(safe_name), 128)]
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        # Write to a unique temp file first so concurrent callers never
        # observe a partially-written cache entry.
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
|
gen_data.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from utils import (
|
| 5 |
+
get_solvers,
|
| 6 |
+
parse_arguments,
|
| 7 |
+
prepare_paths,
|
| 8 |
+
adjust_hyper,
|
| 9 |
+
)
|
| 10 |
+
from models import prepare_stuff, prepare_condition_loader
|
| 11 |
+
import time
|
| 12 |
+
import numpy as np
|
| 13 |
+
import PIL.Image
|
| 14 |
+
|
| 15 |
+
def get_data_inverse_scaler(centered=True):
    """Inverse data normalizer.

    Returns a function mapping model-space values back to image space:
    the identity when `centered` is False, otherwise a rescaling of
    [-1, 1] to [0, 1].
    """
    if not centered:
        return lambda x: x
    # Rescale [-1, 1] to [0, 1]
    return lambda x: (x + 1.0) / 2.0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Generator:
    """Runs a diffusion ODE solver over precomputed timesteps to turn
    latents into decoded samples.

    Timesteps are computed once at construction, either from explicit
    lists of values (`timesteps_1` / `timesteps_2`) or by asking the
    solver to prepare them.
    """

    def __init__(
        self,
        noise_schedule,
        solver,
        order,
        skip_type=None,
        load_from=None,
        timesteps_1=None,
        timesteps_2=None,
        steps=35,
        solver_extra_params=None,
        device=None,
    ) -> None:
        # noise_schedule: project object providing inverse_lambda, T, eps,
        # and prior_transformation (see _precompute_timesteps/_sample).
        self.device = device
        self.noise_schedule = noise_schedule
        self.solver = solver
        self.order = order
        self.skip_type = skip_type
        self.load_from = load_from
        self.timesteps_1 = timesteps_1
        self.timesteps_2 = timesteps_2
        self.steps = steps
        # Extra keyword arguments forwarded verbatim to solver.sample_simple.
        self.solver_extra_params = solver_extra_params

        self.timesteps = None  # set by _precompute_timesteps
        self._precompute_timesteps()

    def _precompute_timesteps(self):
        """Compute self.timesteps / self.timesteps2 once, up front.

        If no checkpoint is given and both custom timestep lists contain
        floats, interpret them as lambda-space values and convert via
        inverse_lambda (presumably returning a tensor -- it is moved to
        the device and cast to float32; TODO confirm). Otherwise delegate
        to the solver's own schedule preparation.
        """
        if self.load_from is None and type(self.timesteps_1) == list and type(self.timesteps_1[0]) == float \
            and type(self.timesteps_2) == list and type(self.timesteps_2[0]) == float:
            self.timesteps = self.noise_schedule.inverse_lambda(-np.log(self.timesteps_1)).to(self.device).float()
            self.timesteps2 = self.noise_schedule.inverse_lambda(-np.log(self.timesteps_2)).to(self.device).float()
        else:
            self.timesteps, self.timesteps2 = self.solver.prepare_timesteps(
                steps=self.steps,
                t_start=self.noise_schedule.T,
                t_end=self.noise_schedule.eps,
                skip_type=self.skip_type,
                device=self.device,
                load_from=self.load_from,
            )

    def _sample(self, net, decoding_fn, latents, condition=None, unconditional_condition=None):
        """Transform latents to the prior, run the solver, and decode."""
        x_next_ = self.noise_schedule.prior_transformation(latents)
        x_next_ = self.solver.sample_simple(
            model_fn=net,
            x=x_next_,
            timesteps=self.timesteps,
            timesteps2=self.timesteps2,
            order=self.order,
            NFEs=self.steps,
            condition=condition,
            unconditional_condition=unconditional_condition,
            **self.solver_extra_params,
        )
        # decoding_fn maps solver output (e.g. a latent) to image space.
        x_next_ = decoding_fn(x_next_)
        return x_next_

    def sample(self, net, decoding_fn, latents, condition=None, unconditional_condition=None, no_grad=True):
        """Public entry point; wraps _sample in torch.no_grad() unless
        gradients are required (e.g. when optimizing the timesteps)."""
        if no_grad:
            with torch.no_grad():
                return self._sample(net, decoding_fn, latents, condition, unconditional_condition)
        else:
            return self._sample(net, decoding_fn, latents, condition, unconditional_condition)
|
| 88 |
+
|
| 89 |
+
def main(args):
    """Generate samples with a diffusion sampler and dump them to disk.

    Loads the model described by `args`, builds a `Generator` around the
    chosen ODE solver, then produces `args.total_samples` images in batches.
    Each sample can be saved as a ``.pt`` record (latent / image / condition)
    and/or as a ``.png``/``.jpg`` file under ``args.data_dir/<desc>``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    wrapped_model, model, decoding_fn, noise_schedule, latent_resolution, latent_channel, img_resolution, img_channel = prepare_stuff(args)
    condition_loader = prepare_condition_loader(
        model_type=args.model,
        model=model,
        scale=args.scale if hasattr(args, "scale") else None,
        condition=args.prompt_path or "random",
        sampling_batch_size=args.sampling_batch_size,
        num_prompt=args.num_prompts,
        num_samples_per_prompt=args.num_samples_per_prompt,
    )
    adjust_hyper(args, latent_resolution, latent_channel)
    desc, _, skip_type = prepare_paths(args)
    data_dir = os.path.join(args.data_dir, desc)
    os.makedirs(data_dir, exist_ok=True)

    solver, steps, solver_extra_params = get_solvers(
        args.solver_name,
        NFEs=args.steps,
        order=args.order,
        noise_schedule=noise_schedule,
        unipc_variant=args.unipc_variant,
    )

    generator = Generator(
        noise_schedule=noise_schedule,
        solver=solver,
        order=args.order,
        skip_type=skip_type,
        load_from=args.load_from,
        timesteps_1=args.custom_ts_1,
        timesteps_2=args.custom_ts_2,
        steps=steps,
        solver_extra_params=solver_extra_params,
        device=device,
    )

    # Debug aid: show the two precomputed time grids before sampling starts.
    print(generator.timesteps, generator.timesteps2)
    inverse_scalar = get_data_inverse_scaler(centered=True)

    start = time.time()
    count = 0  # samples produced so far; used for output file numbering
    batch_size = args.sampling_batch_size
    if args.prompt_path is not None:
        # Cannot generate more samples than there are prompts available.
        args.total_samples = min(args.total_samples, len(condition_loader.prompts))
    num_batches = (args.total_samples + batch_size - 1) // batch_size  # ceil division

    for batch_idx in tqdm(range(num_batches)):
        # The final batch may be smaller than batch_size.
        current_batch_size = min(batch_size, args.total_samples - batch_idx * batch_size)
        sampling_shape = (current_batch_size, latent_channel, latent_resolution, latent_resolution)
        latents = torch.randn(sampling_shape, device=device)

        if condition_loader is not None:
            conditioning, conditioned_unconditioning = next(condition_loader)
        else:
            conditioning = None
            conditioned_unconditioning = None

        img_teacher = generator.sample(wrapped_model, decoding_fn, latents, conditioning, conditioned_unconditioning)

        img_teacher = img_teacher.detach().cpu().view(current_batch_size, img_channel, img_resolution, img_resolution)
        latents = latents.detach().cpu()

        if args.save_pt:
            # NOTE: distinct loop variable `j` — the original reused `i`,
            # shadowing the outer batch index.
            for j in range(current_batch_size):
                latent = latents[j]
                img = img_teacher[j]
                c = conditioning[j] if conditioning is not None else None
                uc = conditioned_unconditioning[j] if conditioned_unconditioning is not None else None
                data = dict(latent=latent, img=img, c=c, uc=uc)
                torch.save(data, os.path.join(data_dir, f"latent_{(count + j):06d}.pt"))

        if args.save_png:
            samples_raw = inverse_scalar(img_teacher)  # undo the centered [-1, 1] scaling
            samples = np.clip(
                samples_raw.permute(0, 2, 3, 1).cpu().numpy() * 255.0, 0, 255
            ).astype(np.uint8)
            images_np = samples.reshape((-1, img_resolution, img_resolution, img_channel))

            for j in range(current_batch_size):
                image_np = images_np[j]
                # hpsv2 evaluation expects 5-digit .jpg filenames.
                if args.prompt_path is not None and args.prompt_path.startswith('hpsv2'):
                    image_path = os.path.join(data_dir, f"{(count + j):05d}.jpg")
                else:
                    image_path = os.path.join(data_dir, f"{(count + j):06d}.png")
                if image_np.shape[2] == 1:
                    PIL.Image.fromarray(image_np[:, :, 0], "L").save(image_path)
                else:
                    PIL.Image.fromarray(image_np, "RGB").save(image_path)

        # Advance by the number of samples actually written; the original
        # added `batch_size`, which over-counts on a final partial batch.
        count += current_batch_size

    end = time.time()
    print(f"Generation time: {end - start}")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# Script entry point: parse command-line arguments and run sample generation.
if __name__ == "__main__":
    args = parse_arguments()
    main(args)
|
ldm/__init__.py
ADDED
|
File without changes
|
ldm/data/__init__.py
ADDED
|
File without changes
|
ldm/data/base.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.size = size
        self.num_records = num_records
        # Keep both the full id list and the (initially identical) list of
        # ids that will actually be sampled from.
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        # Length is the declared record count, not len(valid_ids).
        return self.num_records

    @abstractmethod
    def __iter__(self):
        # Concrete text2img datasets must provide the iteration logic.
        pass
|
ldm/data/imagenet.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, yaml, pickle, shutil, tarfile, glob
|
| 2 |
+
import cv2
|
| 3 |
+
import albumentations
|
| 4 |
+
import PIL
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torchvision.transforms.functional as TF
|
| 7 |
+
from omegaconf import OmegaConf
|
| 8 |
+
from functools import partial
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from torch.utils.data import Dataset, Subset
|
| 12 |
+
|
| 13 |
+
import taming.data.utils as tdu
|
| 14 |
+
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
|
| 15 |
+
from taming.data.imagenet import ImagePaths
|
| 16 |
+
|
| 17 |
+
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def synset2idx(path_to_yaml="data/index_synset.yaml"):
    """Return the synset-id -> class-index mapping.

    The YAML file maps class index -> synset id; this inverts it.
    Uses `yaml.safe_load`: calling `yaml.load` without an explicit Loader is
    deprecated and raises TypeError on PyYAML >= 6, and safe_load also avoids
    arbitrary object construction from the file.
    """
    with open(path_to_yaml) as f:
        di2s = yaml.safe_load(f)
    return {v: k for k, v in di2s.items()}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ImageNetBase(Dataset):
    """Common machinery for the ImageNet train/validation datasets.

    Subclasses implement `_prepare()` (setting `self.root`, `self.datadir`,
    `self.txt_filelist` and `self.random_crop`); `__init__` then downloads
    the auxiliary label files and loads the (optionally filtered) file list.
    """
    def __init__(self, config=None):
        # `config` may be a plain dict or an OmegaConf node; normalize to dict.
        self.config = config or OmegaConf.create()
        if not type(self.config)==dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        # Subclass hook: must set paths and (if needed) download/extract data.
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        """Drop known-bad files and, if `sub_indices` is configured, keep only
        the synsets corresponding to those class indices."""
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            # NOTE(review): self.synset2idx is only assigned on this branch, so
            # keep_orig_class_label without sub_indices would fail in _load().
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        # Download the synset-id -> human-readable-name table if missing or
        # truncated (the expected byte size guards against partial downloads).
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        # Download the class-index -> synset-id YAML mapping if missing.
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        # Download and parse the human-label -> integer-label table
        # ("<int>: <label>" per line, 1000 lines).
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        """Read the file list, filter it, build the per-image label arrays and
        wrap everything in an ImagePaths dataset (or keep raw paths)."""
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        # First path component of each relpath is the synset directory.
        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            # Re-index labels densely over the synsets actually present.
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            # Keep the original ImageNet-1000 indices (requires sub_indices,
            # which populates self.synset2idx).
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            # Caller only wants the file paths (e.g. for the SR datasets).
            self.data = self.abspaths
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class ImageNetTrain(ImageNetBase):
    """ILSVRC2012 training split.

    On first use, downloads the training tarball via academictorrents,
    extracts it (including the per-synset sub-tars) and writes a file list;
    subsequent runs reuse the prepared directory.
    """
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"  # academictorrents id
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,  # expected tarball size in bytes (download sanity check)
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        # process_images=False keeps self.data as raw file paths.
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            # Default location: XDG cache dir (~/.cache/autoencoders/data/...).
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167  # official ILSVRC2012 train image count
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                # Re-download if missing or size mismatch (partial download).
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                # The train tar contains one tar per synset; unpack each into
                # its own directory.
                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            # Write a sorted, datadir-relative list of all JPEGs.
            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class ImageNetValidation(ImageNetBase):
    """ILSVRC2012 validation split.

    On first use, downloads the validation tarball via academictorrents plus a
    file-to-synset table, extracts the images and reorganizes them into
    per-synset folders (mirroring the train layout), then writes a file list.
    """
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"  # academictorrents id
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"  # filename -> synset table
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,  # expected tarball size in bytes
        1950000,     # expected synset-table size in bytes
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        # process_images=False keeps self.data as raw file paths.
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            # Default location: XDG cache dir (~/.cache/autoencoders/data/...).
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000  # official ILSVRC2012 validation image count
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                # Re-download if missing or size mismatch (partial download).
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                # Table mapping each validation image filename to its synset.
                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                # Move each image into a folder named after its synset so the
                # layout matches the training split.
                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            # Write a sorted, datadir-relative list of all JPEGs.
            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:
        1.  crops a crop of size s from image either as random or center crop
        2.  resizes crop to size with cv2.area_interpolation
        3.  degrades resized crop with degradation_fn

        :param size: resizing to size after cropping
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param min_crop_f: determines crop size s,
          where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param data_root:
        :param random_crop:
        """
        # get_base() is supplied by the Train/Validation subclasses and returns
        # a path-only dataset (process_images=False).
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)  # low-resolution target side length
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop

        # HR branch: resize the crop so its smaller side equals `size`.
        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        self.pil_interpolation = False # gets reset later if incase interp_op is from pillow

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            # Plain interpolation-based downscaling; the key selects either a
            # cv2 or a PIL resampling filter.
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        """Return {'image': HR array, 'LR_image': degraded array}, both float32
        scaled to [-1, 1], plus the path labels from the base dataset."""
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        # Square crop whose side is a random fraction of the smaller image side.
        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        # NOTE(review): the cropper is (re)assigned onto self on every call;
        # harmless per-worker, but it mutates shared state if the dataset is
        # ever used from multiple threads — confirm if that matters.
        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)

        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)["image"]

        # Scale uint8 [0, 255] to float32 [-1, 1].
        example["image"] = (image/127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)

        return example
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class ImageNetSRTrain(ImageNetSR):
    """Super-resolution training split: the ImageNet train set restricted to a
    precomputed list of high-resolution image indices."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        # Indices of train images deemed high-resolution enough for SR.
        with open("data/imagenet_train_hr_indices.p", "rb") as fh:
            hr_indices = pickle.load(fh)
        base_dset = ImageNetTrain(process_images=False)
        return Subset(base_dset, hr_indices)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class ImageNetSRValidation(ImageNetSR):
    """Super-resolution validation split: the ImageNet validation set
    restricted to a precomputed list of high-resolution image indices."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        # Indices of validation images deemed high-resolution enough for SR.
        with open("data/imagenet_val_hr_indices.p", "rb") as fh:
            hr_indices = pickle.load(fh)
        base_dset = ImageNetValidation(process_images=False)
        return Subset(base_dset, hr_indices)
|
ldm/data/lsun.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import PIL
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LSUNBase(Dataset):
    """Generic LSUN image dataset.

    Reads a newline-separated list of image paths from `txt_file` (relative to
    `data_root`), and on access center-crops each image to a square, optionally
    resizes it to `size` x `size`, applies a random horizontal flip with
    probability `flip_p`, and returns it scaled to float32 [-1, 1].
    """
    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
        }

        self.size = size
        # "linear" maps to BILINEAR: the PIL.Image.LINEAR alias was removed in
        # Pillow 10 (it was the same filter as BILINEAR), so using BILINEAR
        # keeps old behavior and avoids an AttributeError on current Pillow.
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        """Return {'image': float32 array in [-1, 1], plus path labels}."""
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing: center crop to the largest square
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w, = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
              (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Scale uint8 [0, 255] to float32 [-1, 1].
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class LSUNChurchesTrain(LSUNBase):
    """LSUN church_outdoor training split (hard-coded file list and data root)."""
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class LSUNChurchesValidation(LSUNBase):
    """LSUN church_outdoor validation split; flipping disabled by default."""
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LSUNBedroomsTrain(LSUNBase):
    """LSUN bedrooms training split (hard-coded file list and data root)."""
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class LSUNBedroomsValidation(LSUNBase):
    """LSUN bedrooms validation split; flipping disabled by default."""
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class LSUNCatsTrain(LSUNBase):
    """LSUN cats training split (hard-coded file list and data root)."""
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class LSUNCatsValidation(LSUNBase):
    """LSUN cats validation split; flipping disabled by default."""
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
                         flip_p=flip_p, **kwargs)
|
ldm/lr_scheduler.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class LambdaWarmUpCosineScheduler:
    """
    Linear warmup from lr_start to lr_max over warm_up_steps steps, then a
    cosine decay to lr_min that completes at max_decay_steps.
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            # Warmup: linear ramp from lr_start to lr_max.
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
        else:
            # Cosine decay, with progress t clamped to 1 past max_decay_steps.
            t = min((n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps), 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.

    Each cycle c runs a linear warmup from f_start[c] to f_max[c] over
    warm_up_steps[c] steps, then a cosine decay to f_min[c] over the rest of
    cycle_lengths[c].
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        # All per-cycle configuration lists must line up.
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # cum_cycles[c] is the global step at which cycle c begins.
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Return the index of the cycle containing global step n."""
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1
        # Past the end of the last cycle: stay in the final cycle rather than
        # returning None (which made schedule() crash with a TypeError on
        # self.cum_cycles[None]); schedule() clamps t to 1 for such steps.
        return len(self.cycle_lengths) - 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]  # step within the current cycle
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            # Warmup: linear ramp from f_start to f_max.
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            # Cosine decay, with progress t clamped to 1 at the cycle end.
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                    1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Per-cycle warm-up followed by a *linear* decay to ``f_min``.

    Identical cycle bookkeeping to LambdaWarmUpCosineScheduler2; only the
    post-warm-up shape differs (straight line instead of a half-cosine).
    """

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            # Linear warm-up from f_start to f_max.
            mult = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
        else:
            # Linear decay: full f_max at n == 0, f_min at n == cycle length.
            mult = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
        self.last_f = mult
        return mult
|
| 98 |
+
|
ldm/models/autoencoder.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

import pytorch_lightning as pl
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config

# NOTE(review): VQModel._validation_step also calls version.parse(pl.__version__),
# which needs "from packaging import version" — confirm `packaging` is an
# available dependency before adding it here.
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class VQModel(pl.LightningModule):
    """VQGAN-style autoencoder: Encoder -> vector quantization -> Decoder.

    Trained with two alternating optimizers (autoencoder vs. discriminator),
    as in taming-transformers.  ``lossconfig`` must instantiate a loss that
    exposes ``.discriminator`` and the ``(qloss, x, xrec, optimizer_idx,
    global_step, ...)`` call signature used below.

    NOTE(review): this code relies on names not present in the visible import
    block (``np`` in get_input, ``LitEma`` when use_ema=True, ``LambdaLR`` and
    ``version`` in configure_optimizers/_validation_step) — confirm the file's
    imports cover them.  ``training_step(batch, batch_idx, optimizer_idx)`` is
    the pre-2.0 PyTorch Lightning multi-optimizer API.
    """
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        # ddconfig: kwargs for Encoder/Decoder.  n_embed: codebook size.
        # embed_dim: dimensionality of each codebook vector.
        # NOTE(review): ignore_keys has a mutable default; harmless here since
        # it is only iterated, never mutated.
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # beta=0.25 is the standard VQ commitment-loss weight.
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        # 1x1 convs mapping encoder output <-> codebook embedding space.
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            # random projection used by to_rgb() for segmentation inputs
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        # Temporarily swap model weights for their EMA shadow; restores the
        # training weights on exit.  No-op when use_ema is False.
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        # Load weights non-strictly so architecture variants still restore;
        # keys starting with any prefix in ignore_keys are dropped first.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        # Update the EMA shadow weights after every optimizer step.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        # Returns (quantized latents, embedding loss, quantizer info tuple).
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        # Encoder output *before* quantization (continuous latents).
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        # Decode directly from codebook indices.
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        # info tuple from the quantizer is (_, _, indices); only the indices
        # are optionally surfaced.
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        # batch[k] is HWC (or HW, unsqueezed below); convert to float NCHW.
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                # random multiple of 16 in [lower, upper] — needs numpy as np
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            # NOTE(review): the EMA pass is run for its logging side effects
            # (keys suffixed "_ema"); its return value is intentionally unused.
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        # Runs both loss heads (optimizer_idx 0 and 1) in eval mode and logs
        # everything under "val{suffix}/...".
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # PL >= 1.4 forbids logging the same key twice; drop it from the dict
        # since it was already logged explicitly above.
        # NOTE(review): requires "from packaging import version" at file level.
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        # Two Adam optimizers: generator (encoder/decoder/quantizer) and
        # discriminator, optionally sharing a per-step LambdaLR schedule.
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            # `scheduler` is rebound from the schedule object to the PL
            # scheduler-dict list; scheduler.schedule is captured first.
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # Used by the loss for its adaptive generator/discriminator weighting.
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return a dict of image tensors for the logger (inputs,
        reconstructions, and optionally EMA reconstructions)."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        # Project a multi-channel segmentation map to 3 channels with a fixed
        # random 1x1 conv, then min-max normalize to [-1, 1] for display.
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class VQModelInterface(VQModel):
    """VQModel variant used as a first stage inside latent diffusion.

    encode() stops *before* quantization (returns continuous pre-quant
    latents); decode() quantizes on the way back unless told not to.
    """

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        # Continuous latents only — no codebook lookup here.
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if force_not_quantize:
            quant = h
        else:
            quant, _, _ = self.quantize(h)
        return self.decoder(self.post_quant_conv(quant))
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class AutoencoderKL(pl.LightningModule):
    """KL-regularized continuous autoencoder (the VAE first stage of LDM).

    The encoder emits 2*embed_dim channels (mean and log-variance) which are
    wrapped in a DiagonalGaussianDistribution; training alternates between an
    autoencoder objective and a discriminator, as configured by ``lossconfig``.
    """
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        # NOTE(review): ignore_keys has a mutable default; harmless here since
        # it is only iterated, never mutated.
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # double_z makes the encoder output both mean and log-variance.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            # random projection used by to_rgb() for segmentation inputs
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        # Non-strict restore; keys matching any ignore_keys prefix are dropped.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        # Returns the posterior distribution, not a sample.
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        # sample_posterior=False uses the posterior mode (deterministic).
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        # batch[k] is HWC (or HW, unsqueezed below); convert to float NCHW.
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # NOTE(review): pre-2.0 PyTorch Lightning multi-optimizer API
        # (optimizer_idx argument).
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        # Evaluate both loss heads (0: autoencoder, 1: discriminator) and log.
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        # One Adam for the autoencoder parameters, one for the discriminator.
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # Used by the loss for its adaptive generator/discriminator weighting.
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Return a dict of image tensors for the logger (inputs,
        reconstructions, and decoded random-latent samples)."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        # Project a multi-channel segmentation map to 3 channels with a fixed
        # random 1x1 conv, then min-max normalize to [-1, 1] for display.
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: passes tensors through unchanged.

    Used when diffusion runs directly in pixel space.  With
    ``vq_interface=True``, quantize() mimics the (quant, emb_loss, info)
    triple returned by a real vector quantizer.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Plain (non-Parameter) attribute, so assigning it before
        # Module.__init__ is safe; original order preserved.
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
|
ldm/models/diffusion/__init__.py
ADDED
|
File without changes
|
ldm/models/diffusion/classifier.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import pytorch_lightning as pl
|
| 4 |
+
from omegaconf import OmegaConf
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch.optim import AdamW
|
| 7 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from glob import glob
|
| 11 |
+
from natsort import natsorted
|
| 12 |
+
|
| 13 |
+
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
| 14 |
+
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
| 15 |
+
|
| 16 |
+
# Maps the conditioning key of the diffusion model to the classifier backbone
# instantiated for it in NoisyLatentImageClassifier.load_classifier.
__models__ = {
    'class_label': EncoderUNetModel,
    'segmentation': UNetModel
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore.

    Bound in place of ``Module.train`` on frozen sub-models; ignores ``mode``
    and returns the model unchanged so chained calls still work.
    """
    return self
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class NoisyLatentImageClassifier(pl.LightningModule):
|
| 29 |
+
|
| 30 |
+
def __init__(self,
|
| 31 |
+
diffusion_path,
|
| 32 |
+
num_classes,
|
| 33 |
+
ckpt_path=None,
|
| 34 |
+
pool='attention',
|
| 35 |
+
label_key=None,
|
| 36 |
+
diffusion_ckpt_path=None,
|
| 37 |
+
scheduler_config=None,
|
| 38 |
+
weight_decay=1.e-2,
|
| 39 |
+
log_steps=10,
|
| 40 |
+
monitor='val/loss',
|
| 41 |
+
*args,
|
| 42 |
+
**kwargs):
|
| 43 |
+
super().__init__(*args, **kwargs)
|
| 44 |
+
self.num_classes = num_classes
|
| 45 |
+
# get latest config of diffusion model
|
| 46 |
+
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
|
| 47 |
+
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
| 48 |
+
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
| 49 |
+
self.load_diffusion()
|
| 50 |
+
|
| 51 |
+
self.monitor = monitor
|
| 52 |
+
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
| 53 |
+
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
| 54 |
+
self.log_steps = log_steps
|
| 55 |
+
|
| 56 |
+
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
|
| 57 |
+
else self.diffusion_model.cond_stage_key
|
| 58 |
+
|
| 59 |
+
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
|
| 60 |
+
|
| 61 |
+
if self.label_key not in __models__:
|
| 62 |
+
raise NotImplementedError()
|
| 63 |
+
|
| 64 |
+
self.load_classifier(ckpt_path, pool)
|
| 65 |
+
|
| 66 |
+
self.scheduler_config = scheduler_config
|
| 67 |
+
self.use_scheduler = self.scheduler_config is not None
|
| 68 |
+
self.weight_decay = weight_decay
|
| 69 |
+
|
| 70 |
+
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
| 71 |
+
sd = torch.load(path, map_location="cpu")
|
| 72 |
+
if "state_dict" in list(sd.keys()):
|
| 73 |
+
sd = sd["state_dict"]
|
| 74 |
+
keys = list(sd.keys())
|
| 75 |
+
for k in keys:
|
| 76 |
+
for ik in ignore_keys:
|
| 77 |
+
if k.startswith(ik):
|
| 78 |
+
print("Deleting key {} from state_dict.".format(k))
|
| 79 |
+
del sd[k]
|
| 80 |
+
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
| 81 |
+
sd, strict=False)
|
| 82 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
| 83 |
+
if len(missing) > 0:
|
| 84 |
+
print(f"Missing Keys: {missing}")
|
| 85 |
+
if len(unexpected) > 0:
|
| 86 |
+
print(f"Unexpected Keys: {unexpected}")
|
| 87 |
+
|
| 88 |
+
def load_diffusion(self):
|
| 89 |
+
model = instantiate_from_config(self.diffusion_config)
|
| 90 |
+
self.diffusion_model = model.eval()
|
| 91 |
+
self.diffusion_model.train = disabled_train
|
| 92 |
+
for param in self.diffusion_model.parameters():
|
| 93 |
+
param.requires_grad = False
|
| 94 |
+
|
| 95 |
+
def load_classifier(self, ckpt_path, pool):
|
| 96 |
+
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
| 97 |
+
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
|
| 98 |
+
model_config.out_channels = self.num_classes
|
| 99 |
+
if self.label_key == 'class_label':
|
| 100 |
+
model_config.pool = pool
|
| 101 |
+
|
| 102 |
+
self.model = __models__[self.label_key](**model_config)
|
| 103 |
+
if ckpt_path is not None:
|
| 104 |
+
print('#####################################################################')
|
| 105 |
+
print(f'load from ckpt "{ckpt_path}"')
|
| 106 |
+
print('#####################################################################')
|
| 107 |
+
self.init_from_ckpt(ckpt_path)
|
| 108 |
+
|
| 109 |
+
@torch.no_grad()
|
| 110 |
+
def get_x_noisy(self, x, t, noise=None):
|
| 111 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
| 112 |
+
continuous_sqrt_alpha_cumprod = None
|
| 113 |
+
if self.diffusion_model.use_continuous_noise:
|
| 114 |
+
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
| 115 |
+
# todo: make sure t+1 is correct here
|
| 116 |
+
|
| 117 |
+
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
|
| 118 |
+
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
|
| 119 |
+
|
| 120 |
+
def forward(self, x_noisy, t, *args, **kwargs):
|
| 121 |
+
return self.model(x_noisy, t)
|
| 122 |
+
|
| 123 |
+
@torch.no_grad()
|
| 124 |
+
def get_input(self, batch, k):
|
| 125 |
+
x = batch[k]
|
| 126 |
+
if len(x.shape) == 3:
|
| 127 |
+
x = x[..., None]
|
| 128 |
+
x = rearrange(x, 'b h w c -> b c h w')
|
| 129 |
+
x = x.to(memory_format=torch.contiguous_format).float()
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
@torch.no_grad()
|
| 133 |
+
def get_conditioning(self, batch, k=None):
|
| 134 |
+
if k is None:
|
| 135 |
+
k = self.label_key
|
| 136 |
+
assert k is not None, 'Needs to provide label key'
|
| 137 |
+
|
| 138 |
+
targets = batch[k].to(self.device)
|
| 139 |
+
|
| 140 |
+
if self.label_key == 'segmentation':
|
| 141 |
+
targets = rearrange(targets, 'b h w c -> b c h w')
|
| 142 |
+
for down in range(self.numd):
|
| 143 |
+
h, w = targets.shape[-2:]
|
| 144 |
+
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
|
| 145 |
+
|
| 146 |
+
# targets = rearrange(targets,'b c h w -> b h w c')
|
| 147 |
+
|
| 148 |
+
return targets
|
| 149 |
+
|
| 150 |
+
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
| 151 |
+
_, top_ks = torch.topk(logits, k, dim=1)
|
| 152 |
+
if reduction == "mean":
|
| 153 |
+
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
| 154 |
+
elif reduction == "none":
|
| 155 |
+
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
| 156 |
+
|
| 157 |
+
def on_train_epoch_start(self):
|
| 158 |
+
# save some memory
|
| 159 |
+
self.diffusion_model.model.to('cpu')
|
| 160 |
+
|
| 161 |
+
@torch.no_grad()
|
| 162 |
+
def write_logs(self, loss, logits, targets):
|
| 163 |
+
log_prefix = 'train' if self.training else 'val'
|
| 164 |
+
log = {}
|
| 165 |
+
log[f"{log_prefix}/loss"] = loss.mean()
|
| 166 |
+
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
| 167 |
+
logits, targets, k=1, reduction="mean"
|
| 168 |
+
)
|
| 169 |
+
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
| 170 |
+
logits, targets, k=5, reduction="mean"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
|
| 174 |
+
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
| 175 |
+
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
|
| 176 |
+
lr = self.optimizers().param_groups[0]['lr']
|
| 177 |
+
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
| 178 |
+
|
| 179 |
+
def shared_step(self, batch, t=None):
    """Run one classification-on-noisy-latents step (shared by train/val).

    Args:
        batch: input batch, decoded via the wrapped diffusion model.
        t: optional fixed diffusion timestep (int). When None, a timestep is
            drawn uniformly at random per sample.

    Returns:
        (loss, logits, x_noisy, targets) — scalar mean loss, raw class
        scores, the noised input, and the integer targets.
    """
    x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
    targets = self.get_conditioning(batch)
    # 4-D targets are one-hot segmentation maps; reduce to class indices.
    if targets.dim() == 4:
        targets = targets.argmax(dim=1)
    if t is None:
        # random timestep per batch element
        t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
    else:
        # same fixed timestep for the whole batch
        t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
    x_noisy = self.get_x_noisy(x, t)
    logits = self(x_noisy, t)

    # per-sample loss first so write_logs can average it itself
    loss = F.cross_entropy(logits, targets, reduction='none')

    self.write_logs(loss.detach(), logits.detach(), targets.detach())

    loss = loss.mean()
    return loss, logits, x_noisy, targets
|
| 197 |
+
|
| 198 |
+
def training_step(self, batch, batch_idx):
    """Lightning training hook: delegate to shared_step, return the scalar loss."""
    total_loss = self.shared_step(batch)[0]
    return total_loss
|
| 201 |
+
|
| 202 |
+
def reset_noise_accs(self):
    """Re-initialize the per-timestep accuracy accumulators used in validation."""
    logged_steps = range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)
    self.noisy_acc = {}
    for step in logged_steps:
        self.noisy_acc[step] = {'acc@1': [], 'acc@5': []}
|
| 205 |
+
|
| 206 |
+
def on_validation_start(self):
    """Lightning hook: start each validation run with empty per-timestep
    accuracy buckets (see reset_noise_accs)."""
    self.reset_noise_accs()
|
| 208 |
+
|
| 209 |
+
@torch.no_grad()
def validation_step(self, batch, batch_idx):
    """Validation hook: one random-timestep step for the loss, plus one
    fixed-timestep step per logged t to accumulate noisy accuracies."""
    loss, *_ = self.shared_step(batch)

    # Evaluate accuracy at every logged noise level; results are appended to
    # the accumulators created by reset_noise_accs at validation start.
    for t in self.noisy_acc:
        _, logits, _, targets = self.shared_step(batch, t)
        self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
        self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

    return loss
|
| 219 |
+
|
| 220 |
+
def configure_optimizers(self):
    """Build the AdamW optimizer and, when a scheduler config is present,
    a per-step LambdaLR schedule driven by the instantiated schedule object."""
    optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    if not self.use_scheduler:
        return optimizer

    schedule_obj = instantiate_from_config(self.scheduler_config)

    print("Setting up LambdaLR scheduler...")
    lr_schedulers = [
        {
            'scheduler': LambdaLR(optimizer, lr_lambda=schedule_obj.schedule),
            'interval': 'step',
            'frequency': 1,
        }
    ]
    return [optimizer], lr_schedulers
|
| 236 |
+
|
| 237 |
+
@torch.no_grad()
def log_images(self, batch, N=8, *args, **kwargs):
    """Build a dict of image tensors for visual logging: the inputs, their
    labels, and noisy inputs + predicted segmentation maps at a range of
    diffusion timesteps. At most N samples per entry are returned."""
    log = dict()
    # NOTE(review): this calls self.get_input while shared_step uses
    # self.diffusion_model.get_input — confirm the wrapper defines get_input.
    x = self.get_input(batch, self.diffusion_model.first_stage_key)
    log['inputs'] = x

    y = self.get_conditioning(batch)

    if self.label_key == 'class_label':
        # render the human-readable label text onto an image the size of x
        y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
        log['labels'] = y

    if ismap(y):
        # segmentation-style conditioning: convert the map to RGB for logging
        log['labels'] = self.diffusion_model.to_rgb(y)

    for step in range(self.log_steps):
        current_time = step * self.log_time_interval

        _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

        log[f'inputs@t{current_time}'] = x_noisy

        # argmax -> one-hot -> channels-first so to_rgb can colorize it
        pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
        pred = rearrange(pred, 'b h w c -> b c h w')

        log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

    # truncate every logged tensor to the first N samples
    for key in log:
        log[key] = log[key][:N]

    return log
|
ldm/models/diffusion/ddim.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
| 9 |
+
extract_into_tensor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DDIMSampler(object):
    """DDIM sampler wrapped around a trained DDPM ("sampling only").

    The wrapped ``model`` supplies the learned noise schedule (``betas``,
    ``alphas_cumprod``, ...) and the denoising network via ``apply_model``.
    ``make_schedule`` precomputes the reduced DDIM timestep tables; the
    ``sample``/``decode`` entry points then run the reverse process.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store ``attr`` as a plain attribute (this class is not an nn.Module,
        so there is no real buffer registry).

        Fix: tensors are moved to CUDA only when CUDA is actually available —
        the original unconditionally called ``.to("cuda")`` and crashed on
        CPU-only hosts.
        """
        if isinstance(attr, torch.Tensor):
            if torch.cuda.is_available() and attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timestep subset and per-step alpha/sigma tables.

        Args:
            ddim_num_steps: number of DDIM steps (<= full DDPM timesteps).
            ddim_discretize: timestep selection strategy ("uniform", ...).
            ddim_eta: stochasticity knob; 0 gives deterministic DDIM.
            verbose: forwarded to the schedule helpers.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters (per selected timestep)
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # sigmas to use when stepping on the ORIGINAL (full) timestep grid
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Draw ``batch_size`` samples of latent ``shape`` = (C, H, W) with S DDIM steps.

        Returns:
            (samples, intermediates) — final latents and the dict of logged
            intermediate ``x_inter`` / ``pred_x0`` tensors.
        """
        if conditioning is not None:
            # sanity check: conditioning batch must match the requested batch size
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
        """Inner reverse-diffusion loop over the (possibly reduced) timestep grid."""
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # interpret `timesteps` as a cap: keep only the first portion of the grid
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # iterate from the noisiest timestep down to 0
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                # inpainting: keep masked region pinned to the noised original x0
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """One DDIM update x_t -> x_{t-1}; returns (x_prev, pred_x0)."""
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # classifier-free guidance: one batched pass for uncond + cond
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Noise x0 forward to timestep t (fast, but does not allow for exact
        reconstruction). ``t`` serves as an index to gather the correct alphas."""
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False):
        """Denoise an encoded latent from timestep ``t_start`` back to t=0."""
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
        return x_dec
|
ldm/models/diffusion/ddpm.py
ADDED
|
@@ -0,0 +1,1445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
wild mixture of
|
| 3 |
+
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
| 4 |
+
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
| 5 |
+
https://github.com/CompVis/taming-transformers
|
| 6 |
+
-- merci
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pytorch_lightning as pl
|
| 13 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 14 |
+
from einops import rearrange, repeat
|
| 15 |
+
from contextlib import contextmanager
|
| 16 |
+
from functools import partial
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from torchvision.utils import make_grid
|
| 19 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
| 20 |
+
|
| 21 |
+
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
| 22 |
+
from ldm.modules.ema import LitEma
|
| 23 |
+
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
| 24 |
+
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
|
| 25 |
+
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
| 26 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
| 27 |
+
from torch import clamp
|
| 28 |
+
|
| 29 |
+
# Maps a conditioning mode ('concat' | 'crossattn' | 'adm') to the keyword
# under which its conditioning tensor is passed onward.
__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def disabled_train(self, mode=True):
    """Drop-in replacement for ``nn.Module.train`` that ignores ``mode``.

    Overwrite ``model.train`` with this function so the module's train/eval
    state can never be changed again (used to keep frozen submodules frozen).
    """
    return self
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def uniform_on_device(r1, r2, shape, device):
    """Sample a tensor of ``shape`` uniformly from the interval between ``r1``
    and ``r2``, allocated directly on ``device``."""
    span = r1 - r2
    return span * torch.rand(*shape, device=device) + r2
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DDPM(pl.LightningModule):
|
| 45 |
+
# classic DDPM with Gaussian diffusion, in image space
|
| 46 |
+
def __init__(self,
             unet_config,
             timesteps=1000,
             beta_schedule="linear",
             loss_type="l2",
             ckpt_path=None,
             ignore_keys=[],
             load_only_unet=False,
             monitor="val/loss",
             use_ema=True,
             first_stage_key="image",
             image_size=256,
             channels=3,
             log_every_t=100,
             clip_denoised=True,
             linear_start=1e-4,
             linear_end=2e-2,
             cosine_s=8e-3,
             given_betas=None,
             original_elbo_weight=0.,
             v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
             l_simple_weight=1.,
             conditioning_key=None,
             parameterization="eps",  # all assuming fixed variance schedules
             scheduler_config=None,
             use_positional_encodings=False,
             learn_logvar=False,
             logvar_init=0.,
             ):
    """Classic DDPM (Gaussian diffusion in image space) as a LightningModule.

    Builds the wrapped UNet, the optional EMA shadow, and registers the full
    beta/alpha noise schedule. ``parameterization`` selects whether the net
    predicts noise ("eps") or the clean image ("x0"). When ``ckpt_path`` is
    given, weights are restored (optionally only the UNet).

    NOTE(review): the checkpoint is loaded *before* register_schedule runs —
    the schedule buffers are therefore always recomputed from the arguments,
    not restored; confirm this is intended.
    """
    super().__init__()
    assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
    self.parameterization = parameterization
    print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
    self.cond_stage_model = None
    self.clip_denoised = clip_denoised
    self.log_every_t = log_every_t
    self.first_stage_key = first_stage_key
    self.image_size = image_size  # try conv?
    self.channels = channels
    self.use_positional_encodings = use_positional_encodings
    # the denoising UNet, wrapped so conditioning is routed per conditioning_key
    self.model = DiffusionWrapper(unet_config, conditioning_key)
    count_params(self.model, verbose=True)
    self.use_ema = use_ema
    if self.use_ema:
        # exponential-moving-average shadow of the UNet weights
        self.model_ema = LitEma(self.model)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    self.use_scheduler = scheduler_config is not None
    if self.use_scheduler:
        self.scheduler_config = scheduler_config

    self.v_posterior = v_posterior
    self.original_elbo_weight = original_elbo_weight
    self.l_simple_weight = l_simple_weight

    if monitor is not None:
        # metric name Lightning checkpoints/early-stopping should watch
        self.monitor = monitor
    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)

    # register betas/alphas and all derived diffusion buffers
    self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                           linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

    self.loss_type = loss_type

    self.learn_logvar = learn_logvar
    # per-timestep log-variance; a trainable Parameter only when learn_logvar
    self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
    if self.learn_logvar:
        self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
| 118 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
| 119 |
+
if exists(given_betas):
|
| 120 |
+
betas = given_betas
|
| 121 |
+
else:
|
| 122 |
+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
| 123 |
+
cosine_s=cosine_s)
|
| 124 |
+
alphas = 1. - betas
|
| 125 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
| 126 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
| 127 |
+
|
| 128 |
+
timesteps, = betas.shape
|
| 129 |
+
self.num_timesteps = int(timesteps)
|
| 130 |
+
self.linear_start = linear_start
|
| 131 |
+
self.linear_end = linear_end
|
| 132 |
+
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
| 133 |
+
|
| 134 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
| 135 |
+
|
| 136 |
+
self.register_buffer('betas', to_torch(betas))
|
| 137 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
| 138 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
| 139 |
+
|
| 140 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 141 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
| 142 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
| 143 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
| 144 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
| 145 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
| 146 |
+
|
| 147 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 148 |
+
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
| 149 |
+
1. - alphas_cumprod) + self.v_posterior * betas
|
| 150 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
| 151 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
| 152 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
| 153 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
| 154 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
| 155 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
| 156 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
| 157 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
| 158 |
+
|
| 159 |
+
if self.parameterization == "eps":
|
| 160 |
+
lvlb_weights = self.betas ** 2 / (
|
| 161 |
+
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
| 162 |
+
elif self.parameterization == "x0":
|
| 163 |
+
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
| 164 |
+
else:
|
| 165 |
+
raise NotImplementedError("mu not supported")
|
| 166 |
+
# TODO how to choose this term
|
| 167 |
+
lvlb_weights[0] = lvlb_weights[1]
|
| 168 |
+
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
| 169 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
| 170 |
+
|
| 171 |
+
    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the model's weights for their EMA counterparts.

        On entry (when EMA is enabled) the current parameters are stashed and the
        EMA shadow parameters are copied in; on exit the training weights are
        restored, even if the wrapped block raised.

        :param context: optional label printed when switching/restoring.
        """
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            # restore happens unconditionally in `finally` so an exception inside
            # the `with` body cannot leave EMA weights in place
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")
|
| 185 |
+
|
| 186 |
+
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
| 187 |
+
sd = torch.load(path, map_location="cpu")
|
| 188 |
+
if "state_dict" in list(sd.keys()):
|
| 189 |
+
sd = sd["state_dict"]
|
| 190 |
+
keys = list(sd.keys())
|
| 191 |
+
for k in keys:
|
| 192 |
+
for ik in ignore_keys:
|
| 193 |
+
if k.startswith(ik):
|
| 194 |
+
print("Deleting key {} from state_dict.".format(k))
|
| 195 |
+
del sd[k]
|
| 196 |
+
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
| 197 |
+
sd, strict=False)
|
| 198 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
| 199 |
+
if len(missing) > 0:
|
| 200 |
+
print(f"Missing Keys: {missing}")
|
| 201 |
+
if len(unexpected) > 0:
|
| 202 |
+
print(f"Unexpected Keys: {unexpected}")
|
| 203 |
+
|
| 204 |
+
def q_mean_variance(self, x_start, t):
|
| 205 |
+
"""
|
| 206 |
+
Get the distribution q(x_t | x_0).
|
| 207 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
| 208 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
| 209 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
| 210 |
+
"""
|
| 211 |
+
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
|
| 212 |
+
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
| 213 |
+
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
| 214 |
+
return mean, variance, log_variance
|
| 215 |
+
|
| 216 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
| 217 |
+
return (
|
| 218 |
+
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
| 219 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def q_posterior(self, x_start, x_t, t):
|
| 223 |
+
posterior_mean = (
|
| 224 |
+
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
| 225 |
+
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 226 |
+
)
|
| 227 |
+
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
| 228 |
+
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 229 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 230 |
+
|
| 231 |
+
    def p_mean_variance(self, x, t, clip_denoised: bool):
        """Predict the parameters of p(x_{t-1} | x_t) with the denoising model.

        :param x: noisy sample x_t.
        :param t: timestep tensor for the batch.
        :param clip_denoised: clamp the reconstructed x_0 to [-1, 1] in place.
        :return: (model_mean, posterior_variance, posterior_log_variance).
        """
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            # model predicts noise; invert the forward process to get x_0
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        # NOTE(review): any other parameterization would leave x_recon unbound
        # here; presumably guarded by the assertion in __init__ — confirm.
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        """Draw one ancestral sample x_{t-1} ~ p(x_{t-1} | x_t).

        :param x: current noisy sample x_t.
        :param t: timestep tensor for the batch.
        :param clip_denoised: clamp the reconstruction before computing the mean.
        :param repeat_noise: forwarded to noise_like (shares noise across the batch).
        """
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def p_sample_loop(self, shape, return_intermediates=False):
|
| 254 |
+
device = self.betas.device
|
| 255 |
+
b = shape[0]
|
| 256 |
+
img = torch.randn(shape, device=device)
|
| 257 |
+
intermediates = [img]
|
| 258 |
+
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
|
| 259 |
+
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
|
| 260 |
+
clip_denoised=self.clip_denoised)
|
| 261 |
+
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
|
| 262 |
+
intermediates.append(img)
|
| 263 |
+
if return_intermediates:
|
| 264 |
+
return img, intermediates
|
| 265 |
+
return img
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def sample(self, batch_size=16, return_intermediates=False):
|
| 269 |
+
image_size = self.image_size
|
| 270 |
+
channels = self.channels
|
| 271 |
+
return self.p_sample_loop((batch_size, channels, image_size, image_size),
|
| 272 |
+
return_intermediates=return_intermediates)
|
| 273 |
+
|
| 274 |
+
def q_sample(self, x_start, t, noise=None):
|
| 275 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 276 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
| 277 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
| 278 |
+
|
| 279 |
+
def get_loss(self, pred, target, mean=True):
|
| 280 |
+
if self.loss_type == 'l1':
|
| 281 |
+
loss = (target - pred).abs()
|
| 282 |
+
if mean:
|
| 283 |
+
loss = loss.mean()
|
| 284 |
+
elif self.loss_type == 'l2':
|
| 285 |
+
if mean:
|
| 286 |
+
loss = torch.nn.functional.mse_loss(target, pred)
|
| 287 |
+
else:
|
| 288 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
| 289 |
+
else:
|
| 290 |
+
raise NotImplementedError("unknown loss type '{loss_type}'")
|
| 291 |
+
|
| 292 |
+
return loss
|
| 293 |
+
|
| 294 |
+
    def p_losses(self, x_start, t, noise=None):
        """Compute the training loss for a batch at the given timesteps.

        Noises x_start to x_t, predicts with the model, and combines the simple
        (MSE/L1) loss with the VLB-weighted loss.

        :param x_start: clean input batch (B, C, H, W).
        :param t: per-sample timestep tensor.
        :param noise: optional fixed noise; sampled if None.
        :return: (scalar loss, dict of logged loss components).
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        # the regression target depends on what the network predicts
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")

        # per-sample loss: mean over channel/spatial dims, keep the batch dim
        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        # ELBO term: per-timestep weights applied to the same per-sample loss
        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict
|
| 322 |
+
|
| 323 |
+
def forward(self, x, *args, **kwargs):
|
| 324 |
+
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
|
| 325 |
+
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
|
| 326 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
| 327 |
+
return self.p_losses(x, t, *args, **kwargs)
|
| 328 |
+
|
| 329 |
+
def get_input(self, batch, k):
|
| 330 |
+
x = batch[k]
|
| 331 |
+
if len(x.shape) == 3:
|
| 332 |
+
x = x[..., None]
|
| 333 |
+
x = rearrange(x, 'b h w c -> b c h w')
|
| 334 |
+
x = x.to(memory_format=torch.contiguous_format).float()
|
| 335 |
+
return x
|
| 336 |
+
|
| 337 |
+
def shared_step(self, batch):
|
| 338 |
+
x = self.get_input(batch, self.first_stage_key)
|
| 339 |
+
loss, loss_dict = self(x)
|
| 340 |
+
return loss, loss_dict
|
| 341 |
+
|
| 342 |
+
    def training_step(self, batch, batch_idx):
        """Lightning training step: compute the diffusion loss and log metrics."""
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            # surface the scheduler-driven learning rate in the progress bar
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
    def validation_step(self, batch, batch_idx):
        """Lightning validation step: log losses with both raw and EMA weights."""
        _, loss_dict_no_ema = self.shared_step(batch)
        # repeat the step under EMA weights; keys get an '_ema' suffix
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
| 365 |
+
|
| 366 |
+
def on_train_batch_end(self, *args, **kwargs):
|
| 367 |
+
if self.use_ema:
|
| 368 |
+
self.model_ema(self.model)
|
| 369 |
+
|
| 370 |
+
def _get_rows_from_list(self, samples):
|
| 371 |
+
n_imgs_per_row = len(samples)
|
| 372 |
+
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
|
| 373 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
| 374 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
| 375 |
+
return denoise_grid
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        """Build a dict of visualization tensors: inputs, the forward-diffusion
        row, and (optionally) samples plus the denoising row.

        :param batch: input batch dict.
        :param N: max number of images to visualize.
        :param n_row: rows in the diffusion grid.
        :param sample: also draw samples under EMA weights.
        :param return_keys: if given, restrict the returned dict to these keys.
        """
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                # note: `t` is rebound from an int to a batch-sized tensor here
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            # if none of the requested keys exist, fall back to the full dict
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log
|
| 414 |
+
|
| 415 |
+
def configure_optimizers(self):
|
| 416 |
+
lr = self.learning_rate
|
| 417 |
+
params = list(self.model.parameters())
|
| 418 |
+
if self.learn_logvar:
|
| 419 |
+
params = params + [self.logvar]
|
| 420 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
| 421 |
+
return opt
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class LatentDiffusion(DDPM):
|
| 425 |
+
"""main class"""
|
| 426 |
+
def __init__(self,
|
| 427 |
+
first_stage_config,
|
| 428 |
+
cond_stage_config,
|
| 429 |
+
num_timesteps_cond=None,
|
| 430 |
+
cond_stage_key="image",
|
| 431 |
+
cond_stage_trainable=False,
|
| 432 |
+
concat_mode=True,
|
| 433 |
+
cond_stage_forward=None,
|
| 434 |
+
conditioning_key=None,
|
| 435 |
+
scale_factor=1.0,
|
| 436 |
+
scale_by_std=False,
|
| 437 |
+
*args, **kwargs):
|
| 438 |
+
self.num_timesteps_cond = default(num_timesteps_cond, 1)
|
| 439 |
+
self.scale_by_std = scale_by_std
|
| 440 |
+
assert self.num_timesteps_cond <= kwargs['timesteps']
|
| 441 |
+
# for backwards compatibility after implementation of DiffusionWrapper
|
| 442 |
+
if conditioning_key is None:
|
| 443 |
+
conditioning_key = 'concat' if concat_mode else 'crossattn'
|
| 444 |
+
if cond_stage_config == '__is_unconditional__':
|
| 445 |
+
conditioning_key = None
|
| 446 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
| 447 |
+
ignore_keys = kwargs.pop("ignore_keys", [])
|
| 448 |
+
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
|
| 449 |
+
self.concat_mode = concat_mode
|
| 450 |
+
self.cond_stage_trainable = cond_stage_trainable
|
| 451 |
+
self.cond_stage_key = cond_stage_key
|
| 452 |
+
try:
|
| 453 |
+
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
| 454 |
+
except:
|
| 455 |
+
self.num_downs = 0
|
| 456 |
+
if not scale_by_std:
|
| 457 |
+
self.scale_factor = scale_factor
|
| 458 |
+
else:
|
| 459 |
+
self.register_buffer('scale_factor', torch.tensor(scale_factor))
|
| 460 |
+
self.instantiate_first_stage(first_stage_config)
|
| 461 |
+
self.instantiate_cond_stage(cond_stage_config)
|
| 462 |
+
self.cond_stage_forward = cond_stage_forward
|
| 463 |
+
self.clip_denoised = False
|
| 464 |
+
self.bbox_tokenizer = None
|
| 465 |
+
|
| 466 |
+
self.restarted_from_ckpt = False
|
| 467 |
+
if ckpt_path is not None:
|
| 468 |
+
self.init_from_ckpt(ckpt_path, ignore_keys)
|
| 469 |
+
self.restarted_from_ckpt = True
|
| 470 |
+
|
| 471 |
+
def make_cond_schedule(self, ):
|
| 472 |
+
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
| 473 |
+
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
| 474 |
+
self.cond_ids[:self.num_timesteps_cond] = ids
|
| 475 |
+
|
| 476 |
+
    @rank_zero_only
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        """On the very first batch of a fresh run, recompute ``scale_factor`` as
        the reciprocal std of the first-stage encodings (std-rescaling).

        Runs only on rank zero; skipped entirely when resuming from a checkpoint.
        """
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            # replace the buffer registered in __init__ with the measured value
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")
|
| 492 |
+
|
| 493 |
+
def register_schedule(self,
|
| 494 |
+
given_betas=None, beta_schedule="linear", timesteps=1000,
|
| 495 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
| 496 |
+
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
|
| 497 |
+
|
| 498 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
| 499 |
+
if self.shorten_cond_schedule:
|
| 500 |
+
self.make_cond_schedule()
|
| 501 |
+
|
| 502 |
+
    def instantiate_first_stage(self, config):
        """Build the first-stage autoencoder from config and freeze it."""
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        # overriding .train pins the module in eval mode even when Lightning
        # calls .train() on the whole LatentDiffusion module
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False
|
| 508 |
+
|
| 509 |
+
    def instantiate_cond_stage(self, config):
        """Build the conditioning-stage model.

        Frozen (eval, no grads) unless ``self.cond_stage_trainable``. Two string
        sentinels are supported in the frozen case: '__is_first_stage__' reuses
        the first-stage model, '__is_unconditional__' disables conditioning.
        """
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                # pin eval mode; see instantiate_first_stage for the same trick
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            # trainable conditioning model: sentinels make no sense here
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model
|
| 529 |
+
|
| 530 |
+
    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
        """Decode a list of latent batches and arrange them into one image grid.

        :param samples: list of latent tensors (one per logged denoising step).
        :param desc: tqdm progress-bar label.
        :param force_no_decoder_quantization: skip VQ quantization in the decoder.
        """
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(self.decode_first_stage(zd.to(self.device),
                                                       force_not_quantize=force_no_decoder_quantization))
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid
|
| 541 |
+
|
| 542 |
+
def get_first_stage_encoding(self, encoder_posterior):
|
| 543 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
| 544 |
+
z = encoder_posterior.sample()
|
| 545 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
| 546 |
+
z = encoder_posterior
|
| 547 |
+
else:
|
| 548 |
+
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
|
| 549 |
+
return self.scale_factor * z
|
| 550 |
+
|
| 551 |
+
    def get_learned_conditioning(self, c):
        """Encode raw conditioning input with the conditioning-stage model.

        Uses the model's ``encode`` method when available (taking the mode of a
        DiagonalGaussianDistribution result), otherwise calls the model directly.
        If ``self.cond_stage_forward`` names a method, that method is used instead.
        """
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    # deterministic conditioning: use the distribution mode
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c
|
| 563 |
+
|
| 564 |
+
def meshgrid(self, h, w):
|
| 565 |
+
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
|
| 566 |
+
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
|
| 567 |
+
|
| 568 |
+
arr = torch.cat([y, x], dim=-1)
|
| 569 |
+
return arr
|
| 570 |
+
|
| 571 |
+
def delta_border(self, h, w):
|
| 572 |
+
"""
|
| 573 |
+
:param h: height
|
| 574 |
+
:param w: width
|
| 575 |
+
:return: normalized distance to image border,
|
| 576 |
+
wtith min distance = 0 at border and max dist = 0.5 at image center
|
| 577 |
+
"""
|
| 578 |
+
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
|
| 579 |
+
arr = self.meshgrid(h, w) / lower_right_corner
|
| 580 |
+
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
|
| 581 |
+
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
|
| 582 |
+
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
|
| 583 |
+
return edge_dist
|
| 584 |
+
|
| 585 |
+
    def get_weighting(self, h, w, Ly, Lx, device):
        """Per-pixel blending weights used when stitching overlapping crops.

        :param h: crop height, :param w: crop width.
        :param Ly: number of crop rows, :param Lx: number of crop columns.
        :param device: device for the returned tensor of shape (1, h*w, Ly*Lx).
        """
        weighting = self.delta_border(h, w)
        # clamp so border pixels never get zero weight
        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
                               self.split_input_params["clip_max_weight"], )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            # additional per-crop weighting to break ties between crops
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(L_weighting,
                                     self.split_input_params["clip_min_tie_weight"],
                                     self.split_input_params["clip_max_tie_weight"])

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting
|
| 600 |
+
|
| 601 |
+
    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])

        Also returns the matching fold operator, a normalization map that undoes
        the overlap double-counting, and the per-crop weighting tensor.
        ``uf`` upscales the output grid (decoder side), ``df`` downscales it
        (encoder side); only one of the two may differ from 1.
        """
        bs, nc, h, w = x.shape

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            # NOTE(review): both entries use kernel_size[0]; for non-square
            # kernels the second presumably should be kernel_size[1] — confirm
            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                                dilation=1, padding=0,
                                stride=(stride[0] * uf, stride[1] * uf))
            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

        elif df > 1 and uf == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            # NOTE(review): same kernel_size[0]/[1] asymmetry as the uf branch — confirm
            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                                dilation=1, padding=0,
                                stride=(stride[0] // df, stride[1] // df))
            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None):
        """Fetch input images, encode them to latents, and build the conditioning.

        :param batch: input batch dict.
        :param k: key of the image input.
        :param return_first_stage_outputs: also return the input and its
            first-stage reconstruction.
        :param force_c_encode: encode the conditioning even when the
            conditioning model is trainable.
        :param cond_key: override for self.cond_stage_key.
        :param return_original_cond: also return the raw (un-encoded) conditioning.
        :param bs: optional batch-size cap.
        :return: list [z, c] plus optional extras in the order x, xrec, xc.
        """
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        # latents are detached: the first stage is frozen
        z = self.get_first_stage_encoding(encoder_posterior).detach()

        if self.model.conditioning_key is not None:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ['caption', 'coordinates_bbox']:
                    # raw text / bbox conditioning is passed through as-is
                    xc = batch[cond_key]
                elif cond_key == 'class_label':
                    # class-conditional models consume the whole batch dict
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                # conditioning on the image itself
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    # import pudb; pudb.set_trace()
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                # trainable conditioning model: encode later, inside the training step
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

        else:
            # unconditional model
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_original_cond:
            out.append(xc)
        return out
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Decode latents back to image space with the first-stage model.

        :param z: latent tensor (rescaled by self.scale_factor).
        :param predict_cids: treat z as codebook logits/ids and look up the
            corresponding VQ codebook entries before decoding.
        :param force_not_quantize: skip quantization in a VQ decoder.

        When ``self.split_input_params`` is set, the latent is decoded in
        overlapping crops via unfold/fold and blended with precomputed weights.
        """
        if predict_cids:
            if z.dim() == 4:
                # logits over the codebook -> hard ids
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        # undo the latent rescaling applied at encode time
        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                uf = self.split_input_params["vqf"]
                bs, nc, h, w = z.shape
                # shrink kernel/stride if the latent is smaller than configured
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (bn, nc * prod(**ks), L)
                # 1. Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                # 2. apply model loop over last dim
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:

                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
                # blend overlapping crops with the precomputed weights
                o = o * weighting
                # Reverse 1. reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)
|
| 764 |
+
|
| 765 |
+
# same as above but without decorator
|
| 766 |
+
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
| 767 |
+
if predict_cids:
|
| 768 |
+
if z.dim() == 4:
|
| 769 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
| 770 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
| 771 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
| 772 |
+
|
| 773 |
+
z = 1. / self.scale_factor * z
|
| 774 |
+
|
| 775 |
+
if hasattr(self, "split_input_params"):
|
| 776 |
+
if self.split_input_params["patch_distributed_vq"]:
|
| 777 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
| 778 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
| 779 |
+
uf = self.split_input_params["vqf"]
|
| 780 |
+
bs, nc, h, w = z.shape
|
| 781 |
+
if ks[0] > h or ks[1] > w:
|
| 782 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
| 783 |
+
print("reducing Kernel")
|
| 784 |
+
|
| 785 |
+
if stride[0] > h or stride[1] > w:
|
| 786 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
| 787 |
+
print("reducing stride")
|
| 788 |
+
|
| 789 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
| 790 |
+
|
| 791 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
| 792 |
+
# 1. Reshape to img shape
|
| 793 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 794 |
+
|
| 795 |
+
# 2. apply model loop over last dim
|
| 796 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 797 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
| 798 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
| 799 |
+
for i in range(z.shape[-1])]
|
| 800 |
+
else:
|
| 801 |
+
|
| 802 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
| 803 |
+
for i in range(z.shape[-1])]
|
| 804 |
+
|
| 805 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
| 806 |
+
o = o * weighting
|
| 807 |
+
# Reverse 1. reshape to img shape
|
| 808 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
| 809 |
+
# stitch crops together
|
| 810 |
+
decoded = fold(o)
|
| 811 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
| 812 |
+
return decoded
|
| 813 |
+
else:
|
| 814 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 815 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
| 816 |
+
else:
|
| 817 |
+
return self.first_stage_model.decode(z)
|
| 818 |
+
|
| 819 |
+
else:
|
| 820 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
| 821 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
| 822 |
+
else:
|
| 823 |
+
return self.first_stage_model.decode(z)
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def encode_first_stage(self, x):
|
| 827 |
+
if hasattr(self, "split_input_params"):
|
| 828 |
+
if self.split_input_params["patch_distributed_vq"]:
|
| 829 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
| 830 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
| 831 |
+
df = self.split_input_params["vqf"]
|
| 832 |
+
self.split_input_params['original_image_size'] = x.shape[-2:]
|
| 833 |
+
bs, nc, h, w = x.shape
|
| 834 |
+
if ks[0] > h or ks[1] > w:
|
| 835 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
| 836 |
+
print("reducing Kernel")
|
| 837 |
+
|
| 838 |
+
if stride[0] > h or stride[1] > w:
|
| 839 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
| 840 |
+
print("reducing stride")
|
| 841 |
+
|
| 842 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
|
| 843 |
+
z = unfold(x) # (bn, nc * prod(**ks), L)
|
| 844 |
+
# Reshape to img shape
|
| 845 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 846 |
+
|
| 847 |
+
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
|
| 848 |
+
for i in range(z.shape[-1])]
|
| 849 |
+
|
| 850 |
+
o = torch.stack(output_list, axis=-1)
|
| 851 |
+
o = o * weighting
|
| 852 |
+
|
| 853 |
+
# Reverse reshape to img shape
|
| 854 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
| 855 |
+
# stitch crops together
|
| 856 |
+
decoded = fold(o)
|
| 857 |
+
decoded = decoded / normalization
|
| 858 |
+
return decoded
|
| 859 |
+
|
| 860 |
+
else:
|
| 861 |
+
return self.first_stage_model.encode(x)
|
| 862 |
+
else:
|
| 863 |
+
return self.first_stage_model.encode(x)
|
| 864 |
+
|
| 865 |
+
def shared_step(self, batch, **kwargs):
|
| 866 |
+
x, c = self.get_input(batch, self.first_stage_key)
|
| 867 |
+
loss = self(x, c)
|
| 868 |
+
return loss
|
| 869 |
+
|
| 870 |
+
def forward(self, x, c, *args, **kwargs):
|
| 871 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
| 872 |
+
if self.model.conditioning_key is not None:
|
| 873 |
+
assert c is not None
|
| 874 |
+
if self.cond_stage_trainable:
|
| 875 |
+
c = self.get_learned_conditioning(c)
|
| 876 |
+
if self.shorten_cond_schedule: # TODO: drop this option
|
| 877 |
+
tc = self.cond_ids[t].to(self.device)
|
| 878 |
+
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
| 879 |
+
return self.p_losses(x, c, t, *args, **kwargs)
|
| 880 |
+
|
| 881 |
+
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
| 882 |
+
def rescale_bbox(bbox):
|
| 883 |
+
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
| 884 |
+
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
| 885 |
+
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
| 886 |
+
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
| 887 |
+
return x0, y0, w, h
|
| 888 |
+
|
| 889 |
+
return [rescale_bbox(b) for b in bboxes]
|
| 890 |
+
|
| 891 |
+
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
| 892 |
+
|
| 893 |
+
if isinstance(cond, dict):
|
| 894 |
+
# hybrid case, cond is exptected to be a dict
|
| 895 |
+
pass
|
| 896 |
+
else:
|
| 897 |
+
if not isinstance(cond, list):
|
| 898 |
+
cond = [cond]
|
| 899 |
+
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
|
| 900 |
+
cond = {key: cond}
|
| 901 |
+
|
| 902 |
+
if hasattr(self, "split_input_params"):
|
| 903 |
+
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
| 904 |
+
assert not return_ids
|
| 905 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
| 906 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
| 907 |
+
|
| 908 |
+
h, w = x_noisy.shape[-2:]
|
| 909 |
+
|
| 910 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
|
| 911 |
+
|
| 912 |
+
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
|
| 913 |
+
# Reshape to img shape
|
| 914 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 915 |
+
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
| 916 |
+
|
| 917 |
+
if self.cond_stage_key in ["image", "LR_image", "segmentation",
|
| 918 |
+
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
|
| 919 |
+
c_key = next(iter(cond.keys())) # get key
|
| 920 |
+
c = next(iter(cond.values())) # get value
|
| 921 |
+
assert (len(c) == 1) # todo extend to list with more than one elem
|
| 922 |
+
c = c[0] # get element
|
| 923 |
+
|
| 924 |
+
c = unfold(c)
|
| 925 |
+
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
| 926 |
+
|
| 927 |
+
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
| 928 |
+
|
| 929 |
+
elif self.cond_stage_key == 'coordinates_bbox':
|
| 930 |
+
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
| 931 |
+
|
| 932 |
+
# assuming padding of unfold is always 0 and its dilation is always 1
|
| 933 |
+
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
| 934 |
+
full_img_h, full_img_w = self.split_input_params['original_image_size']
|
| 935 |
+
# as we are operating on latents, we need the factor from the original image size to the
|
| 936 |
+
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
|
| 937 |
+
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
| 938 |
+
rescale_latent = 2 ** (num_downs)
|
| 939 |
+
|
| 940 |
+
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
| 941 |
+
# need to rescale the tl patch coordinates to be in between (0,1)
|
| 942 |
+
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
| 943 |
+
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
| 944 |
+
for patch_nr in range(z.shape[-1])]
|
| 945 |
+
|
| 946 |
+
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
|
| 947 |
+
patch_limits = [(x_tl, y_tl,
|
| 948 |
+
rescale_latent * ks[0] / full_img_w,
|
| 949 |
+
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
|
| 950 |
+
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
|
| 951 |
+
|
| 952 |
+
# tokenize crop coordinates for the bounding boxes of the respective patches
|
| 953 |
+
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
|
| 954 |
+
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
|
| 955 |
+
print(patch_limits_tknzd[0].shape)
|
| 956 |
+
# cut tknzd crop position from conditioning
|
| 957 |
+
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
|
| 958 |
+
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
|
| 959 |
+
print(cut_cond.shape)
|
| 960 |
+
|
| 961 |
+
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
|
| 962 |
+
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
|
| 963 |
+
print(adapted_cond.shape)
|
| 964 |
+
adapted_cond = self.get_learned_conditioning(adapted_cond)
|
| 965 |
+
print(adapted_cond.shape)
|
| 966 |
+
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
|
| 967 |
+
print(adapted_cond.shape)
|
| 968 |
+
|
| 969 |
+
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
|
| 970 |
+
|
| 971 |
+
else:
|
| 972 |
+
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
|
| 973 |
+
|
| 974 |
+
# apply model by loop over crops
|
| 975 |
+
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
|
| 976 |
+
assert not isinstance(output_list[0],
|
| 977 |
+
tuple) # todo cant deal with multiple model outputs check this never happens
|
| 978 |
+
|
| 979 |
+
o = torch.stack(output_list, axis=-1)
|
| 980 |
+
o = o * weighting
|
| 981 |
+
# Reverse reshape to img shape
|
| 982 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
| 983 |
+
# stitch crops together
|
| 984 |
+
x_recon = fold(o) / normalization
|
| 985 |
+
|
| 986 |
+
else:
|
| 987 |
+
x_recon = self.model(x_noisy, t, **cond)
|
| 988 |
+
|
| 989 |
+
if isinstance(x_recon, tuple) and not return_ids:
|
| 990 |
+
return x_recon[0]
|
| 991 |
+
else:
|
| 992 |
+
return x_recon
|
| 993 |
+
|
| 994 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
| 995 |
+
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
|
| 996 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 997 |
+
|
| 998 |
+
def _prior_bpd(self, x_start):
|
| 999 |
+
"""
|
| 1000 |
+
Get the prior KL term for the variational lower-bound, measured in
|
| 1001 |
+
bits-per-dim.
|
| 1002 |
+
This term can't be optimized, as it only depends on the encoder.
|
| 1003 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 1004 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
| 1005 |
+
"""
|
| 1006 |
+
batch_size = x_start.shape[0]
|
| 1007 |
+
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
| 1008 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
| 1009 |
+
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
| 1010 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
| 1011 |
+
|
| 1012 |
+
def p_losses(self, x_start, cond, t, noise=None):
|
| 1013 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 1014 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
| 1015 |
+
model_output = self.apply_model(x_noisy, t, cond)
|
| 1016 |
+
|
| 1017 |
+
loss_dict = {}
|
| 1018 |
+
prefix = 'train' if self.training else 'val'
|
| 1019 |
+
|
| 1020 |
+
if self.parameterization == "x0":
|
| 1021 |
+
target = x_start
|
| 1022 |
+
elif self.parameterization == "eps":
|
| 1023 |
+
target = noise
|
| 1024 |
+
else:
|
| 1025 |
+
raise NotImplementedError()
|
| 1026 |
+
|
| 1027 |
+
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
| 1028 |
+
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
|
| 1029 |
+
|
| 1030 |
+
logvar_t = self.logvar[t].to(self.device)
|
| 1031 |
+
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
| 1032 |
+
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
| 1033 |
+
if self.learn_logvar:
|
| 1034 |
+
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
|
| 1035 |
+
loss_dict.update({'logvar': self.logvar.data.mean()})
|
| 1036 |
+
|
| 1037 |
+
loss = self.l_simple_weight * loss.mean()
|
| 1038 |
+
|
| 1039 |
+
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
| 1040 |
+
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
| 1041 |
+
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
|
| 1042 |
+
loss += (self.original_elbo_weight * loss_vlb)
|
| 1043 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
| 1044 |
+
|
| 1045 |
+
return loss, loss_dict
|
| 1046 |
+
|
| 1047 |
+
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
|
| 1048 |
+
return_x0=False, score_corrector=None, corrector_kwargs=None):
|
| 1049 |
+
t_in = t
|
| 1050 |
+
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
|
| 1051 |
+
|
| 1052 |
+
if score_corrector is not None:
|
| 1053 |
+
assert self.parameterization == "eps"
|
| 1054 |
+
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
|
| 1055 |
+
|
| 1056 |
+
if return_codebook_ids:
|
| 1057 |
+
model_out, logits = model_out
|
| 1058 |
+
|
| 1059 |
+
if self.parameterization == "eps":
|
| 1060 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
| 1061 |
+
elif self.parameterization == "x0":
|
| 1062 |
+
x_recon = model_out
|
| 1063 |
+
else:
|
| 1064 |
+
raise NotImplementedError()
|
| 1065 |
+
|
| 1066 |
+
if clip_denoised:
|
| 1067 |
+
x_recon.clamp_(-1., 1.)
|
| 1068 |
+
if quantize_denoised:
|
| 1069 |
+
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
|
| 1070 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
| 1071 |
+
if return_codebook_ids:
|
| 1072 |
+
return model_mean, posterior_variance, posterior_log_variance, logits
|
| 1073 |
+
elif return_x0:
|
| 1074 |
+
return model_mean, posterior_variance, posterior_log_variance, x_recon
|
| 1075 |
+
else:
|
| 1076 |
+
return model_mean, posterior_variance, posterior_log_variance
|
| 1077 |
+
|
| 1078 |
+
|
| 1079 |
+
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
|
| 1080 |
+
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
|
| 1081 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
|
| 1082 |
+
b, *_, device = *x.shape, x.device
|
| 1083 |
+
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
|
| 1084 |
+
return_codebook_ids=return_codebook_ids,
|
| 1085 |
+
quantize_denoised=quantize_denoised,
|
| 1086 |
+
return_x0=return_x0,
|
| 1087 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
| 1088 |
+
if return_codebook_ids:
|
| 1089 |
+
raise DeprecationWarning("Support dropped.")
|
| 1090 |
+
model_mean, _, model_log_variance, logits = outputs
|
| 1091 |
+
elif return_x0:
|
| 1092 |
+
model_mean, _, model_log_variance, x0 = outputs
|
| 1093 |
+
else:
|
| 1094 |
+
model_mean, _, model_log_variance = outputs
|
| 1095 |
+
|
| 1096 |
+
noise = noise_like(x.shape, device, repeat_noise) * temperature
|
| 1097 |
+
if noise_dropout > 0.:
|
| 1098 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 1099 |
+
# no noise when t == 0
|
| 1100 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
| 1101 |
+
|
| 1102 |
+
if return_codebook_ids:
|
| 1103 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
|
| 1104 |
+
if return_x0:
|
| 1105 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
|
| 1106 |
+
else:
|
| 1107 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
| 1108 |
+
|
| 1109 |
+
|
| 1110 |
+
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
|
| 1111 |
+
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
|
| 1112 |
+
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
|
| 1113 |
+
log_every_t=None):
|
| 1114 |
+
if not log_every_t:
|
| 1115 |
+
log_every_t = self.log_every_t
|
| 1116 |
+
timesteps = self.num_timesteps
|
| 1117 |
+
if batch_size is not None:
|
| 1118 |
+
b = batch_size if batch_size is not None else shape[0]
|
| 1119 |
+
shape = [batch_size] + list(shape)
|
| 1120 |
+
else:
|
| 1121 |
+
b = batch_size = shape[0]
|
| 1122 |
+
if x_T is None:
|
| 1123 |
+
img = torch.randn(shape, device=self.device)
|
| 1124 |
+
else:
|
| 1125 |
+
img = x_T
|
| 1126 |
+
intermediates = []
|
| 1127 |
+
if cond is not None:
|
| 1128 |
+
if isinstance(cond, dict):
|
| 1129 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
| 1130 |
+
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
| 1131 |
+
else:
|
| 1132 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
| 1133 |
+
|
| 1134 |
+
if start_T is not None:
|
| 1135 |
+
timesteps = min(timesteps, start_T)
|
| 1136 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
|
| 1137 |
+
total=timesteps) if verbose else reversed(
|
| 1138 |
+
range(0, timesteps))
|
| 1139 |
+
if type(temperature) == float:
|
| 1140 |
+
temperature = [temperature] * timesteps
|
| 1141 |
+
|
| 1142 |
+
for i in iterator:
|
| 1143 |
+
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
|
| 1144 |
+
if self.shorten_cond_schedule:
|
| 1145 |
+
assert self.model.conditioning_key != 'hybrid'
|
| 1146 |
+
tc = self.cond_ids[ts].to(cond.device)
|
| 1147 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
| 1148 |
+
|
| 1149 |
+
img, x0_partial = self.p_sample(img, cond, ts,
|
| 1150 |
+
clip_denoised=self.clip_denoised,
|
| 1151 |
+
quantize_denoised=quantize_denoised, return_x0=True,
|
| 1152 |
+
temperature=temperature[i], noise_dropout=noise_dropout,
|
| 1153 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
| 1154 |
+
if mask is not None:
|
| 1155 |
+
assert x0 is not None
|
| 1156 |
+
img_orig = self.q_sample(x0, ts)
|
| 1157 |
+
img = img_orig * mask + (1. - mask) * img
|
| 1158 |
+
|
| 1159 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
| 1160 |
+
intermediates.append(x0_partial)
|
| 1161 |
+
if callback: callback(i)
|
| 1162 |
+
if img_callback: img_callback(img, i)
|
| 1163 |
+
return img, intermediates
|
| 1164 |
+
|
| 1165 |
+
|
| 1166 |
+
def p_sample_loop(self, cond, shape, return_intermediates=False,
|
| 1167 |
+
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
|
| 1168 |
+
mask=None, x0=None, img_callback=None, start_T=None,
|
| 1169 |
+
log_every_t=None):
|
| 1170 |
+
|
| 1171 |
+
if not log_every_t:
|
| 1172 |
+
log_every_t = self.log_every_t
|
| 1173 |
+
device = self.betas.device
|
| 1174 |
+
b = shape[0]
|
| 1175 |
+
if x_T is None:
|
| 1176 |
+
img = torch.randn(shape, device=device)
|
| 1177 |
+
else:
|
| 1178 |
+
img = x_T
|
| 1179 |
+
|
| 1180 |
+
intermediates = [img]
|
| 1181 |
+
if timesteps is None:
|
| 1182 |
+
timesteps = self.num_timesteps
|
| 1183 |
+
|
| 1184 |
+
if start_T is not None:
|
| 1185 |
+
timesteps = min(timesteps, start_T)
|
| 1186 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
|
| 1187 |
+
range(0, timesteps))
|
| 1188 |
+
|
| 1189 |
+
if mask is not None:
|
| 1190 |
+
assert x0 is not None
|
| 1191 |
+
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
|
| 1192 |
+
|
| 1193 |
+
for i in iterator:
|
| 1194 |
+
ts = torch.full((b,), i, device=device, dtype=torch.long)
|
| 1195 |
+
if self.shorten_cond_schedule:
|
| 1196 |
+
assert self.model.conditioning_key != 'hybrid'
|
| 1197 |
+
tc = self.cond_ids[ts].to(cond.device)
|
| 1198 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
| 1199 |
+
|
| 1200 |
+
img = self.p_sample(img, cond, ts,
|
| 1201 |
+
clip_denoised=self.clip_denoised,
|
| 1202 |
+
quantize_denoised=quantize_denoised)
|
| 1203 |
+
if mask is not None:
|
| 1204 |
+
img_orig = self.q_sample(x0, ts)
|
| 1205 |
+
img = img_orig * mask + (1. - mask) * img
|
| 1206 |
+
|
| 1207 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
| 1208 |
+
intermediates.append(img)
|
| 1209 |
+
if callback: callback(i)
|
| 1210 |
+
if img_callback: img_callback(img, i)
|
| 1211 |
+
|
| 1212 |
+
if return_intermediates:
|
| 1213 |
+
return img, intermediates
|
| 1214 |
+
return img
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
| 1218 |
+
verbose=True, timesteps=None, quantize_denoised=False,
|
| 1219 |
+
mask=None, x0=None, shape=None,**kwargs):
|
| 1220 |
+
if shape is None:
|
| 1221 |
+
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
| 1222 |
+
if cond is not None:
|
| 1223 |
+
if isinstance(cond, dict):
|
| 1224 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
| 1225 |
+
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
| 1226 |
+
else:
|
| 1227 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
| 1228 |
+
return self.p_sample_loop(cond,
|
| 1229 |
+
shape,
|
| 1230 |
+
return_intermediates=return_intermediates, x_T=x_T,
|
| 1231 |
+
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
|
| 1232 |
+
mask=mask, x0=x0)
|
| 1233 |
+
|
| 1234 |
+
|
| 1235 |
+
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
|
| 1236 |
+
|
| 1237 |
+
if ddim:
|
| 1238 |
+
ddim_sampler = DDIMSampler(self)
|
| 1239 |
+
shape = (self.channels, self.image_size, self.image_size)
|
| 1240 |
+
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
| 1241 |
+
shape,cond,verbose=False,**kwargs)
|
| 1242 |
+
|
| 1243 |
+
else:
|
| 1244 |
+
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
| 1245 |
+
return_intermediates=True,**kwargs)
|
| 1246 |
+
|
| 1247 |
+
return samples, intermediates
|
| 1248 |
+
|
| 1249 |
+
|
| 1250 |
+
|
| 1251 |
+
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
| 1252 |
+
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
| 1253 |
+
plot_diffusion_rows=True, **kwargs):
|
| 1254 |
+
|
| 1255 |
+
use_ddim = ddim_steps is not None
|
| 1256 |
+
|
| 1257 |
+
log = dict()
|
| 1258 |
+
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
| 1259 |
+
return_first_stage_outputs=True,
|
| 1260 |
+
force_c_encode=True,
|
| 1261 |
+
return_original_cond=True,
|
| 1262 |
+
bs=N)
|
| 1263 |
+
N = min(x.shape[0], N)
|
| 1264 |
+
n_row = min(x.shape[0], n_row)
|
| 1265 |
+
log["inputs"] = x
|
| 1266 |
+
log["reconstruction"] = xrec
|
| 1267 |
+
if self.model.conditioning_key is not None:
|
| 1268 |
+
if hasattr(self.cond_stage_model, "decode"):
|
| 1269 |
+
xc = self.cond_stage_model.decode(c)
|
| 1270 |
+
log["conditioning"] = xc
|
| 1271 |
+
elif self.cond_stage_key in ["caption"]:
|
| 1272 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
|
| 1273 |
+
log["conditioning"] = xc
|
| 1274 |
+
elif self.cond_stage_key == 'class_label':
|
| 1275 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
| 1276 |
+
log['conditioning'] = xc
|
| 1277 |
+
elif isimage(xc):
|
| 1278 |
+
log["conditioning"] = xc
|
| 1279 |
+
if ismap(xc):
|
| 1280 |
+
log["original_conditioning"] = self.to_rgb(xc)
|
| 1281 |
+
|
| 1282 |
+
if plot_diffusion_rows:
|
| 1283 |
+
# get diffusion row
|
| 1284 |
+
diffusion_row = list()
|
| 1285 |
+
z_start = z[:n_row]
|
| 1286 |
+
for t in range(self.num_timesteps):
|
| 1287 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
| 1288 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
| 1289 |
+
t = t.to(self.device).long()
|
| 1290 |
+
noise = torch.randn_like(z_start)
|
| 1291 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
| 1292 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
| 1293 |
+
|
| 1294 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
| 1295 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
| 1296 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
| 1297 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
| 1298 |
+
log["diffusion_row"] = diffusion_grid
|
| 1299 |
+
|
| 1300 |
+
if sample:
|
| 1301 |
+
# get denoise row
|
| 1302 |
+
with self.ema_scope("Plotting"):
|
| 1303 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
| 1304 |
+
ddim_steps=ddim_steps,eta=ddim_eta)
|
| 1305 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
| 1306 |
+
x_samples = self.decode_first_stage(samples)
|
| 1307 |
+
log["samples"] = x_samples
|
| 1308 |
+
if plot_denoise_rows:
|
| 1309 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
| 1310 |
+
log["denoise_row"] = denoise_grid
|
| 1311 |
+
|
| 1312 |
+
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
| 1313 |
+
self.first_stage_model, IdentityFirstStage):
|
| 1314 |
+
# also display when quantizing x0 while sampling
|
| 1315 |
+
with self.ema_scope("Plotting Quantized Denoised"):
|
| 1316 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
| 1317 |
+
ddim_steps=ddim_steps,eta=ddim_eta,
|
| 1318 |
+
quantize_denoised=True)
|
| 1319 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
| 1320 |
+
# quantize_denoised=True)
|
| 1321 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
| 1322 |
+
log["samples_x0_quantized"] = x_samples
|
| 1323 |
+
|
| 1324 |
+
if inpaint:
|
| 1325 |
+
# make a simple center square
|
| 1326 |
+
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
| 1327 |
+
mask = torch.ones(N, h, w).to(self.device)
|
| 1328 |
+
# zeros will be filled in
|
| 1329 |
+
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
| 1330 |
+
mask = mask[:, None, ...]
|
| 1331 |
+
with self.ema_scope("Plotting Inpaint"):
|
| 1332 |
+
|
| 1333 |
+
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
| 1334 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
| 1335 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
| 1336 |
+
log["samples_inpainting"] = x_samples
|
| 1337 |
+
log["mask"] = mask
|
| 1338 |
+
|
| 1339 |
+
# outpaint
|
| 1340 |
+
with self.ema_scope("Plotting Outpaint"):
|
| 1341 |
+
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
| 1342 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
| 1343 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
| 1344 |
+
log["samples_outpainting"] = x_samples
|
| 1345 |
+
|
| 1346 |
+
if plot_progressive_rows:
|
| 1347 |
+
with self.ema_scope("Plotting Progressives"):
|
| 1348 |
+
img, progressives = self.progressive_denoising(c,
|
| 1349 |
+
shape=(self.channels, self.image_size, self.image_size),
|
| 1350 |
+
batch_size=N)
|
| 1351 |
+
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
|
| 1352 |
+
log["progressive_row"] = prog_row
|
| 1353 |
+
|
| 1354 |
+
if return_keys:
|
| 1355 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
| 1356 |
+
return log
|
| 1357 |
+
else:
|
| 1358 |
+
return {key: log[key] for key in return_keys}
|
| 1359 |
+
return log
|
| 1360 |
+
|
| 1361 |
+
def configure_optimizers(self):
|
| 1362 |
+
lr = self.learning_rate
|
| 1363 |
+
params = list(self.model.parameters())
|
| 1364 |
+
if self.cond_stage_trainable:
|
| 1365 |
+
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
|
| 1366 |
+
params = params + list(self.cond_stage_model.parameters())
|
| 1367 |
+
if self.learn_logvar:
|
| 1368 |
+
print('Diffusion model optimizing logvar')
|
| 1369 |
+
params.append(self.logvar)
|
| 1370 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
| 1371 |
+
if self.use_scheduler:
|
| 1372 |
+
assert 'target' in self.scheduler_config
|
| 1373 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
| 1374 |
+
|
| 1375 |
+
print("Setting up LambdaLR scheduler...")
|
| 1376 |
+
scheduler = [
|
| 1377 |
+
{
|
| 1378 |
+
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
|
| 1379 |
+
'interval': 'step',
|
| 1380 |
+
'frequency': 1
|
| 1381 |
+
}]
|
| 1382 |
+
return [opt], scheduler
|
| 1383 |
+
return opt
|
| 1384 |
+
|
| 1385 |
+
|
| 1386 |
+
def to_rgb(self, x):
|
| 1387 |
+
x = x.float()
|
| 1388 |
+
if not hasattr(self, "colorize"):
|
| 1389 |
+
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
|
| 1390 |
+
x = nn.functional.conv2d(x, weight=self.colorize)
|
| 1391 |
+
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
| 1392 |
+
return x
|
| 1393 |
+
|
| 1394 |
+
|
| 1395 |
+
class DiffusionWrapper(pl.LightningModule):
|
| 1396 |
+
def __init__(self, diff_model_config, conditioning_key):
|
| 1397 |
+
super().__init__()
|
| 1398 |
+
self.diffusion_model = instantiate_from_config(diff_model_config)
|
| 1399 |
+
self.conditioning_key = conditioning_key
|
| 1400 |
+
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
|
| 1401 |
+
|
| 1402 |
+
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
|
| 1403 |
+
if self.conditioning_key is None:
|
| 1404 |
+
out = self.diffusion_model(x, t)
|
| 1405 |
+
elif self.conditioning_key == 'concat':
|
| 1406 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
| 1407 |
+
out = self.diffusion_model(xc, t)
|
| 1408 |
+
elif self.conditioning_key == 'crossattn':
|
| 1409 |
+
cc = torch.cat(c_crossattn, 1)
|
| 1410 |
+
out = self.diffusion_model(x, t, context=cc)
|
| 1411 |
+
elif self.conditioning_key == 'hybrid':
|
| 1412 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
| 1413 |
+
cc = torch.cat(c_crossattn, 1)
|
| 1414 |
+
out = self.diffusion_model(xc, t, context=cc)
|
| 1415 |
+
elif self.conditioning_key == 'adm':
|
| 1416 |
+
cc = c_crossattn[0]
|
| 1417 |
+
out = self.diffusion_model(x, t, y=cc)
|
| 1418 |
+
else:
|
| 1419 |
+
raise NotImplementedError()
|
| 1420 |
+
|
| 1421 |
+
return out
|
| 1422 |
+
|
| 1423 |
+
|
| 1424 |
+
class Layout2ImgDiffusion(LatentDiffusion):
    """Latent diffusion conditioned on tokenized bounding-box layouts.

    Only supports ``cond_stage_key == 'coordinates_bbox'``.
    """
    # TODO: move all layout-specific hacks to this class

    def __init__(self, cond_stage_key, *args, **kwargs):
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        """Extend parent logging with a rendered bounding-box image per sample."""
        logs = super().log_images(batch=batch, N=N, *args, **kwargs)

        split = 'train' if self.training else 'validation'
        dataset = self.trainer.datamodule.datasets[split]
        builder = dataset.conditional_builders[self.cond_stage_key]

        def label_for(catno):
            # Map a category number to its human-readable label.
            return dataset.get_textual_label(dataset.get_category_id(catno))

        rendered = [
            builder.plot(tokens.detach().cpu(), label_for, (256, 256))
            for tokens in batch[self.cond_stage_key][:N]
        ]
        logs['bbox_image'] = torch.stack(rendered, dim=0)
        return logs
|
ldm/models/diffusion/dpm_solver/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sampler import DPMSolverSampler
|
ldm/models/diffusion/dpm_solver/dpm_solver.py
ADDED
|
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class NoiseScheduleVP:
    """Forward-SDE noise schedule of a variance-preserving diffusion model.

    Supports a 'discrete' schedule (built from the ``betas`` or
    ``alphas_cumprod`` of a trained discrete-time DPM) and the continuous
    'linear' and 'cosine' schedules. Exposes alpha_t, sigma_t, and the
    half-logSNR lambda_t = log(alpha_t) - log(sigma_t) as functions of a
    continuous time label t in [0, T].
    """

    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):
        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )

        self.schedule = schedule
        if schedule == "discrete":
            # Recover log(alpha_t) from either the beta sequence or the
            # cumulative alpha products of a discrete-time model.
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
            # For the cosine schedule, T = 1 has numerical issues, so the
            # ending time is manually set to 0.9946 (empirically safe,
            # not necessarily optimal).
            self.T = 0.9946 if schedule == "cosine" else 1.0

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        if self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        # cosine schedule
        raw = torch.log(torch.cos((t + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
        return raw - self.cosine_log_alpha_0

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean))
        return log_mean - log_std

    def inverse_lambda(self, lamb, return_scalar=False):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        if self.schedule == "discrete":
            # Accept plain floats as well as tensors.
            if not isinstance(lamb, torch.Tensor):
                lamb = torch.tensor(lamb)
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            # Invert by interpolating the (flipped, so ascending) alpha table.
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            if return_scalar:
                return t.reshape((-1,)).item()
            return t.reshape((-1,))
        # cosine schedule
        log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
        return (
            torch.arccos(torch.exp(log_alpha + self.cosine_log_alpha_0))
            * 2.0
            * (1.0 + self.cosine_s)
            / math.pi
            - self.cosine_s
        )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs={},
):
    """Wrap a diffusion model as a continuous-time noise prediction function.

    The returned ``model_fn(x, t_continuous)`` takes continuous time in
    [epsilon, T] and always outputs predicted noise, regardless of whether
    the wrapped model predicts noise, x0, or v, and regardless of the
    guidance mode (unconditional, classifier, or classifier-free).
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        # Broadcast a scalar time to the batch dimension.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        dims = x.dim()
        if model_type == "x_start":
            alpha_t = noise_schedule.marginal_alpha(t_continuous)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        if model_type == "v":
            alpha_t = noise_schedule.marginal_alpha(t_continuous)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        if model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise predicition model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        if guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        # classifier-free guidance (the asserts below guarantee this branch)
        if guidance_scale == 1.0 or unconditional_condition is None:
            return noise_pred_fn(x, t_continuous, cond=condition)
        # Batch the conditional and unconditional passes into one call.
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t_continuous] * 2)
        c_in = torch.cat([unconditional_condition, condition])
        noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
        return noise_uncond + guidance_scale * (noise - noise_uncond)

    assert model_type in ["noise", "x_start", "v"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class DPM_Solver:
|
| 217 |
+
def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.0):
|
| 218 |
+
"""Construct a DPM-Solver.
|
| 219 |
+
|
| 220 |
+
We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
|
| 221 |
+
If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
|
| 222 |
+
If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
|
| 223 |
+
In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
|
| 224 |
+
The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
|
| 228 |
+
``
|
| 229 |
+
def model_fn(x, t_continuous):
|
| 230 |
+
return noise
|
| 231 |
+
``
|
| 232 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
| 233 |
+
predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
|
| 234 |
+
thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
|
| 235 |
+
max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
|
| 236 |
+
|
| 237 |
+
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
|
| 238 |
+
"""
|
| 239 |
+
self.model = model_fn
|
| 240 |
+
self.noise_schedule = noise_schedule
|
| 241 |
+
self.predict_x0 = predict_x0
|
| 242 |
+
self.thresholding = thresholding
|
| 243 |
+
self.max_val = max_val
|
| 244 |
+
|
| 245 |
+
def noise_prediction_fn(self, x, t):
|
| 246 |
+
"""
|
| 247 |
+
Return the noise prediction model.
|
| 248 |
+
"""
|
| 249 |
+
return self.model(x, t)
|
| 250 |
+
|
| 251 |
+
def data_prediction_fn(self, x, t):
    """Predict x0 from the noise prediction, with optional dynamic thresholding.

    Converts predicted noise to data via x0 = (x - sigma_t * eps) / alpha_t,
    then, if ``self.thresholding`` is set, clamps x0 at its per-sample 99.5th
    absolute percentile (at least ``self.max_val``) and rescales, as in the
    Imagen paper.
    """
    eps = self.noise_prediction_fn(x, t)
    dims = x.dim()
    alpha_t = self.noise_schedule.marginal_alpha(t)
    sigma_t = self.noise_schedule.marginal_std(t)
    x0 = (x - expand_dims(sigma_t, dims) * eps) / expand_dims(alpha_t, dims)
    if self.thresholding:
        p = 0.995  # percentile hyperparameter from the Imagen paper [1]
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
    return x0
|
| 265 |
+
|
| 266 |
+
def model_fn(self, x, t):
|
| 267 |
+
"""
|
| 268 |
+
Convert the model to the noise prediction model or the data prediction model.
|
| 269 |
+
"""
|
| 270 |
+
if self.predict_x0:
|
| 271 |
+
return self.data_prediction_fn(x, t)
|
| 272 |
+
else:
|
| 273 |
+
return self.noise_prediction_fn(x, t)
|
| 274 |
+
|
| 275 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
| 276 |
+
"""Compute the intermediate time steps for sampling.
|
| 277 |
+
|
| 278 |
+
Args:
|
| 279 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
| 280 |
+
- 'logSNR': uniform logSNR for the time steps.
|
| 281 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
|
| 282 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
|
| 283 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
| 284 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
| 285 |
+
N: A `int`. The total number of the spacing of the time steps.
|
| 286 |
+
device: A torch device.
|
| 287 |
+
Returns:
|
| 288 |
+
A pytorch tensor of the time steps, with the shape (N + 1,).
|
| 289 |
+
"""
|
| 290 |
+
if skip_type == "logSNR":
|
| 291 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
| 292 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
| 293 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
| 294 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
| 295 |
+
elif skip_type == "time_uniform":
|
| 296 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
| 297 |
+
elif skip_type == "time_quadratic":
|
| 298 |
+
t_order = 2
|
| 299 |
+
t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
|
| 300 |
+
return t
|
| 301 |
+
else:
|
| 302 |
+
raise ValueError(
|
| 303 |
+
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
| 307 |
+
"""
|
| 308 |
+
Get the order of each step for sampling by the singlestep DPM-Solver.
|
| 309 |
+
|
| 310 |
+
We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
|
| 311 |
+
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
|
| 312 |
+
- If order == 1:
|
| 313 |
+
We take `steps` of DPM-Solver-1 (i.e. DDIM).
|
| 314 |
+
- If order == 2:
|
| 315 |
+
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
|
| 316 |
+
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
|
| 317 |
+
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
| 318 |
+
- If order == 3:
|
| 319 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
| 320 |
+
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
| 321 |
+
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
|
| 322 |
+
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
|
| 323 |
+
|
| 324 |
+
============================================
|
| 325 |
+
Args:
|
| 326 |
+
order: A `int`. The max order for the solver (2 or 3).
|
| 327 |
+
steps: A `int`. The total number of function evaluations (NFE).
|
| 328 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
| 329 |
+
- 'logSNR': uniform logSNR for the time steps.
|
| 330 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
|
| 331 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
|
| 332 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
| 333 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
| 334 |
+
device: A torch device.
|
| 335 |
+
Returns:
|
| 336 |
+
orders: A list of the solver order of each step.
|
| 337 |
+
"""
|
| 338 |
+
if order == 3:
|
| 339 |
+
K = steps // 3 + 1
|
| 340 |
+
if steps % 3 == 0:
|
| 341 |
+
orders = [
|
| 342 |
+
3,
|
| 343 |
+
] * (
|
| 344 |
+
K - 2
|
| 345 |
+
) + [2, 1]
|
| 346 |
+
elif steps % 3 == 1:
|
| 347 |
+
orders = [
|
| 348 |
+
3,
|
| 349 |
+
] * (
|
| 350 |
+
K - 1
|
| 351 |
+
) + [1]
|
| 352 |
+
else:
|
| 353 |
+
orders = [
|
| 354 |
+
3,
|
| 355 |
+
] * (
|
| 356 |
+
K - 1
|
| 357 |
+
) + [2]
|
| 358 |
+
elif order == 2:
|
| 359 |
+
if steps % 2 == 0:
|
| 360 |
+
K = steps // 2
|
| 361 |
+
orders = [
|
| 362 |
+
2,
|
| 363 |
+
] * K
|
| 364 |
+
else:
|
| 365 |
+
K = steps // 2 + 1
|
| 366 |
+
orders = [
|
| 367 |
+
2,
|
| 368 |
+
] * (
|
| 369 |
+
K - 1
|
| 370 |
+
) + [1]
|
| 371 |
+
elif order == 1:
|
| 372 |
+
K = 1
|
| 373 |
+
orders = [
|
| 374 |
+
1,
|
| 375 |
+
] * steps
|
| 376 |
+
else:
|
| 377 |
+
raise ValueError("'order' must be '1' or '2' or '3'.")
|
| 378 |
+
if skip_type == "logSNR":
|
| 379 |
+
# To reproduce the results in DPM-Solver paper
|
| 380 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
| 381 |
+
else:
|
| 382 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
|
| 383 |
+
torch.cumsum(
|
| 384 |
+
torch.tensor(
|
| 385 |
+
[
|
| 386 |
+
0,
|
| 387 |
+
]
|
| 388 |
+
+ orders
|
| 389 |
+
),
|
| 390 |
+
dim=0,
|
| 391 |
+
).to(device)
|
| 392 |
+
]
|
| 393 |
+
return timesteps_outer, orders
|
| 394 |
+
|
| 395 |
+
def denoise_to_zero_fn(self, x, s):
    """Final denoising step.

    Equivalent to solving the ODE from lambda_s to infinity by a
    first-order discretization, i.e. returning the predicted x0.
    """
    return self.data_prediction_fn(x, s)
|
| 400 |
+
|
| 401 |
+
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
    """
    DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
        t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    """
    ns = self.noise_schedule
    dims = x.dim()
    h = ns.marginal_lambda(t) - ns.marginal_lambda(s)
    log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
    sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)

    if model_s is None:
        model_s = self.model_fn(x, s)

    if self.predict_x0:
        # DPM-Solver++ (data prediction) first-order step.
        phi_1 = torch.expm1(-h)
        alpha_t = torch.exp(log_alpha_t)
        x_t = expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s
    else:
        # DPM-Solver (noise prediction) first-order step.
        phi_1 = torch.expm1(h)
        x_t = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
            - expand_dims(sigma_t * phi_1, dims) * model_s
        )

    if return_intermediate:
        return x_t, {"model_s": model_s}
    return x_t
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
    """
    Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
        t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
        solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    """
    ns = self.noise_schedule
    dims = x.dim()
    model_prev_1, model_prev_0 = model_prev_list[-2:]
    t_prev_1, t_prev_0 = t_prev_list[-2:]
    lambda_prev_1 = ns.marginal_lambda(t_prev_1)
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    log_alpha_prev_0 = ns.marginal_log_mean_coeff(t_prev_0)
    log_alpha_t = ns.marginal_log_mean_coeff(t)
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    h_0 = lambda_prev_0 - lambda_prev_1
    h = lambda_t - lambda_prev_0
    r0 = h_0 / h
    # First-order finite difference of the two most recent model outputs.
    D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
    if self.predict_x0:
        base = (
            expand_dims(sigma_t / sigma_prev_0, dims) * x
            - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
        )
        if solver_type == "dpm_solver" or solver_type == "dpmsolver":
            x_t = base - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0
        elif solver_type == "taylor":
            x_t = base + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1_0
    else:
        base = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
            - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
        )
        if solver_type == "dpm_solver" or solver_type == "dpmsolver":
            x_t = base - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0
        elif solver_type == "taylor":
            x_t = base - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1_0
    return x_t
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
    """
    Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
        t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
        solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    """
    ns = self.noise_schedule
    dims = x.dim()
    model_prev_2, model_prev_1, model_prev_0 = model_prev_list
    t_prev_2, t_prev_1, t_prev_0 = t_prev_list
    lambda_prev_2 = ns.marginal_lambda(t_prev_2)
    lambda_prev_1 = ns.marginal_lambda(t_prev_1)
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    log_alpha_prev_0 = ns.marginal_log_mean_coeff(t_prev_0)
    log_alpha_t = ns.marginal_log_mean_coeff(t)
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    h_1 = lambda_prev_1 - lambda_prev_2
    h_0 = lambda_prev_0 - lambda_prev_1
    h = lambda_t - lambda_prev_0
    r0, r1 = h_0 / h, h_1 / h
    # First- and second-order finite differences of the model history.
    D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
    D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2)
    D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
    D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1)
    if self.predict_x0:
        x_t = (
            expand_dims(sigma_t / sigma_prev_0, dims) * x
            - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
            + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
            - expand_dims(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims) * D2
        )
    else:
        x_t = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
            - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
            - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
            - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims) * D2
        )
    return x_t
|
| 557 |
+
|
| 558 |
+
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"):
    """
    Advance the multistep DPM-Solver by one step of the given `order`,
    from time `t_prev_list[-1]` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
        t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
        order: A `int`. The order of DPM-Solver. Only 1, 2 or 3 is supported.
        solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance; 'dpm_solver' is recommended.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 1:
        # The first-order update only needs the most recent model output/time.
        return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
    if order == 2:
        update_fn = self.multistep_dpm_solver_second_update
    elif order == 3:
        update_fn = self.multistep_dpm_solver_third_update
    else:
        raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
    # Orders 2 and 3 share the same call signature; dispatch to the chosen one.
    return update_fn(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
| 581 |
+
|
| 582 |
+
def dpm_solver_adaptive(
    self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpm_solver"
):
    """
    The adaptive step size solver based on singlestep DPM-Solver.

    An embedded pair of solvers (order-1 "lower" and order "higher") estimates the
    local error at each step; the step size `h` (in logSNR space) is grown or shrunk
    accordingly, following [1].

    Args:
        x: A pytorch tensor. The initial value at time `t_T`.
        order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        h_init: A `float`. The initial step size (for logSNR).
        atol: A `float`. The absolute tolerance of the solver (0.0078 for image data, followed [1]).
        rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
        theta: A `float`. The safety hyperparameter for adapting the step size (0.9, followed [1]).
        t_err: A `float`. Stop once the absolute error between the current time and `t_0`
            is less than `t_err`. The default setting is 1e-5.
        solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
    Returns:
        x_0: A pytorch tensor. The approximated solution at time `t_0`.

    [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas,
    "Gotta go fast when generating data with score-based models," arXiv:2105.14080, 2021.
    """
    ns = self.noise_schedule
    s = t_T * torch.ones((x.shape[0],)).to(x)  # current time, broadcast per sample
    lambda_s = ns.marginal_lambda(s)
    lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))  # target logSNR
    h = h_init * torch.ones_like(s).to(x)  # current step size in logSNR space
    x_prev = x
    nfe = 0  # number of model-function evaluations performed
    if order == 2:
        r1 = 0.5
        # Lower estimate: first-order; higher estimate: singlestep second-order.
        # `return_intermediate=True` lets the higher update reuse the model calls.
        lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
        higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
            x, s, t, r1=r1, solver_type=solver_type, **kwargs
        )
    elif order == 3:
        r1, r2 = 1.0 / 3.0, 2.0 / 3.0
        # Lower estimate: second-order; higher estimate: singlestep third-order.
        lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
            x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
        )
        higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
            x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
        )
    else:
        raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
    while torch.abs((s - t_0)).mean() > t_err:
        # Candidate next time corresponding to a logSNR step of size h.
        t = ns.inverse_lambda(lambda_s + h)
        x_lower, lower_noise_kwargs = lower_update(x, s, t)
        x_higher = higher_update(x, s, t, **lower_noise_kwargs)
        # Mixed absolute/relative tolerance, elementwise (as in [1]).
        delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
        norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
        E = norm_fn((x_higher - x_lower) / delta).max()  # scaled local-error estimate
        if torch.all(E <= 1.0):
            # Accept the step: keep the higher-order solution and advance time.
            x = x_higher
            s = t
            x_prev = x_lower
            lambda_s = ns.marginal_lambda(s)
        # Adapt the step size by E^(-1/order) with safety factor theta, never
        # stepping past the target logSNR. This runs whether or not the step was
        # accepted (rejected steps are retried with the smaller h).
        h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
        nfe += order
    print("adaptive solver nfe", nfe)  # NOTE(review): debug print left in upstream code
    return x
|
| 645 |
+
|
| 646 |
+
def sample(
    self,
    x,
    steps=20,
    t_start=None,
    t_end=None,
    order=3,
    skip_type="time_uniform",
    method="singlestep",
    lower_order_final=True,
    denoise_to_zero=False,
    solver_type="dpm_solver",
    atol=0.0078,
    rtol=0.05,
    flags=None,
):
    """
    Sample by running the multistep DPM-Solver from the starting time down to the
    ending time with `steps` solver steps.

    Args:
        x: A pytorch tensor. The initial value (noise) at the starting time.
        steps: An `int`. The number of solver steps; the time grid has steps + 1 points.
        t_start: A `float` or None. Starting time; defaults to the schedule's T.
        t_end: A `float` or None. Ending time; defaults to 1 / total_N.
        order: An `int`. The multistep solver order (1, 2 or 3).
        skip_type: A `str`. Time-step spacing passed to `get_time_steps`.
        flags: An optional namespace. When `flags.learn` is truthy, a learned
            time-step schedule is loaded from `flags.log_path` instead of being
            generated; otherwise (including flags=None) the schedule is computed.
    Returns:
        x: A pytorch tensor. The approximated solution at the final time step.

    NOTE(review): `method`, `lower_order_final`, `denoise_to_zero`, `solver_type`,
    `atol` and `rtol` are accepted for interface compatibility but are unused by
    this multistep-only implementation.
    """
    device = x.device
    with torch.no_grad():
        # Bug fix: `flags` defaults to None, so the original unconditional
        # `flags.learn` dereference raised AttributeError for the default call;
        # guard for None first (behavior unchanged when flags is provided).
        if flags is not None and flags.learn:
            load_from = f"{flags.log_path}/NFE-{steps}-256LSUN-dpmsolver++-{order}-decode/best.pt"
            timesteps = torch.load(load_from)['best_t_steps'].to(x.device)
            # The checkpoint stores two schedules back to back: the first half
            # drives the solver updates (timesteps), the second half drives the
            # model evaluations (timesteps2). The original inner `if flags:`
            # test here was always true inside this branch, so the split is
            # performed unconditionally.
            length = timesteps.shape[0] // 2
            timesteps2 = timesteps[length:]
            timesteps = timesteps[:length]
        else:
            t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
            t_T = self.noise_schedule.T if t_start is None else t_start
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            timesteps2 = timesteps
        assert timesteps.shape[0] - 1 == steps

        def one_step(t1, t2, t_prev_list, model_prev_list, step, x_next, order, first=True):
            # Advance x with a `step`-order multistep update to time t1, then
            # evaluate the model at time t2 and push the result into the history.
            # NOTE(review): solver_type is passed as "dpmsolver" here while the
            # rest of this file spells it "dpm_solver" — confirm the downstream
            # second/third updates accept this spelling.
            x_next = self.multistep_dpm_solver_update(x_next, model_prev_list, t_prev_list, t1, step, solver_type="dpmsolver")
            model_x_next = self.model_fn(x_next, t2)
            update_lists(t_prev_list, model_prev_list, t1, model_x_next, order, first=first)
            return x_next

        def update_lists(t_list, model_list, t_, model_x, order, first=False):
            # During warm-up (`first=True`) the history grows; afterwards it is
            # a fixed-size FIFO of the last `order` (time, model output) pairs.
            if first:
                t_list.append(t_)
                model_list.append(model_x)
                return
            for m in range(order - 1):
                t_list[m] = t_list[m + 1]
                model_list[m] = model_list[m + 1]
            t_list[-1] = t_
            model_list[-1] = model_x

        timesteps1 = timesteps
        step = 0
        vec_t1 = timesteps1[0].expand((x.shape[0]))
        vec_t2 = timesteps2[0].expand((x.shape[0]))
        t_prev_list = [vec_t1]
        model_prev_list = [self.model_fn(x, vec_t2)]

        # Warm-up: raise the solver order one step at a time (1, ..., order - 1)
        # while the history fills up.
        for step in range(1, order):
            vec_t1 = timesteps1[step].expand(x.shape[0])
            vec_t2 = timesteps2[step].expand(x.shape[0])
            x = one_step(vec_t1, vec_t2, t_prev_list, model_prev_list, step, x, order, first=True)

        # Main loop at full order; drop to lower orders for the final steps so
        # the last updates do not reference times past the end of the grid.
        for step in range(order, steps + 1):
            step_order = min(order, steps + 1 - step)
            vec_t1 = timesteps1[step].expand(x.shape[0])
            vec_t2 = timesteps2[step].expand(x.shape[0])
            x = one_step(vec_t1, vec_t2, t_prev_list, model_prev_list, step_order, x, order, first=False)

        return x
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
#############################################################
|
| 720 |
+
# other utility functions
|
| 721 |
+
#############################################################
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    Implemented with differentiable tensor ops so it works under autograd.
    For x outside the range of xp, the outermost segment is extended, i.e.
    out-of-range queries are linearly extrapolated.

    Args:
        x: PyTorch tensor with shape [N, C] (N = batch size, C = channels; C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K] (K = number of keypoints).
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_kp = x.shape[0], xp.shape[1]
    # Merge each query with the keypoints so one sort locates its segment.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    combined_sorted, sort_order = torch.sort(combined, dim=2)
    # Where the query (originally at index 0) landed inside the sorted sequence.
    query_pos = torch.argmin(sort_order, dim=2)
    seg_candidate = query_pos - 1
    # Clamp to a valid segment at both ends so extreme queries extrapolate.
    seg_start = torch.where(
        torch.eq(query_pos, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(query_pos, num_kp),
            torch.tensor(num_kp - 2, device=x.device),
            seg_candidate,
        ),
    )
    # When clamped at an end, the matching endpoint is two slots away.
    seg_end = torch.where(torch.eq(seg_start, seg_candidate), seg_start + 2, seg_start + 1)
    x_lo = torch.gather(combined_sorted, dim=2, index=seg_start.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(combined_sorted, dim=2, index=seg_end.unsqueeze(2)).squeeze(2)
    # Segment index into the *original* keypoint arrays, for the y values.
    y_seg = torch.where(
        torch.eq(query_pos, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(query_pos, num_kp),
            torch.tensor(num_kp - 2, device=x.device),
            seg_candidate,
        ),
    )
    yp_expanded = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_expanded, dim=2, index=y_seg.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_expanded, dim=2, index=(y_seg + 1).unsqueeze(2)).squeeze(2)
    # Standard linear interpolation on the located segment.
    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def expand_dims(v, dims):
    """
    Expand the tensor `v` to a total of `dims` dimensions by appending
    trailing singleton axes.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: a `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] whose total dimension is `dims`.
    """
    out = v
    # Append (dims - 1) trailing axes of size 1, one at a time.
    for _ in range(dims - 1):
        out = out.unsqueeze(-1)
    return out
|
ldm/models/diffusion/dpm_solver/sampler.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DPMSolverSampler(object):
    """Wraps a latent-diffusion model so it can be sampled with DPM-Solver."""

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        # Snapshot the model's alphas_cumprod as a detached float32 tensor on
        # the model's device; the noise schedule is built from it in sample().
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        # Store a tensor as a plain attribute (this is not an nn.Module buffer).
        # NOTE(review): tensors are unconditionally moved to CUDA here, so this
        # sampler cannot run on a CPU-only machine — confirm that is intended.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        flags=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Sample `batch_size` latents of the given `shape` with `S` solver steps.

        Returns a tuple `(samples, None)` to mirror the other samplers' API.
        Most DDIM-style keyword arguments (eta, mask, x0, callbacks, ...) are
        accepted for interface compatibility but are not used by DPM-Solver.
        NOTE(review): `flags.skip_type` is dereferenced below, so despite the
        default of None, `flags` must be provided — confirm against callers.
        """
        # Sanity-check that the conditioning batch matches the requested batch.
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')

        device = self.model.betas.device
        # Start from provided noise x_T, or draw fresh Gaussian noise.
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        # Discrete-time VP noise schedule built from the stored alphas_cumprod.
        ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        if conditioning is None:
            # Unconditional sampling: third-order multistep solver.
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            # Classifier-free guidance: second-order multistep solver.
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=unconditional_guidance_scale,
            )
            ORDER = 2

        # predict_x0=True selects the DPM-Solver++ (data-prediction) variant.
        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
        x = dpm_solver.sample(
            img, steps=S, skip_type=flags.skip_type, method="multistep", order=ORDER, lower_order_final=True, flags=flags
        )
        # Second element is None to match the (samples, intermediates) API.
        return x.to(device), None
|
ldm/models/diffusion/dpm_solver_v3/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sampler import DPMSolverv3Sampler
|
ldm/models/diffusion/dpm_solver_v3/dpm_solver_v3.py
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NoiseScheduleVP:
    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):
        """Create a wrapper class for the forward SDE (VP type).

        The forward SDE ensures q_{t|0}(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 * I).
        We define the half-logSNR lambda_t = log(alpha_t) - log(sigma_t) and expose:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)
            t = self.inverse_lambda(lambda_t)   # lambda(t) is invertible

        1. Discrete-time DPMs (schedule='discrete'): trained on n = 0, ..., N-1;
           discrete steps are mapped to continuous time by t_i = (i + 1) / N, so for
           N = 1000 we solve the ODE from T = 1 down to t_0 = 1e-3. log(alpha_t) is
           recovered by piecewise-linear interpolation over the stored grid.
           Provide exactly one of `betas` or `alphas_cumprod` (alphas_cumprod is the
           \hat{alpha_n} of DDPM, so alpha_{t_n} = sqrt(\hat{alpha_n}), i.e.
           log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n})).

        2. Continuous-time DPMs: 'linear' (DDPM, with beta_0 = continuous_beta_0 and
           beta_1 = continuous_beta_1) or 'cosine' (improved-DDPM defaults).

        Args:
            schedule: A `str`. 'discrete' for discrete-time DPMs, 'linear' or
                'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).

        Example:
            >>> ns = NoiseScheduleVP('discrete', betas=betas)
            >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
            >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """

        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )

        self.schedule = schedule
        if schedule == "discrete":
            # Build the log(alpha_t) grid from either betas or alphas_cumprod.
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            # Keypoint grids (shape (1, N)) for piecewise-linear interpolation:
            # t_array holds t_i = (i + 1) / N, log_alpha_array the matching values.
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape(
                (
                    1,
                    -1,
                )
            )
        else:
            # Continuous-time schedules ('linear' or 'cosine').
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            # Largest t at which the cosine schedule's beta stays below cosine_beta_max.
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
            self.schedule = schedule
            if schedule == "cosine":
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.0

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            # Piecewise-linear interpolation over the stored (t, log_alpha) grid.
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == "linear":
            # Closed form for the linear-beta VPSDE.
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == "cosine":
            # Cosine schedule, normalized so that log(alpha_0) = 0.
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].

        Uses the VP identity sigma_t^2 = 1 - alpha_t^2.
        """
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            # Solve the quadratic in t from the closed-form log_alpha expression.
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            # log(alpha) = -0.5 * log(1 + exp(-2 * lambda)); invert by interpolating
            # the flipped grids (so the keypoint x-values are increasing).
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            return t.reshape((-1,))
        else:
            # Cosine schedule: invert the normalized cosine via arccos.
            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            t_fn = (
                lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            t = t_fn(log_alpha)
            return t
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def model_wrapper(
|
| 207 |
+
model,
|
| 208 |
+
noise_schedule,
|
| 209 |
+
model_type="noise",
|
| 210 |
+
model_kwargs={},
|
| 211 |
+
guidance_type="uncond",
|
| 212 |
+
condition=None,
|
| 213 |
+
unconditional_condition=None,
|
| 214 |
+
guidance_scale=1.0,
|
| 215 |
+
classifier_fn=None,
|
| 216 |
+
classifier_kwargs={},
|
| 217 |
+
):
|
| 218 |
+
"""Create a wrapper function for the noise prediction model.
|
| 219 |
+
|
| 220 |
+
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
| 221 |
+
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
|
| 222 |
+
|
| 223 |
+
We support four types of the diffusion model by setting `model_type`:
|
| 224 |
+
|
| 225 |
+
1. "noise": noise prediction model. (Trained by predicting noise).
|
| 226 |
+
|
| 227 |
+
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
| 228 |
+
|
| 229 |
+
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
| 230 |
+
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
| 231 |
+
|
| 232 |
+
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
| 233 |
+
arXiv preprint arXiv:2202.00512 (2022).
|
| 234 |
+
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
| 235 |
+
arXiv preprint arXiv:2210.02303 (2022).
|
| 236 |
+
|
| 237 |
+
4. "score": marginal score function. (Trained by denoising score matching).
|
| 238 |
+
Note that the score function and the noise prediction model follows a simple relationship:
|
| 239 |
+
```
|
| 240 |
+
noise(x_t, t) = -sigma_t * score(x_t, t)
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
| 244 |
+
1. "uncond": unconditional sampling by DPMs.
|
| 245 |
+
The input `model` has the following format:
|
| 246 |
+
``
|
| 247 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
| 248 |
+
``
|
| 249 |
+
|
| 250 |
+
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
| 251 |
+
The input `model` has the following format:
|
| 252 |
+
``
|
| 253 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
| 254 |
+
``
|
| 255 |
+
|
| 256 |
+
The input `classifier_fn` has the following format:
|
| 257 |
+
``
|
| 258 |
+
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
| 259 |
+
``
|
| 260 |
+
|
| 261 |
+
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
| 262 |
+
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
| 263 |
+
|
| 264 |
+
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
| 265 |
+
The input `model` has the following format:
|
| 266 |
+
``
|
| 267 |
+
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
| 268 |
+
``
|
| 269 |
+
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
| 270 |
+
|
| 271 |
+
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
| 272 |
+
arXiv preprint arXiv:2207.12598 (2022).
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
| 276 |
+
or continuous-time labels (i.e. epsilon to T).
|
| 277 |
+
|
| 278 |
+
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
|
| 279 |
+
``
|
| 280 |
+
def model_fn(x, t_continuous) -> noise:
|
| 281 |
+
t_input = get_model_input_time(t_continuous)
|
| 282 |
+
return noise_pred(model, x, t_input, **model_kwargs)
|
| 283 |
+
``
|
| 284 |
+
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
| 285 |
+
|
| 286 |
+
===============================================================
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
model: A diffusion model with the corresponding format described above.
|
| 290 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
| 291 |
+
model_type: A `str`. The parameterization type of the diffusion model.
|
| 292 |
+
"noise" or "x_start" or "v" or "score".
|
| 293 |
+
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
| 294 |
+
guidance_type: A `str`. The type of the guidance for sampling.
|
| 295 |
+
"uncond" or "classifier" or "classifier-free".
|
| 296 |
+
condition: A pytorch tensor. The condition for the guided sampling.
|
| 297 |
+
Only used for "classifier" or "classifier-free" guidance type.
|
| 298 |
+
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
| 299 |
+
Only used for "classifier-free" guidance type.
|
| 300 |
+
guidance_scale: A `float`. The scale for the guided sampling.
|
| 301 |
+
classifier_fn: A classifier function. Only used for the classifier guidance.
|
| 302 |
+
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
| 303 |
+
Returns:
|
| 304 |
+
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
def get_model_input_time(t_continuous):
|
| 308 |
+
"""
|
| 309 |
+
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
| 310 |
+
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
| 311 |
+
For continuous-time DPMs, we just use `t_continuous`.
|
| 312 |
+
"""
|
| 313 |
+
if noise_schedule.schedule == "discrete":
|
| 314 |
+
return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
|
| 315 |
+
else:
|
| 316 |
+
return t_continuous
|
| 317 |
+
|
| 318 |
+
def noise_pred_fn(x, t_continuous, cond=None):
|
| 319 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
| 320 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
| 321 |
+
t_input = get_model_input_time(t_continuous)
|
| 322 |
+
if cond is None:
|
| 323 |
+
output = model(x, t_input, None, **model_kwargs)
|
| 324 |
+
else:
|
| 325 |
+
output = model(x, t_input, cond, **model_kwargs)
|
| 326 |
+
if model_type == "noise":
|
| 327 |
+
return output
|
| 328 |
+
elif model_type == "x_start":
|
| 329 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
| 330 |
+
dims = x.dim()
|
| 331 |
+
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
|
| 332 |
+
elif model_type == "v":
|
| 333 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
| 334 |
+
dims = x.dim()
|
| 335 |
+
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
|
| 336 |
+
elif model_type == "score":
|
| 337 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
| 338 |
+
dims = x.dim()
|
| 339 |
+
return -expand_dims(sigma_t, dims) * output
|
| 340 |
+
|
| 341 |
+
def cond_grad_fn(x, t_input):
|
| 342 |
+
"""
|
| 343 |
+
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
| 344 |
+
"""
|
| 345 |
+
with torch.enable_grad():
|
| 346 |
+
x_in = x.detach().requires_grad_(True)
|
| 347 |
+
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
| 348 |
+
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
| 349 |
+
|
| 350 |
+
def model_fn(x, t_continuous):
|
| 351 |
+
"""
|
| 352 |
+
The noise predicition model function that is used for DPM-Solver.
|
| 353 |
+
"""
|
| 354 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
| 355 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
| 356 |
+
if guidance_type == "uncond":
|
| 357 |
+
return noise_pred_fn(x, t_continuous)
|
| 358 |
+
elif guidance_type == "classifier":
|
| 359 |
+
assert classifier_fn is not None
|
| 360 |
+
t_input = get_model_input_time(t_continuous)
|
| 361 |
+
cond_grad = cond_grad_fn(x, t_input)
|
| 362 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
| 363 |
+
noise = noise_pred_fn(x, t_continuous)
|
| 364 |
+
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
|
| 365 |
+
elif guidance_type == "classifier-free":
|
| 366 |
+
if guidance_scale == 1.0 or unconditional_condition is None:
|
| 367 |
+
return noise_pred_fn(x, t_continuous, cond=condition)
|
| 368 |
+
else:
|
| 369 |
+
x_in = torch.cat([x] * 2)
|
| 370 |
+
t_in = torch.cat([t_continuous] * 2)
|
| 371 |
+
c_in = torch.cat([unconditional_condition, condition])
|
| 372 |
+
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
| 373 |
+
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
| 374 |
+
|
| 375 |
+
assert model_type in ["noise", "x_start", "v"]
|
| 376 |
+
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
| 377 |
+
return model_fn
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def weighted_cumsumexp_trapezoid(a, x, b, cumsum=True):
    """Trapezoidal-rule integral of ``b * e^a`` along the leading axis (NumPy).

    Inputs ``a``, ``x`` and (optionally) ``b`` all have shape (N+1, ...).
    With ``cumsum=True`` the result y has shape (N+1, ...) with
        y_0 = 0,
        y_n = sum_{i=1}^{n} 0.5 * (x_i - x_{i-1}) * (b_i e^{a_i} + b_{i-1} e^{a_{i-1}}).
    With ``cumsum=False`` only the total integral is returned.
    ``b=None`` is treated as b == 1 everywhere.
    """
    assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
    if b is not None:
        assert a.shape[0] == b.shape[0] and a.ndim == b.ndim

    # Factor out the running maximum of `a` before exponentiating so the
    # exponentials stay in a numerically safe range (log-sum-exp trick).
    offset = np.amax(a, axis=0, keepdims=True)
    if b is None:
        scaled = np.exp(a - offset)
    else:
        scaled = np.asarray(b) * np.exp(a - offset)

    # One trapezoid area per interval [x_{i-1}, x_i]; shape (N, ...).
    areas = 0.5 * (x[1:] - x[:-1]) * (scaled[1:] + scaled[:-1])
    if not cumsum:
        return np.sum(areas, axis=0) * np.exp(offset)
    running = np.cumsum(areas, axis=0) * np.exp(offset)
    # Prepend the y_0 = 0 row so the output aligns with the input grid.
    return np.concatenate([np.zeros_like(running[[0]]), running], axis=0)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=True):
    """Torch twin of :func:`weighted_cumsumexp_trapezoid`.

    Integrates ``b * e^a`` over ``x`` along dim 0 with the trapezoidal rule,
    using the same max-subtraction trick for numerical stability.
    ``b=None`` means b == 1; ``cumsum=False`` returns only the total.
    """
    assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
    if b is not None:
        assert a.shape[0] == b.shape[0] and a.ndim == b.ndim

    # Stabilize the exponentials by pulling out max(a) along dim 0.
    offset = torch.amax(a, dim=0, keepdim=True)
    weighted = torch.exp(a - offset) if b is None else b * torch.exp(a - offset)

    # Per-interval trapezoid areas, shape (N, ...).
    areas = 0.5 * (x[1:] - x[:-1]) * (weighted[1:] + weighted[:-1])
    if not cumsum:
        return torch.sum(areas, dim=0) * torch.exp(offset)
    running = torch.cumsum(areas, dim=0) * torch.exp(offset)
    # Leading zero row keeps the output aligned with the input grid.
    return torch.cat([torch.zeros_like(running[[0]]), running], dim=0)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def index_list(lst, index):
    """Return the elements of `lst` at the positions given in `index`.

    Equivalent to fancy indexing for plain Python lists; `index` entries may
    be negative (counted from the end), exactly as with `lst[i]`.
    """
    return [lst[i] for i in index]
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class DPM_Solver_v3:
    """DPM-Solver-v3 sampler driven by precomputed empirical model statistics.

    The statistics (arrays ``l``, ``s``, ``b`` loaded from ``statistics_dir``)
    parameterize the solver; their exponential integrals L, S, I, B, C are
    precomputed on a dense logSNR grid in ``__init__`` and later indexed by
    timestep index during sampling.

    NOTE(review): several tensors are moved with hard-coded ``.cuda()`` calls
    even though a ``device`` argument exists — this class will not run on a
    CPU-only machine as written; confirm whether ``device`` should be used
    instead.
    """

    def __init__(
        self,
        statistics_dir,
        noise_schedule,
        steps=10,
        t_start=None,
        t_end=None,
        skip_type="time_uniform",
        degenerated=False,
        device="cuda",
    ):
        """Load statistics, precompute integral tables and the sampling timesteps.

        Args:
            statistics_dir: directory containing ``l.npz`` (key "l") and
                ``sb.npz`` (keys "s", "b").
            noise_schedule: schedule object providing ``total_N``, ``T``,
                ``marginal_lambda`` and ``inverse_lambda``.
            steps: number of sampling steps.
            t_start / t_end: sampling time range; default to the schedule's
                [1/N, T] range when None.
            skip_type: timestep spacing — "logSNR", "time_uniform",
                "time_quadratic" or "edm".
            degenerated: if True, replace the statistics with l=1, s=0, b=0
                (degenerates the solver; presumably the DPM-Solver baseline —
                TODO confirm).
            device: target device for the precomputed timesteps.
        """
        self.device = device
        # The model function is attached later, in sample().
        self.model = None
        self.noise_schedule = noise_schedule
        self.steps = steps
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"

        # Load the empirical model statistics computed offline.
        l = np.load(os.path.join(statistics_dir, "l.npz"))["l"]
        sb = np.load(os.path.join(statistics_dir, "sb.npz"))
        s, b = sb["s"], sb["b"]
        if degenerated:
            l = np.ones_like(l)
            s = np.zeros_like(s)
            b = np.zeros_like(b)
        # The statistics are tabulated on a grid of statistics_steps+1 points.
        self.statistics_steps = l.shape[0] - 1
        # Dense logSNR grid matching the statistics, reshaped to (N+1, 1, 1, 1)
        # so it broadcasts against image-shaped tensors.
        ts = noise_schedule.marginal_lambda(
            self.get_time_steps("logSNR", t_T, t_0, self.statistics_steps, "cpu")
        ).numpy()[:, None, None, None]
        self.ts = torch.from_numpy(ts).cuda()
        self.lambda_T = self.ts[0].cpu().item()
        self.lambda_0 = self.ts[-1].cpu().item()
        z = np.zeros_like(l)
        o = np.ones_like(l)
        # Exponential integrals of the statistics over the logSNR grid
        # (trapezoidal rule); these appear in the closed-form update rules.
        L = weighted_cumsumexp_trapezoid(z, ts, l)
        S = weighted_cumsumexp_trapezoid(z, ts, s)

        I = weighted_cumsumexp_trapezoid(L + S, ts, o)
        B = weighted_cumsumexp_trapezoid(-S, ts, b)
        C = weighted_cumsumexp_trapezoid(L + S, ts, B)
        self.l = torch.from_numpy(l).cuda()
        self.s = torch.from_numpy(s).cuda()
        self.b = torch.from_numpy(b).cuda()
        self.L = torch.from_numpy(L).cuda()
        self.S = torch.from_numpy(S).cuda()
        self.I = torch.from_numpy(I).cuda()
        self.B = torch.from_numpy(B).cuda()
        self.C = torch.from_numpy(C).cuda()

        # precompute timesteps
        if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic":
            self.timesteps = self.get_time_steps(skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            self.indexes = self.convert_to_indexes(self.timesteps)
            # Snap timesteps onto the statistics grid so statistics lookups
            # are exact at each sampling step.
            self.timesteps = self.convert_to_timesteps(self.indexes, device)
        elif skip_type == "edm":
            self.indexes, self.timesteps = self.get_timesteps_edm(N=steps, device=device)
            self.timesteps = self.convert_to_timesteps(self.indexes, device)
        else:
            raise ValueError(f"Unsupported timestep strategy {skip_type}")

        print("Indexes", self.indexes)
        print("Time steps", self.timesteps)
        print("LogSNR steps", self.noise_schedule.marginal_lambda(self.timesteps))

        # store high-order exponential coefficients (lazy)
        self.exp_coeffs = {}

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def convert_to_indexes(self, timesteps):
        """Map each timestep to the nearest index on the dense statistics grid
        (linear in logSNR between lambda_T and lambda_0)."""
        logSNR_steps = self.noise_schedule.marginal_lambda(timesteps)
        indexes = list(
            (self.statistics_steps * (logSNR_steps - self.lambda_T) / (self.lambda_0 - self.lambda_T))
            .round()
            .cpu()
            .numpy()
            .astype(np.int64)
        )
        return indexes

    def convert_to_timesteps(self, indexes, device):
        """Inverse of convert_to_indexes: grid indexes -> logSNR -> timesteps."""
        logSNR_steps = (
            self.lambda_T + (self.lambda_0 - self.lambda_T) * torch.Tensor(indexes).to(device) / self.statistics_steps
        )
        return self.noise_schedule.inverse_lambda(logSNR_steps)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: A `int`. The total number of the spacing of the time steps.
            device: A torch device.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == "logSNR":
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            t_order = 2
            t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
            )

    def get_timesteps_edm(self, N, device):
        """Constructs the noise schedule of Karras et al. (2022)."""

        rho = 7.0  # 7.0 is the value used in the paper

        # sigma = exp(-lambda): convert the logSNR endpoints to EDM sigmas.
        sigma_min: float = np.exp(-self.lambda_0)
        sigma_max: float = np.exp(-self.lambda_T)
        ramp = np.linspace(0, 1, N + 1)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        # EDM rho-schedule: interpolate in sigma^(1/rho) space.
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        lambdas = torch.Tensor(-np.log(sigmas)).to(device)
        timesteps = self.noise_schedule.inverse_lambda(lambdas)

        # Nearest statistics-grid index for each EDM step.
        indexes = list(
            (self.statistics_steps * (lambdas - self.lambda_T) / (self.lambda_0 - self.lambda_T))
            .round()
            .cpu()
            .numpy()
            .astype(np.int64)
        )
        return indexes, timesteps

    def get_g(self, f_t, i_s, i_t):
        """Change of variables from f at grid index i_t to the g-parameterization
        anchored at index i_s (uses the precomputed S and B tables)."""
        return torch.exp(self.S[i_s] - self.S[i_t]) * f_t - torch.exp(self.S[i_s]) * (self.B[i_t] - self.B[i_s])

    def compute_exponential_coefficients_high_order(self, i_s, i_t, order=2):
        """Integral coefficient for the (order)-th Taylor term between grid
        indexes i_s and i_t; cached in self.exp_coeffs since it only depends
        on (i_s, i_t, order)."""
        key = (i_s, i_t, order)
        if key in self.exp_coeffs.keys():
            coeffs = self.exp_coeffs[key]
        else:
            n = order - 1
            # Integrand exponent relative to the anchor i_s.
            a = self.L[i_s : i_t + 1] + self.S[i_s : i_t + 1] - self.L[i_s] - self.S[i_s]
            x = self.ts[i_s : i_t + 1]
            # Polynomial weight (lambda - lambda_s)^n / n! of the Taylor term.
            b = (self.ts[i_s : i_t + 1] - self.ts[i_s]) ** n / math.factorial(n)
            coeffs = weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=False)
            self.exp_coeffs[key] = coeffs
        return coeffs

    def compute_high_order_derivatives(self, n, lambda_0n, g_0n, pseudo=False):
        # return g^(1), ..., g^(n)
        if pseudo:
            # Divided-difference (Newton) table; D[i][0]*i! approximates g^(i).
            D = [[] for _ in range(n + 1)]
            D[0] = g_0n
            for i in range(1, n + 1):
                for j in range(n - i + 1):
                    D[i].append((D[i - 1][j] - D[i - 1][j + 1]) / (lambda_0n[j] - lambda_0n[i + j]))

            return [D[i][0] * math.factorial(i) for i in range(1, n + 1)]
        else:
            # Solve the Vandermonde-like system R @ (g^(i)/i!) = g_i - g_0
            # exactly via matrix inverse.
            R = []
            for i in range(1, n + 1):
                R.append(torch.pow(lambda_0n[1:] - lambda_0n[0], i))
            R = torch.stack(R).t()
            B = (torch.stack(g_0n[1:]) - g_0n[0]).reshape(n, -1)
            shape = g_0n[0].shape
            solution = torch.linalg.inv(R) @ B
            solution = solution.reshape([n] + list(shape))
            return [solution[i - 1] * math.factorial(i) for i in range(1, n + 1)]

    def multistep_predictor_update(self, x_lst, eps_lst, time_lst, index_lst, t, i_t, order=1, pseudo=False):
        """Advance x from the latest cached state to time t (grid index i_t)
        using an order-(order) multistep update built from the last `order`
        cached (x, eps) pairs."""
        # x_lst: [..., x_s]
        # eps_lst: [..., eps_s]
        # time_lst: [..., time_s]
        ns = self.noise_schedule
        n = order - 1
        # Most recent `order` history entries, newest first.
        indexes = [-i - 1 for i in range(n + 1)]
        x_0n = index_list(x_lst, indexes)
        eps_0n = index_list(eps_lst, indexes)
        time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()
        index_0n = index_list(index_lst, indexes)
        lambda_0n = ns.marginal_lambda(time_0n)
        alpha_0n = ns.marginal_alpha(time_0n)
        sigma_0n = ns.marginal_std(time_0n)

        alpha_s, alpha_t = alpha_0n[0], ns.marginal_alpha(t)
        i_s = index_0n[0]
        x_s = x_0n[0]
        # g-values of the history points, all anchored at the current index i_s.
        g_0n = []
        for i in range(n + 1):
            f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
            g_i = self.get_g(f_i, index_0n[0], index_0n[i])
            g_0n.append(g_i)
        g_0 = g_0n[0]
        # First-order (order=1) closed-form update from the integral tables.
        x_t = (
            alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
            - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
            - alpha_t
            * torch.exp(-self.L[i_t])
            * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
        )
        if order > 1:
            # Add higher-order Taylor corrections from estimated derivatives of g.
            g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
            for i in range(order - 1):
                x_t = (
                    x_t
                    - alpha_t
                    * torch.exp(self.L[i_s] - self.L[i_t])
                    * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
                    * g_d[i]
                )
        return x_t

    def multistep_corrector_update(self, x_lst, eps_lst, time_lst, index_lst, order=1, pseudo=False):
        """Recompute the latest state x_lst[-1] (at time t) using the now
        available model output at t; same formulas as the predictor, but the
        step goes from x_lst[-2] to x_lst[-1]. Requires order >= 2."""
        # x_lst: [..., x_s, x_t]
        # eps_lst: [..., eps_s, eps_t]
        # lambda_lst: [..., lambda_s, lambda_t]
        ns = self.noise_schedule
        n = order - 1
        # Newest-first history, but with the pair (x_s, x_t) swapped to the
        # front so index 0 is the step start and index 1 the step end.
        indexes = [-i - 1 for i in range(n + 1)]
        indexes[0] = -2
        indexes[1] = -1
        x_0n = index_list(x_lst, indexes)
        eps_0n = index_list(eps_lst, indexes)
        time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()
        index_0n = index_list(index_lst, indexes)
        lambda_0n = ns.marginal_lambda(time_0n)
        alpha_0n = ns.marginal_alpha(time_0n)
        sigma_0n = ns.marginal_std(time_0n)

        alpha_s, alpha_t = alpha_0n[0], alpha_0n[1]
        i_s, i_t = index_0n[0], index_0n[1]
        x_s = x_0n[0]
        g_0n = []
        for i in range(n + 1):
            f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
            g_i = self.get_g(f_i, index_0n[0], index_0n[i])
            g_0n.append(g_i)
        g_0 = g_0n[0]
        x_t_new = (
            alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
            - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
            - alpha_t
            * torch.exp(-self.L[i_t])
            * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
        )
        if order > 1:
            g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
            for i in range(order - 1):
                x_t_new = (
                    x_t_new
                    - alpha_t
                    * torch.exp(self.L[i_s] - self.L[i_t])
                    * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
                    * g_d[i]
                )
        return x_t_new

    def sample(
        self,
        x,
        model_fn,
        order,
        p_pseudo,
        use_corrector,
        c_pseudo,
        lower_order_final,
        half=False,
        return_intermediate=False,
    ):
        """Run the predictor(-corrector) loop from the initial noise x.

        Args:
            x: initial sample (noise) tensor.
            model_fn: noise-prediction function taking (x, t_continuous).
            order: maximum solver order.
            p_pseudo / c_pseudo: use pseudo (divided-difference) derivatives in
                the predictor / corrector. c_pseudo is also added to the
                predictor order to get the corrector order (bool -> int).
            use_corrector: apply the multistep corrector after each model call.
            lower_order_final: cap the order near the end (step budget limit).
            half: if True, skip the corrector once t <= 0.5.
            return_intermediate: also return the cached per-step states.
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        steps = self.steps
        cached_x = []
        cached_model_output = []
        cached_time = []
        cached_index = []
        indexes, timesteps = self.indexes, self.timesteps
        step_p_order = 0

        for step in range(1, steps + 1):
            cached_x.append(x)
            cached_model_output.append(self.noise_prediction_fn(x, timesteps[step - 1]))
            cached_time.append(timesteps[step - 1])
            cached_index.append(indexes[step - 1])
            if use_corrector and (timesteps[step - 1] > 0.5 or not half):
                step_c_order = step_p_order + c_pseudo
                if step_c_order > 1:
                    x_new = self.multistep_corrector_update(
                        cached_x, cached_model_output, cached_time, cached_index, order=step_c_order, pseudo=c_pseudo
                    )
                    # Keep the cached model output consistent with the corrected
                    # x: N = sigma*eps - l*x is invariant, so re-derive eps.
                    sigma_t = self.noise_schedule.marginal_std(cached_time[-1])
                    l_t = self.l[cached_index[-1]]
                    N_old = sigma_t * cached_model_output[-1] - l_t * cached_x[-1]
                    cached_x[-1] = x_new
                    cached_model_output[-1] = (N_old + l_t * cached_x[-1]) / sigma_t
            # Warm-up: grow the predictor order until `order` is reached.
            if step < order:
                step_p_order = step
            else:
                step_p_order = order
            if lower_order_final:
                step_p_order = min(step_p_order, steps + 1 - step)
            t = timesteps[step]
            i_t = indexes[step]

            x = self.multistep_predictor_update(
                cached_x, cached_model_output, cached_time, cached_index, t, i_t, order=step_p_order, pseudo=p_pseudo
            )

        if return_intermediate:
            return x, cached_x
        else:
            return x
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
#############################################################
|
| 764 |
+
# other utility functions
|
| 765 |
+
#############################################################
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    # Insert the query value in front of the keypoints: [N, C, K+1].
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    # The query was at position 0 before sorting, so argmin over the
    # permutation indices recovers where it landed in sorted order.
    x_idx = torch.argmin(x_indices, dim=2)
    # Candidate left neighbor of the query in the sorted array.
    cand_start_idx = x_idx - 1
    # Clamp for out-of-range queries: below the first keypoint use the first
    # segment, above the last keypoint use the last segment (extrapolation).
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    # Right endpoint of the segment in sorted_all_x coordinates.
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    # Same clamping, but in yp coordinates (K entries — no inserted query).
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    # Linear interpolation on the selected segment.
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dim`: a `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Append dims-1 singleton axes after the existing ones.
    trailing_axes = (None,) * (dims - 1)
    return v[(Ellipsis, *trailing_axes)]
|
ldm/models/diffusion/dpm_solver_v3/sampler.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DPMSolverv3Sampler:
    """Adapter exposing DPM-Solver-v3 through the sampler interface used by
    the surrounding LDM code base (mirrors the DDIM/PLMS sampler classes).

    Builds a discrete NoiseScheduleVP from the wrapped model's alphas_cumprod
    and precomputes a DPM_Solver_v3 instance from the statistics directory.
    """

    def __init__(self, ckp_path, stats_dir, model, steps, guidance_scale, **kwargs):
        """
        Args:
            ckp_path: checkpoint path (unused here; kept for interface parity).
            stats_dir: directory with the precomputed EMS statistics
                (l.npz / sb.npz) consumed by DPM_Solver_v3.
            model: latent diffusion model exposing alphas_cumprod, betas,
                device and apply_model.
            steps: number of sampling steps.
            guidance_scale: classifier-free guidance scale.
        """
        super().__init__()
        self.model = model
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.alphas_cumprod = to_torch(model.alphas_cumprod)
        self.device = self.model.betas.device
        self.guidance_scale = guidance_scale

        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        # BUG FIX: the message previously interpolated the undefined name
        # `stats_base`, so a missing stats_dir raised NameError instead of
        # the intended AssertionError.
        assert stats_dir is not None, f"No statistics file found in {stats_dir}."
        print("Use statistics", stats_dir)
        self.dpm_solver_v3 = DPM_Solver_v3(
            statistics_dir=stats_dir,
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type="time_uniform",
            degenerated=False,
            device=self.device,
        )
        self.steps = steps

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Draw `batch_size` samples of latent `shape` = (C, H, W).

        Uses unconditional sampling (order 3) when `conditioning` is None,
        otherwise classifier-free guidance (order 2). Returns (samples, None)
        to match the other samplers' return convention.
        """
        if conditioning is not None:
            # Sanity-check that the conditioning batch matches batch_size.
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        # Start from supplied noise if given, otherwise fresh Gaussian noise.
        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=self.guidance_scale,
            )
            ORDER = 2

        x = self.dpm_solver_v3.sample(
            img,
            model_fn,
            order=ORDER,
            p_pseudo=False,
            c_pseudo=True,
            lower_order_final=True,
            use_corrector=use_corrector,
            half=half,
        )

        return x.to(self.device), None
|
ldm/models/diffusion/plms.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PLMSSampler(object):
|
| 12 |
+
    def __init__(self, model, schedule="linear", **kwargs):
        """Wrap a diffusion `model` for PLMS sampling.

        Args:
            model: diffusion model; must expose `num_timesteps` (read here)
                and — per make_schedule — alphas_cumprod/betas/device.
            schedule: beta-schedule name; stored as-is (presumably consumed
                elsewhere — TODO confirm).
            **kwargs: ignored.
        """
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
|
| 17 |
+
|
| 18 |
+
    def register_buffer(self, name, attr):
        """Store `attr` as a plain attribute named `name`.

        Unlike nn.Module.register_buffer this is just setattr, but torch
        tensors are moved to CUDA first. NOTE(review): hard-coded cuda device
        — fails on CPU-only machines; confirm intended.
        """
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)
|
| 23 |
+
|
| 24 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
| 25 |
+
if ddim_eta != 0:
|
| 26 |
+
raise ValueError('ddim_eta must be 0 for PLMS')
|
| 27 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
| 28 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
| 29 |
+
alphas_cumprod = self.model.alphas_cumprod
|
| 30 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
| 31 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
| 32 |
+
|
| 33 |
+
self.register_buffer('betas', to_torch(self.model.betas))
|
| 34 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
| 35 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
| 36 |
+
|
| 37 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 38 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
| 39 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
| 40 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
| 41 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
| 42 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
| 43 |
+
|
| 44 |
+
# ddim sampling parameters
|
| 45 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
| 46 |
+
ddim_timesteps=self.ddim_timesteps,
|
| 47 |
+
eta=ddim_eta,verbose=verbose)
|
| 48 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
| 49 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
| 50 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
| 51 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
| 52 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
| 53 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
| 54 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
| 55 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
| 56 |
+
|
| 57 |
+
@torch.no_grad()
|
| 58 |
+
def sample(self,
|
| 59 |
+
S,
|
| 60 |
+
batch_size,
|
| 61 |
+
shape,
|
| 62 |
+
conditioning=None,
|
| 63 |
+
callback=None,
|
| 64 |
+
normals_sequence=None,
|
| 65 |
+
img_callback=None,
|
| 66 |
+
quantize_x0=False,
|
| 67 |
+
eta=0.,
|
| 68 |
+
mask=None,
|
| 69 |
+
x0=None,
|
| 70 |
+
temperature=1.,
|
| 71 |
+
noise_dropout=0.,
|
| 72 |
+
score_corrector=None,
|
| 73 |
+
corrector_kwargs=None,
|
| 74 |
+
verbose=True,
|
| 75 |
+
x_T=None,
|
| 76 |
+
log_every_t=100,
|
| 77 |
+
unconditional_guidance_scale=1.,
|
| 78 |
+
unconditional_conditioning=None,
|
| 79 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 80 |
+
**kwargs
|
| 81 |
+
):
|
| 82 |
+
if conditioning is not None:
|
| 83 |
+
if isinstance(conditioning, dict):
|
| 84 |
+
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
| 85 |
+
if cbs != batch_size:
|
| 86 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 87 |
+
else:
|
| 88 |
+
if conditioning.shape[0] != batch_size:
|
| 89 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
| 90 |
+
|
| 91 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
| 92 |
+
# sampling
|
| 93 |
+
C, H, W = shape
|
| 94 |
+
size = (batch_size, C, H, W)
|
| 95 |
+
print(f'Data shape for PLMS sampling is {size}')
|
| 96 |
+
|
| 97 |
+
samples, intermediates = self.plms_sampling(conditioning, size,
|
| 98 |
+
callback=callback,
|
| 99 |
+
img_callback=img_callback,
|
| 100 |
+
quantize_denoised=quantize_x0,
|
| 101 |
+
mask=mask, x0=x0,
|
| 102 |
+
ddim_use_original_steps=False,
|
| 103 |
+
noise_dropout=noise_dropout,
|
| 104 |
+
temperature=temperature,
|
| 105 |
+
score_corrector=score_corrector,
|
| 106 |
+
corrector_kwargs=corrector_kwargs,
|
| 107 |
+
x_T=x_T,
|
| 108 |
+
log_every_t=log_every_t,
|
| 109 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 110 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 111 |
+
)
|
| 112 |
+
return samples, intermediates
|
| 113 |
+
|
| 114 |
+
@torch.no_grad()
|
| 115 |
+
def plms_sampling(self, cond, shape,
|
| 116 |
+
x_T=None, ddim_use_original_steps=False,
|
| 117 |
+
callback=None, timesteps=None, quantize_denoised=False,
|
| 118 |
+
mask=None, x0=None, img_callback=None, log_every_t=100,
|
| 119 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
| 120 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
| 121 |
+
device = self.model.betas.device
|
| 122 |
+
b = shape[0]
|
| 123 |
+
if x_T is None:
|
| 124 |
+
img = torch.randn(shape, device=device)
|
| 125 |
+
else:
|
| 126 |
+
img = x_T
|
| 127 |
+
|
| 128 |
+
if timesteps is None:
|
| 129 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
| 130 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
| 131 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
| 132 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
| 133 |
+
|
| 134 |
+
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
| 135 |
+
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
| 136 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
| 137 |
+
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
| 138 |
+
|
| 139 |
+
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
| 140 |
+
old_eps = []
|
| 141 |
+
|
| 142 |
+
for i, step in enumerate(iterator):
|
| 143 |
+
index = total_steps - i - 1
|
| 144 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
| 145 |
+
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
| 146 |
+
|
| 147 |
+
if mask is not None:
|
| 148 |
+
assert x0 is not None
|
| 149 |
+
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
| 150 |
+
img = img_orig * mask + (1. - mask) * img
|
| 151 |
+
|
| 152 |
+
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
| 153 |
+
quantize_denoised=quantize_denoised, temperature=temperature,
|
| 154 |
+
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
| 155 |
+
corrector_kwargs=corrector_kwargs,
|
| 156 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 157 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 158 |
+
old_eps=old_eps, t_next=ts_next)
|
| 159 |
+
img, pred_x0, e_t = outs
|
| 160 |
+
old_eps.append(e_t)
|
| 161 |
+
if len(old_eps) >= 4:
|
| 162 |
+
old_eps.pop(0)
|
| 163 |
+
if callback: callback(i)
|
| 164 |
+
if img_callback: img_callback(pred_x0, i)
|
| 165 |
+
|
| 166 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
| 167 |
+
intermediates['x_inter'].append(img)
|
| 168 |
+
intermediates['pred_x0'].append(pred_x0)
|
| 169 |
+
|
| 170 |
+
return img, intermediates
|
| 171 |
+
|
| 172 |
+
@torch.no_grad()
|
| 173 |
+
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
| 174 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
| 175 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
| 176 |
+
b, *_, device = *x.shape, x.device
|
| 177 |
+
|
| 178 |
+
def get_model_output(x, t):
|
| 179 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
| 180 |
+
e_t = self.model.apply_model(x, t, c)
|
| 181 |
+
else:
|
| 182 |
+
x_in = torch.cat([x] * 2)
|
| 183 |
+
t_in = torch.cat([t] * 2)
|
| 184 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
| 185 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
| 186 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
| 187 |
+
|
| 188 |
+
if score_corrector is not None:
|
| 189 |
+
assert self.model.parameterization == "eps"
|
| 190 |
+
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
| 191 |
+
|
| 192 |
+
return e_t
|
| 193 |
+
|
| 194 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
| 195 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
| 196 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
| 197 |
+
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
| 198 |
+
|
| 199 |
+
def get_x_prev_and_pred_x0(e_t, index):
|
| 200 |
+
# select parameters corresponding to the currently considered timestep
|
| 201 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
| 202 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
| 203 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
| 204 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
| 205 |
+
|
| 206 |
+
# current prediction for x_0
|
| 207 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
| 208 |
+
if quantize_denoised:
|
| 209 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
| 210 |
+
# direction pointing to x_t
|
| 211 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
| 212 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
| 213 |
+
if noise_dropout > 0.:
|
| 214 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 215 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
| 216 |
+
return x_prev, pred_x0
|
| 217 |
+
|
| 218 |
+
e_t = get_model_output(x, t)
|
| 219 |
+
if len(old_eps) == 0:
|
| 220 |
+
# Pseudo Improved Euler (2nd order)
|
| 221 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
| 222 |
+
e_t_next = get_model_output(x_prev, t_next)
|
| 223 |
+
e_t_prime = (e_t + e_t_next) / 2
|
| 224 |
+
elif len(old_eps) == 1:
|
| 225 |
+
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
| 226 |
+
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
| 227 |
+
elif len(old_eps) == 2:
|
| 228 |
+
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
| 229 |
+
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
| 230 |
+
elif len(old_eps) >= 3:
|
| 231 |
+
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
| 232 |
+
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
| 233 |
+
|
| 234 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
| 235 |
+
|
| 236 |
+
return x_prev, pred_x0, e_t
|
ldm/models/diffusion/uni_pc/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sampler import UniPCSampler
|
ldm/models/diffusion/uni_pc/sampler.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class UniPCSampler(object):
    """Thin adapter that runs UniPC sampling on top of a latent-diffusion model.

    Builds a discrete ``NoiseScheduleVP`` from the model's ``alphas_cumprod``,
    wraps ``model.apply_model`` via ``model_wrapper`` (unconditional or
    classifier-free guidance), and delegates the actual solver loop to
    ``UniPC.sample``.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        # Detach, cast to float32, and move to the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        # Attach a (possibly tensor) attribute to the sampler.
        # NOTE(review): tensors are unconditionally moved to CUDA — CPU-only
        # runs would fail; confirm this is intended.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        flags=None,  # options object; must provide at least `skip_type` — TODO confirm full contract
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Draw ``batch_size`` samples of latent ``shape`` (C, H, W) in ``S`` steps.

        Returns ``(samples, None)`` — the second slot mirrors the
        ``(samples, intermediates)`` interface of the other samplers but no
        intermediates are collected here.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                # Check the batch dimension of the first conditioning entry only.
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        device = self.model.betas.device
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        if conditioning is None:
            # Unconditional sampling; 3rd-order multistep solver.
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            # Classifier-free guidance; 2nd order is used for guided sampling.
            model_fn = model_wrapper(
                lambda x, t, c: self.model.apply_model(x, t, c),
                ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=unconditional_guidance_scale,  # 7.5
            )
            ORDER = 2

        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant="bh2")
        x = uni_pc.sample(
            img, steps=S, skip_type=flags.skip_type, method="multistep", order=ORDER, lower_order_final=True, flags=flags
        )

        return x.to(device), None
|
| 82 |
+
|
| 83 |
+
|
ldm/models/diffusion/uni_pc/uni_pc.py
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class NoiseScheduleVP:
    """Variance-preserving (VP) noise schedule used by the UniPC solver.

    Supports three schedules:
      * "discrete": built from trained ``betas`` or ``alphas_cumprod`` tables;
        continuous-time quantities are obtained by piecewise-linear
        interpolation over ``t_array`` / ``log_alpha_array``.
      * "linear": closed-form beta(t) linear in t, parameterized by
        ``continuous_beta_0`` / ``continuous_beta_1``.
      * "cosine": the cosine schedule with offset ``cosine_s``.
    """

    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):

        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )

        self.schedule = schedule
        if schedule == "discrete":
            # log(alpha_t) table: either from betas (cumulative product in log
            # space) or directly from the cumulative alphas.
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            # Discrete steps are mapped to continuous times in (0, 1].
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape(
                (
                    1,
                    -1,
                )
            )
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            # Largest t for which the cosine schedule's beta stays below cosine_beta_max.
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
            self.schedule = schedule
            if schedule == "cosine":
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.0

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == "cosine":
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        # VP property: alpha_t^2 + sigma_t^2 = 1.
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            # Solve the quadratic in t implied by the linear beta schedule;
            # logaddexp keeps log(1 + e^{-2*lambda}) numerically stable.
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            # Invert by interpolating the flipped (monotone-increasing) table.
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            return t.reshape((-1,))
        else:
            # Cosine schedule: closed-form inverse via arccos.
            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            t_fn = (
                lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            t = t_fn(log_alpha)
            return t
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},  # NOTE(review): mutable default; read-only here, but verify it is never mutated downstream
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs={},  # NOTE(review): mutable default; read-only here
):
    """Wrap a diffusion ``model`` into a continuous-time noise-prediction
    function ``model_fn(x, t_continuous)`` for the UniPC solver.

    ``model_type`` selects how the raw model output is converted to a noise
    prediction ("noise", "x_start", "v", or "score"); ``guidance_type``
    selects unconditional, classifier, or classifier-free guidance.
    Returns the wrapped ``model_fn``.
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        # Broadcast a scalar time to the batch dimension.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, None, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        # Convert the model's native parameterization to an epsilon prediction.
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            # Classifier guidance: shift the noise prediction by the scaled
            # classifier gradient.
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1.0 or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Batch the unconditional and conditional passes into one call.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    # NOTE(review): "score" is accepted by noise_pred_fn above but rejected here — confirm intended.
    assert model_type in ["noise", "x_start", "v"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class UniPC:
|
| 212 |
+
def __init__(self, model_fn, noise_schedule, predict_x0=True, thresholding=False, max_val=1.0, variant="bh1"):
|
| 213 |
+
"""Construct a UniPC.
|
| 214 |
+
|
| 215 |
+
We support both data_prediction and noise_prediction.
|
| 216 |
+
"""
|
| 217 |
+
self.model = model_fn
|
| 218 |
+
self.noise_schedule = noise_schedule
|
| 219 |
+
self.variant = variant
|
| 220 |
+
self.predict_x0 = predict_x0
|
| 221 |
+
self.thresholding = thresholding
|
| 222 |
+
self.max_val = max_val
|
| 223 |
+
|
| 224 |
+
def dynamic_thresholding_fn(self, x0, t=None):
|
| 225 |
+
"""
|
| 226 |
+
The dynamic thresholding method.
|
| 227 |
+
"""
|
| 228 |
+
dims = x0.dim()
|
| 229 |
+
p = self.dynamic_thresholding_ratio
|
| 230 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
| 231 |
+
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
| 232 |
+
x0 = torch.clamp(x0, -s, s) / s
|
| 233 |
+
return x0
|
| 234 |
+
|
| 235 |
+
def noise_prediction_fn(self, x, t):
|
| 236 |
+
"""
|
| 237 |
+
Return the noise prediction model.
|
| 238 |
+
"""
|
| 239 |
+
return self.model(x, t)
|
| 240 |
+
|
| 241 |
+
def data_prediction_fn(self, x, t):
|
| 242 |
+
"""
|
| 243 |
+
Return the data prediction model (with thresholding).
|
| 244 |
+
"""
|
| 245 |
+
noise = self.noise_prediction_fn(x, t)
|
| 246 |
+
dims = x.dim()
|
| 247 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
| 248 |
+
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
| 249 |
+
if self.thresholding:
|
| 250 |
+
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
| 251 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
| 252 |
+
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
| 253 |
+
x0 = torch.clamp(x0, -s, s) / s
|
| 254 |
+
return x0
|
| 255 |
+
|
| 256 |
+
def model_fn(self, x, t):
|
| 257 |
+
"""
|
| 258 |
+
Convert the model to the noise prediction model or the data prediction model.
|
| 259 |
+
"""
|
| 260 |
+
if self.predict_x0:
|
| 261 |
+
return self.data_prediction_fn(x, t)
|
| 262 |
+
else:
|
| 263 |
+
return self.noise_prediction_fn(x, t)
|
| 264 |
+
|
| 265 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
| 266 |
+
"""Compute the intermediate time steps for sampling."""
|
| 267 |
+
if skip_type == "logSNR":
|
| 268 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
| 269 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
| 270 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
| 271 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
| 272 |
+
elif skip_type == "time_uniform":
|
| 273 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
| 274 |
+
elif skip_type == "time_quadratic":
|
| 275 |
+
t_order = 2
|
| 276 |
+
t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
|
| 277 |
+
return t
|
| 278 |
+
else:
|
| 279 |
+
raise ValueError(
|
| 280 |
+
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, t2, order, **kwargs):
|
| 284 |
+
if len(t.shape) == 0:
|
| 285 |
+
t = t.view(-1)
|
| 286 |
+
if "bh" in self.variant:
|
| 287 |
+
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, t2, order, **kwargs)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, t2, order, x_t=None, use_corrector=True):
    """One multistep UniPC update of the B(h) variant.

    Takes the current state `x`, the history of model outputs
    (`model_prev_list`) and their times (`t_prev_list`), and produces the
    state at time `t`. Optionally refines the prediction with a corrector
    step that evaluates the model at `t2` (presumably the model-evaluation
    time grid; may differ from `t` — TODO confirm with caller).

    Returns:
        (x_t, model_t): the updated state and the model output at the new
        point (None when `use_corrector` is False).
    """
    # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
    ns = self.noise_schedule
    assert order <= len(model_prev_list)
    dims = x.dim()

    # first compute rks
    # rk_i = (lambda_{prev_i} - lambda_{prev_0}) / h: relative positions of
    # the history points in log-SNR space; D1s are scaled differences of
    # the corresponding model outputs.
    t_prev_0 = t_prev_list[-1]
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    model_prev_0 = model_prev_list[-1]
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
    alpha_t = torch.exp(log_alpha_t)

    # step size in log-SNR space
    h = lambda_t - lambda_prev_0

    rks = []
    D1s = []
    for i in range(1, order):
        t_prev_i = t_prev_list[-(i + 1)]
        model_prev_i = model_prev_list[-(i + 1)]
        lambda_prev_i = ns.marginal_lambda(t_prev_i)
        rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
        rks.append(rk)
        D1s.append((model_prev_i - model_prev_0) / rk)

    rks.append(1.0)
    rks = torch.tensor(rks, device=x.device)

    # Build the linear system R @ rho = b whose solution gives the UniPC
    # combination coefficients (rows are powers of rks).
    R = []
    b = []

    # sign convention flips between data- and noise-prediction modes
    hh = -h[0] if self.predict_x0 else h[0]
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.variant == "bh1":
        B_h = hh
    elif self.variant == "bh2":
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        # recurrence for the higher-order phi functions
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=x.device)

    # now predictor
    # The predictor is only used when we have history (order > 1) and the
    # caller did not already supply x_t.
    use_predictor = len(D1s) > 0 and x_t is None
    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        if x_t is None:
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                # predictor coefficients solve the truncated (order-1) system
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
    else:
        D1s = None

    if use_corrector:
        # print('using corrector')
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            # corrector coefficients solve the full system
            rhos_c = torch.linalg.solve(R, b)

    model_t = None
    if self.predict_x0:
        # data-prediction update: base term plus weighted history residuals
        x_t_ = expand_dims(sigma_t / sigma_prev_0, dims) * x - expand_dims(alpha_t * h_phi_1, dims) * model_prev_0

        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res

        if use_corrector:
            # evaluate the model at the predicted point, then correct
            model_t = self.model_fn(x_t, t2)
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - model_prev_0
            x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
    else:
        # noise-prediction update (mirror of the branch above with
        # alpha/sigma roles exchanged)
        x_t_ = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
            - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
        )
        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

        if use_corrector:
            model_t = self.model_fn(x_t, t2)
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - model_prev_0
            x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
    return x_t, model_t
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def sample(
    self,
    x,
    steps=20,
    t_start=None,
    t_end=None,
    order=3,
    skip_type="time_uniform",
    method="singlestep",
    lower_order_final=True,
    denoise_to_zero=False,
    solver_type="dpm_solver",
    atol=0.0078,
    rtol=0.05,
    corrector=False,
    flags=None,
):
    """Run the multistep UniPC solver from the initial state `x`.

    Args:
        x: initial noisy tensor with shape (batch, ...).
        steps: number of solver steps (NFE); must be >= `order`.
        t_start, t_end: integration limits; default to the noise schedule's
            T and 1/total_N.
        order: maximum solver order; warm-up steps use lower orders.
        skip_type: time-step spacing, forwarded to `get_time_steps`.
        flags: optional namespace. When `flags.learn` is truthy, time steps
            are loaded from a checkpoint under `flags.log_path` instead of
            being computed; `flags.vs` splits the loaded steps into separate
            trajectory/model-evaluation grids. `None` uses the computed
            schedule.
        method, lower_order_final, denoise_to_zero, solver_type, atol, rtol,
        corrector: accepted for interface compatibility; unused here.

    Returns:
        The tensor after `steps` UniPC updates.
    """
    device = x.device
    assert steps >= order
    with torch.no_grad():
        # Fix: guard against the default flags=None (previously this crashed
        # with AttributeError on `flags.learn` when called without flags).
        if flags is not None and flags.learn:
            # Load time steps learned offline for this NFE/order setting.
            load_from = f"{flags.log_path}/NFE-{steps}-256LSUN-uni_pc-{order}-decode/best.pt"
            timesteps = torch.load(load_from)['best_t_steps'].to(x.device)
            if flags.vs:
                # First half drives the trajectory, second half the model calls.
                length = timesteps.shape[0] // 2
                timesteps2 = timesteps[length:]
                timesteps = timesteps[:length]
            else:
                timesteps2 = timesteps
        else:
            t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
            t_T = self.noise_schedule.T if t_start is None else t_start
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            timesteps2 = timesteps
            # computed schedules always contain steps + 1 points
            assert timesteps.shape[0] - 1 == steps

        def one_step(t1, t2, t_prev_list, model_prev_list, step, x_next, order, first=True, use_corrector=True):
            # Advance x by one UniPC update and refresh the history buffers.
            x_next, model_x_next = self.multistep_uni_pc_update(x_next, model_prev_list, t_prev_list, t1, t2, step, use_corrector=use_corrector)
            if model_x_next is None:
                # Corrector disabled: evaluate the model at the new point ourselves.
                model_x_next = self.model_fn(x_next, t2)
            update_lists(t_prev_list, model_prev_list, t1, model_x_next, order, first=first)
            return x_next

        def update_lists(t_list, model_list, t_, model_x, order, first=False):
            # During warm-up (first=True) the history grows; afterwards it is
            # shifted left and the newest entry overwrites the last slot.
            if first:
                t_list.append(t_)
                model_list.append(model_x)
                return
            for m in range(order - 1):
                t_list[m] = t_list[m + 1]
                model_list[m] = model_list[m + 1]
            t_list[-1] = t_
            model_list[-1] = model_x

        timesteps1 = timesteps
        step = 0
        vec_t1 = timesteps1[0].expand((x.shape[0]))  # bs
        vec_t2 = timesteps2[0].expand((x.shape[0]))  # bs
        t_prev_list = [vec_t1]
        model_prev_list = [self.model_fn(x, vec_t2)]

        # Warm-up: build up history using increasing orders 1..order-1.
        for step in range(1, order):
            vec_t1 = timesteps1[step].expand((x.shape[0]))
            vec_t2 = timesteps2[step].expand((x.shape[0]))
            x = one_step(vec_t1, vec_t2, t_prev_list, model_prev_list, step, x, order, first=True)

        # Main loop at full order; the very last step skips the corrector.
        for step in range(order, steps + 1):
            step_order = min(order, steps + 1 - step)
            vec_t1 = timesteps1[step].expand((x.shape[0]))
            vec_t2 = timesteps2[step].expand((x.shape[0]))
            use_corrector = True
            if step == steps:
                use_corrector = False
            x = one_step(vec_t1, vec_t2, t_prev_list, model_prev_list, step_order, x, order, first=False, use_corrector=use_corrector)
        return x
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
#############################################################
|
| 487 |
+
# other utility functions
|
| 488 |
+
#############################################################
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keypoints = x.shape[0], xp.shape[1]
    # Sort the query values together with the keypoints so we can locate,
    # per query, the two keypoints that bracket it.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    sorted_combined, sort_indices = torch.sort(combined, dim=2)
    pos = torch.argmin(sort_indices, dim=2)  # rank of each query among the keypoints
    cand = pos - 1
    # Clamp to the outermost segment when the query falls outside [xp_0, xp_K-1].
    left = torch.where(
        torch.eq(pos, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(pos, num_keypoints),
            torch.tensor(num_keypoints - 2, device=x.device),
            cand,
        ),
    )
    right = torch.where(torch.eq(left, cand), left + 2, left + 1)
    start_x = torch.gather(sorted_combined, dim=2, index=left.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_combined, dim=2, index=right.unsqueeze(2)).squeeze(2)
    # Index of the bracketing segment's left keypoint in the original array.
    left_kp = torch.where(
        torch.eq(pos, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(pos, num_keypoints),
            torch.tensor(num_keypoints - 2, device=x.device),
            cand,
        ),
    )
    yp_expanded = yp.unsqueeze(0).expand(batch, -1, -1)
    start_y = torch.gather(yp_expanded, dim=2, index=left_kp.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(yp_expanded, dim=2, index=(left_kp + 1).unsqueeze(2)).squeeze(2)
    # Linear interpolation (or extrapolation at the boundaries).
    return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the target total number of dimensions.
    Returns:
        a PyTorch tensor (a view) with shape [N, 1, 1, ..., 1] and the total
        dimension is `dims`.
    """
    trailing = (None,) * (dims - 1)
    return v[(Ellipsis,) + trailing]
|
ldm/modules/attention.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from inspect import isfunction
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn, einsum
|
| 6 |
+
from einops import rearrange, repeat
|
| 7 |
+
|
| 8 |
+
from ldm.modules.diffusionmodules.util import checkpoint
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def exists(val):
    """Return True iff `val` holds an actual value (i.e. is not None)."""
    if val is None:
        return False
    return True
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def uniq(arr):
    """Deduplicate `arr` preserving first-seen order; returns a dict keys view."""
    seen = {element: True for element in arr}
    return seen.keys()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def default(val, d):
    """Return `val` when it is not None; otherwise the fallback `d`
    (called first when `d` is a function, i.e. a lazy default factory)."""
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def max_neg_value(t):
    """Most negative finite value representable in `t`'s dtype."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def init_(tensor):
    """Uniformly initialise `tensor` in-place within +/- 1/sqrt(last dim)
    and return it."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# feedforward
|
| 37 |
+
class GEGLU(nn.Module):
    """Gated GELU projection: a single linear layer produces both the value
    and the gate; the GELU of the gate multiplies the value."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: dim -> dim*mult -> dim_out (defaults to dim),
    with GELU (or GEGLU when `glu=True`) activation and dropout."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim if dim_out is None else dim_out
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for param in module.parameters():
        with torch.no_grad():
            param.zero_()
    return module
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def Normalize(in_channels):
    """GroupNorm factory (32 groups, eps 1e-6, affine), the normalization
    used throughout this module."""
    norm_layer = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return norm_layer
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LinearAttention(nn.Module):
    """Linear-complexity multi-head attention over 2-D feature maps:
    softmax over keys first, so the context is a (d x d) product instead of
    an (n x n) attention matrix."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        batch, channels, height, width = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        # Normalising keys first makes the attention product linear in n.
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=height, w=width)
        return self.to_out(out)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the spatial (h*w) positions of a 2-D
    feature map: 1x1-conv q/k/v projections, scaled dot-product attention,
    1x1-conv output projection, residual add."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convolutions act as per-pixel linear projections.
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)  # (b, hw, hw) raw scores

        w_ = w_ * (int(c) ** (-0.5))  # scale by 1/sqrt(channels)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_  # residual connection
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class CrossAttention(nn.Module):
    """Multi-head (cross-)attention on token sequences.

    Queries come from `x`; keys/values come from `context` (self-attention
    when no context is given). All q/k/v projections are bias-free linears.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5  # 1/sqrt(d_head) attention scaling
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)  # self-attention when context is None
        k = self.to_k(context)
        v = self.to_v(context)

        # fold heads into the batch axis: (b, n, h*d) -> (b*h, n, d)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            # boolean mask over context positions; masked entries are filled
            # with the most negative finite value before the softmax
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then cross-attention on
    `context`, then feed-forward — each with a residual connection.

    NOTE(review): the `checkpoint` __init__ argument shadows the imported
    `checkpoint` function inside __init__; the stored attribute only toggles
    whether gradient checkpointing is enabled in forward.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=False):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # Run _forward through the (optional) gradient-checkpointing wrapper.
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        # 1x1 conv projection into the transformer width
        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) for d in range(depth)]
        )

        # zero-initialised output projection: the block starts as identity
        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)
        return x + x_in  # residual around the whole transformer
|
ldm/modules/diffusionmodules/__init__.py
ADDED
|
File without changes
|
ldm/modules/diffusionmodules/model.py
ADDED
|
@@ -0,0 +1,835 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pytorch_diffusion + derived encoder decoder
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
from ldm.util import instantiate_from_config
|
| 9 |
+
from ldm.modules.attention import LinearAttention
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # geometric frequency ladder from 1 down to ~1/10000
    freq_scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(-freq_scale * torch.arange(half_dim, dtype=torch.float32))
    freqs = freqs.to(device=timesteps.device)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def nonlinearity(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def Normalize(in_channels, num_groups=32):
    """GroupNorm factory (eps 1e-6, affine) used by the encoder/decoder blocks."""
    gn = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    return gn
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Upsample(nn.Module):
    """2x nearest-neighbour spatial upsampling, optionally followed by a
    3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            out = self.conv(out)
        return out
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Downsample(nn.Module):
    """2x spatial downsampling: strided 3x3 conv with asymmetric zero
    padding when `with_conv`, otherwise 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # pad one pixel on the right/bottom so stride-2 halves the map
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ResnetBlock(nn.Module):
    """GroupNorm / Swish / Conv residual block with optional timestep
    conditioning.

    When in/out channel counts differ, the skip path is matched with either
    a 3x3 conv (`conv_shortcut=True`) or a 1x1 "network-in-network" conv.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            # projects the timestep embedding onto the feature channels
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        # main branch: norm -> swish -> conv (twice), with temb injected between
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # broadcast the projected embedding over the spatial dims
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            # match the skip path's channel count before the residual add
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    # Wraps LinearAttention as a single-head block whose head width equals
    # the channel count, so it is a drop-in replacement for AttnBlock.
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions of a feature
    map (1x1-conv q/k/v projections, scaled dot-product, residual add)."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # attention weights: softmax over keys of scaled dot products
        batch, channels, height, width = q.shape
        n = height * width
        queries = q.reshape(batch, channels, n).permute(0, 2, 1)  # b, hw, c
        keys = k.reshape(batch, channels, n)                      # b, c, hw
        attn = torch.bmm(queries, keys)  # attn[b,i,j] = <q_i, k_j>
        attn = attn * (int(channels) ** (-0.5))
        attn = torch.nn.functional.softmax(attn, dim=2)

        # weighted sum of values
        values = v.reshape(batch, channels, n)
        attn = attn.permute(0, 2, 1)  # b, hw(k), hw(q)
        out = torch.bmm(values, attn).reshape(batch, channels, height, width)

        return x + self.proj_out(out)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def make_attn(in_channels, attn_type="vanilla"):
    """Factory for an attention block.

    Args:
        in_channels: channel count the attention block must accept/return.
        attn_type: "vanilla" (full self-attention), "linear"
            (LinAttnBlock), or "none" (identity pass-through).

    Returns:
        An ``nn.Module`` mapping (B, C, H, W) -> (B, C, H, W).

    Raises:
        AssertionError: if ``attn_type`` is not one of the three options.
    """
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        # nn.Identity ignores constructor arguments, so passing in_channels
        # here was misleading no-op noise; behavior is unchanged.
        return nn.Identity()
    else:
        return LinAttnBlock(in_channels)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class Model(nn.Module):
    """DDPM-style UNet: conv stem, downsampling ResNet stages with optional
    attention, a middle block, and mirrored upsampling stages with skip
    connections; optionally conditioned on a sinusoidal timestep embedding.

    NOTE(review): module construction order below also fixes the RNG order
    of weight initialization; do not reorder without checking reproducibility.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        # timestep-embedding width is 4x the base channel count
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: two-layer MLP applied to the sinusoidal embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch,
                                self.temb_ch),
                torch.nn.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # channel multiplier of the *input* to each level (prepend 1 for the stem)
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                # every level except the last halves the spatial resolution
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            # one extra ResNet block per level on the way up
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    # last block of the level consumes the skip from the
                    # level *input* (hence in_ch_mult, not ch_mult)
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t=None, context=None):
        """Run the UNet.

        x: (B, C, H, W) input; t: per-sample timesteps (required when
        use_timestep); context: optional tensor concatenated to x along
        the channel axis (assumed spatially aligned).
        """
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling; hs collects every activation for the skip connections
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: pop skips in reverse order and concat along channels
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        # Used by callers that need the final conv's weights (e.g. for
        # adaptive loss weighting) — TODO confirm against call sites.
        return self.conv_out.weight
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class Encoder(nn.Module):
    """Downsampling half of an autoencoder: conv stem, ResNet stages with
    optional attention, middle block, then a conv head producing z_channels
    (doubled when double_z, e.g. for mean+logvar of a Gaussian posterior).

    No timestep conditioning (temb_ch = 0); ResNet blocks receive temb=None.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # channel multiplier of the *input* to each level (prepend 1 for the stem)
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                # every level except the last halves the spatial resolution
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        # 2x for (mean, logvar) when double_z
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Encode (B, in_channels, H, W) -> latent feature map."""
        # timestep embedding (unused in the encoder)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class Decoder(nn.Module):
    """Upsampling half of an autoencoder: maps a latent z to an image-space
    tensor via a conv stem, middle block, and mirrored upsampling ResNet
    stages with optional attention.

    give_pre_end returns the features before the final norm/act/conv;
    tanh_out squashes the output to [-1, 1]. No timestep conditioning.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        # NOTE(review): in_ch_mult is assigned but never used below.
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling (built from lowest to highest resolution, prepended)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            # one extra ResNet block per level compared to the encoder
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        """Decode latent z -> (B, out_ch, H, W) tensor."""
        #assert z.shape[1:] == self.z_shape[1:]
        # remember the last input shape for external inspection
        self.last_z_shape = z.shape

        # timestep embedding (unused in the decoder)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            # expose pre-head features (used e.g. for feature extraction)
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
class SimpleDecoder(nn.Module):
    """Small fixed-depth decoder: 1x1 conv, three ResNet blocks
    (widen 2x -> 4x -> narrow back to 2x), 1x1 conv back to in_channels,
    a single 2x Upsample, then norm/act/conv to out_channels.

    Extra *args/**kwargs are accepted and ignored for interface
    compatibility with the other decoders.
    """
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
                                     ResnetBlock(in_channels=in_channels,
                                                 out_channels=2 * in_channels,
                                                 temb_channels=0, dropout=0.0),
                                     ResnetBlock(in_channels=2 * in_channels,
                                                out_channels=4 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                     ResnetBlock(in_channels=4 * in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                     nn.Conv2d(2*in_channels, in_channels, 1),
                                     Upsample(in_channels, with_conv=True)])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            # indices 1-3 are ResnetBlocks, which take (x, temb); temb unused
            if i in [1,2,3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class UpsampleDecoder(nn.Module):
    """Decoder made of ResNet-block stages with a 2x Upsample between each
    stage (no attention, no skip connections, no timestep conditioning)."""
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        # NOTE(review): curr_res is tracked but never used after the loop.
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in,
                                             out_channels=block_out,
                                             temb_channels=self.temb_ch,
                                             dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            # an Upsample between stages, none after the last stage
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
class LatentRescaler(nn.Module):
    """Rescale a latent feature map by `factor` via nearest interpolation,
    sandwiched between ResNet-block stacks and a self-attention block:
    conv_in -> res_block1 -> interpolate(factor) -> attn -> res_block2 -> 1x1 conv_out.
    """
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels,
                                 mid_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0,
                                                     dropout=0.0) for _ in range(depth)])
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0,
                                                     dropout=0.0) for _ in range(depth)])

        self.conv_out = nn.Conv2d(mid_channels,
                                  out_channels,
                                  kernel_size=1,
                                  )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        # rescale H and W by self.factor (rounded to integer sizes;
        # default interpolate mode is nearest)
        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
class MergedRescaleEncoder(nn.Module):
    """Encoder followed by a LatentRescaler: encodes to ch*ch_mult[-1]
    channels (no doubling), then rescales spatially by `rescale_factor`
    and projects to `out_ch` channels."""
    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        # out_ch=None: the Encoder does not use out_ch, only z_channels
        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
                               z_channels=intermediate_chn, double_z=False, resolution=resolution,
                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
                               out_ch=None)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)

    def forward(self, x):
        x = self.encoder(x)
        x = self.rescaler(x)
        return x
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
class MergedRescaleDecoder(nn.Module):
    """LatentRescaler followed by a Decoder: rescales the latent by
    `rescale_factor` (widening z_channels -> z_channels*ch_mult[-1]),
    then decodes to `out_ch` image channels."""
    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        tmp_chn = z_channels*ch_mult[-1]
        # in_channels=None: the Decoder stores but does not use in_channels
        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
                               ch_mult=ch_mult, resolution=resolution, ch=ch)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
                                       out_channels=tmp_chn, depth=rescale_module_depth)

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
class Upsampler(nn.Module):
    """Upscale from in_size to out_size: a LatentRescaler for any fractional
    factor, then a Decoder whose number of levels covers the 2x steps."""
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        # number of decoder levels needed to cover log2(out/in) doublings
        num_blocks = int(np.log2(out_size//in_size))+1
        # NOTE(review): factor derived from the division remainder; only
        # differs from 1.0 when out_size is not a multiple of in_size —
        # confirm this matches the intended fractional rescale.
        factor_up = 1.+ (out_size % in_size)
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
                                       out_channels=in_channels)
        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
                               attn_resolutions=[], in_channels=None, ch=in_channels,
                               ch_mult=[ch_mult for _ in range(num_blocks)])

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
class Resize(nn.Module):
    """Spatial resizing by interpolation.

    The `learned` (conv-based) variant is not implemented: constructing with
    learned=True raises NotImplementedError. With learned=False, forward
    interpolates by `scale_factor` using the fixed `mode`, or returns the
    input unchanged when scale_factor == 1.0.
    """
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # bug fix: was `self.__class__.__name` (missing trailing
            # underscores), which raised AttributeError instead of printing
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError()
            # NOTE: everything below is unreachable (kept from the original
            # as a sketch of the intended learned implementation)
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1)

    def forward(self, x, scale_factor=1.0):
        """Return x resized by scale_factor (identity when 1.0)."""
        if scale_factor==1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
        return x
|
| 769 |
+
|
| 770 |
+
class FirstStagePostProcessor(nn.Module):
    """Encode inputs with a frozen pretrained first-stage model, then refine
    the latent with a projection plus ResNet/Downsample stages.

    Exactly one of `pretrained_model` / `pretrained_config` must be given;
    the pretrained model is set to eval mode and its parameters frozen.
    """

    def __init__(self, ch_mult:list, in_channels,
                 pretrained_model:nn.Module=None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is None:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape

        if n_channels is None:
            # default to the pretrained encoder's base channel count
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
                              stride=1,padding=1)

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            # NOTE(review): ResnetBlock is called without temb_channels here,
            # relying on its default — confirm it matches temb=None below.
            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)


    def instantiate_pretrained(self, config):
        """Build the pretrained model from config, freeze it, set eval mode."""
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        # self.pretrained_model.train = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False


    @torch.no_grad()
    def encode_with_pretrained(self,x):
        """Encode x with the frozen model; take the distribution mode when
        the encoder returns a DiagonalGaussianDistribution."""
        c = self.pretrained_model.encode(x)
        if isinstance(c, DiagonalGaussianDistribution):
            c = c.mode()
        return c

    def forward(self,x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)

        # each stage: ResNet block then downsample
        for submodel, downmodel in zip(self.model,self.downsampler):
            z = submodel(z,temb=None)
            z = downmodel(z)

        if self.do_reshape:
            # flatten spatial dims to a token sequence: (B, H*W, C)
            z = rearrange(z,'b c h w -> b (h w) c')
        return z
|
| 835 |
+
|