nvan13 committed commit 1c8e113 (verified · 1 parent: b7eae34)

Add files using upload-large-folder tool

Files changed (50)
  1. README.md +283 -0
  2. assets/OHRFT_scheme.png +0 -0
  3. assets/figure_nlp.png +0 -0
  4. generation/control/ControlNet/.gitignore +143 -0
  5. generation/control/ControlNet/LICENSE +201 -0
  6. generation/control/ControlNet/README.md +348 -0
  7. generation/control/ControlNet/cldm/cldm.py +435 -0
  8. generation/control/ControlNet/cldm/ddim_hacked.py +317 -0
  9. generation/control/ControlNet/cldm/hack.py +111 -0
  10. generation/control/ControlNet/cldm/logger.py +76 -0
  11. generation/control/ControlNet/cldm/model.py +28 -0
  12. generation/control/ControlNet/config.py +1 -0
  13. generation/control/ControlNet/docs/annotator.md +49 -0
  14. generation/control/ControlNet/docs/faq.md +21 -0
  15. generation/control/ControlNet/docs/low_vram.md +15 -0
  16. generation/control/ControlNet/docs/train.md +276 -0
  17. generation/control/ControlNet/environment.yaml +35 -0
  18. generation/control/ControlNet/gradio_annotator.py +160 -0
  19. generation/control/ControlNet/gradio_canny2image.py +97 -0
  20. generation/control/ControlNet/gradio_depth2image.py +98 -0
  21. generation/control/ControlNet/gradio_fake_scribble2image.py +102 -0
  22. generation/control/ControlNet/gradio_hed2image.py +98 -0
  23. generation/control/ControlNet/gradio_hough2image.py +100 -0
  24. generation/control/ControlNet/gradio_normal2image.py +99 -0
  25. generation/control/ControlNet/gradio_pose2image.py +98 -0
  26. generation/control/ControlNet/gradio_scribble2image.py +92 -0
  27. generation/control/ControlNet/gradio_scribble2image_interactive.py +102 -0
  28. generation/control/ControlNet/gradio_seg2image.py +97 -0
  29. generation/control/ControlNet/ldm/data/__init__.py +0 -0
  30. generation/control/ControlNet/ldm/models/autoencoder.py +219 -0
  31. generation/control/ControlNet/ldm/models/diffusion/__init__.py +0 -0
  32. generation/control/ControlNet/ldm/models/diffusion/ddim.py +336 -0
  33. generation/control/ControlNet/ldm/util.py +197 -0
  34. generation/control/ControlNet/share.py +8 -0
  35. generation/control/ControlNet/tool_add_control.py +50 -0
  36. generation/control/ControlNet/tool_add_control_sd21.py +50 -0
  37. generation/control/ControlNet/tool_transfer_control.py +59 -0
  38. generation/control/ControlNet/tutorial_dataset.py +39 -0
  39. generation/control/ControlNet/tutorial_dataset_test.py +12 -0
  40. generation/control/ControlNet/tutorial_train.py +35 -0
  41. generation/control/ControlNet/tutorial_train_sd21.py +35 -0
  42. generation/control/download_ade20k.sh +10 -0
  43. generation/control/download_celebhq.sh +10 -0
  44. generation/control/eval_canny.py +130 -0
  45. generation/control/eval_landmark.py +127 -0
  46. generation/control/generation.py +238 -0
  47. generation/control/hra.py +254 -0
  48. generation/control/tool_add_hra.py +81 -0
  49. generation/control/train.py +149 -0
  50. generation/env.yml +172 -0
README.md ADDED
@@ -0,0 +1,283 @@
<div align="center">

# [NeurIPS 2024 Spotlight] Bridging The Gap between Low-rank and Orthogonal Adaptation via Householder Reflection Adaptation

[![arXiv](https://img.shields.io/badge/arXiv-2405.17484-b31b1b?style=flat&logo=arxiv)](https://arxiv.org/pdf/2405.17484)
[![Hugging Face](https://img.shields.io/badge/Hugging%20Face-Peft-orange?style=flat&logo=huggingface)](https://huggingface.co/docs/peft/en/package_reference/hra)

</div>

<div align="center">
<img src="assets/OHRFT_scheme.png" width="1100"/>
</div>

## Introduction

This repository contains the official implementation of [HRA](https://arxiv.org/pdf/2405.17484).
We propose a simple yet effective adapter-based orthogonal fine-tuning method, HRA.
Given a pre-trained model, our method fine-tunes its layers by multiplying each frozen weight matrix with an orthogonal matrix constructed from a chain of learnable Householder reflections (HRs).

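For intuition, here is a minimal NumPy sketch of the construction (illustrative only, not the repository's implementation; the layer width, number of reflections, and random directions below are placeholders). Each Householder reflection H_i = I - 2 u_i u_iᵀ with a unit vector u_i is orthogonal, so their product Q is orthogonal, and right-multiplying a frozen weight by Q preserves its norm:

```python
import numpy as np

def householder_chain(us):
    """Product of Householder reflections H_i = I - 2 u_i u_i^T."""
    d = us.shape[1]
    Q = np.eye(d)
    for u in us:
        u = u / np.linalg.norm(u)                 # unit reflection direction
        Q = Q @ (np.eye(d) - 2.0 * np.outer(u, u))
    return Q

rng = np.random.default_rng(0)
d, r = 6, 4                                       # hypothetical layer width and number of HRs
us = rng.standard_normal((r, d))                  # learnable directions (random stand-ins here)
Q = householder_chain(us)

W = rng.standard_normal((d, d))                   # frozen pre-trained weight (random stand-in)
W_adapted = W @ Q                                 # HRA-style orthogonally adapted weight

print(np.allclose(Q.T @ Q, np.eye(d)))            # Q is orthogonal
```

Because Q is orthogonal, the Frobenius norm of the adapted weight equals that of the frozen weight, which is the sense in which the adaptation is norm-preserving.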
## Usage

### Subject-driven Generation

<div align="center">
<img src="assets/subject.png" width="600"/>
</div>

Given several images of a specific subject and a textual prompt, subject-driven generation aims to generate images of the same subject in a context aligned with the prompt.

#### Environment Setup

```bash
cd generation
conda env create -f env.yml
```

#### Prepare Dataset

Download the [dreambooth](https://github.com/google/dreambooth) dataset by running this script:

```bash
cd subject
bash download_dreambooth.sh
```

After downloading the data, your directory structure should look like this:

```
dreambooth
├── dataset
│   ├── backpack
│   ├── backpack_dog
│   ...
```

You can also put your own custom images into `dreambooth/dataset`.

#### Finetune

```bash
prompt_idx=0
class_idx=0
./train_dreambooth.sh $prompt_idx $class_idx
```

where `$prompt_idx` selects one of the prompts (0 to 24) and `$class_idx` selects one of the subjects (0 to 29).

Launch the training script with `accelerate` and pass hyperparameters, as well as HRA-specific arguments, such as:

- `use_hra`: enables HRA in the training script.
- `hra_r`: the number of HRs (i.e., r) across different layers, expressed as an `int`.
  As r increases, the number of trainable parameters increases, which generally leads to improved performance.
  However, this also results in higher memory consumption and longer computation time.
  Therefore, r is usually set to 8.
  **Note:** please set r to an even number to avoid potential issues during initialization.
- `hra_apply_GS`: applies Gram-Schmidt orthogonalization. Default is `false`.
- `hra_bias`: specifies whether the `bias` parameters should be trained. Can be `none`, `all`, or `hra_only`.
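One plausible reading of the even-r note (an illustration, not a statement of the repository's actual initialization scheme): a single Householder reflection is never the identity, but a pair of identical reflections cancels exactly, so an even-length chain can start as the identity and training begins from the unmodified pre-trained weights. A small NumPy check:

```python
import numpy as np

def reflect(u):
    """Householder reflection H = I - 2 u u^T for a unit vector u."""
    u = u / np.linalg.norm(u)
    return np.eye(len(u)) - 2.0 * np.outer(u, u)

rng = np.random.default_rng(1)
d = 5
u = rng.standard_normal(d)

# A lone reflection has determinant -1, so it cannot equal the identity;
# but applying the same reflection twice cancels: H(u) @ H(u) = I.
H = reflect(u)
print(int(round(np.linalg.det(H))))       # determinant of a single reflection
print(np.allclose(H @ H, np.eye(d)))      # a pair of identical reflections is the identity
```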

#### Evaluation

```bash
python evaluate.py
python get_result.py
```

### Controllable Generation

<div align="center">
<img src="assets/control.png" width="650"/>
</div>

Controllable generation aims to generate images aligned with a textual prompt and additional control signals (such as facial landmark annotations, canny edges, and segmentation maps).

#### Prepare Dataset

Download the ADE20K and CelebA-HQ datasets by running these scripts:

```bash
cd control
bash download_ade20k.sh
bash download_celebhq.sh
```

For the COCO dataset, we follow [OFT](https://github.com/Zeju1997/oft) to download and preprocess it.

After downloading the data, your directory structure should look like this:

```
data
├── ADE20K
│   ├── train
│   │   ├── color
│   │   ├── segm
│   │   └── prompt_train_blip.json
│   └── val
│       ├── color
│       ├── segm
│       └── prompt_val_blip.json
└── COCO
    ├── train
    │   ├── color
    │   ├── depth
    ...
```

#### Prepare Pre-trained Model

Download the pre-trained model weights [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) and save them in the `models` directory.

#### Fine-tuning

1. Create the model with additional **HRA** parameters:
   ```bash
   python tool_add_hra.py \
       --input_path=./models/v1-5-pruned.ckpt \
       --output_path=./models/hra_r_8.ckpt \
       --r=8
   ```
2. Specify the control signal and dataset, then train the model with the same hyperparameters as above:
   ```bash
   python train.py \
       --r=8 \
       --control=segm
   ```

#### Generation

1. After fine-tuning with **HRA**, run inference to generate images based on the control signals. Because inference takes some time, for large-scale evaluation we split the dataset into sub-datasets and run inference on multiple GPUs:
   ```bash
   python generation.py \
       --r=8 \
       --control=segm
   ```
2. To evaluate **HRA** results on the three tasks, canny edge to image (C2I) on the COCO dataset, landmark to face (L2F) on the CelebA-HQ dataset, and segmentation map to image (S2I) on the ADE20K dataset, run the following scripts on the generated images:
   ```bash
   python eval_landmark.py
   ```
   ```bash
   python eval_canny.py
   ```
   **Note:** to evaluate the segmentation map-to-image (S2I) task, please install the [SegFormer](https://github.com/NVlabs/SegFormer) repository and run the following testing command on both the original and generated images:
   ```bash
   python tools/test.py local_configs/segformer/B4/segformer.b4.512x512.ade.160k.py ./weights/segformer.b4.512x512.ade.160k.pth
   ```

### Natural Language Understanding

<div align="center">
<img src="assets/figure_nlp.png" width="300"/>
</div>

We adapt [DeBERTaV3-base](https://arxiv.org/abs/2111.09543) and test the performance of the adapted models on the [General Language Understanding Evaluation (GLUE) benchmark](https://gluebenchmark.com/).

#### Environment Setup

```bash
cd nlu
conda env create -f env.yml
```

Before fine-tuning, install the package dependencies:

```bash
python setup.py install
```

#### Prepare Dataset

Run this script to download the GLUE dataset:

```bash
cache_dir=/tmp/DeBERTa/
cd experiments/glue
./download_data.sh $cache_dir/glue_tasks
```

#### Finetune

Run the task scripts:

```bash
./mnli.sh
./cola.sh
./mrpc.sh
./qnli.sh
./qqp.sh
./rte.sh
./sst2.sh
./stsb.sh
```

### Mathematical Reasoning

We have not yet completed the integration of the HRA code into PEFT. Until then, if you want to try fine-tuning large models with HRA, you can follow the steps below.

Go to the llama folder:

```bash
cd llama
```

#### Environment Setup

We recommend using Python 3.10 and creating the environment with conda:

```bash
conda create -n pytorch python=3.10
```

Then install the required packages:

```bash
pip install -r requirements.txt
```

Please note that the versions of the `peft` and `transformers` packages must match those listed in the requirements file.

After installation, replace the **oft** folder inside **peft/tuners** within your environment's **site-packages** with the **oft** folder from the current directory. The path to the oft folder in the environment should be:

```bash
/your_path/anaconda3/envs/pytorch/lib/python3.10/site-packages/peft/tuners/
```

The **layer.py** in the current oft directory implements the case where λ is not infinity.

If you want to simulate λ being infinity, replace **layer.py** with **layer_GS_HRA.py** and set the hyperparameter λ to 0 during training.
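The `hra_apply_GS` option above and `layer_GS_HRA.py` here both appear to refer to Gram-Schmidt orthogonalization of the learnable reflection directions, which enforces exact orthogonality (the λ-at-infinity regime). A minimal illustrative sketch, not the repository's implementation:

```python
import numpy as np

def gram_schmidt(U):
    """Orthonormalize the rows of U (classical Gram-Schmidt, illustrative only)."""
    Q = []
    for u in U:
        for q in Q:
            u = u - np.dot(q, u) * q      # remove components along earlier vectors
        Q.append(u / np.linalg.norm(u))
    return np.array(Q)

rng = np.random.default_rng(2)
U = rng.standard_normal((4, 8))           # e.g. r=4 reflection directions in an 8-dim layer
Q = gram_schmidt(U)
print(np.allclose(Q @ Q.T, np.eye(4)))    # rows are now exactly orthonormal
```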

#### Prepare Dataset

The dataset we use for fine-tuning is MetaMathQA-40K, which can be downloaded through this [link](https://huggingface.co/datasets/meta-math/MetaMathQA-40K).

#### Prepare Model

The model we use for fine-tuning is Llama 2. You can choose the model you want to fine-tune.

#### Finetune

Run the following command to fine-tune:

```bash
bash tune.sh
```

Please note that you need to set the dataset path and the path of the pre-trained model in `tune.sh`, and you can adjust the other parameters to your needs. That is:

```bash
BASE_MODEL="YOUR_MODEL_PATH"
DATA_PATH="YOUR_DATA_PATH"
OUTPUT="YOUR_MODEL_SAVED_PATH"
```

#### Evaluation

After training is complete, run the following command to test:

```bash
bash test.sh
```

Please note to change the model paths in it:

```bash
BASE_MODEL="YOUR_MODEL_PATH"
OUTPUT="YOUR_MODEL_SAVED_PATH"
```

## 📌 Citing our work

If you find our work useful, please cite it:

```bibtex
@inproceedings{yuanbridging,
  title={Bridging The Gap between Low-rank and Orthogonal Adaptation via Householder Reflection Adaptation},
  author={Yuan, Shen and Liu, Haotian and Xu, Hongteng},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024}
}
```
assets/OHRFT_scheme.png ADDED
assets/figure_nlp.png ADDED
generation/control/ControlNet/.gitignore ADDED
@@ -0,0 +1,143 @@
.idea/

training/
lightning_logs/
image_log/

*.pth
*.pt
*.ckpt
*.safetensors

gradio_pose2image_private.py
gradio_canny2image_private.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
generation/control/ControlNet/LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
generation/control/ControlNet/README.md ADDED
@@ -0,0 +1,348 @@
# News: A nightly version of ControlNet 1.1 is released!

[ControlNet 1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly) is released. Those new models will be merged into this repo after we make sure that everything is good.

# Below is ControlNet 1.0

Official implementation of [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543).

ControlNet is a neural network structure to control diffusion models by adding extra conditions.

![img](github_page/he.png)

It copies the weights of neural network blocks into a "locked" copy and a "trainable" copy.

The "trainable" one learns your condition. The "locked" one preserves your model.

Thanks to this, training with a small dataset of image pairs will not destroy the production-ready diffusion models.

The "zero convolution" is a 1×1 convolution with both weight and bias initialized as zeros.

Before training, all zero convolutions output zeros, and ControlNet will not cause any distortion.
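As a sanity check of that claim (a sketch with NumPy standing in for a real convolution layer; the shapes are arbitrary): a 1×1 convolution is a per-pixel linear map over channels, so zero-initialized weight and bias send any input to an all-zero output.

```python
import numpy as np

def zero_conv1x1(x, w, b):
    """1x1 convolution: a per-pixel linear map over channels. x: (C_in, H, W)."""
    return np.tensordot(w, x, axes=([1], [0])) + b[:, None, None]

c_in, c_out, h, wd = 3, 4, 8, 8
w = np.zeros((c_out, c_in))                       # zero-initialized weight
b = np.zeros(c_out)                               # zero-initialized bias
x = np.random.default_rng(3).standard_normal((c_in, h, wd))

# Whatever the input, the output is all zeros before training begins.
print(np.allclose(zero_conv1x1(x, w, b), 0.0))    # True
```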
22
+
23
+ No layer is trained from scratch. You are still fine-tuning. Your original model is safe.
24
+
25
+ This allows training on small-scale or even personal devices.
26
+
27
+ This is also friendly to merge/replacement/offsetting of models/weights/blocks/layers.
28
+
29
+ ### FAQ
30
+
31
+ **Q:** But wait, if the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why "zero convolution" works?
32
+
33
+ **A:** This is not true. [See an explanation here](docs/faq.md).
34
+
35
+ # Stable Diffusion + ControlNet
36
+
37
+ By repeating the above simple structure 14 times, we can control stable diffusion in this way:
38
+
39
+ ![img](github_page/sd.png)
40
+
41
+ In this way, the ControlNet can **reuse** the SD encoder as a **deep, strong, robust, and powerful backbone** to learn diverse controls. Many evidences (like [this](https://jerryxu.net/ODISE/) and [this](https://vpd.ivg-research.xyz/)) validate that the SD encoder is an excellent backbone.
42
+
43
+ Note that the way we connect layers is computational efficient. The original SD encoder does not need to store gradients (the locked original SD Encoder Block 1234 and Middle). The required GPU memory is not much larger than original SD, although many layers are added. Great!
44
+
45
+ # Features & News
46
+
47
+ 2023/0/14 - We released [ControlNet 1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly). Those new models will be merged to this repo after we make sure that everything is good.
48
+
49
+ 2023/03/03 - We released a discussion - [Precomputed ControlNet: Speed up ControlNet by 45%, but is it necessary?](https://github.com/lllyasviel/ControlNet/discussions/216)
50
+
51
+ 2023/02/26 - We released a blog - [Ablation Study: Why ControlNets use deep encoder? What if it was lighter? Or even an MLP?](https://github.com/lllyasviel/ControlNet/discussions/188)
52
+
53
+ 2023/02/20 - Implementation for non-prompt mode released. See also [Guess Mode / Non-Prompt Mode](#guess-anchor).
54
+
55
+ 2023/02/12 - Now you can play with any community model by [Transferring the ControlNet](https://github.com/lllyasviel/ControlNet/discussions/12).
56
+
57
+ 2023/02/11 - [Low VRAM mode](docs/low_vram.md) is added. Please use this mode if you are using 8GB GPU(s) or if you want larger batch size.
58
+
59
+ # Production-Ready Pretrained Models
60
+
61
+ First create a new conda environment
62
+
63
+ conda env create -f environment.yaml
64
+ conda activate control
65
+
66
+ All models and detectors can be downloaded from [our Hugging Face page](https://huggingface.co/lllyasviel/ControlNet). Make sure that SD models are put in "ControlNet/models" and detectors are put in "ControlNet/annotator/ckpts". Also make sure that you download all necessary pretrained weights and detector models from that Hugging Face page, including the HED edge detection model, the Midas depth estimation model, Openpose, and so on.
67
+
68
+ We provide 9 Gradio apps with these models.
69
+
70
+ All test images can be found at the folder "test_imgs".
71
+
72
+ ## ControlNet with Canny Edge
73
+
74
+ Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection)
75
+
76
+ python gradio_canny2image.py
77
+
78
+ The Gradio app also allows you to change the Canny edge thresholds. Just try it for more details.
79
+
80
+ Prompt: "bird"
81
+ ![p](github_page/p1.png)
82
+
83
+ Prompt: "cute dog"
84
+ ![p](github_page/p2.png)
85
+
86
+ ## ControlNet with M-LSD Lines
87
+
88
+ Stable Diffusion 1.5 + ControlNet (using simple M-LSD straight line detection)
89
+
90
+ python gradio_hough2image.py
91
+
92
+ The Gradio app also allows you to change the M-LSD thresholds. Just try it for more details.
93
+
94
+ Prompt: "room"
95
+ ![p](github_page/p3.png)
96
+
97
+ Prompt: "building"
98
+ ![p](github_page/p4.png)
99
+
100
+ ## ControlNet with HED Boundary
101
+
102
+ Stable Diffusion 1.5 + ControlNet (using soft HED Boundary)
103
+
104
+ python gradio_hed2image.py
105
+
106
+ The soft HED Boundary will preserve many details in input images, making this app suitable for recoloring and stylizing. Just try it for more details.
107
+
108
+ Prompt: "oil painting of handsome old man, masterpiece"
109
+ ![p](github_page/p5.png)
110
+
111
+ Prompt: "Cyberpunk robot"
112
+ ![p](github_page/p6.png)
113
+
114
+ ## ControlNet with User Scribbles
115
+
116
+ Stable Diffusion 1.5 + ControlNet (using Scribbles)
117
+
118
+ python gradio_scribble2image.py
119
+
120
+ Note that the UI is based on Gradio, and Gradio is somewhat difficult to customize. Right now you need to draw scribbles outside the UI (using your favorite drawing software, for example, MS Paint) and then import the scribble image to Gradio.
121
+
122
+ Prompt: "turtle"
123
+ ![p](github_page/p7.png)
124
+
125
+ Prompt: "hot air balloon"
126
+ ![p](github_page/p8.png)
127
+
128
+ ### Interactive Interface
129
+
130
+ We actually provide an interactive interface
131
+
132
+ python gradio_scribble2image_interactive.py
133
+
134
+ ~~However, because gradio is very [buggy](https://github.com/gradio-app/gradio/issues/3166) and difficult to customize, right now, user need to first set canvas width and heights and then click "Open drawing canvas" to get a drawing area. Please do not upload image to that drawing canvas. Also, the drawing area is very small; it should be bigger. But I failed to find out how to make it larger. Again, gradio is really buggy.~~ (Now fixed, will update asap)
135
+
136
+ The below dog sketch is drawn by me. Perhaps we should draw a better dog for showcase.
137
+
138
+ Prompt: "dog in a room"
139
+ ![p](github_page/p20.png)
140
+
141
+ ## ControlNet with Fake Scribbles
142
+
143
+ Stable Diffusion 1.5 + ControlNet (using fake scribbles)
144
+
145
+ python gradio_fake_scribble2image.py
146
+
147
+ Sometimes we are lazy and do not want to draw scribbles. This script uses exactly the same scribble-based model but applies a simple algorithm to synthesize scribbles from input images.
148
+
149
+ Prompt: "bag"
150
+ ![p](github_page/p9.png)
151
+
152
+ Prompt: "shose" (Note that "shose" is a typo; it should be "shoes". But it still seems to work.)
153
+ ![p](github_page/p10.png)
154
+
155
+ ## ControlNet with Human Pose
156
+
157
+ Stable Diffusion 1.5 + ControlNet (using human pose)
158
+
159
+ python gradio_pose2image.py
160
+
161
+ Apparently, this model deserves a better UI to directly manipulate the pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image, and then Openpose will detect the pose for you.
162
+
163
+ Prompt: "Chief in the kitchen"
164
+ ![p](github_page/p11.png)
165
+
166
+ Prompt: "An astronaut on the moon"
167
+ ![p](github_page/p12.png)
168
+
169
+ ## ControlNet with Semantic Segmentation
170
+
171
+ Stable Diffusion 1.5 + ControlNet (using semantic segmentation)
172
+
173
+ python gradio_seg2image.py
174
+
175
+ This model uses ADE20K's segmentation protocol. Again, this model deserves a better UI to directly draw the segmentations. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image, and then a model called Uniformer will detect the segmentations for you. Just try it for more details.
176
+
177
+ Prompt: "House"
178
+ ![p](github_page/p13.png)
179
+
180
+ Prompt: "River"
181
+ ![p](github_page/p14.png)
182
+
183
+ ## ControlNet with Depth
184
+
185
+ Stable Diffusion 1.5 + ControlNet (using depth map)
186
+
187
+ python gradio_depth2image.py
188
+
189
+ Great! Now SD 1.5 also has depth control. FINALLY. So many possibilities (considering SD1.5 has far more community models than SD2).
190
+
191
+ Note that, unlike Stability's model, this ControlNet receives the full 512×512 depth map rather than a 64×64 depth map (Stability's SD2 depth model uses 64×64 depth maps). This means that the ControlNet will preserve more details from the depth map.
192
+
193
+ This is always a strength: if users do not want to preserve the extra details, they can simply use another SD to post-process with an image-to-image pass; but if they do want to preserve them, ControlNet becomes their only choice. Again, SD2 uses 64×64 depth; we use 512×512.
194
+
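The resolution gap above is easy to quantify. A hedged NumPy toy (the 8×8 average pooling here is an illustrative stand-in, not Stability's actual downsampling): going from 512×512 to 64×64 means each coarse cell must summarize an 8×8 patch, i.e. 64× fewer values to carry fine structure.

```python
import numpy as np

depth = np.random.rand(512, 512)                        # full-resolution depth map
coarse = depth.reshape(64, 8, 64, 8).mean(axis=(1, 3))  # 64x64 average-pooled version
ratio = depth.size // coarse.size                       # detail values lost per cell
```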
195
+ Prompt: "Stormtrooper's lecture"
196
+ ![p](github_page/p15.png)
197
+
198
+ ## ControlNet with Normal Map
199
+
200
+ Stable Diffusion 1.5 + ControlNet (using normal map)
201
+
202
+ python gradio_normal2image.py
203
+
204
+ This model uses normal maps. Right now in the app, the normal map is computed from the Midas depth map and a user threshold (which determines how much of the area is background with an identity normal facing the viewer; tune the "Normal background threshold" in the Gradio app to get a feel for it).
205
+
206
+ Prompt: "Cute toy"
207
+ ![p](github_page/p17.png)
208
+
209
+ Prompt: "Plaster statue of Abraham Lincoln"
210
+ ![p](github_page/p18.png)
211
+
212
+ Compared to the depth model, this model seems to be a bit better at preserving geometry. This is intuitive: minor details are not salient in depth maps, but are salient in normal maps. Below is the depth result with the same inputs. You can see that the hairstyle of the man in the input image is modified by the depth model but preserved by the normal model.
213
+
214
+ Prompt: "Plaster statue of Abraham Lincoln"
215
+ ![p](github_page/p19.png)
216
+
217
+ ## ControlNet with Anime Line Drawing
218
+
219
+ We also trained a relatively simple ControlNet for anime line drawings. This tool may be useful for artistic creations. (Although the image details in the results are a bit modified, since it still diffuses latent images.)
220
+
221
+ This model is not available right now. We need to evaluate the potential risks before releasing this model. Nevertheless, you may be interested in [transferring the ControlNet to any community model](https://github.com/lllyasviel/ControlNet/discussions/12).
222
+
223
+ ![p](github_page/p21.png)
224
+
225
+ <a id="guess-anchor"></a>
226
+
227
+ # Guess Mode / Non-Prompt Mode
228
+
229
+ The "guess mode" (or called non-prompt mode) will completely unleash all the power of the very powerful ControlNet encoder.
230
+
231
+ See also the blog - [Ablation Study: Why ControlNets use deep encoder? What if it was lighter? Or even an MLP?](https://github.com/lllyasviel/ControlNet/discussions/188)
232
+
233
+ You need to manually check the "Guess Mode" toggle to enable this mode.
234
+
235
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input control map (depth map, edge map, scribbles, etc.), even if you remove all prompts.
236
+
237
+ **Let's have fun with some very challenging experimental settings!**
238
+
239
+ **No prompts. No "positive" prompts. No "negative" prompts. No extra caption detector. One single diffusion loop.**
240
+
241
+ For this mode, we recommend using 50 steps and a guidance scale between 3 and 5.
242
+
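Under the hood, guess mode also changes how strongly each ControlNet residual is applied. A sketch of the per-block scaling schedule used by the repo's Gradio demos (cf. `model.control_scales` in scripts like `gradio_canny2image.py`; the exact constant there is 0.825): deeper blocks keep full control strength while shallower blocks are damped geometrically.

```python
strength = 1.0
# 13 scales: one per zero-conv output plus the middle block
scales = [strength * (0.825 ** (12 - i)) for i in range(13)]
```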
243
+ ![p](github_page/uc2a.png)
244
+
245
+ No prompts:
246
+
247
+ ![p](github_page/uc2b.png)
248
+
249
+ Note that the below example is 768×768. No prompts. No "positive" prompts. No "negative" prompts.
250
+
251
+ ![p](github_page/uc1.png)
252
+
253
+ By tuning the parameters, you can get some very interesting results like the ones below:
254
+
255
+ ![p](github_page/uc3.png)
256
+
257
+ Because no prompt is available, the ControlNet encoder will "guess" what is in the control map. Sometimes the guess is really interesting. And because the diffusion algorithm can essentially give multiple results, the ControlNet seems able to give multiple guesses, like this:
258
+
259
+ ![p](github_page/uc4.png)
260
+
261
+ Without a prompt, the HED model seems good at generating images that look like paintings when the control strength is relatively low:
262
+
263
+ ![p](github_page/uc6.png)
264
+
265
+ The Guess Mode is also supported in [WebUI Plugin](https://github.com/Mikubill/sd-webui-controlnet):
266
+
267
+ ![p](github_page/uci1.png)
268
+
269
+ No prompts. Default WebUI parameters. Pure random results with the seed being 12345. Standard SD1.5. Input scribble is in "test_imgs" folder to reproduce.
270
+
271
+ ![p](github_page/uci2.png)
272
+
273
+ Below is another challenging example:
274
+
275
+ ![p](github_page/uci3.png)
276
+
277
+ No prompts. Default WebUI parameters. Pure random results with the seed being 12345. Standard SD1.5. Input scribble is in "test_imgs" folder to reproduce.
278
+
279
+ ![p](github_page/uci4.png)
280
+
281
+ Note that in the guess mode, you will still be able to input prompts. The only difference is that the model will "try harder" to guess what is in the control map even if you do not provide the prompt. Just try it yourself!
282
+
283
+ Besides, if you write some scripts (e.g., with BLIP) to generate image captions from the "guess mode" images, and then use the generated captions as prompts to diffuse again, you will get a SOTA pipeline for fully automatic conditional image generation.
284
+
285
+ # Combining Multiple ControlNets
286
+
287
+ ControlNets are composable: more than one ControlNet can easily be composed for multi-condition control.
288
+
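The composition requires no special machinery. Since each ControlNet emits a list of per-block residuals for the same locked SD UNet, a hypothetical multi-ControlNet step can simply take a weighted sum of residuals before handing them to the UNet (the names and the weighting scheme here are assumptions for illustration, not the plugin's actual API).

```python
def combine_controls(control_lists, weights):
    # control_lists: one residual list per ControlNet; weights: one scalar each
    n_blocks = len(control_lists[0])
    return [sum(w * controls[i] for controls, w in zip(control_lists, weights))
            for i in range(n_blocks)]

pose_residuals = [1.0, 2.0, 3.0]    # toy residuals from a pose ControlNet
depth_residuals = [0.5, 0.5, 0.5]   # toy residuals from a depth ControlNet
combined = combine_controls([pose_residuals, depth_residuals], [1.0, 0.5])
```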
289
+ Right now this feature is in an experimental stage in [Mikubill's A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet):
290
+
291
+ ![p](github_page/multi2.png)
292
+
293
+ ![p](github_page/multi.png)
294
+
295
+ As long as the models are controlling the same SD, the "boundary" between different research projects does not even exist. This plugin also allows different methods to work together!
296
+
297
+ # Use ControlNet in Any Community Model (SD1.X)
298
+
299
+ This is an experimental feature.
300
+
301
+ [See the steps here](https://github.com/lllyasviel/ControlNet/discussions/12).
302
+
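The transfer described in that discussion boils down to a weight offset. A hedged toy (plain arithmetic; the real script operates on full state dicts and matching keys): since the ControlNet copies the base SD1.5 encoder, moving it to a community SD1.x checkpoint can be approximated by adding the community model's delta from the base to the ControlNet's weights.

```python
import numpy as np

base       = {"w": np.array([1.0, 2.0])}   # base SD1.5 weights (toy)
community  = {"w": np.array([1.5, 1.5])}   # fine-tuned community checkpoint
controlnet = {"w": np.array([1.1, 2.2])}   # ControlNet copy trained from SD1.5

# transferred = controlnet + (community - base), key by key
transferred = {k: controlnet[k] + (community[k] - base[k]) for k in controlnet}
```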
303
+ Or you may want to use [Mikubill's A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet), which is plug-and-play and does not need manual merging.
304
+
305
+ # Annotate Your Own Data
306
+
307
+ We provide simple python scripts to process images.
308
+
309
+ [See a gradio example here](docs/annotator.md).
310
+
311
+ # Train with Your Own Data
312
+
313
+ Training a ControlNet is as easy as (or even easier than) training a simple pix2pix.
314
+
315
+ [See the steps here](docs/train.md).
316
+
317
+ # Related Resources
318
+
319
+ Special thanks to the great project - [Mikubill's A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet)!
320
+
321
+ We also thank Hysts for making the [Hugging Face Space](https://huggingface.co/spaces/hysts/ControlNet), as well as more than 65 models in that amazing [Colab list](https://github.com/camenduru/controlnet-colab)!
322
+
323
+ Thank haofanwang for making [ControlNet-for-Diffusers](https://github.com/haofanwang/ControlNet-for-Diffusers)!
324
+
325
+ We also thank all authors for making ControlNet demos, including but not limited to [fffiloni](https://huggingface.co/spaces/fffiloni/ControlNet-Video), [other-model](https://huggingface.co/spaces/hysts/ControlNet-with-other-models), [ThereforeGames](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/7784), [RamAnanth1](https://huggingface.co/spaces/RamAnanth1/ControlNet), etc.!
326
+
327
+ Besides, you may also want to read these amazing related works:
328
+
329
+ [Composer: Creative and Controllable Image Synthesis with Composable Conditions](https://github.com/damo-vilab/composer): A much bigger model to control diffusion!
330
+
331
+ [T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://github.com/TencentARC/T2I-Adapter): A much smaller model to control stable diffusion!
332
+
333
+ [ControlLoRA: A Light Neural Network To Control Stable Diffusion Spatial Information](https://github.com/HighCWu/ControlLoRA): Implement Controlnet using LORA!
334
+
335
+ And these amazing recent projects: [InstructPix2Pix Learning to Follow Image Editing Instructions](https://www.timothybrooks.com/instruct-pix2pix), [Pix2pix-zero: Zero-shot Image-to-Image Translation](https://github.com/pix2pixzero/pix2pix-zero), [Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation](https://github.com/MichalGeyer/plug-and-play), [MaskSketch: Unpaired Structure-guided Masked Image Generation](https://arxiv.org/abs/2302.05496), [SEGA: Instructing Diffusion using Semantic Dimensions](https://arxiv.org/abs/2301.12247), [Universal Guidance for Diffusion Models](https://github.com/arpitbansal297/Universal-Guided-Diffusion), [Region-Aware Diffusion for Zero-shot Text-driven Image Editing](https://github.com/haha-lisa/RDM-Region-Aware-Diffusion-Model), [Domain Expansion of Image Generators](https://arxiv.org/abs/2301.05225), [Image Mixer](https://twitter.com/LambdaAPI/status/1626327289288957956), [MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation](https://multidiffusion.github.io/)
336
+
337
+ # Citation
338
+
339
+ @misc{zhang2023adding,
340
+ title={Adding Conditional Control to Text-to-Image Diffusion Models},
341
+ author={Lvmin Zhang and Anyi Rao and Maneesh Agrawala},
342
+ booktitle={IEEE International Conference on Computer Vision (ICCV)},
343
+ year={2023},
344
+ }
345
+
346
+ [Arxiv Link](https://arxiv.org/abs/2302.05543)
347
+
348
+ [Supplementary Materials](https://lllyasviel.github.io/misc/202309/cnet_supp.pdf)
generation/control/ControlNet/cldm/cldm.py ADDED
@@ -0,0 +1,435 @@
1
+ import einops
2
+ import torch
3
+ import torch as th
4
+ import torch.nn as nn
5
+
6
+ from ldm.modules.diffusionmodules.util import (
7
+ conv_nd,
8
+ linear,
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from einops import rearrange, repeat
14
+ from torchvision.utils import make_grid
15
+ from ldm.modules.attention import SpatialTransformer
16
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
17
+ from ldm.models.diffusion.ddpm import LatentDiffusion
18
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config
19
+ from ldm.models.diffusion.ddim import DDIMSampler
20
+
21
+
22
+ class ControlledUnetModel(UNetModel):
23
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
24
+ hs = []
25
+ with torch.no_grad():
26
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
27
+ emb = self.time_embed(t_emb)
28
+ h = x.type(self.dtype)
29
+ for module in self.input_blocks:
30
+ h = module(h, emb, context)
31
+ hs.append(h)
32
+ h = self.middle_block(h, emb, context)
33
+
34
+ if control is not None:
35
+ h += control.pop()
36
+
37
+ for i, module in enumerate(self.output_blocks):
38
+ if only_mid_control or control is None:
39
+ h = torch.cat([h, hs.pop()], dim=1)
40
+ else:
41
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
42
+ h = module(h, emb, context)
43
+
44
+ h = h.type(x.dtype)
45
+ return self.out(h)
46
+
47
+
48
+ class ControlNet(nn.Module):
49
+ def __init__(
50
+ self,
51
+ image_size,
52
+ in_channels,
53
+ model_channels,
54
+ hint_channels,
55
+ num_res_blocks,
56
+ attention_resolutions,
57
+ dropout=0,
58
+ channel_mult=(1, 2, 4, 8),
59
+ conv_resample=True,
60
+ dims=2,
61
+ use_checkpoint=False,
62
+ use_fp16=False,
63
+ num_heads=-1,
64
+ num_head_channels=-1,
65
+ num_heads_upsample=-1,
66
+ use_scale_shift_norm=False,
67
+ resblock_updown=False,
68
+ use_new_attention_order=False,
69
+ use_spatial_transformer=False, # custom transformer support
70
+ transformer_depth=1, # custom transformer support
71
+ context_dim=None, # custom transformer support
72
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
73
+ legacy=True,
74
+ disable_self_attentions=None,
75
+ num_attention_blocks=None,
76
+ disable_middle_self_attn=False,
77
+ use_linear_in_transformer=False,
78
+ ):
79
+ super().__init__()
80
+ if use_spatial_transformer:
81
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
82
+
83
+ if context_dim is not None:
84
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
85
+ from omegaconf.listconfig import ListConfig
86
+ if type(context_dim) == ListConfig:
87
+ context_dim = list(context_dim)
88
+
89
+ if num_heads_upsample == -1:
90
+ num_heads_upsample = num_heads
91
+
92
+ if num_heads == -1:
93
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
94
+
95
+ if num_head_channels == -1:
96
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
97
+
98
+ self.dims = dims
99
+ self.image_size = image_size
100
+ self.in_channels = in_channels
101
+ self.model_channels = model_channels
102
+ if isinstance(num_res_blocks, int):
103
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
104
+ else:
105
+ if len(num_res_blocks) != len(channel_mult):
106
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
107
+ "as a list/tuple (per-level) with the same length as channel_mult")
108
+ self.num_res_blocks = num_res_blocks
109
+ if disable_self_attentions is not None:
110
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
111
+ assert len(disable_self_attentions) == len(channel_mult)
112
+ if num_attention_blocks is not None:
113
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
114
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
115
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
116
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
117
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
118
+ f"attention will still not be set.")
119
+
120
+ self.attention_resolutions = attention_resolutions
121
+ self.dropout = dropout
122
+ self.channel_mult = channel_mult
123
+ self.conv_resample = conv_resample
124
+ self.use_checkpoint = use_checkpoint
125
+ self.dtype = th.float16 if use_fp16 else th.float32
126
+ self.num_heads = num_heads
127
+ self.num_head_channels = num_head_channels
128
+ self.num_heads_upsample = num_heads_upsample
129
+ self.predict_codebook_ids = n_embed is not None
130
+
131
+ time_embed_dim = model_channels * 4
132
+ self.time_embed = nn.Sequential(
133
+ linear(model_channels, time_embed_dim),
134
+ nn.SiLU(),
135
+ linear(time_embed_dim, time_embed_dim),
136
+ )
137
+
138
+ self.input_blocks = nn.ModuleList(
139
+ [
140
+ TimestepEmbedSequential(
141
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
142
+ )
143
+ ]
144
+ )
145
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
146
+
147
+ self.input_hint_block = TimestepEmbedSequential(
148
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
149
+ nn.SiLU(),
150
+ conv_nd(dims, 16, 16, 3, padding=1),
151
+ nn.SiLU(),
152
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
153
+ nn.SiLU(),
154
+ conv_nd(dims, 32, 32, 3, padding=1),
155
+ nn.SiLU(),
156
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
157
+ nn.SiLU(),
158
+ conv_nd(dims, 96, 96, 3, padding=1),
159
+ nn.SiLU(),
160
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
161
+ nn.SiLU(),
162
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
163
+ )
164
+
165
+ self._feature_size = model_channels
166
+ input_block_chans = [model_channels]
167
+ ch = model_channels
168
+ ds = 1
169
+ for level, mult in enumerate(channel_mult):
170
+ for nr in range(self.num_res_blocks[level]):
171
+ layers = [
172
+ ResBlock(
173
+ ch,
174
+ time_embed_dim,
175
+ dropout,
176
+ out_channels=mult * model_channels,
177
+ dims=dims,
178
+ use_checkpoint=use_checkpoint,
179
+ use_scale_shift_norm=use_scale_shift_norm,
180
+ )
181
+ ]
182
+ ch = mult * model_channels
183
+ if ds in attention_resolutions:
184
+ if num_head_channels == -1:
185
+ dim_head = ch // num_heads
186
+ else:
187
+ num_heads = ch // num_head_channels
188
+ dim_head = num_head_channels
189
+ if legacy:
190
+ # num_heads = 1
191
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
192
+ if exists(disable_self_attentions):
193
+ disabled_sa = disable_self_attentions[level]
194
+ else:
195
+ disabled_sa = False
196
+
197
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
198
+ layers.append(
199
+ AttentionBlock(
200
+ ch,
201
+ use_checkpoint=use_checkpoint,
202
+ num_heads=num_heads,
203
+ num_head_channels=dim_head,
204
+ use_new_attention_order=use_new_attention_order,
205
+ ) if not use_spatial_transformer else SpatialTransformer(
206
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
207
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
208
+ use_checkpoint=use_checkpoint
209
+ )
210
+ )
211
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
212
+ self.zero_convs.append(self.make_zero_conv(ch))
213
+ self._feature_size += ch
214
+ input_block_chans.append(ch)
215
+ if level != len(channel_mult) - 1:
216
+ out_ch = ch
217
+ self.input_blocks.append(
218
+ TimestepEmbedSequential(
219
+ ResBlock(
220
+ ch,
221
+ time_embed_dim,
222
+ dropout,
223
+ out_channels=out_ch,
224
+ dims=dims,
225
+ use_checkpoint=use_checkpoint,
226
+ use_scale_shift_norm=use_scale_shift_norm,
227
+ down=True,
228
+ )
229
+ if resblock_updown
230
+ else Downsample(
231
+ ch, conv_resample, dims=dims, out_channels=out_ch
232
+ )
233
+ )
234
+ )
235
+ ch = out_ch
236
+ input_block_chans.append(ch)
237
+ self.zero_convs.append(self.make_zero_conv(ch))
238
+ ds *= 2
239
+ self._feature_size += ch
240
+
241
+ if num_head_channels == -1:
242
+ dim_head = ch // num_heads
243
+ else:
244
+ num_heads = ch // num_head_channels
245
+ dim_head = num_head_channels
246
+ if legacy:
247
+ # num_heads = 1
248
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
249
+ self.middle_block = TimestepEmbedSequential(
250
+ ResBlock(
251
+ ch,
252
+ time_embed_dim,
253
+ dropout,
254
+ dims=dims,
255
+ use_checkpoint=use_checkpoint,
256
+ use_scale_shift_norm=use_scale_shift_norm,
257
+ ),
258
+ AttentionBlock(
259
+ ch,
260
+ use_checkpoint=use_checkpoint,
261
+ num_heads=num_heads,
262
+ num_head_channels=dim_head,
263
+ use_new_attention_order=use_new_attention_order,
264
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
265
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
266
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
267
+ use_checkpoint=use_checkpoint
268
+ ),
269
+ ResBlock(
270
+ ch,
271
+ time_embed_dim,
272
+ dropout,
273
+ dims=dims,
274
+ use_checkpoint=use_checkpoint,
275
+ use_scale_shift_norm=use_scale_shift_norm,
276
+ ),
277
+ )
278
+ self.middle_block_out = self.make_zero_conv(ch)
279
+ self._feature_size += ch
280
+
281
+ def make_zero_conv(self, channels):
282
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
283
+
284
+ def forward(self, x, hint, timesteps, context, **kwargs):
285
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
286
+ emb = self.time_embed(t_emb)
287
+
288
+ guided_hint = self.input_hint_block(hint, emb, context)
289
+
290
+ outs = []
291
+
292
+ h = x.type(self.dtype)
293
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
294
+ if guided_hint is not None:
295
+ h = module(h, emb, context)
296
+ h += guided_hint
297
+ guided_hint = None
298
+ else:
299
+ h = module(h, emb, context)
300
+ outs.append(zero_conv(h, emb, context))
301
+
302
+ h = self.middle_block(h, emb, context)
303
+ outs.append(self.middle_block_out(h, emb, context))
304
+
305
+ return outs
306
+
307
+
308
+ class ControlLDM(LatentDiffusion):
309
+
310
+ def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
311
+ super().__init__(*args, **kwargs)
312
+ self.control_model = instantiate_from_config(control_stage_config)
313
+ self.control_key = control_key
314
+ self.only_mid_control = only_mid_control
315
+ self.control_scales = [1.0] * 13
316
+
317
+ @torch.no_grad()
318
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
319
+ x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
320
+ control = batch[self.control_key]
321
+ if bs is not None:
322
+ control = control[:bs]
323
+ control = control.to(self.device)
324
+ control = einops.rearrange(control, 'b h w c -> b c h w')
325
+ control = control.to(memory_format=torch.contiguous_format).float()
326
+ return x, dict(c_crossattn=[c], c_concat=[control])
327
+
328
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
329
+ assert isinstance(cond, dict)
330
+ diffusion_model = self.model.diffusion_model
331
+
332
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
333
+
334
+ if cond['c_concat'] is None:
335
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
336
+ else:
337
+ control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
338
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
339
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
340
+
341
+ return eps
342
+
343
+ @torch.no_grad()
344
+ def get_unconditional_conditioning(self, N):
345
+ return self.get_learned_conditioning([""] * N)
346
+
347
+ @torch.no_grad()
348
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
349
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
350
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
351
+ use_ema_scope=True,
352
+ **kwargs):
353
+ use_ddim = ddim_steps is not None
354
+
355
+ log = dict()
356
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
357
+ c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
358
+ N = min(z.shape[0], N)
359
+ n_row = min(z.shape[0], n_row)
360
+ log["reconstruction"] = self.decode_first_stage(z)
361
+ log["control"] = c_cat * 2.0 - 1.0
362
+ log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
363
+
364
+ if plot_diffusion_rows:
365
+ # get diffusion row
366
+ diffusion_row = list()
367
+ z_start = z[:n_row]
368
+ for t in range(self.num_timesteps):
369
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
370
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
371
+ t = t.to(self.device).long()
372
+ noise = torch.randn_like(z_start)
373
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
374
+ diffusion_row.append(self.decode_first_stage(z_noisy))
375
+
376
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
377
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                     batch_size=N, ddim=use_ddim,
+                                                     ddim_steps=ddim_steps, eta=ddim_eta)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(N)
+            uc_cat = c_cat  # torch.zeros_like(c_cat)
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                             batch_size=N, ddim=use_ddim,
+                                             ddim_steps=ddim_steps, eta=ddim_eta,
+                                             unconditional_guidance_scale=unconditional_guidance_scale,
+                                             unconditional_conditioning=uc_full,
+                                             )
+            x_samples_cfg = self.decode_first_stage(samples_cfg)
+            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        return log
+
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        ddim_sampler = DDIMSampler(self)
+        b, c, h, w = cond["c_concat"][0].shape
+        shape = (self.channels, h // 8, w // 8)
+        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
+        return samples, intermediates
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.control_model.parameters())
+        if not self.sd_locked:
+            params += list(self.model.diffusion_model.output_blocks.parameters())
+            params += list(self.model.diffusion_model.out.parameters())
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+    def low_vram_shift(self, is_diffusing):
+        if is_diffusing:
+            self.model = self.model.cuda()
+            self.control_model = self.control_model.cuda()
+            self.first_stage_model = self.first_stage_model.cpu()
+            self.cond_stage_model = self.cond_stage_model.cpu()
+        else:
+            self.model = self.model.cpu()
+            self.control_model = self.control_model.cpu()
+            self.first_stage_model = self.first_stage_model.cuda()
+            self.cond_stage_model = self.cond_stage_model.cuda()
generation/control/ControlNet/cldm/ddim_hacked.py ADDED
@@ -0,0 +1,317 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta, verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0 = outs
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            model_t = self.model.apply_model(x, t, c)
+            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", 'not implemented'
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        num_reference_steps = timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback: callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False, callback=None):
+
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+            if callback: callback(i)
+        return x_dec
generation/control/ControlNet/cldm/hack.py ADDED
@@ -0,0 +1,111 @@
+import torch
+import einops
+
+import ldm.modules.encoders.modules
+import ldm.modules.attention
+
+from transformers import logging
+from ldm.modules.attention import default
+
+
+def disable_verbosity():
+    logging.set_verbosity_error()
+    print('logging improved.')
+    return
+
+
+def enable_sliced_attention():
+    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attention_forward
+    print('Enabled sliced_attention.')
+    return
+
+
+def hack_everything(clip_skip=0):
+    disable_verbosity()
+    ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
+    ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
+    print('Enabled clip hacks.')
+    return
+
+
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+    PAD = self.tokenizer.pad_token_id
+    EOS = self.tokenizer.eos_token_id
+    BOS = self.tokenizer.bos_token_id
+
+    def tokenize(t):
+        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
+
+    def transformer_encode(t):
+        if self.clip_skip > 1:
+            rt = self.transformer(input_ids=t, output_hidden_states=True)
+            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
+        else:
+            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
+
+    def split(x):
+        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
+
+    def pad(x, p, i):
+        return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+    raw_tokens_list = tokenize(text)
+    tokens_list = []
+
+    for raw_tokens in raw_tokens_list:
+        raw_tokens_123 = split(raw_tokens)
+        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
+        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+        tokens_list.append(raw_tokens_123)
+
+    tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
+    y = transformer_encode(feed)
+    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
+
+    return z
+
+
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+    k = self.to_k(context)
+    v = self.to_v(context)
+    del context, x
+
+    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    limit = k.shape[0]
+    att_step = 1
+    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+    q_chunks.reverse()
+    k_chunks.reverse()
+    v_chunks.reverse()
+    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+    del k, q, v
+    for i in range(0, limit, att_step):
+        q_buffer = q_chunks.pop()
+        k_buffer = k_chunks.pop()
+        v_buffer = v_chunks.pop()
+        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+        del k_buffer, q_buffer
+        # attention, what we cannot get enough of, by chunks
+
+        sim_buffer = sim_buffer.softmax(dim=-1)
+
+        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+        del v_buffer
+        sim[i:i + att_step, :, :] = sim_buffer
+
+    del sim_buffer
+    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(sim)
generation/control/ControlNet/cldm/logger.py ADDED
@@ -0,0 +1,76 @@
+import os
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+
+class ImageLogger(Callback):
+    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
+                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+                 log_images_kwargs=None):
+        super().__init__()
+        self.rescale = rescale
+        self.batch_freq = batch_frequency
+        self.max_images = max_images
+        if not increase_log_steps:
+            self.log_steps = [self.batch_freq]
+        self.clamp = clamp
+        self.disabled = disabled
+        self.log_on_batch_idx = log_on_batch_idx
+        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+        self.log_first_step = log_first_step
+
+    @rank_zero_only
+    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
+        root = os.path.join(save_dir, "image_log", split)
+        for k in images:
+            grid = torchvision.utils.make_grid(images[k], nrow=4)
+            if self.rescale:
+                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+            grid = grid.numpy()
+            grid = (grid * 255).astype(np.uint8)
+            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+            path = os.path.join(root, filename)
+            os.makedirs(os.path.split(path)[0], exist_ok=True)
+            Image.fromarray(grid).save(path)
+
+    def log_img(self, pl_module, batch, batch_idx, split="train"):
+        check_idx = batch_idx  # if self.log_on_batch_idx else pl_module.global_step
+        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
+                hasattr(pl_module, "log_images") and
+                callable(pl_module.log_images) and
+                self.max_images > 0):
+            logger = type(pl_module.logger)
+
+            is_train = pl_module.training
+            if is_train:
+                pl_module.eval()
+
+            with torch.no_grad():
+                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+            for k in images:
+                N = min(images[k].shape[0], self.max_images)
+                images[k] = images[k][:N]
+                if isinstance(images[k], torch.Tensor):
+                    images[k] = images[k].detach().cpu()
+                    if self.clamp:
+                        images[k] = torch.clamp(images[k], -1., 1.)
+
+            self.log_local(pl_module.logger.save_dir, split, images,
+                           pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+            if is_train:
+                pl_module.train()
+
+    def check_frequency(self, check_idx):
+        return check_idx % self.batch_freq == 0
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        if not self.disabled:
+            self.log_img(pl_module, batch, batch_idx, split="train")
generation/control/ControlNet/cldm/model.py ADDED
@@ -0,0 +1,28 @@
+import os
+import torch
+
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+    return d.get('state_dict', d)
+
+
+def load_state_dict(ckpt_path, location='cpu'):
+    _, extension = os.path.splitext(ckpt_path)
+    if extension.lower() == ".safetensors":
+        import safetensors.torch
+        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+    else:
+        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+    state_dict = get_state_dict(state_dict)
+    print(f'Loaded state_dict from [{ckpt_path}]')
+    return state_dict
+
+
+def create_model(config_path):
+    config = OmegaConf.load(config_path)
+    model = instantiate_from_config(config.model).cpu()
+    print(f'Loaded model config from [{config_path}]')
+    return model
generation/control/ControlNet/config.py ADDED
@@ -0,0 +1 @@
+save_memory = False
generation/control/ControlNet/docs/annotator.md ADDED
@@ -0,0 +1,49 @@
+# Automatic Annotations
+
+We provide gradio examples to obtain annotations that are aligned to our pretrained production-ready models.
+
+Just run
+
+    python gradio_annotator.py
+
+Since everyone organizes their datasets differently, we do not hard-code any scripts for batch processing. But "gradio_annotator.py" is written in a very readable way, and modifying it to annotate your own images should be easy.
+
+In the gradio UI of "gradio_annotator.py", we have the following interfaces:
+
+### Canny Edge
+
+Be careful about "black edge and white background" or "white edge and black background".
+
+![p](../github_page/a1.png)
+
+### HED Edge
+
+Be careful about "black edge and white background" or "white edge and black background".
+
+![p](../github_page/a2.png)
+
+### MLSD Edge
+
+Be careful about "black edge and white background" or "white edge and black background".
+
+![p](../github_page/a3.png)
+
+### MIDAS Depth and Normal
+
+Be careful about RGB or BGR in normal maps.
+
+![p](../github_page/a4.png)
+
+### Openpose
+
+Be careful about RGB or BGR in pose maps.
+
+For our production-ready model, the hand pose option is turned off.
+
+![p](../github_page/a5.png)
+
+### Uniformer Segmentation
+
+Be careful about RGB or BGR in segmentation maps.
+
+![p](../github_page/a6.png)
generation/control/ControlNet/docs/faq.md ADDED
@@ -0,0 +1,21 @@
+# FAQs
+
+**Q:** If the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why does "zero convolution" work?
+
+**A:** This is wrong. Let us consider a very simple case
+
+$$y=wx+b$$
+
+for which we have
+
+$$\partial y/\partial w=x, \partial y/\partial x=w, \partial y/\partial b=1$$
+
+and if $w=0$ and $x \neq 0$, then
+
+$$\partial y/\partial w \neq 0, \partial y/\partial x=0, \partial y/\partial b\neq 0$$
+
+which means that as long as $x \neq 0$, one gradient descent iteration will make $w$ non-zero. Then
+
+$$\partial y/\partial x\neq 0$$
+
+so the zero convolutions progressively become ordinary conv layers with non-zero weights.
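The argument can be checked numerically with a one-step scalar sketch (pure Python; the squared-error loss is our choice for illustration, not something fixed by the argument above):

```python
# One gradient step on y = w*x + b, with w initialized to zero.
w, b = 0.0, 0.0
x = 2.0        # nonzero input
target = 1.0
lr = 0.1

y = w * x + b
# For loss L = 0.5 * (y - target)^2:
dL_dy = y - target
dL_dw = dL_dy * x   # nonzero because x != 0, even though w == 0
dL_db = dL_dy

w -= lr * dL_dw
b -= lr * dL_db
print(w)  # 0.2 -- no longer zero after a single update
```

After this first step $\partial y/\partial x = w \neq 0$, so gradients start flowing through the layer's input as well.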
generation/control/ControlNet/docs/low_vram.md ADDED
@@ -0,0 +1,15 @@
+# Enable Low VRAM Mode
+
+If you are using an 8GB GPU (or if you want a larger batch size), open "config.py" and set
+
+```python
+save_memory = True
+```
+
+This feature is still being tested - not all graphics cards are guaranteed to succeed.
+
+But it should be useful: I can now diffuse at a batch size of 12.
+
+(prompt "man")
+
+![p](../github_page/ram12.jpg)
generation/control/ControlNet/docs/train.md ADDED
@@ -0,0 +1,276 @@
+# Train a ControlNet to Control SD
+
+You are here because you want to control SD in your own way: maybe you have an idea for your perfect research project, and you have already annotated (or will annotate) your own dataset, automatically or manually. Here, the control can be anything that can be converted to images, such as edges, keypoints, segments, etc.
+
+Before moving on to your own dataset, we highly recommend first trying the toy dataset, Fill50K, as a sanity check. This will help you get a "feeling" for the training: how long it takes for the model to converge, whether your device can complete the training in an acceptable amount of time, and what it "feels" like when the model converges.
+
+We hope that after you read this page, you will find that training a ControlNet is as easy as (or easier than) training a pix2pix.
+
+## Step 0 - Design your control
+
+Let us take a look at a very simple task: controlling SD to fill circles with color.
+
+![p](../github_page/t1.png)
+
+This is simple: we want to control SD to fill a circle with colors, and the prompt contains some description of our target.
+
+Stable diffusion is trained on billions of images, and it already knows what "cyan", "circle", "pink", and "background" are.
+
+But it does not know the meaning of that "Control Image (Source Image)". Our target is to let it know.
+
+## Step 1 - Get a dataset
+
+Just download the Fill50K dataset from [our huggingface page](https://huggingface.co/lllyasviel/ControlNet) (training/fill50k.zip, the file is only 200M!). Make sure that the data is decompressed as
+
+    ControlNet/training/fill50k/prompt.json
+    ControlNet/training/fill50k/source/X.png
+    ControlNet/training/fill50k/target/X.png
+
+In the folder "fill50k/source", you will have 50k images of circle outlines.
+
+![p](../github_page/t2.png)
+
+In the folder "fill50k/target", you will have 50k images of filled circles.
+
+![p](../github_page/t3.png)
+
+In "fill50k/prompt.json", you will have their filenames and prompts. Each prompt is like "a balabala color circle in some other color background."
+
+![p](../github_page/t4.png)
+
+## Step 2 - Load the dataset
+
+Then you need to write a simple script to read this dataset for pytorch. (In fact we have written it for you in "tutorial_dataset.py".)
+
+```python
+import json
+import cv2
+import numpy as np
+
+from torch.utils.data import Dataset
+
+
+class MyDataset(Dataset):
+    def __init__(self):
+        self.data = []
+        with open('./training/fill50k/prompt.json', 'rt') as f:
+            for line in f:
+                self.data.append(json.loads(line))
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        item = self.data[idx]
+
+        source_filename = item['source']
+        target_filename = item['target']
+        prompt = item['prompt']
+
+        source = cv2.imread('./training/fill50k/' + source_filename)
+        target = cv2.imread('./training/fill50k/' + target_filename)
+
+        # Do not forget that OpenCV reads images in BGR order.
+        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
+        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
+
+        # Normalize source images to [0, 1].
+        source = source.astype(np.float32) / 255.0
+
+        # Normalize target images to [-1, 1].
+        target = (target.astype(np.float32) / 127.5) - 1.0
+
+        return dict(jpg=target, txt=prompt, hint=source)
+```
+
+This will make your dataset into an array-like object in python. You can test this dataset simply by accessing the array, like this
+
+```python
+from tutorial_dataset import MyDataset
+
+dataset = MyDataset()
+print(len(dataset))
+
+item = dataset[1234]
+jpg = item['jpg']
+txt = item['txt']
+hint = item['hint']
+print(txt)
+print(jpg.shape)
+print(hint.shape)
+```
+
+The outputs of this simple test on my machine are
+
+    50000
+    burly wood circle with orange background
+    (512, 512, 3)
+    (512, 512, 3)
+
+This code is in "tutorial_dataset_test.py".
+
+In this way, the dataset is an array-like object with 50000 items. Each item is a dict with three entries: "jpg", "txt", and "hint". The "jpg" is the target image, the "hint" is the control image, and the "txt" is the prompt.
+
+Do not ask us why we use these three names - this is related to the dark history of a library called LDM.
118
+ ## Step 3 - What SD model do you want to control?
119
+
120
+ Then you need to decide which Stable Diffusion Model you want to control. In this example, we will just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main).
121
+
122
+ (Or ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main) if you are using SD2.)
123
+
124
+ Then you need to attach a ControlNet to the SD model. The architecture is
125
+
126
+ ![img](../github_page/sd.png)
127
+
128
+ Note that all weights inside the ControlNet are also copied from SD, so that no layer is trained from scratch and you are still fine-tuning the entire model.
129
+
130
+ We provide a simple script for you to achieve this easily. If your SD filename is "./models/v1-5-pruned.ckpt" and you want the script to save the processed model (SD+ControlNet) at location "./models/control_sd15_ini.ckpt", you can just run:
131
+
132
+ python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt
133
+
134
+ Or if you are using SD2:
135
+
136
+ python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt
137
+
138
+ You may also use other filenames, as long as the command follows the form "python tool_add_control.py input_path output_path".
139
+
140
+ This is the correct output from my machine:
141
+
142
+ ![img](../github_page/t5.png)
143
+
144
+ ## Step 4 - Train!
145
+
146
+ Happy! We finally come to the most exciting part: training!
147
+
148
+ The training code in "tutorial_train.py" is actually surprisingly simple:
149
+
150
+ ```python
151
+ import pytorch_lightning as pl
152
+ from torch.utils.data import DataLoader
153
+ from tutorial_dataset import MyDataset
154
+ from cldm.logger import ImageLogger
155
+ from cldm.model import create_model, load_state_dict
156
+
157
+
158
+ # Configs
159
+ resume_path = './models/control_sd15_ini.ckpt'
160
+ batch_size = 4
161
+ logger_freq = 300
162
+ learning_rate = 1e-5
163
+ sd_locked = True
164
+ only_mid_control = False
165
+
166
+
167
+ # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
168
+ model = create_model('./models/cldm_v15.yaml').cpu()
169
+ model.load_state_dict(load_state_dict(resume_path, location='cpu'))
170
+ model.learning_rate = learning_rate
171
+ model.sd_locked = sd_locked
172
+ model.only_mid_control = only_mid_control
173
+
174
+
175
+ # Misc
176
+ dataset = MyDataset()
177
+ dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
178
+ logger = ImageLogger(batch_frequency=logger_freq)
179
+ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
180
+
181
+
182
+ # Train!
183
+ trainer.fit(model, dataloader)
184
+
185
+ ```
186
+ (or "tutorial_train_sd21.py" if you are using SD2)
187
+
188
+ Thanks to our organized PyTorch dataset object and the power of pytorch_lightning, the entire script is remarkably short.
189
+
190
+ Now, you may take a look at the [PyTorch Lightning official Trainer documentation](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#trainer) to find out how to enable many useful features like gradient accumulation, multi-GPU training, accelerated dataset loading, flexible checkpoint saving, etc. Each of these needs only about one line of code. Great!
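As a minimal sketch of what "one line per feature" means, here are the corresponding Trainer keyword arguments (these names are standard PyTorch Lightning 1.x API; the values are purely illustrative, not recommendations):

```python
# Each feature is roughly one extra Trainer keyword argument.
# These would be passed as pl.Trainer(**trainer_kwargs) in tutorial_train.py.
trainer_kwargs = dict(
    gpus=2,                     # multi-GPU training
    precision=16,               # mixed precision to save VRAM
    accumulate_grad_batches=4,  # gradient accumulation
)

# With a per-step batch size of 4, the effective (logic) batch size becomes:
effective_batch = 4 * trainer_kwargs["accumulate_grad_batches"] * trainer_kwargs["gpus"]
print(effective_batch)  # 32
```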
191
+
192
+ Note that if you hit OOM (out-of-memory) errors, you may need to enable [Low VRAM mode](low_vram.md), and perhaps also use a smaller batch size with gradient accumulation. Or you may want to use some "advanced" tricks like sliced attention or xformers. For example:
193
+
194
+ ```python
195
+ # Configs
196
+ batch_size = 1
197
+
198
+ # Misc
199
+ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger], accumulate_grad_batches=4) # But this will be 4x slower
200
+ ```
201
+
202
+ Note that training on an 8 GB laptop GPU is challenging. That would need GPU memory optimization at least as good as automatic1111's UI, which may require expert modifications to the code.
203
+
204
+ ### Screenshots
205
+
206
+ The training is fast. After 4000 steps (batch size 4, learning rate 1e-5, about 50 minutes on a PCIe 40G GPU), the results on my machine (in the output folder "image_log") are
207
+
208
+ Control:
209
+
210
+ ![img](../github_page/t/ip.png)
211
+
212
+ Prompt:
213
+
214
+ ![img](../github_page/t/t.png)
215
+
216
+ Prediction:
217
+
218
+ ![img](../github_page/t/op.png)
219
+
220
+ Ground Truth:
221
+
222
+ ![img](../github_page/t/gt.png)
223
+
224
+ Note that the SD's capability is preserved: even when trained on this highly aligned dataset, it still draws some random textures and those snow decorations. (Also, note that the ground truth looks slightly modified because it was reconstructed from SD's latent image.)
225
+
226
+ Larger batch size and longer training will further improve this. Adequate training will make the filling perfect.
227
+
228
+ Of course, training SD to fill circles is meaningless, but this is a successful beginning of your story.
229
+
230
+ Let us work together to control large models more and more.
231
+
232
+ ## Other options
233
+
234
+ Beyond the standard options, we also provide two important parameters, "sd_locked" and "only_mid_control", that you need to know about.
235
+
236
+ ### only_mid_control
237
+
238
+ By default, only_mid_control is False. When it is True, you will train the architecture below.
239
+
240
+ ![img](../github_page/t6.png)
241
+
242
+ This can be helpful when your compute power is limited and you want to speed up training, or when you want to facilitate "global" context learning. Note that you can alternate: pause training, set it to True, resume, pause again, set it back, and resume again, as needed.
243
+
244
+ If your compute device is powerful, perhaps you do not need this. But I also know some artists are willing to train a model on their laptop for a month - in that case, this option can be useful.
245
+
246
+ ### sd_locked
247
+
248
+ By default, sd_locked is True. When it is False, you will train the architecture below.
249
+
250
+ ![img](../github_page/t7.png)
251
+
252
+ This will unlock some layers in SD and you will train them as a whole.
253
+
254
+ This option is DANGEROUS! If your dataset is not good enough, this may downgrade the capability of your SD model.
255
+
256
+ However, this option is also very useful when you are training on images with a specific style, or on special datasets (like a medical dataset of X-ray images, or a geographic dataset with lots of Google Maps imagery). You can think of this as simultaneously training the ControlNet and something like a DreamBooth.
257
+
258
+ Also, if your dataset is large, you may want to end the training with a few thousand steps with those layers unlocked. This usually improves the "problem-specific" results a little. You can try it yourself to feel the difference.
259
+
260
+ Also, if you unlock some original layers, you may want a lower learning rate, like 2e-6.
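A hedged sketch of the corresponding config change in "tutorial_train.py" (the variable names match the training script above; the exact learning rate is a suggestion from the text, not a tuned value):

```python
# Unlock some original SD layers and train them together with the ControlNet.
sd_locked = False

# Use a lower learning rate than the locked default (1e-5) to avoid
# degrading the pretrained SD weights.
learning_rate = 2e-6

# As in tutorial_train.py, these are copied onto the model before trainer.fit:
#   model.sd_locked = sd_locked
#   model.learning_rate = learning_rate
assert learning_rate < 1e-5
```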
261
+
262
+ ## More Consideration: Sudden Converge Phenomenon and Gradient Accumulation
263
+
264
+ ![img](../github_page/ex1.jpg)
265
+
266
+ Because we use zero convolutions, the SD should always be able to predict meaningful images. (If it cannot, the training has already failed.)
267
+
268
+ You will always find that at some iteration, the model "suddenly" becomes able to fit some training conditions. This means that you will get a basically usable model at about 3k to 7k steps (further training will improve it, but the model after the first "sudden converge" should be basically functional).
269
+
270
+ Note that 3k to 7k steps is not very many, and you should consider a larger batch size rather than more training steps. If you can observe the "sudden converge" at 3k steps with batch size 4, then, rather than training 300k further steps, a better idea is to use 100x gradient accumulation to re-train those 3k steps with 100x the batch size. Perhaps we should not take this *too* far (100x accumulation may be too extreme), but since the "sudden converge" will *always* happen at that certain point, getting a better convergence there is more important.
271
+
272
+ Because that "sudden converge" always happens, let's say it happens at 3k steps and our budget covers 90k steps. Then we have two options: (1) train 3k steps, hit the sudden converge, then train 87k more steps; or (2) use 30x gradient accumulation and train 3k steps (90k real computation steps), ending at the sudden converge.
273
+
274
+ In my experiments, (2) is usually better than (1). However, in real cases you may need to tune the split of steps before and after the "sudden converge" yourself. The training after the "sudden converge" is also important.
275
+
276
+ But usually, if your logic batch size is already bigger than 256, further extending it is not very meaningful, and a better idea is to train more steps instead. I tried some "common" logic batch sizes of 64, 96, and 128 (via gradient accumulation), and many complicated conditions can already be handled very well.
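The budget arithmetic above can be sketched as follows (the numbers are taken from the text; this is an illustration of the trade-off, not a measured result):

```python
# Total optimizer budget: 90k real computation steps.
budget = 90_000
converge_step = 3_000  # where "sudden converge" is observed at batch size 4

# Option 1: converge at 3k steps, then spend the rest on refinement.
steps_after_converge = budget - converge_step  # 87,000 refinement steps

# Option 2: spend the whole budget reaching convergence with a bigger
# logic batch size via gradient accumulation.
accumulation = budget // converge_step         # 30x accumulation
logic_batch_size = 4 * accumulation            # 120

print(steps_after_converge, accumulation, logic_batch_size)  # 87000 30 120
```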
generation/control/ControlNet/environment.yaml ADDED
@@ -0,0 +1,35 @@
1
+ name: control
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.12.1
10
+ - torchvision=0.13.1
11
+ - numpy=1.23.1
12
+ - pip:
13
+ - gradio==3.16.2
14
+ - albumentations==1.3.0
15
+ - opencv-contrib-python==4.3.0.36
16
+ - imageio==2.9.0
17
+ - imageio-ffmpeg==0.4.2
18
+ - pytorch-lightning==1.5.0
19
+ - omegaconf==2.1.1
20
+ - test-tube>=0.7.5
21
+ - streamlit==1.12.1
22
+ - einops==0.3.0
23
+ - transformers==4.19.2
24
+ - webdataset==0.2.5
25
+ - kornia==0.6
26
+ - open_clip_torch==2.0.2
27
+ - invisible-watermark>=0.1.5
28
+ - streamlit-drawable-canvas==0.8.0
29
+ - torchmetrics==0.6.0
30
+ - timm==0.6.12
31
+ - addict==2.4.0
32
+ - yapf==0.32.0
33
+ - prettytable==3.6.0
34
+ - safetensors==0.2.7
35
+ - basicsr==1.4.2
generation/control/ControlNet/gradio_annotator.py ADDED
@@ -0,0 +1,160 @@
1
+ import gradio as gr
2
+
3
+ from annotator.util import resize_image, HWC3
4
+
5
+
6
+ model_canny = None
7
+
8
+
9
+ def canny(img, res, l, h):
10
+ img = resize_image(HWC3(img), res)
11
+ global model_canny
12
+ if model_canny is None:
13
+ from annotator.canny import CannyDetector
14
+ model_canny = CannyDetector()
15
+ result = model_canny(img, l, h)
16
+ return [result]
17
+
18
+
19
+ model_hed = None
20
+
21
+
22
+ def hed(img, res):
23
+ img = resize_image(HWC3(img), res)
24
+ global model_hed
25
+ if model_hed is None:
26
+ from annotator.hed import HEDdetector
27
+ model_hed = HEDdetector()
28
+ result = model_hed(img)
29
+ return [result]
30
+
31
+
32
+ model_mlsd = None
33
+
34
+
35
+ def mlsd(img, res, thr_v, thr_d):
36
+ img = resize_image(HWC3(img), res)
37
+ global model_mlsd
38
+ if model_mlsd is None:
39
+ from annotator.mlsd import MLSDdetector
40
+ model_mlsd = MLSDdetector()
41
+ result = model_mlsd(img, thr_v, thr_d)
42
+ return [result]
43
+
44
+
45
+ model_midas = None
46
+
47
+
48
+ def midas(img, res, a):
49
+ img = resize_image(HWC3(img), res)
50
+ global model_midas
51
+ if model_midas is None:
52
+ from annotator.midas import MidasDetector
53
+ model_midas = MidasDetector()
54
+ results = model_midas(img, a)
55
+ return results
56
+
57
+
58
+ model_openpose = None
59
+
60
+
61
+ def openpose(img, res, has_hand):
62
+ img = resize_image(HWC3(img), res)
63
+ global model_openpose
64
+ if model_openpose is None:
65
+ from annotator.openpose import OpenposeDetector
66
+ model_openpose = OpenposeDetector()
67
+ result, _ = model_openpose(img, has_hand)
68
+ return [result]
69
+
70
+
71
+ model_uniformer = None
72
+
73
+
74
+ def uniformer(img, res):
75
+ img = resize_image(HWC3(img), res)
76
+ global model_uniformer
77
+ if model_uniformer is None:
78
+ from annotator.uniformer import UniformerDetector
79
+ model_uniformer = UniformerDetector()
80
+ result = model_uniformer(img)
81
+ return [result]
82
+
83
+
84
+ block = gr.Blocks().queue()
85
+ with block:
86
+ with gr.Row():
87
+ gr.Markdown("## Canny Edge")
88
+ with gr.Row():
89
+ with gr.Column():
90
+ input_image = gr.Image(source='upload', type="numpy")
91
+ low_threshold = gr.Slider(label="low_threshold", minimum=1, maximum=255, value=100, step=1)
92
+ high_threshold = gr.Slider(label="high_threshold", minimum=1, maximum=255, value=200, step=1)
93
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
94
+ run_button = gr.Button(label="Run")
95
+ with gr.Column():
96
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
97
+ run_button.click(fn=canny, inputs=[input_image, resolution, low_threshold, high_threshold], outputs=[gallery])
98
+
99
+ with gr.Row():
100
+ gr.Markdown("## HED Edge")
101
+ with gr.Row():
102
+ with gr.Column():
103
+ input_image = gr.Image(source='upload', type="numpy")
104
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
105
+ run_button = gr.Button(label="Run")
106
+ with gr.Column():
107
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
108
+ run_button.click(fn=hed, inputs=[input_image, resolution], outputs=[gallery])
109
+
110
+ with gr.Row():
111
+ gr.Markdown("## MLSD Edge")
112
+ with gr.Row():
113
+ with gr.Column():
114
+ input_image = gr.Image(source='upload', type="numpy")
115
+ value_threshold = gr.Slider(label="value_threshold", minimum=0.01, maximum=2.0, value=0.1, step=0.01)
116
+ distance_threshold = gr.Slider(label="distance_threshold", minimum=0.01, maximum=20.0, value=0.1, step=0.01)
117
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=384, step=64)
118
+ run_button = gr.Button(label="Run")
119
+ with gr.Column():
120
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
121
+ run_button.click(fn=mlsd, inputs=[input_image, resolution, value_threshold, distance_threshold], outputs=[gallery])
122
+
123
+ with gr.Row():
124
+ gr.Markdown("## MIDAS Depth and Normal")
125
+ with gr.Row():
126
+ with gr.Column():
127
+ input_image = gr.Image(source='upload', type="numpy")
128
+ alpha = gr.Slider(label="alpha", minimum=0.1, maximum=20.0, value=6.2, step=0.01)
129
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=384, step=64)
130
+ run_button = gr.Button(label="Run")
131
+ with gr.Column():
132
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
133
+ run_button.click(fn=midas, inputs=[input_image, resolution, alpha], outputs=[gallery])
134
+
135
+ with gr.Row():
136
+ gr.Markdown("## Openpose")
137
+ with gr.Row():
138
+ with gr.Column():
139
+ input_image = gr.Image(source='upload', type="numpy")
140
+ hand = gr.Checkbox(label='detect hand', value=False)
141
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
142
+ run_button = gr.Button(label="Run")
143
+ with gr.Column():
144
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
145
+ run_button.click(fn=openpose, inputs=[input_image, resolution, hand], outputs=[gallery])
146
+
147
+
148
+ with gr.Row():
149
+ gr.Markdown("## Uniformer Segmentation")
150
+ with gr.Row():
151
+ with gr.Column():
152
+ input_image = gr.Image(source='upload', type="numpy")
153
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
154
+ run_button = gr.Button(label="Run")
155
+ with gr.Column():
156
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
157
+ run_button.click(fn=uniformer, inputs=[input_image, resolution], outputs=[gallery])
158
+
159
+
160
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_canny2image.py ADDED
@@ -0,0 +1,97 @@
1
+ from share import *
2
+ import config
3
+
4
+ import cv2
5
+ import einops
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ import random
10
+
11
+ from pytorch_lightning import seed_everything
12
+ from annotator.util import resize_image, HWC3
13
+ from annotator.canny import CannyDetector
14
+ from cldm.model import create_model, load_state_dict
15
+ from cldm.ddim_hacked import DDIMSampler
16
+
17
+
18
+ apply_canny = CannyDetector()
19
+
20
+ model = create_model('./models/cldm_v15.yaml').cpu()
21
+ model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
22
+ model = model.cuda()
23
+ ddim_sampler = DDIMSampler(model)
24
+
25
+
26
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
27
+ with torch.no_grad():
28
+ img = resize_image(HWC3(input_image), image_resolution)
29
+ H, W, C = img.shape
30
+
31
+ detected_map = apply_canny(img, low_threshold, high_threshold)
32
+ detected_map = HWC3(detected_map)
33
+
34
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
35
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
36
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
37
+
38
+ if seed == -1:
39
+ seed = random.randint(0, 65535)
40
+ seed_everything(seed)
41
+
42
+ if config.save_memory:
43
+ model.low_vram_shift(is_diffusing=False)
44
+
45
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
46
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
47
+ shape = (4, H // 8, W // 8)
48
+
49
+ if config.save_memory:
50
+ model.low_vram_shift(is_diffusing=True)
51
+
52
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
53
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
54
+ shape, cond, verbose=False, eta=eta,
55
+ unconditional_guidance_scale=scale,
56
+ unconditional_conditioning=un_cond)
57
+
58
+ if config.save_memory:
59
+ model.low_vram_shift(is_diffusing=False)
60
+
61
+ x_samples = model.decode_first_stage(samples)
62
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
63
+
64
+ results = [x_samples[i] for i in range(num_samples)]
65
+ return [255 - detected_map] + results
66
+
67
+
68
+ block = gr.Blocks().queue()
69
+ with block:
70
+ with gr.Row():
71
+ gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
72
+ with gr.Row():
73
+ with gr.Column():
74
+ input_image = gr.Image(source='upload', type="numpy")
75
+ prompt = gr.Textbox(label="Prompt")
76
+ run_button = gr.Button(label="Run")
77
+ with gr.Accordion("Advanced options", open=False):
78
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
79
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
80
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
81
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
82
+ low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
83
+ high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
84
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
85
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
86
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
87
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
88
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
89
+ n_prompt = gr.Textbox(label="Negative Prompt",
90
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
91
+ with gr.Column():
92
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
93
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold]
94
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
95
+
96
+
97
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_depth2image.py ADDED
@@ -0,0 +1,98 @@
1
+ from share import *
2
+ import config
3
+
4
+ import cv2
5
+ import einops
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ import random
10
+
11
+ from pytorch_lightning import seed_everything
12
+ from annotator.util import resize_image, HWC3
13
+ from annotator.midas import MidasDetector
14
+ from cldm.model import create_model, load_state_dict
15
+ from cldm.ddim_hacked import DDIMSampler
16
+
17
+
18
+ apply_midas = MidasDetector()
19
+
20
+ model = create_model('./models/cldm_v15.yaml').cpu()
21
+ model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda'))
22
+ model = model.cuda()
23
+ ddim_sampler = DDIMSampler(model)
24
+
25
+
26
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
27
+ with torch.no_grad():
28
+ input_image = HWC3(input_image)
29
+ detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
30
+ detected_map = HWC3(detected_map)
31
+ img = resize_image(input_image, image_resolution)
32
+ H, W, C = img.shape
33
+
34
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
35
+
36
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
37
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
38
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
39
+
40
+ if seed == -1:
41
+ seed = random.randint(0, 65535)
42
+ seed_everything(seed)
43
+
44
+ if config.save_memory:
45
+ model.low_vram_shift(is_diffusing=False)
46
+
47
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
48
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
49
+ shape = (4, H // 8, W // 8)
50
+
51
+ if config.save_memory:
52
+ model.low_vram_shift(is_diffusing=True)
53
+
54
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
55
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
56
+ shape, cond, verbose=False, eta=eta,
57
+ unconditional_guidance_scale=scale,
58
+ unconditional_conditioning=un_cond)
59
+
60
+ if config.save_memory:
61
+ model.low_vram_shift(is_diffusing=False)
62
+
63
+ x_samples = model.decode_first_stage(samples)
64
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
65
+
66
+ results = [x_samples[i] for i in range(num_samples)]
67
+ return [detected_map] + results
68
+
69
+
70
+ block = gr.Blocks().queue()
71
+ with block:
72
+ with gr.Row():
73
+ gr.Markdown("## Control Stable Diffusion with Depth Maps")
74
+ with gr.Row():
75
+ with gr.Column():
76
+ input_image = gr.Image(source='upload', type="numpy")
77
+ prompt = gr.Textbox(label="Prompt")
78
+ run_button = gr.Button(label="Run")
79
+ with gr.Accordion("Advanced options", open=False):
80
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
81
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
82
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
83
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
84
+ detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1)
85
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
86
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
87
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
88
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
89
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
90
+ n_prompt = gr.Textbox(label="Negative Prompt",
91
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
92
+ with gr.Column():
93
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
94
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
95
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
96
+
97
+
98
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_fake_scribble2image.py ADDED
@@ -0,0 +1,102 @@
1
+ from share import *
2
+ import config
3
+
4
+ import cv2
5
+ import einops
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ import random
10
+
11
+ from pytorch_lightning import seed_everything
12
+ from annotator.util import resize_image, HWC3
13
+ from annotator.hed import HEDdetector, nms
14
+ from cldm.model import create_model, load_state_dict
15
+ from cldm.ddim_hacked import DDIMSampler
16
+
17
+
18
+ apply_hed = HEDdetector()
19
+
20
+ model = create_model('./models/cldm_v15.yaml').cpu()
21
+ model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
22
+ model = model.cuda()
23
+ ddim_sampler = DDIMSampler(model)
24
+
25
+
26
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
27
+ with torch.no_grad():
28
+ input_image = HWC3(input_image)
29
+ detected_map = apply_hed(resize_image(input_image, detect_resolution))
30
+ detected_map = HWC3(detected_map)
31
+ img = resize_image(input_image, image_resolution)
32
+ H, W, C = img.shape
33
+
34
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
35
+ detected_map = nms(detected_map, 127, 3.0)
36
+ detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
37
+ detected_map[detected_map > 4] = 255
38
+ detected_map[detected_map < 255] = 0
39
+
40
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
41
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
42
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
43
+
44
+ if seed == -1:
45
+ seed = random.randint(0, 65535)
46
+ seed_everything(seed)
47
+
48
+ if config.save_memory:
49
+ model.low_vram_shift(is_diffusing=False)
50
+
51
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
52
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
53
+ shape = (4, H // 8, W // 8)
54
+
55
+ if config.save_memory:
56
+ model.low_vram_shift(is_diffusing=True)
57
+
58
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
59
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
60
+ shape, cond, verbose=False, eta=eta,
61
+ unconditional_guidance_scale=scale,
62
+ unconditional_conditioning=un_cond)
63
+
64
+ if config.save_memory:
65
+ model.low_vram_shift(is_diffusing=False)
66
+
67
+ x_samples = model.decode_first_stage(samples)
68
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
69
+
70
+ results = [x_samples[i] for i in range(num_samples)]
71
+ return [255 - detected_map] + results
72
+
73
+
74
+ block = gr.Blocks().queue()
75
+ with block:
76
+ with gr.Row():
77
+ gr.Markdown("## Control Stable Diffusion with Fake Scribble Maps")
78
+ with gr.Row():
79
+ with gr.Column():
80
+ input_image = gr.Image(source='upload', type="numpy")
81
+ prompt = gr.Textbox(label="Prompt")
82
+ run_button = gr.Button(label="Run")
83
+ with gr.Accordion("Advanced options", open=False):
84
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
85
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
86
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
87
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
88
+ detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
89
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
90
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
91
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
92
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
93
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
94
+ n_prompt = gr.Textbox(label="Negative Prompt",
95
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
96
+ with gr.Column():
97
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
98
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
99
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
100
+
101
+
102
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_hed2image.py ADDED
@@ -0,0 +1,98 @@
1
+ from share import *
2
+ import config
3
+
4
+ import cv2
5
+ import einops
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ import random
10
+
11
+ from pytorch_lightning import seed_everything
12
+ from annotator.util import resize_image, HWC3
13
+ from annotator.hed import HEDdetector
14
+ from cldm.model import create_model, load_state_dict
15
+ from cldm.ddim_hacked import DDIMSampler
16
+
17
+
18
+ apply_hed = HEDdetector()
19
+
20
+ model = create_model('./models/cldm_v15.yaml').cpu()
21
+ model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cuda'))
22
+ model = model.cuda()
23
+ ddim_sampler = DDIMSampler(model)
24
+
25
+
26
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
27
+ with torch.no_grad():
28
+ input_image = HWC3(input_image)
29
+ detected_map = apply_hed(resize_image(input_image, detect_resolution))
30
+ detected_map = HWC3(detected_map)
31
+ img = resize_image(input_image, image_resolution)
32
+ H, W, C = img.shape
33
+
34
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
35
+
36
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
37
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
38
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
39
+
40
+ if seed == -1:
41
+ seed = random.randint(0, 65535)
42
+ seed_everything(seed)
43
+
44
+ if config.save_memory:
45
+ model.low_vram_shift(is_diffusing=False)
46
+
47
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
48
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
49
+ shape = (4, H // 8, W // 8)
50
+
51
+ if config.save_memory:
52
+ model.low_vram_shift(is_diffusing=True)
53
+
54
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
55
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
56
+ shape, cond, verbose=False, eta=eta,
57
+ unconditional_guidance_scale=scale,
58
+ unconditional_conditioning=un_cond)
59
+
60
+ if config.save_memory:
61
+ model.low_vram_shift(is_diffusing=False)
62
+
63
+ x_samples = model.decode_first_stage(samples)
64
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
65
+
66
+ results = [x_samples[i] for i in range(num_samples)]
67
+ return [detected_map] + results
68
+
69
+
70
+ block = gr.Blocks().queue()
71
+ with block:
72
+ with gr.Row():
73
+ gr.Markdown("## Control Stable Diffusion with HED Maps")
74
+ with gr.Row():
75
+ with gr.Column():
76
+ input_image = gr.Image(source='upload', type="numpy")
77
+ prompt = gr.Textbox(label="Prompt")
78
+ run_button = gr.Button(label="Run")
79
+ with gr.Accordion("Advanced options", open=False):
80
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
81
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
82
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
83
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
84
+ detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
85
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
86
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
87
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
88
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
89
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
90
+ n_prompt = gr.Textbox(label="Negative Prompt",
91
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
92
+ with gr.Column():
93
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
94
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
95
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
96
+
97
+
98
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_hough2image.py ADDED
@@ -0,0 +1,100 @@
+ from share import *
+ import config
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from annotator.mlsd import MLSDdetector
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+
+
+ apply_mlsd = MLSDdetector()
+
+ model = create_model('./models/cldm_v15.yaml').cpu()
+ model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold):
+     with torch.no_grad():
+         input_image = HWC3(input_image)
+         detected_map = apply_mlsd(resize_image(input_image, detect_resolution), value_threshold, distance_threshold)
+         detected_map = HWC3(detected_map)
+         img = resize_image(input_image, image_resolution)
+         H, W, C = img.shape
+
+         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
+
+         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                      shape, cond, verbose=False, eta=eta,
+                                                      unconditional_guidance_scale=scale,
+                                                      unconditional_conditioning=un_cond)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+         results = [x_samples[i] for i in range(num_samples)]
+     return [255 - cv2.dilate(detected_map, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)] + results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Hough Line Maps")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(source='upload', type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 detect_resolution = gr.Slider(label="Hough Resolution", minimum=128, maximum=1024, value=512, step=1)
+                 value_threshold = gr.Slider(label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01)
+                 distance_threshold = gr.Slider(label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_normal2image.py ADDED
@@ -0,0 +1,99 @@
+ from share import *
+ import config
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from annotator.midas import MidasDetector
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+
+
+ apply_midas = MidasDetector()
+
+ model = create_model('./models/cldm_v15.yaml').cpu()
+ model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold):
+     with torch.no_grad():
+         input_image = HWC3(input_image)
+         _, detected_map = apply_midas(resize_image(input_image, detect_resolution), bg_th=bg_threshold)
+         detected_map = HWC3(detected_map)
+         img = resize_image(input_image, image_resolution)
+         H, W, C = img.shape
+
+         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
+
+         control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().cuda() / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                      shape, cond, verbose=False, eta=eta,
+                                                      unconditional_guidance_scale=scale,
+                                                      unconditional_conditioning=un_cond)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+         results = [x_samples[i] for i in range(num_samples)]
+     return [detected_map] + results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Normal Maps")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(source='upload', type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 detect_resolution = gr.Slider(label="Normal Resolution", minimum=128, maximum=1024, value=384, step=1)
+                 bg_threshold = gr.Slider(label="Normal background threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.01)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_pose2image.py ADDED
@@ -0,0 +1,98 @@
+ from share import *
+ import config
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from annotator.openpose import OpenposeDetector
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+
+
+ apply_openpose = OpenposeDetector()
+
+ model = create_model('./models/cldm_v15.yaml').cpu()
+ model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+     with torch.no_grad():
+         input_image = HWC3(input_image)
+         detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
+         detected_map = HWC3(detected_map)
+         img = resize_image(input_image, image_resolution)
+         H, W, C = img.shape
+
+         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
+
+         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                      shape, cond, verbose=False, eta=eta,
+                                                      unconditional_guidance_scale=scale,
+                                                      unconditional_conditioning=un_cond)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+         results = [x_samples[i] for i in range(num_samples)]
+     return [detected_map] + results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Human Pose")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(source='upload', type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=1024, value=512, step=1)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_scribble2image.py ADDED
@@ -0,0 +1,92 @@
+ from share import *
+ import config
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+
+
+ model = create_model('./models/cldm_v15.yaml').cpu()
+ model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+     with torch.no_grad():
+         img = resize_image(HWC3(input_image), image_resolution)
+         H, W, C = img.shape
+
+         detected_map = np.zeros_like(img, dtype=np.uint8)
+         detected_map[np.min(img, axis=2) < 127] = 255
+
+         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                      shape, cond, verbose=False, eta=eta,
+                                                      unconditional_guidance_scale=scale,
+                                                      unconditional_conditioning=un_cond)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+         results = [x_samples[i] for i in range(num_samples)]
+     return [255 - detected_map] + results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Scribble Maps")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(source='upload', type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_scribble2image_interactive.py ADDED
@@ -0,0 +1,102 @@
+ from share import *
+ import config
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+
+
+ model = create_model('./models/cldm_v15.yaml').cpu()
+ model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+     with torch.no_grad():
+         img = resize_image(HWC3(input_image['mask'][:, :, 0]), image_resolution)
+         H, W, C = img.shape
+
+         detected_map = np.zeros_like(img, dtype=np.uint8)
+         detected_map[np.min(img, axis=2) > 127] = 255
+
+         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                      shape, cond, verbose=False, eta=eta,
+                                                      unconditional_guidance_scale=scale,
+                                                      unconditional_conditioning=un_cond)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+         results = [x_samples[i] for i in range(num_samples)]
+     return [255 - detected_map] + results
+
+
+ def create_canvas(w, h):
+     return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Interactive Scribbles")
+     with gr.Row():
+         with gr.Column():
+             canvas_width = gr.Slider(label="Canvas Width", minimum=256, maximum=1024, value=512, step=1)
+             canvas_height = gr.Slider(label="Canvas Height", minimum=256, maximum=1024, value=512, step=1)
+             create_button = gr.Button(label="Start", value='Open drawing canvas!')
+             input_image = gr.Image(source='upload', type='numpy', tool='sketch')
+             gr.Markdown(value='Do not forget to change your brush width to make it thinner. (Gradio do not allow developers to set brush width so you need to do it manually.) '
+                               'Just click on the small pencil icon in the upper right corner of the above block.')
+             create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/gradio_seg2image.py ADDED
@@ -0,0 +1,97 @@
+ from share import *
+ import config
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from annotator.uniformer import UniformerDetector
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+
+
+ apply_uniformer = UniformerDetector()
+
+ model = create_model('./models/cldm_v15.yaml').cpu()
+ model.load_state_dict(load_state_dict('./models/control_sd15_seg.pth', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+     with torch.no_grad():
+         input_image = HWC3(input_image)
+         detected_map = apply_uniformer(resize_image(input_image, detect_resolution))
+         img = resize_image(input_image, image_resolution)
+         H, W, C = img.shape
+
+         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
+
+         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                      shape, cond, verbose=False, eta=eta,
+                                                      unconditional_guidance_scale=scale,
+                                                      unconditional_conditioning=un_cond)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+         results = [x_samples[i] for i in range(num_samples)]
+     return [detected_map] + results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Segmentation Maps")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(source='upload', type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 detect_resolution = gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024, value=512, step=1)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch(server_name='0.0.0.0')
generation/control/ControlNet/ldm/data/__init__.py ADDED
File without changes
generation/control/ControlNet/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,219 @@
+ import torch
+ import pytorch_lightning as pl
+ import torch.nn.functional as F
+ from contextlib import contextmanager
+ 
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+ 
+ from ldm.util import instantiate_from_config
+ from ldm.modules.ema import LitEma
+ 
+ 
+ class AutoencoderKL(pl.LightningModule):
+     def __init__(self,
+                  ddconfig,
+                  lossconfig,
+                  embed_dim,
+                  ckpt_path=None,
+                  ignore_keys=[],
+                  image_key="image",
+                  colorize_nlabels=None,
+                  monitor=None,
+                  ema_decay=None,
+                  learn_logvar=False
+                  ):
+         super().__init__()
+         self.learn_logvar = learn_logvar
+         self.image_key = image_key
+         self.encoder = Encoder(**ddconfig)
+         self.decoder = Decoder(**ddconfig)
+         self.loss = instantiate_from_config(lossconfig)
+         assert ddconfig["double_z"]
+         self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+         self.embed_dim = embed_dim
+         if colorize_nlabels is not None:
+             assert type(colorize_nlabels)==int
+             self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+         if monitor is not None:
+             self.monitor = monitor
+ 
+         self.use_ema = ema_decay is not None
+         if self.use_ema:
+             self.ema_decay = ema_decay
+             assert 0. < ema_decay < 1.
+             self.model_ema = LitEma(self, decay=ema_decay)
+             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+ 
+         if ckpt_path is not None:
+             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ 
+     def init_from_ckpt(self, path, ignore_keys=list()):
+         sd = torch.load(path, map_location="cpu")["state_dict"]
+         keys = list(sd.keys())
+         for k in keys:
+             for ik in ignore_keys:
+                 if k.startswith(ik):
+                     print("Deleting key {} from state_dict.".format(k))
+                     del sd[k]
+         self.load_state_dict(sd, strict=False)
+         print(f"Restored from {path}")
+ 
+     @contextmanager
+     def ema_scope(self, context=None):
+         if self.use_ema:
+             self.model_ema.store(self.parameters())
+             self.model_ema.copy_to(self)
+             if context is not None:
+                 print(f"{context}: Switched to EMA weights")
+         try:
+             yield None
+         finally:
+             if self.use_ema:
+                 self.model_ema.restore(self.parameters())
+                 if context is not None:
+                     print(f"{context}: Restored training weights")
+ 
+     def on_train_batch_end(self, *args, **kwargs):
+         if self.use_ema:
+             self.model_ema(self)
+ 
+     def encode(self, x):
+         h = self.encoder(x)
+         moments = self.quant_conv(h)
+         posterior = DiagonalGaussianDistribution(moments)
+         return posterior
+ 
+     def decode(self, z):
+         z = self.post_quant_conv(z)
+         dec = self.decoder(z)
+         return dec
+ 
+     def forward(self, input, sample_posterior=True):
+         posterior = self.encode(input)
+         if sample_posterior:
+             z = posterior.sample()
+         else:
+             z = posterior.mode()
+         dec = self.decode(z)
+         return dec, posterior
+ 
+     def get_input(self, batch, k):
+         x = batch[k]
+         if len(x.shape) == 3:
+             x = x[..., None]
+         x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+         return x
+ 
+     def training_step(self, batch, batch_idx, optimizer_idx):
+         inputs = self.get_input(batch, self.image_key)
+         reconstructions, posterior = self(inputs)
+ 
+         if optimizer_idx == 0:
+             # train encoder+decoder+logvar
+             aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                             last_layer=self.get_last_layer(), split="train")
+             self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+             self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+             return aeloss
+ 
+         if optimizer_idx == 1:
+             # train the discriminator
+             discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                                 last_layer=self.get_last_layer(), split="train")
+ 
+             self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+             self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+             return discloss
+ 
+     def validation_step(self, batch, batch_idx):
+         log_dict = self._validation_step(batch, batch_idx)
+         with self.ema_scope():
+             log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+         return log_dict
+ 
+     def _validation_step(self, batch, batch_idx, postfix=""):
+         inputs = self.get_input(batch, self.image_key)
+         reconstructions, posterior = self(inputs)
+         aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+                                         last_layer=self.get_last_layer(), split="val"+postfix)
+ 
+         discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+                                             last_layer=self.get_last_layer(), split="val"+postfix)
+ 
+         self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+         self.log_dict(log_dict_ae)
+         self.log_dict(log_dict_disc)
+         return self.log_dict
+ 
+     def configure_optimizers(self):
+         lr = self.learning_rate
+         ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+             self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+         if self.learn_logvar:
+             print(f"{self.__class__.__name__}: Learning logvar")
+             ae_params_list.append(self.loss.logvar)
+         opt_ae = torch.optim.Adam(ae_params_list,
+                                   lr=lr, betas=(0.5, 0.9))
+         opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                     lr=lr, betas=(0.5, 0.9))
+         return [opt_ae, opt_disc], []
+ 
+     def get_last_layer(self):
+         return self.decoder.conv_out.weight
+ 
+     @torch.no_grad()
+     def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+         log = dict()
+         x = self.get_input(batch, self.image_key)
+         x = x.to(self.device)
+         if not only_inputs:
+             xrec, posterior = self(x)
+             if x.shape[1] > 3:
+                 # colorize with random projection
+                 assert xrec.shape[1] > 3
+                 x = self.to_rgb(x)
+                 xrec = self.to_rgb(xrec)
+             log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+             log["reconstructions"] = xrec
+             if log_ema or self.use_ema:
+                 with self.ema_scope():
+                     xrec_ema, posterior_ema = self(x)
+                     if x.shape[1] > 3:
+                         # colorize with random projection
+                         assert xrec_ema.shape[1] > 3
+                         xrec_ema = self.to_rgb(xrec_ema)
+                     log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+                     log["reconstructions_ema"] = xrec_ema
+         log["inputs"] = x
+         return log
+ 
+     def to_rgb(self, x):
+         assert self.image_key == "segmentation"
+         if not hasattr(self, "colorize"):
+             self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+         x = F.conv2d(x, weight=self.colorize)
+         x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+         return x
+ 
+ 
+ class IdentityFirstStage(torch.nn.Module):
+     def __init__(self, *args, vq_interface=False, **kwargs):
+         self.vq_interface = vq_interface
+         super().__init__()
+ 
+     def encode(self, x, *args, **kwargs):
+         return x
+ 
+     def decode(self, x, *args, **kwargs):
+         return x
+ 
+     def quantize(self, x, *args, **kwargs):
+         if self.vq_interface:
+             return x, None, [None, None, None]
+         return x
+ 
+     def forward(self, x, *args, **kwargs):
+         return x
+ 
generation/control/ControlNet/ldm/models/diffusion/__init__.py ADDED
File without changes
generation/control/ControlNet/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,336 @@
+ """SAMPLING ONLY."""
+ 
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ 
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+ 
+ 
+ class DDIMSampler(object):
+     def __init__(self, model, schedule="linear", **kwargs):
+         super().__init__()
+         self.model = model
+         self.ddpm_num_timesteps = model.num_timesteps
+         self.schedule = schedule
+ 
+     def register_buffer(self, name, attr):
+         if type(attr) == torch.Tensor:
+             if attr.device != torch.device("cuda"):
+                 attr = attr.to(torch.device("cuda"))
+         setattr(self, name, attr)
+ 
+     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+         self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                   num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+         alphas_cumprod = self.model.alphas_cumprod
+         assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+ 
+         self.register_buffer('betas', to_torch(self.model.betas))
+         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+         self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+ 
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+         self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+ 
+         # ddim sampling parameters
+         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                    ddim_timesteps=self.ddim_timesteps,
+                                                                                    eta=ddim_eta, verbose=verbose)
+         self.register_buffer('ddim_sigmas', ddim_sigmas)
+         self.register_buffer('ddim_alphas', ddim_alphas)
+         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+         self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                     1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+         self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+ 
+     @torch.no_grad()
+     def sample(self,
+                S,
+                batch_size,
+                shape,
+                conditioning=None,
+                callback=None,
+                normals_sequence=None,
+                img_callback=None,
+                quantize_x0=False,
+                eta=0.,
+                mask=None,
+                x0=None,
+                temperature=1.,
+                noise_dropout=0.,
+                score_corrector=None,
+                corrector_kwargs=None,
+                verbose=True,
+                x_T=None,
+                log_every_t=100,
+                unconditional_guidance_scale=1.,
+                unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens
+                dynamic_threshold=None,
+                ucg_schedule=None,
+                **kwargs
+                ):
+         if conditioning is not None:
+             if isinstance(conditioning, dict):
+                 ctmp = conditioning[list(conditioning.keys())[0]]
+                 while isinstance(ctmp, list): ctmp = ctmp[0]
+                 cbs = ctmp.shape[0]
+                 if cbs != batch_size:
+                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ 
+             elif isinstance(conditioning, list):
+                 for ctmp in conditioning:
+                     if ctmp.shape[0] != batch_size:
+                         print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+ 
+             else:
+                 if conditioning.shape[0] != batch_size:
+                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+ 
+         self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+         # sampling
+         C, H, W = shape
+         size = (batch_size, C, H, W)
+         print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+ 
+         samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                     callback=callback,
+                                                     img_callback=img_callback,
+                                                     quantize_denoised=quantize_x0,
+                                                     mask=mask, x0=x0,
+                                                     ddim_use_original_steps=False,
+                                                     noise_dropout=noise_dropout,
+                                                     temperature=temperature,
+                                                     score_corrector=score_corrector,
+                                                     corrector_kwargs=corrector_kwargs,
+                                                     x_T=x_T,
+                                                     log_every_t=log_every_t,
+                                                     unconditional_guidance_scale=unconditional_guidance_scale,
+                                                     unconditional_conditioning=unconditional_conditioning,
+                                                     dynamic_threshold=dynamic_threshold,
+                                                     ucg_schedule=ucg_schedule
+                                                     )
+         return samples, intermediates
+ 
+     @torch.no_grad()
+     def ddim_sampling(self, cond, shape,
+                       x_T=None, ddim_use_original_steps=False,
+                       callback=None, timesteps=None, quantize_denoised=False,
+                       mask=None, x0=None, img_callback=None, log_every_t=100,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                       ucg_schedule=None):
+         device = self.model.betas.device
+         b = shape[0]
+         if x_T is None:
+             img = torch.randn(shape, device=device)
+         else:
+             img = x_T
+ 
+         if timesteps is None:
+             timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+         elif timesteps is not None and not ddim_use_original_steps:
+             subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+             timesteps = self.ddim_timesteps[:subset_end]
+ 
+         intermediates = {'x_inter': [img], 'pred_x0': [img]}
+         time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+         print(f"Running DDIM Sampling with {total_steps} timesteps")
+ 
+         iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+ 
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full((b,), step, device=device, dtype=torch.long)
+ 
+             if mask is not None:
+                 assert x0 is not None
+                 img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                 img = img_orig * mask + (1. - mask) * img
+ 
+             if ucg_schedule is not None:
+                 assert len(ucg_schedule) == len(time_range)
+                 unconditional_guidance_scale = ucg_schedule[i]
+ 
+             outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                       quantize_denoised=quantize_denoised, temperature=temperature,
+                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                       corrector_kwargs=corrector_kwargs,
+                                       unconditional_guidance_scale=unconditional_guidance_scale,
+                                       unconditional_conditioning=unconditional_conditioning,
+                                       dynamic_threshold=dynamic_threshold)
+             img, pred_x0 = outs
+             if callback: callback(i)
+             if img_callback: img_callback(pred_x0, i)
+ 
+             if index % log_every_t == 0 or index == total_steps - 1:
+                 intermediates['x_inter'].append(img)
+                 intermediates['pred_x0'].append(pred_x0)
+ 
+         return img, intermediates
+ 
+     @torch.no_grad()
+     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None,
+                       dynamic_threshold=None):
+         b, *_, device = *x.shape, x.device
+ 
+         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+             model_output = self.model.apply_model(x, t, c)
+         else:
+             x_in = torch.cat([x] * 2)
+             t_in = torch.cat([t] * 2)
+             if isinstance(c, dict):
+                 assert isinstance(unconditional_conditioning, dict)
+                 c_in = dict()
+                 for k in c:
+                     if isinstance(c[k], list):
+                         c_in[k] = [torch.cat([
+                             unconditional_conditioning[k][i],
+                             c[k][i]]) for i in range(len(c[k]))]
+                     else:
+                         c_in[k] = torch.cat([
+                             unconditional_conditioning[k],
+                             c[k]])
+             elif isinstance(c, list):
+                 c_in = list()
+                 assert isinstance(unconditional_conditioning, list)
+                 for i in range(len(c)):
+                     c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+             else:
+                 c_in = torch.cat([unconditional_conditioning, c])
+             model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+             model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+ 
+         if self.model.parameterization == "v":
+             e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+         else:
+             e_t = model_output
+ 
+         if score_corrector is not None:
+             assert self.model.parameterization == "eps", 'not implemented'
+             e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+ 
+         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+         alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+         sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+         sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+         # select parameters corresponding to the currently considered timestep
+         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+         sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+ 
+         # current prediction for x_0
+         if self.model.parameterization != "v":
+             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+         else:
+             pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+ 
+         if quantize_denoised:
+             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ 
+         if dynamic_threshold is not None:
+             raise NotImplementedError()
+ 
+         # direction pointing to x_t
+         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+         if noise_dropout > 0.:
+             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+         x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+         return x_prev, pred_x0
+ 
+     @torch.no_grad()
+     def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+                unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+         num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+ 
+         assert t_enc <= num_reference_steps
+         num_steps = t_enc
+ 
+         if use_original_steps:
+             alphas_next = self.alphas_cumprod[:num_steps]
+             alphas = self.alphas_cumprod_prev[:num_steps]
+         else:
+             alphas_next = self.ddim_alphas[:num_steps]
+             alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+ 
+         x_next = x0
+         intermediates = []
+         inter_steps = []
+         for i in tqdm(range(num_steps), desc='Encoding Image'):
+             t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+             if unconditional_guidance_scale == 1.:
+                 noise_pred = self.model.apply_model(x_next, t, c)
+             else:
+                 assert unconditional_conditioning is not None
+                 e_t_uncond, noise_pred = torch.chunk(
+                     self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                            torch.cat((unconditional_conditioning, c))), 2)
+                 noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+ 
+             xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+             weighted_noise_pred = alphas_next[i].sqrt() * (
+                     (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+             x_next = xt_weighted + weighted_noise_pred
+             if return_intermediates and i % (
+                     num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                 intermediates.append(x_next)
+                 inter_steps.append(i)
+             elif return_intermediates and i >= num_steps - 2:
+                 intermediates.append(x_next)
+                 inter_steps.append(i)
+             if callback: callback(i)
+ 
+         out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+         if return_intermediates:
+             out.update({'intermediates': intermediates})
+         return x_next, out
+ 
+     @torch.no_grad()
+     def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+         # fast, but does not allow for exact reconstruction
+         # t serves as an index to gather the correct alphas
+         if use_original_steps:
+             sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+             sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+         else:
+             sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+ 
+         if noise is None:
+             noise = torch.randn_like(x0)
+         return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                 extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+ 
+     @torch.no_grad()
+     def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+                use_original_steps=False, callback=None):
+ 
+         timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+         timesteps = timesteps[:t_start]
+ 
+         time_range = np.flip(timesteps)
+         total_steps = timesteps.shape[0]
+         print(f"Running DDIM Sampling with {total_steps} timesteps")
+ 
+         iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+         x_dec = x_latent
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+             x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                           unconditional_guidance_scale=unconditional_guidance_scale,
+                                           unconditional_conditioning=unconditional_conditioning)
+             if callback: callback(i)
+         return x_dec
generation/control/ControlNet/ldm/util.py ADDED
@@ -0,0 +1,197 @@
1
+ import importlib
2
+
3
+ import torch
4
+ from torch import optim
5
+ import numpy as np
6
+
7
+ from inspect import isfunction
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+
11
+ def log_txt_as_img(wh, xc, size=10):
12
+ # wh a tuple of (width, height)
13
+ # xc a list of captions to plot
14
+ b = len(xc)
15
+ txts = list()
16
+ for bi in range(b):
17
+ txt = Image.new("RGB", wh, color="white")
18
+ draw = ImageDraw.Draw(txt)
19
+ font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
20
+ nc = int(40 * (wh[0] / 256))
21
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
22
+
23
+ try:
24
+ draw.text((0, 0), lines, fill="black", font=font)
25
+ except UnicodeEncodeError:
26
+ print("Cant encode string for logging. Skipping.")
27
+
28
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
29
+ txts.append(txt)
30
+ txts = np.stack(txts)
31
+ txts = torch.tensor(txts)
32
+ return txts
33
+
34
+
35
+ def ismap(x):
36
+ if not isinstance(x, torch.Tensor):
37
+ return False
38
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
39
+
40
+
41
+ def isimage(x):
42
+ if not isinstance(x,torch.Tensor):
43
+ return False
44
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
45
+
46
+
47
+ def exists(x):
48
+ return x is not None
49
+
50
+
51
+ def default(val, d):
52
+ if exists(val):
53
+ return val
54
+ return d() if isfunction(d) else d
55
+
56
+
57
+ def mean_flat(tensor):
58
+ """
59
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
60
+ Take the mean over all non-batch dimensions.
61
+ """
62
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
63
+
64
+
65
+ def count_params(model, verbose=False):
66
+ total_params = sum(p.numel() for p in model.parameters())
67
+ if verbose:
68
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
69
+ return total_params
70
+
71
+
72
+ def instantiate_from_config(config):
73
+ if not "target" in config:
74
+ if config == '__is_first_stage__':
75
+ return None
76
+ elif config == "__is_unconditional__":
77
+ return None
78
+ raise KeyError("Expected key `target` to instantiate.")
79
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
80
+
81
+
82
+ def get_obj_from_str(string, reload=False):
83
+ module, cls = string.rsplit(".", 1)
84
+ if reload:
85
+ module_imp = importlib.import_module(module)
86
+ importlib.reload(module_imp)
87
+ return getattr(importlib.import_module(module, package=None), cls)
88
+
89
+
90
+ class AdamWwithEMAandWings(optim.Optimizer):
91
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
92
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
93
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
94
+ ema_power=1., param_names=()):
95
+ """AdamW that saves EMA versions of the parameters."""
96
+ if not 0.0 <= lr:
97
+ raise ValueError("Invalid learning rate: {}".format(lr))
98
+ if not 0.0 <= eps:
99
+ raise ValueError("Invalid epsilon value: {}".format(eps))
100
+ if not 0.0 <= betas[0] < 1.0:
101
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
102
+ if not 0.0 <= betas[1] < 1.0:
103
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
104
+ if not 0.0 <= weight_decay:
105
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
106
+ if not 0.0 <= ema_decay <= 1.0:
107
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
108
+ defaults = dict(lr=lr, betas=betas, eps=eps,
109
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
110
+ ema_power=ema_power, param_names=param_names)
111
+         super().__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super().__setstate__(state)
+         for group in self.param_groups:
+             group.setdefault('amsgrad', False)
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step.
+         Args:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             params_with_grad = []
+             grads = []
+             exp_avgs = []
+             exp_avg_sqs = []
+             ema_params_with_grad = []
+             state_sums = []
+             max_exp_avg_sqs = []
+             state_steps = []
+             amsgrad = group['amsgrad']
+             beta1, beta2 = group['betas']
+             ema_decay = group['ema_decay']
+             ema_power = group['ema_power']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 params_with_grad.append(p)
+                 if p.grad.is_sparse:
+                     raise RuntimeError('AdamW does not support sparse gradients')
+                 grads.append(p.grad)
+
+                 state = self.state[p]
+
+                 # State initialization
+                 if len(state) == 0:
+                     state['step'] = 0
+                     # Exponential moving average of gradient values
+                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                     # Exponential moving average of squared gradient values
+                     state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                     if amsgrad:
+                         # Maintains max of all exp. moving avg. of sq. grad. values
+                         state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                     # Exponential moving average of parameter values
+                     state['param_exp_avg'] = p.detach().float().clone()
+
+                 exp_avgs.append(state['exp_avg'])
+                 exp_avg_sqs.append(state['exp_avg_sq'])
+                 ema_params_with_grad.append(state['param_exp_avg'])
+
+                 if amsgrad:
+                     max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+                 # update the steps for each param group update
+                 state['step'] += 1
+                 # record the step after step update
+                 state_steps.append(state['step'])
+
+             optim._functional.adamw(params_with_grad,
+                                     grads,
+                                     exp_avgs,
+                                     exp_avg_sqs,
+                                     max_exp_avg_sqs,
+                                     state_steps,
+                                     amsgrad=amsgrad,
+                                     beta1=beta1,
+                                     beta2=beta2,
+                                     lr=group['lr'],
+                                     weight_decay=group['weight_decay'],
+                                     eps=group['eps'],
+                                     maximize=False)
+
+             cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+             for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+                 ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+         return loss
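The `step()` above runs a standard AdamW update and then maintains an exponential moving average (EMA) of the parameters with a warm-up schedule, `cur_ema_decay = min(ema_decay, 1 - step ** -ema_power)`. A minimal standalone sketch of that schedule and update rule (illustrative scalar values, not from this repo):

```python
def ema_decay_at(step, ema_decay=0.9999, ema_power=1.0):
    # Warm-up: the decay starts at 0 and approaches ema_decay as step grows,
    # so early EMA values track the raw parameters closely.
    return min(ema_decay, 1 - step ** -ema_power)

# The EMA update itself, on a scalar "parameter":
# ema <- d * ema + (1 - d) * param, matching ema_param.mul_(d).add_(param, alpha=1 - d)
ema, param = 0.0, 2.0
for step in range(1, 101):
    d = ema_decay_at(step)
    ema = d * ema + (1 - d) * param

print(ema_decay_at(1), ema_decay_at(10), ema)  # 0.0 0.9 2.0
```

Because the decay is 0 at step 1, the EMA is seeded with the first parameter value and only afterwards starts averaging slowly.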
generation/control/ControlNet/share.py ADDED
@@ -0,0 +1,8 @@
+ import config
+ from cldm.hack import disable_verbosity, enable_sliced_attention
+
+
+ disable_verbosity()
+
+ if config.save_memory:
+     enable_sliced_attention()
generation/control/ControlNet/tool_add_control.py ADDED
@@ -0,0 +1,50 @@
+ import sys
+ import os
+
+ assert len(sys.argv) == 3, 'Args are wrong.'
+
+ input_path = sys.argv[1]
+ output_path = sys.argv[2]
+
+ assert os.path.exists(input_path), 'Input model does not exist.'
+ assert not os.path.exists(output_path), 'Output filename already exists.'
+ assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'
+
+ import torch
+ from share import *
+ from cldm.model import create_model
+
+
+ def get_node_name(name, parent_name):
+     if len(name) <= len(parent_name):
+         return False, ''
+     p = name[:len(parent_name)]
+     if p != parent_name:
+         return False, ''
+     return True, name[len(parent_name):]
+
+
+ model = create_model(config_path='./models/cldm_v15.yaml')
+
+ pretrained_weights = torch.load(input_path)
+ if 'state_dict' in pretrained_weights:
+     pretrained_weights = pretrained_weights['state_dict']
+
+ scratch_dict = model.state_dict()
+
+ target_dict = {}
+ for k in scratch_dict.keys():
+     is_control, name = get_node_name(k, 'control_')
+     if is_control:
+         copy_k = 'model.diffusion_' + name
+     else:
+         copy_k = k
+     if copy_k in pretrained_weights:
+         target_dict[k] = pretrained_weights[copy_k].clone()
+     else:
+         target_dict[k] = scratch_dict[k].clone()
+         print(f'These weights are newly added: {k}')
+
+ model.load_state_dict(target_dict, strict=True)
+ torch.save(model.state_dict(), output_path)
+ print('Done.')
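`get_node_name` above drives the weight copy: every `control_*` key in the freshly created model is initialized from the matching `model.diffusion_*` key of the pretrained SD checkpoint. A minimal sketch of that remapping (the key name below is hypothetical but follows the SD1.5 state-dict layout):

```python
def get_node_name(name, parent_name):
    # Returns (True, suffix) when `name` starts with `parent_name`, else (False, '').
    if len(name) <= len(parent_name):
        return False, ''
    if name[:len(parent_name)] != parent_name:
        return False, ''
    return True, name[len(parent_name):]

key = 'control_model.input_blocks.0.0.weight'  # hypothetical ControlNet key
is_control, suffix = get_node_name(key, 'control_')
copy_k = 'model.diffusion_' + suffix if is_control else key
print(copy_k)  # model.diffusion_model.input_blocks.0.0.weight
```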
generation/control/ControlNet/tool_add_control_sd21.py ADDED
@@ -0,0 +1,50 @@
+ import sys
+ import os
+
+ assert len(sys.argv) == 3, 'Args are wrong.'
+
+ input_path = sys.argv[1]
+ output_path = sys.argv[2]
+
+ assert os.path.exists(input_path), 'Input model does not exist.'
+ assert not os.path.exists(output_path), 'Output filename already exists.'
+ assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'
+
+ import torch
+ from share import *
+ from cldm.model import create_model
+
+
+ def get_node_name(name, parent_name):
+     if len(name) <= len(parent_name):
+         return False, ''
+     p = name[:len(parent_name)]
+     if p != parent_name:
+         return False, ''
+     return True, name[len(parent_name):]
+
+
+ model = create_model(config_path='./models/cldm_v21.yaml')
+
+ pretrained_weights = torch.load(input_path)
+ if 'state_dict' in pretrained_weights:
+     pretrained_weights = pretrained_weights['state_dict']
+
+ scratch_dict = model.state_dict()
+
+ target_dict = {}
+ for k in scratch_dict.keys():
+     is_control, name = get_node_name(k, 'control_')
+     if is_control:
+         copy_k = 'model.diffusion_' + name
+     else:
+         copy_k = k
+     if copy_k in pretrained_weights:
+         target_dict[k] = pretrained_weights[copy_k].clone()
+     else:
+         target_dict[k] = scratch_dict[k].clone()
+         print(f'These weights are newly added: {k}')
+
+ model.load_state_dict(target_dict, strict=True)
+ torch.save(model.state_dict(), output_path)
+ print('Done.')
generation/control/ControlNet/tool_transfer_control.py ADDED
@@ -0,0 +1,59 @@
+ path_sd15 = './models/v1-5-pruned.ckpt'
+ path_sd15_with_control = './models/control_sd15_openpose.pth'
+ path_input = './models/anything-v3-full.safetensors'
+ path_output = './models/control_any3_openpose.pth'
+
+
+ import os
+
+
+ assert os.path.exists(path_sd15), 'Input path_sd15 does not exist!'
+ assert os.path.exists(path_sd15_with_control), 'Input path_sd15_with_control does not exist!'
+ assert os.path.exists(path_input), 'Input path_input does not exist!'
+ assert os.path.exists(os.path.dirname(path_output)), 'Output folder does not exist!'
+
+
+ import torch
+ from share import *
+ from cldm.model import load_state_dict
+
+
+ sd15_state_dict = load_state_dict(path_sd15)
+ sd15_with_control_state_dict = load_state_dict(path_sd15_with_control)
+ input_state_dict = load_state_dict(path_input)
+
+
+ def get_node_name(name, parent_name):
+     if len(name) <= len(parent_name):
+         return False, ''
+     p = name[:len(parent_name)]
+     if p != parent_name:
+         return False, ''
+     return True, name[len(parent_name):]
+
+
+ keys = sd15_with_control_state_dict.keys()
+
+ final_state_dict = {}
+ for key in keys:
+     is_first_stage, _ = get_node_name(key, 'first_stage_model')
+     is_cond_stage, _ = get_node_name(key, 'cond_stage_model')
+     if is_first_stage or is_cond_stage:
+         final_state_dict[key] = input_state_dict[key]
+         continue
+     p = sd15_with_control_state_dict[key]
+     is_control, node_name = get_node_name(key, 'control_')
+     if is_control:
+         sd15_key_name = 'model.diffusion_' + node_name
+     else:
+         sd15_key_name = key
+     if sd15_key_name in input_state_dict:
+         p_new = p + input_state_dict[sd15_key_name] - sd15_state_dict[sd15_key_name]
+         # print(f'Offset clone from [{sd15_key_name}] to [{key}]')
+     else:
+         p_new = p
+         # print(f'Direct clone to [{key}]')
+     final_state_dict[key] = p_new
+
+ torch.save(final_state_dict, path_output)
+ print('Transferred model saved at ' + path_output)
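The arithmetic in the loop above, `p_new = p + input - sd15`, transplants the ControlNet's learned offset from SD1.5 onto another base model: the delta the ControlNet added relative to vanilla SD1.5 is re-applied on top of the new base weights. A scalar sketch with made-up, exactly representable values:

```python
# Illustrative scalars standing in for a single weight entry (made-up values):
control_sd15 = 1.25  # weight in the SD1.5 ControlNet checkpoint
sd15 = 1.00          # same weight in vanilla SD1.5
other_base = 1.50    # same weight in another base model (e.g. Anything-v3)

# Keep the ControlNet *offset* relative to SD1.5 and re-apply it on the new base:
p_new = control_sd15 + other_base - sd15
print(p_new)  # 1.75
```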
generation/control/ControlNet/tutorial_dataset.py ADDED
@@ -0,0 +1,39 @@
+ import json
+ import cv2
+ import numpy as np
+
+ from torch.utils.data import Dataset
+
+
+ class MyDataset(Dataset):
+     def __init__(self):
+         self.data = []
+         with open('./training/fill50k/prompt.json', 'rt') as f:
+             for line in f:
+                 self.data.append(json.loads(line))
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         item = self.data[idx]
+
+         source_filename = item['source']
+         target_filename = item['target']
+         prompt = item['prompt']
+
+         source = cv2.imread('./training/fill50k/' + source_filename)
+         target = cv2.imread('./training/fill50k/' + target_filename)
+
+         # Do not forget that OpenCV reads images in BGR order.
+         source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
+         target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
+
+         # Normalize source images to [0, 1].
+         source = source.astype(np.float32) / 255.0
+
+         # Normalize target images to [-1, 1].
+         target = (target.astype(np.float32) / 127.5) - 1.0
+
+         return dict(jpg=target, txt=prompt, hint=source)
+
generation/control/ControlNet/tutorial_dataset_test.py ADDED
@@ -0,0 +1,12 @@
+ from tutorial_dataset import MyDataset
+
+ dataset = MyDataset()
+ print(len(dataset))
+
+ item = dataset[1234]
+ jpg = item['jpg']
+ txt = item['txt']
+ hint = item['hint']
+ print(txt)
+ print(jpg.shape)
+ print(hint.shape)
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from share import *
2
+
3
+ import pytorch_lightning as pl
4
+ from torch.utils.data import DataLoader
5
+ from tutorial_dataset import MyDataset
6
+ from cldm.logger import ImageLogger
7
+ from cldm.model import create_model, load_state_dict
8
+
9
+
10
+ # Configs
11
+ resume_path = './models/control_sd15_ini.ckpt'
12
+ batch_size = 4
13
+ logger_freq = 300
14
+ learning_rate = 1e-5
15
+ sd_locked = True
16
+ only_mid_control = False
17
+
18
+
19
+ # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
20
+ model = create_model('./models/cldm_v15.yaml').cpu()
21
+ model.load_state_dict(load_state_dict(resume_path, location='cpu'))
22
+ model.learning_rate = learning_rate
23
+ model.sd_locked = sd_locked
24
+ model.only_mid_control = only_mid_control
25
+
26
+
27
+ # Misc
28
+ dataset = MyDataset()
29
+ dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
30
+ logger = ImageLogger(batch_frequency=logger_freq)
31
+ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
32
+
33
+
34
+ # Train!
35
+ trainer.fit(model, dataloader)
generation/control/ControlNet/tutorial_train_sd21.py ADDED
@@ -0,0 +1,35 @@
+ from share import *
+
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+ from tutorial_dataset import MyDataset
+ from cldm.logger import ImageLogger
+ from cldm.model import create_model, load_state_dict
+
+
+ # Configs
+ resume_path = './models/control_sd21_ini.ckpt'
+ batch_size = 4
+ logger_freq = 300
+ learning_rate = 1e-5
+ sd_locked = True
+ only_mid_control = False
+
+
+ # First load the model on CPU. PyTorch Lightning will automatically move it to the GPUs.
+ model = create_model('./models/cldm_v21.yaml').cpu()
+ model.load_state_dict(load_state_dict(resume_path, location='cpu'))
+ model.learning_rate = learning_rate
+ model.sd_locked = sd_locked
+ model.only_mid_control = only_mid_control
+
+
+ # Misc
+ dataset = MyDataset()
+ dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
+ logger = ImageLogger(batch_frequency=logger_freq)
+ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
+
+
+ # Train!
+ trainer.fit(model, dataloader)
generation/control/download_ade20k.sh ADDED
@@ -0,0 +1,10 @@
+ #!/bin/bash
+
+ if [ ! -d "data" ]; then
+     mkdir data
+ fi
+ cd data
+
+ wget -O ade20k.zip https://keeper.mpdl.mpg.de/f/80b2fc97ffc3430c98de/?dl=1
+ unzip ade20k.zip
+ rm ade20k.zip
generation/control/download_celebhq.sh ADDED
@@ -0,0 +1,10 @@
+ #!/bin/bash
+
+ if [ ! -d "data" ]; then
+     mkdir data
+ fi
+ cd data
+
+ wget -O celebhq-text.zip https://keeper.mpdl.mpg.de/f/72c34a6017cb40b896e9/?dl=1
+ unzip celebhq-text.zip
+ rm celebhq-text.zip
generation/control/eval_canny.py ADDED
@@ -0,0 +1,130 @@
+ from PIL import Image
+ from tqdm import tqdm
+ import sys
+ import os
+ import json
+ import cv2
+ import torch
+ from torch.autograd import Variable
+ import torchvision
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision.datasets import CocoCaptions
+ from torchvision.transforms import ToTensor
+ import torchvision.transforms as transforms
+ from torchvision.utils import save_image
+ import numpy as np
+
+ # from net_canny import Net
+ from ControlNet.annotator.canny import CannyDetector
+ from ControlNet.annotator.util import resize_image, HWC3
+
+ # from pytorch_msssim import ssim, ms_ssim, SSIM
+
+ class ResultFolderDataset(Dataset):
+     def __init__(self, data_dir, results_dir, n, transform=None):
+         self.data_dir = data_dir
+         self.results_dir = results_dir
+         self.n = n
+         self.image_paths = sorted([f for f in os.listdir(data_dir) if f.lower().endswith('.png')])
+         # self.image_paths = sorted([os.path.join(folder, f) for f in os.listdir(folder) if f.lower().endswith('_{}.png'.format(n))])
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         image_name = self.image_paths[idx]
+         source_path = os.path.join(self.data_dir, image_name)
+
+         base_name = image_name.split('_')[1].split('.')[0]  # Extract 'x' from 'image_x.png'
+         image_name2 = f'result_{base_name}_{self.n}.png'
+         result_path = os.path.join(self.results_dir, image_name2)
+
+         source_image = Image.open(source_path)  # .convert('RGB')
+         result_image = Image.open(result_path)  # .convert('RGB')
+         if self.transform:
+             source_image = self.transform(source_image)
+             result_image = self.transform(result_image)
+         return source_image, result_image
+
+
+ def calculate_metrics(pred, target):
+     intersection = np.logical_and(pred, target)
+     union = np.logical_or(pred, target)
+     iou_score = np.sum(intersection) / np.sum(union)
+
+     accuracy = np.sum(pred == target) / target.size
+
+     tp = np.sum(intersection)  # True positive
+     fp = np.sum(pred) - tp  # False positive
+     fn = np.sum(target) - tp  # False negative
+
+     f1_score = (2 * tp) / (2 * tp + fp + fn)
+
+     return iou_score, accuracy, f1_score
+
+ if __name__ == '__main__':
+     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+     low_threshold = 100
+     high_threshold = 200
+
+     n = 0
+     # epoch = 10
+     # experiment = './log/image_log_oft_COCO_canny_eps_1-3_r_4_cayley_4gpu'
+     experiment = 'log/image_log_householder_gramschmidt_COCO_canny_eps_7e-06_pe_diff_mlp_l_8_8gpu_2024-05-19-21-22-24-466032'
+
+     if 'train_with_norm' in experiment:
+         epoch = 4
+     else:
+         if 'COCO' in experiment:
+             epoch = 10
+         else:
+             epoch = 19
+
+     data_dir = os.path.join(experiment, 'source', str(epoch))
+     result_dir = os.path.join(experiment, 'results', str(epoch))
+     json_file = os.path.join(experiment, 'results.json')
+
+     # Define the transforms to apply to the images
+     transform = transforms.Compose([
+         # transforms.Resize((512, 512)),
+         # transforms.CenterCrop(512),
+         transforms.ToTensor()
+     ])
+
+     dataset = ResultFolderDataset(data_dir, result_dir, n=n, transform=transform)
+     data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
+
+     apply_canny = CannyDetector()
+
+     loss = 0
+     iou_score_mean = 0
+     accuracy_mean = 0
+     f1_score_mean = 0
+     ssim_mean = 0
+     for i, data in tqdm(enumerate(data_loader), total=len(data_loader)):
+         source_image, result_image = data
+         # Convert the tensor to a numpy array and transpose it to have the channels last (H, W, C)
+         source_image_np = source_image.squeeze(0).permute(1, 2, 0).numpy()
+         result_image_np = result_image.squeeze(0).permute(1, 2, 0).numpy()
+
+         # Convert the image to 8-bit unsigned integers (0-255)
+         source_image_np = (source_image_np * 255).astype(np.uint8)
+         result_image_np = (result_image_np * 255).astype(np.uint8)
+
+         source_detected_map = apply_canny(source_image_np, low_threshold, high_threshold) / 255
+         result_detected_map = apply_canny(result_image_np, low_threshold, high_threshold) / 255
+
+         iou_score, accuracy, f1_score = calculate_metrics(result_detected_map, source_detected_map)
+
+         iou_score_mean = iou_score_mean + iou_score
+         accuracy_mean = accuracy_mean + accuracy
+         f1_score_mean = f1_score_mean + f1_score
+
+     iou_score_mean = iou_score_mean / len(dataset)
+     accuracy_mean = accuracy_mean / len(dataset)
+     f1_score_mean = f1_score_mean / len(dataset)
+
+     print(experiment)
+     print('[Canny]', '[IOU]:', iou_score_mean, '[F1 Score]:', f1_score_mean)
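`calculate_metrics` above compares two binarized edge maps via intersection-over-union, pixel accuracy, and F1. A hand-checkable sketch of the same definitions on a flattened 2x2 map (pure Python to stay dependency-free; the pixel values are made up):

```python
pred   = [1, 1, 0, 0]  # flattened predicted edge map (1 = edge pixel)
target = [1, 0, 0, 0]  # flattened ground-truth edge map

tp = sum(p and t for p, t in zip(pred, target))       # intersection = 1
union = sum(p or t for p, t in zip(pred, target))     # union = 2
iou = tp / union                                      # 1/2
accuracy = sum(p == t for p, t in zip(pred, target)) / len(target)  # 3/4
fp = sum(pred) - tp    # false positives = 1
fn = sum(target) - tp  # false negatives = 0
f1 = (2 * tp) / (2 * tp + fp + fn)                    # 2/3
print(iou, accuracy, f1)
```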
generation/control/eval_landmark.py ADDED
@@ -0,0 +1,127 @@
+ import os
+ from glob import glob
+ from tqdm import tqdm
+ import numpy as np
+ from pathlib import Path
+ from skimage.io import imread, imsave
+ import cv2
+ import json
+
+ end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype=np.int32) - 1
+ def plot_kpts(image, kpts, color='g'):
+     ''' Draw 68 key points
+     Args:
+         image: the input image
+         kpt: (68, 3).
+     '''
+     if color == 'r':
+         c = (255, 0, 0)
+     elif color == 'g':
+         c = (0, 255, 0)
+     elif color == 'b':
+         c = (0, 0, 255)
+     image = image.copy()
+     kpts = kpts.copy()
+     radius = max(int(min(image.shape[0], image.shape[1]) / 200), 1)
+     for i in range(kpts.shape[0]):
+         st = kpts[i, :2]
+         if kpts.shape[1] == 4:
+             if kpts[i, 3] > 0.5:
+                 c = (0, 255, 0)
+             else:
+                 c = (0, 0, 255)
+         image = cv2.circle(image, (int(st[0]), int(st[1])), radius, c, radius * 2)
+         if i in end_list:
+             continue
+         ed = kpts[i + 1, :2]
+         image = cv2.line(image, (int(st[0]), int(st[1])), (int(ed[0]), int(ed[1])), (255, 255, 255), radius)
+     return image
+
+ def plot_points(image, kpts, color='w'):
+     ''' Draw key points
+     Args:
+         image: the input image
+         kpt: (n, 3).
+     '''
+     if color == 'r':
+         c = (255, 0, 0)
+     elif color == 'g':
+         c = (0, 255, 0)
+     elif color == 'b':
+         c = (0, 0, 255)
+     elif color == 'y':
+         c = (0, 255, 255)
+     elif color == 'w':
+         c = (255, 255, 255)
+     image = image.copy()
+     kpts = kpts.copy()
+     kpts = kpts.astype(np.int32)
+     radius = max(int(min(image.shape[0], image.shape[1]) / 200), 1)
+     for i in range(kpts.shape[0]):
+         st = kpts[i, :2]
+         image = cv2.circle(image, (int(st[0]), int(st[1])), radius, c, radius * 2)
+     return image
+
+ def generate_landmark2d(inputpath, savepath, n=0, device='cuda:0', vis=False):
+     print('generate 2d landmarks')
+     os.makedirs(savepath, exist_ok=True)
+     import face_alignment
+     detect_model = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device, flip_input=False)
+
+     imagepath_list = glob(os.path.join(inputpath, '*_{}.png'.format(n)))
+     imagepath_list = sorted(imagepath_list)
+     for imagepath in tqdm(imagepath_list):
+         name = Path(imagepath).stem
+
+         image = imread(imagepath)[:, :, :3]
+         out = detect_model.get_landmarks(image)
+         if out is None:
+             continue
+         kpt = out[0].squeeze()
+         np.savetxt(os.path.join(savepath, f'{name}.txt'), kpt)
+         if vis:
+             image = cv2.imread(imagepath)
+             image_point = plot_kpts(image, kpt)
+             # check
+             cv2.imwrite(os.path.join(savepath, f'{name}_overlay.jpg'), image_point)
+             # background = np.zeros_like(image)
+             # cv2.imwrite(os.path.join(savepath, f'{name}_line.jpg'), plot_kpts(background, kpt))
+             # cv2.imwrite(os.path.join(savepath, f'{name}_point.jpg'), plot_points(background, kpt))
+             # exit()
+
+ def landmark_comparison(lmk_folder, gt_lmk_folder, n=0):
+     print('calculate reprojection error')
+     lmk_err = []
+     gt_lmk_folder = './data/celebhq-text/celeba-hq-landmark2d'
+     with open('./data/celebhq-text/prompt_val_blip_full.json', 'rt') as f:  # fill50k, COCO
+         for line in f:
+             val_data = json.loads(line)
+     # for i in tqdm(range(2000)):
+     for i in tqdm(range(len(val_data))):
+         # import ipdb; ipdb.set_trace()
+         # line = val_data[n]
+         line = val_data[i]
+
+         img_name = line["image"][:-4]
+         lmk1_path = os.path.join(gt_lmk_folder, f'{img_name}.txt')
+
+         lmk1 = np.loadtxt(lmk1_path) / 2
+         lmk2_path = os.path.join(lmk_folder, f'result_{i}_{n}.txt')
+         if not os.path.exists(lmk2_path):
+             print(f'{lmk2_path} not exist')
+             continue
+         lmk2 = np.loadtxt(lmk2_path)
+         lmk_err.append(np.mean(np.linalg.norm(lmk1 - lmk2, axis=1)))
+     print(np.mean(lmk_err))
+     np.save(os.path.join(lmk_folder, 'lmk_err.npy'), lmk_err)
+
+ n = 0
+ epoch = 19
+ gt_lmk_folder = './data/celebhq-text/celeba-hq-landmark2d'
+ # input_folder = os.path.join('./data/image_log_opt_lora_CelebA_landmark_lr_5-6_pe_diff_mlp_r_4_cayley_4gpu/results', str(epoch))
+ input_folder = os.path.join('log/image_log_householder_none_ADE20K_segm_eps_7e-06_pe_diff_mlp_l_8_8gpu_2024-05-15-19-33-41-650524/results', str(epoch))
+ # input_folder = os.path.join('log/image_log_oft_CelebA_landmark_eps_0.001_pe_diff_mlp_r_4_8gpu_2024-03-21-19-07-34-175825/train_with_norm/results', str(epoch))
+ save_folder = os.path.join(input_folder, 'landmark')
+
+ generate_landmark2d(input_folder, save_folder, n, device='cuda:0', vis=False)
+ landmark_comparison(save_folder, gt_lmk_folder, n)
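`landmark_comparison` above scores each image by the mean Euclidean distance between predicted and ground-truth landmarks (`np.mean(np.linalg.norm(lmk1 - lmk2, axis=1))`). The same quantity on two made-up 2D landmarks, using only the standard library:

```python
import math

lmk_gt   = [(0.0, 0.0), (3.0, 4.0)]  # hypothetical ground-truth landmarks
lmk_pred = [(0.0, 0.0), (0.0, 0.0)]  # hypothetical predictions

# Mean of per-landmark Euclidean distances: (0 + 5) / 2
err = sum(math.dist(a, b) for a, b in zip(lmk_gt, lmk_pred)) / len(lmk_gt)
print(err)  # 2.5
```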
generation/control/generation.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from oldm.hack import disable_verbosity
2
+ disable_verbosity()
3
+
4
+ import sys
5
+ import os
6
+ import cv2
7
+ import einops
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torch
11
+ import random
12
+ import json
13
+ import argparse
14
+
15
+ file_path = os.path.abspath(__file__)
16
+ parent_dir = os.path.abspath(os.path.dirname(file_path) + '/..')
17
+ if parent_dir not in sys.path:
18
+ sys.path.append(parent_dir)
19
+
20
+ from PIL import Image
21
+ from pytorch_lightning import seed_everything
22
+ from oldm.model import create_model, load_state_dict
23
+ from oldm.ddim_hacked import DDIMSampler
24
+ from oft import inject_trainable_oft, inject_trainable_oft_conv, inject_trainable_oft_extended, inject_trainable_oft_with_norm
25
+ from hra import inject_trainable_hra
26
+ from lora import inject_trainable_lora
27
+
28
+ from dataset.utils import return_dataset
29
+
30
+
31
+ def process(input_image, prompt, hint_image, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
32
+ with torch.no_grad():
33
+ # img = resize_image(HWC3(input_image), image_resolution)
34
+ H, W, C = input_image.shape
35
+
36
+ #detected_map = apply_canny(input_image, low_threshold, high_threshold)
37
+ #detected_map = HWC3(detected_map)
38
+
39
+ # control = torch.from_numpy(hint_image.copy()).float().cuda() / 255.0
40
+ control = torch.from_numpy(hint_image.copy()).float().cuda()
41
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
42
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
43
+
44
+ if seed == -1:
45
+ seed = random.randint(0, 65535)
46
+ seed_everything(seed)
47
+
48
+ # if config.save_memory:
49
+ # model.low_vram_shift(is_diffusing=False)
50
+
51
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
52
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
53
+ shape = (4, H // 8, W // 8)
54
+
55
+ # if config.save_memory:
56
+ # model.low_vram_shift(is_diffusing=True)
57
+
58
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
59
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
60
+ shape, cond, verbose=False, eta=eta,
61
+ unconditional_guidance_scale=scale,
62
+ unconditional_conditioning=un_cond)
63
+
64
+ # if config.save_memory:
65
+ # model.low_vram_shift(is_diffusing=False)
66
+
67
+ x_samples = model.decode_first_stage(samples)
68
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
69
+
70
+ results = [x_samples[i] for i in range(num_samples)]
71
+ # return [255 - hint_image] + results
72
+ return [input_image] + [hint_image] + results
73
+
74
+ parser = argparse.ArgumentParser()
75
+
76
+ parser.add_argument('--d', type=int, help='the index of GPU', default=0)
77
+
78
+ # HRA
79
+ parser.add_argument('--hra_r', type=int, default=8)
80
+ parser.add_argument('--hra_apply_GS', action="store_true", default=False)
81
+
82
+ # OFT
83
+ parser.add_argument('--oft_r', type=int, default=4)
84
+ parser.add_argument('--oft_eps',
85
+ type=float,
86
+ choices=[1e-3, 2e-5, 7e-6],
87
+ default=7e-6,
88
+ )
89
+ parser.add_argument('--oft_coft', action="store_true", default=True)
90
+ parser.add_argument('--oft_block_share', action="store_true", default=False)
91
+ parser.add_argument('--img_ID', type=int, default=1)
92
+ parser.add_argument('--num_samples', type=int, default=1)
93
+ parser.add_argument('--batch', type=int, default=20)
94
+ parser.add_argument('--sd_locked', action="store_true", default=True)
95
+ parser.add_argument('--only_mid_control', action="store_true", default=False)
96
+ parser.add_argument('--num_gpus', type=int, default=8)
97
+ # parser.add_argument('--time_str', type=str, default=datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))
98
+ parser.add_argument('--time_str', type=str, default='2024-03-18-10-55-21-089985')
99
+ parser.add_argument('--epoch', type=int, default=19)
100
+ parser.add_argument('--control',
101
+ type=str,
102
+ help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]',
103
+ default="segm")
104
+
105
+ args = parser.parse_args()
106
+
107
+ if __name__ == '__main__':
108
+ # Configs
109
+ epoch = args.epoch
110
+ control = args.control
111
+ _, dataset, data_name, logger_freq, max_epochs = return_dataset(control, full=True)
112
+
113
+ # specify the experiment name
114
+ # experiment = './log/image_log_oft_{}_{}_eps_{}_pe_diff_mlp_r_{}_{}gpu'.format(data_name, control, args.eps, args.r, args.num_gpus)
115
+
116
+ num_gpus = args.num_gpus
117
+ time_str = args.time_str
118
+ # experiment = 'log/image_log_oft_{}_{}_eps_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.eps, args.r, num_gpus, time_str)
119
+ experiment = './log/image_log_hra_0.0_ADE20K_segm_pe_diff_mlp_r_8_8gpu_2024-06-27-19-57-34-979197'
120
+ # experiment = './log/image_log_oft_ADE20K_segm_eps_0.001_pe_diff_mlp_r_4_8gpu_2024-03-25-21-04-17-549433/train_with_norm'
121
+
122
+ assert args.control in experiment
123
+
124
+ if 'train_with_norm' in experiment:
125
+ epoch = 4
126
+ else:
127
+ if 'COCO' in experiment:
128
+ epoch = 9
129
+ else:
130
+ epoch = 19
131
+
132
+ resume_path = os.path.join(experiment, f'model-epoch={epoch:02d}.ckpt')
133
+ sd_locked = args.sd_locked
134
+ only_mid_control = args.only_mid_control
135
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
136
+
137
+ # Result directory
138
+ result_dir = os.path.join(experiment, 'results', str(epoch))
139
+ os.makedirs(result_dir, exist_ok=True)
140
+ source_dir = os.path.join(experiment, 'source', str(epoch))
141
+ os.makedirs(source_dir, exist_ok=True)
142
+ hint_dir = os.path.join(experiment, 'hints', str(epoch))
143
+ os.makedirs(hint_dir, exist_ok=True)
144
+
145
+ model = create_model('./configs/oft_ldm_v15.yaml').cpu()
146
+ model.model.requires_grad_(False)
147
+
148
+ if 'hra' in experiment:
149
+ unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
150
+ elif 'lora' in experiment:
151
+ unet_lora_params, train_names = inject_trainable_lora(model.model, rank=args.r, network_alpha=None)
152
+ else:
153
+ if 'train_with_norm' in experiment:
154
+ unet_opt_params, train_names = inject_trainable_oft_with_norm(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
155
+ else:
156
+ unet_lora_params, train_names = inject_trainable_oft(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
157
+ # unet_lora_params, train_names = inject_trainable_oft_conv(model.model, r=args.r, eps=args.eps, is_coft=args.coft, block_share=args.block_share)
158
+ # unet_lora_params, train_names = inject_trainable_oft_extended(model.model, r=args.r, eps=args.eps, is_coft=args.coft, block_share=args.block_share)
159
+
160
+
161
+ model.load_state_dict(load_state_dict(resume_path, location='cuda'))
162
+ model = model.cuda()
163
+ ddim_sampler = DDIMSampler(model)
164
+
165
+ # pack = range(0, len(dataset), args.batch)
166
+ # formatted_data = {}
167
+ # for index in range(args.batch):
168
+ # # import ipdb; ipdb.set_trace()
169
+ # start_point = pack[args.img_ID]
170
+ # idx = start_point + index
171
+
172
+ # canny
173
+ # img_list = [378, 441, 0, 31, 115, 182, 59, 60, 66, 269, ]
174
+ # landmark
175
+ # img_list = [139, 179, 197, 144, 54, 71, 76, 98, 100, 277, ]
176
+ # segm
177
+ # img_list = [14, 667, 576, 1387, 1603, 1697, 987, 1830, 1232, 1881, ]
+
+ # for idx in img_list:
+
+ num_pack = len(dataset) // args.num_gpus
+ start_idx = args.d * num_pack
+ end_idx = (args.d + 1) * num_pack if args.d < args.num_gpus - 1 else len(dataset)
+
+ for idx in range(start_idx, end_idx):
+
+     data = dataset[idx]
+     input_image, prompt, hint = data['jpg'], data['txt'], data['hint']
+     # input_image, hint = input_image.to(device), hint.to(device)
+
+     if not os.path.exists(os.path.join(result_dir, f'result_{idx}_0.png')):
+         result_images = process(
+             input_image=input_image,
+             prompt=prompt,
+             hint_image=hint,
+             a_prompt="",
+             n_prompt="",
+             num_samples=args.num_samples,
+             image_resolution=512,
+             ddim_steps=50,
+             guess_mode=False,
+             strength=1,
+             scale=9.0,
+             seed=-1,
+             eta=0.0,
+             low_threshold=100,
+             high_threshold=200,
+         )
+         for i, image in enumerate(result_images):
+             if i == 0:
+                 image = ((image + 1) * 127.5).clip(0, 255).astype(np.uint8)
+                 pil_image = Image.fromarray(image)
+                 output_path = os.path.join(source_dir, f'image_{idx}.png')
+                 pil_image.save(output_path)
+             elif i == 1:
+                 image = (image * 255).clip(0, 255).astype(np.uint8)
+                 # Convert numpy array to PIL Image
+                 pil_image = Image.fromarray(image)
+                 # Save PIL Image to file
+                 output_path = os.path.join(hint_dir, f'hint_{idx}.png')
+                 pil_image.save(output_path)
+             else:
+                 n = i - 2
+                 # Convert numpy array to PIL Image
+                 pil_image = Image.fromarray(image)
+                 # Save PIL Image to file
+                 output_path = os.path.join(result_dir, f'result_{idx}_{n}.png')
+                 pil_image.save(output_path)
+
+     # formatted_data[f"item{idx}"] = {
+     #     "image_name": f'result_{idx}.png',
+     #     "prompt": prompt
+     # }
+
+     # with open(os.path.join(experiment, 'results_{}.json'.format(img_ID)), 'w') as f:
+     #     json.dump(formatted_data, f)
+
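The `num_pack` / `start_idx` / `end_idx` arithmetic above shards the dataset across GPUs, with the last worker (`args.d == args.num_gpus - 1`) absorbing the remainder. A minimal, dependency-free sketch of the same arithmetic (`shard_range` is a hypothetical helper name, not part of the repo):

```python
def shard_range(n_items: int, num_workers: int, worker_id: int):
    """Split n_items across num_workers; the last worker takes the remainder.

    Mirrors the num_pack / start_idx / end_idx logic in the loop above.
    """
    pack = n_items // num_workers
    start = worker_id * pack
    end = (worker_id + 1) * pack if worker_id < num_workers - 1 else n_items
    return start, end


# 10 items over 3 workers: (0, 3), (3, 6), (6, 10) — every index covered once
ranges = [shard_range(10, 3, w) for w in range(3)]
```

Together the per-worker ranges cover `range(n_items)` exactly once, which is why each GPU process can safely write its own `result_{idx}_*.png` files without coordination.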
generation/control/hra.py ADDED
@@ -0,0 +1,254 @@
+ """
+ This script utilizes code from lora, available at:
+ https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+
+ Original Author: Simo Ryu
+ License: Apache License 2.0
+ """
+
+
+ import json
+ import math
+ from itertools import groupby
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
+
+ import pickle
+
+ import numpy as np
+ import PIL
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import matplotlib.pyplot as plt
+
+ try:
+     from safetensors.torch import safe_open
+     from safetensors.torch import save_file as safe_save
+
+     safetensors_available = True
+ except ImportError:
+     from .safe_open import safe_open
+
+     def safe_save(
+         tensors: Dict[str, torch.Tensor],
+         filename: str,
+         metadata: Optional[Dict[str, str]] = None,
+     ) -> None:
+         raise EnvironmentError(
+             "Saving safetensors requires the safetensors library. Please install with pip or similar."
+         )
+
+     safetensors_available = False
+
+
+ def project(R, eps):
+     I = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
+     diff = R - I
+     norm_diff = torch.norm(diff)
+     if norm_diff <= eps:
+         return R
+     else:
+         return I + eps * (diff / norm_diff)
+
+
+ def project_batch(R, eps=1e-5):
+     # scaling factor for each of the smaller block matrices
+     eps = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
+     I = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
+     diff = R - I
+     norm_diff = torch.norm(R - I, dim=(1, 2), keepdim=True)
+     mask = (norm_diff <= eps).bool()
+     out = torch.where(mask, R, I + eps * (diff / norm_diff))
+     return out
+
+
+ class HRAInjectedLinear(nn.Module):
+     def __init__(
+         self, in_features, out_features, bias=False, r=8, apply_GS=False,
+     ):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+
+         self.r = r
+         self.apply_GS = apply_GS
+
+         half_u = torch.zeros(in_features, r // 2)
+         nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
+         self.hra_u = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1), requires_grad=True)
+
+         self.fixed_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
+
+     def forward(self, x):
+         orig_weight = self.fixed_linear.weight.data
+         if self.apply_GS:
+             # Gram-Schmidt: orthonormalize the columns of hra_u before reflecting
+             weight = [(self.hra_u[:, 0] / self.hra_u[:, 0].norm()).view(-1, 1)]
+             for i in range(1, self.r):
+                 ui = self.hra_u[:, i].view(-1, 1)
+                 for j in range(i):
+                     ui = ui - (weight[j].t() @ ui) * weight[j]
+                 weight.append((ui / ui.norm()).view(-1, 1))
+             weight = torch.cat(weight, dim=1)
+             new_weight = orig_weight @ (torch.eye(self.in_features, device=x.device) - 2 * weight @ weight.t())
+         else:
+             # apply a chain of r Householder reflections I - 2 u u^T
+             new_weight = orig_weight
+             hra_u_norm = self.hra_u / self.hra_u.norm(dim=0)
+             for i in range(self.r):
+                 ui = hra_u_norm[:, i].view(-1, 1)
+                 new_weight = torch.mm(new_weight, torch.eye(self.in_features, device=x.device) - 2 * ui @ ui.t())
+
+         out = nn.functional.linear(input=x, weight=new_weight, bias=self.fixed_linear.bias)
+         return out
+
+
+ UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
+
+ UNET_CONV_TARGET_REPLACE = {"ResBlock"}
+
+ UNET_EXTENDED_TARGET_REPLACE = {"ResBlock", "CrossAttention", "Attention", "GEGLU"}
+
+ TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
+
+ TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
+
+ DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
+
+ EMBED_FLAG = "<embed>"
+
+
+ def _find_children(
+     model,
+     search_class: List[Type[nn.Module]] = [nn.Linear],
+ ):
+     """
+     Find all modules of a certain class (or union of classes).
+     Returns all matching modules, along with the parents of those modules and the
+     names they are referenced by.
+     """
+     result = []
+     for parent in model.modules():
+         for name, module in parent.named_children():
+             if any([isinstance(module, _class) for _class in search_class]):
+                 result.append((parent, name, module))  # Append the result to the list
+
+     return result  # Return the list instead of using 'yield'
+
+
+ def _find_modules_v2(
+     model,
+     ancestor_class: Optional[Set[str]] = None,
+     search_class: List[Type[nn.Module]] = [nn.Linear],
+     exclude_children_of: Optional[List[Type[nn.Module]]] = [
+         HRAInjectedLinear,
+     ],
+ ):
+     """
+     Find all modules of a certain class (or union of classes) that are direct or
+     indirect descendants of other modules of a certain class (or union of classes).
+     Returns all matching modules, along with the parents of those modules and the
+     names they are referenced by.
+     """
+
+     # Get the targets we should replace all linears under
+     if ancestor_class is not None:
+         ancestors = (
+             module
+             for module in model.modules()
+             if module.__class__.__name__ in ancestor_class
+         )
+     else:
+         # The first module is the most senior parent class; use it
+         # in case you want to naively iterate over all modules.
+         for module in model.modules():
+             ancestor_class = module.__class__.__name__
+             break
+         ancestors = (
+             module
+             for module in model.modules()
+             if module.__class__.__name__ in ancestor_class
+         )
+
+     results = []
+     # For each target find every linear_class module that isn't a child of a HRAInjectedLinear
+     for ancestor in ancestors:
+         for fullname, module in ancestor.named_modules():
+             if any([isinstance(module, _class) for _class in search_class]):
+                 # Find the direct parent if this is a descendant, not a child, of target
+                 *path, name = fullname.split(".")
+                 parent = ancestor
+                 while path:
+                     parent = parent.get_submodule(path.pop(0))
+                 # Skip this linear if it's a child of a HRAInjectedLinear
+                 if exclude_children_of and any(
+                     [isinstance(parent, _class) for _class in exclude_children_of]
+                 ):
+                     continue
+                 results.append((parent, name, module))  # Append the result to the list
+
+     return results  # Return the list instead of using 'yield'
+
+
+ def _find_modules_old(
+     model,
+     ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
+     search_class: List[Type[nn.Module]] = [nn.Linear],
+     exclude_children_of: Optional[List[Type[nn.Module]]] = [HRAInjectedLinear],
+ ):
+     ret = []
+     for _module in model.modules():
+         if _module.__class__.__name__ in ancestor_class:
+             for name, _child_module in _module.named_modules():
+                 if _child_module.__class__ in search_class:
+                     ret.append((_module, name, _child_module))
+
+     return ret
+
+
+ _find_modules = _find_modules_v2
+ # _find_modules = _find_modules_old
+
+
+ def inject_trainable_hra(
+     model: nn.Module,
+     target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
+     verbose: bool = False,
+     r: int = 8,
+     apply_GS: bool = False,
+ ):
+     """
+     Inject HRA into the model, and return the HRA parameter groups.
+     """
+
+     require_grad_params = []
+     names = []
+
+     for _module, name, _child_module in _find_modules(
+         model, target_replace_module, search_class=[nn.Linear]
+     ):
+         weight = _child_module.weight
+         bias = _child_module.bias
+         if verbose:
+             print("HRA Injection : injecting hra into ", name)
+             print("HRA Injection : weight shape", weight.shape)
+         _tmp = HRAInjectedLinear(
+             _child_module.in_features,
+             _child_module.out_features,
+             _child_module.bias is not None,
+             r=r,
+             apply_GS=apply_GS,
+         )
+         _tmp.fixed_linear.weight = weight
+         if bias is not None:
+             _tmp.fixed_linear.bias = bias
+
+         # switch the module
+         _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
+         _module._modules[name] = _tmp
+
+         require_grad_params.append(_module._modules[name].hra_u)
+         _module._modules[name].hra_u.requires_grad = True
+
+         names.append(name)
+
+     return require_grad_params, names
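The non-GS branch of `HRAInjectedLinear.forward` multiplies the frozen weight by a chain of Householder reflections `I - 2 u u^T`. A quick NumPy check of why this parameterization is well behaved (dimensions `d=8`, `r=4` are chosen for illustration only): each reflection is orthogonal, so their product is orthogonal, and right-multiplying the weight by it leaves the singular values unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 8, 4

# r unit vectors, analogous to the normalized columns of hra_u
U = rng.standard_normal((d, r))
U /= np.linalg.norm(U, axis=0)

# chain of Householder reflections H_i = I - 2 u_i u_i^T
H = np.eye(d)
for i in range(r):
    u = U[:, i:i + 1]
    H = H @ (np.eye(d) - 2 * u @ u.T)

# the product of reflections is orthogonal: H^T H = I
assert np.allclose(H.T @ H, np.eye(d), atol=1e-8)

# so the adapted weight W @ H keeps the singular values of W
W = rng.standard_normal((5, d))
sv_orig = np.linalg.svd(W, compute_uv=False)
sv_new = np.linalg.svd(W @ H, compute_uv=False)
assert np.allclose(np.sort(sv_orig), np.sort(sv_new), atol=1e-8)
```

This is the property the `apply_GS` branch also preserves: after Gram-Schmidt the columns are orthonormal, so `I - 2 W W^T` is again an orthogonal (block-reflection) matrix.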
generation/control/tool_add_hra.py ADDED
@@ -0,0 +1,81 @@
+ """
+ This script utilizes code from ControlNet, available at:
+ https://github.com/lllyasviel/ControlNet/blob/main/tool_add_control.py
+
+ Original Author: Lvmin Zhang
+ License: Apache License 2.0
+ """
+
+ import sys
+ import os
+ os.environ['HF_HOME'] = '/tmp'
+
+ # assert len(sys.argv) == 3, 'Args are wrong.'
+
+ # input_path = sys.argv[1]
+ # output_path = sys.argv[2]
+
+ import torch
+ from oldm.hack import disable_verbosity
+ disable_verbosity()
+ from oldm.model import create_model
+ from hra import inject_trainable_hra
+
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_path', type=str, default='./models/v1-5-pruned.ckpt')
+ parser.add_argument('--output_path', type=str, default='./models/hra_half_init_l_8.ckpt')
+ parser.add_argument('--r', type=int, default=8)
+ parser.add_argument('--apply_GS', action='store_true', default=False)
+ args = parser.parse_args()
+
+ # args.output_path = f'./models/hra_none_l_8.ckpt'
+
+ assert os.path.exists(args.input_path), 'Input model does not exist.'
+ # assert not os.path.exists(output_path), 'Output filename already exists.'
+ assert os.path.exists(os.path.dirname(args.output_path)), 'Output path is not valid.'
+
+
+ def get_node_name(name, parent_name):
+     if len(name) <= len(parent_name):
+         return False, ''
+     p = name[:len(parent_name)]
+     if p != parent_name:
+         return False, ''
+     return True, name[len(parent_name):]
+
+
+ model = create_model(config_path='./configs/oft_ldm_v15.yaml')
+ model.model.requires_grad_(False)
+
+ unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.r, apply_GS=args.apply_GS)
+
+ pretrained_weights = torch.load(args.input_path)
+ if 'state_dict' in pretrained_weights:
+     pretrained_weights = pretrained_weights['state_dict']
+
+ scratch_dict = model.state_dict()
+
+ target_dict = {}
+ names = []
+ for k in scratch_dict.keys():
+     names.append(k)
+
+     if k in pretrained_weights:
+         target_dict[k] = pretrained_weights[k].clone()
+     else:
+         if 'fixed_linear.' in k:
+             copy_k = k.replace('fixed_linear.', '')
+             target_dict[k] = pretrained_weights[copy_k].clone()
+         else:
+             target_dict[k] = scratch_dict[k].clone()
+             print(f'These weights are newly added: {k}')
+
+ with open('HRA_model_names.txt', 'w') as file:
+     for element in names:
+         file.write(element + '\n')
+
+ model.load_state_dict(target_dict, strict=True)
+ torch.save(model.state_dict(), args.output_path)
+ # print('Model not saved')
+ print('Done.')
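The weight-copy loop above handles three cases: keys that exist verbatim in the pretrained checkpoint, keys that were wrapped by injection (the pretrained `...weight` now lives under `...fixed_linear.weight`), and brand-new HRA parameters that keep their scratch initialization. A pure-dict sketch of that remapping (`remap_checkpoint` and the example key names are illustrative, not from the repo):

```python
def remap_checkpoint(scratch, pretrained):
    """Copy pretrained weights into an injected model's state dict.

    Exact key matches are copied; 'fixed_linear.'-wrapped keys fall back to
    the unwrapped pretrained key; everything else (e.g. hra_u) stays fresh.
    """
    target, fresh = {}, []
    for k, v in scratch.items():
        if k in pretrained:
            target[k] = pretrained[k]
        elif 'fixed_linear.' in k:
            target[k] = pretrained[k.replace('fixed_linear.', '')]
        else:
            target[k] = v  # newly added parameter, keep scratch init
            fresh.append(k)
    return target, fresh


pretrained = {'attn.to_q.weight': 'W_q'}
scratch = {'attn.to_q.fixed_linear.weight': 0, 'attn.to_q.hra_u': 'init'}
target, fresh = remap_checkpoint(scratch, pretrained)
# target['attn.to_q.fixed_linear.weight'] == 'W_q'; fresh == ['attn.to_q.hra_u']
```

Because every scratch key ends up in `target`, the subsequent `load_state_dict(..., strict=True)` in the script can succeed.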
generation/control/train.py ADDED
@@ -0,0 +1,149 @@
+ from oldm.hack import disable_verbosity
+ disable_verbosity()
+
+ import os
+ import sys
+ import torch
+ from datetime import datetime
+
+ file_path = os.path.abspath(__file__)
+ parent_dir = os.path.abspath(os.path.dirname(file_path) + '/..')
+ if parent_dir not in sys.path:
+     sys.path.append(parent_dir)
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ from torch.utils.data import DataLoader
+ from oldm.logger import ImageLogger
+ from oldm.model import create_model, load_state_dict
+ from dataset.utils import return_dataset
+
+ from oft import inject_trainable_oft, inject_trainable_oft_conv, inject_trainable_oft_extended, inject_trainable_oft_with_norm
+ from hra import inject_trainable_hra
+ from lora import inject_trainable_lora
+
+ import argparse
+
+
+ parser = argparse.ArgumentParser()
+
+ # HRA
+ parser.add_argument('--hra_r', type=int, default=8)
+ parser.add_argument('--hra_apply_GS', action='store_true', default=False)
+ parser.add_argument('--hra_lamb', type=float, default=0.0)
+
+ # OFT
+ parser.add_argument('--oft_r', type=int, default=4)
+ parser.add_argument('--oft_eps', type=float, default=7e-6)
+ parser.add_argument('--oft_coft', action="store_true", default=True)
+ parser.add_argument('--oft_block_share', action="store_true", default=False)
+
+ # LoRA
+ parser.add_argument('--lora_r', type=int, default=8)
+
+ parser.add_argument('--batch_size', type=int, default=8)
+ parser.add_argument('--num_samples', type=int, default=1)
+ parser.add_argument('--plot_frequency', type=int, default=100)
+ parser.add_argument('--learning_rate', type=float, default=9e-4)
+ parser.add_argument('--sd_locked', action="store_true", default=True)
+ parser.add_argument('--only_mid_control', action="store_true", default=False)
+ parser.add_argument('--num_gpus', type=int, default=torch.cuda.device_count())
+ parser.add_argument('--resume_path',
+                     type=str,
+                     default='./models/hra_half_init_l_8.ckpt',
+                     )
+ parser.add_argument('--time_str', type=str, default=datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))
+ parser.add_argument('--num_workers', type=int, default=8)
+ parser.add_argument('--control',
+                     type=str,
+                     help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]',
+                     default="segm")
+
+ args = parser.parse_args()
+
+
+ if __name__ == "__main__":
+     # specify the control signal and dataset
+     control = args.control
+
+     # create dataset
+     train_dataset, val_dataset, data_name, logger_freq, max_epochs = return_dataset(control)  # , n_samples=n_samples)
+
+     # Configs
+     resume_path = args.resume_path
+
+     batch_size = args.batch_size
+     num_samples = args.num_samples
+     plot_frequency = args.plot_frequency
+     learning_rate = args.learning_rate
+     sd_locked = args.sd_locked
+     only_mid_control = args.only_mid_control
+     num_gpus = args.num_gpus
+     time_str = args.time_str
+     num_workers = args.num_workers
+
+     for arg in vars(args):
+         print(f'{arg}: {getattr(args, arg)}')
+     print(f'data_name: {data_name}\nlogger_freq: {logger_freq}\nmax_epochs: {max_epochs}')
+
+     if 'oft' in args.resume_path:
+         experiment = 'oft_{}_{}_eps_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.oft_eps, args.oft_r, num_gpus, time_str)
+     elif 'hra' in args.resume_path:
+         if args.hra_apply_GS:
+             experiment = 'hra_apply_GS_{}_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.hra_r, num_gpus, time_str)
+         else:
+             experiment = 'hra_{}_{}_pe_diff_mlp_r_{}_lambda_{}_lr_{}_{}gpu_{}'.format(data_name, control, args.hra_r, args.hra_lamb, args.learning_rate, num_gpus, time_str)
+     elif 'lora' in args.resume_path:
+         experiment = 'lora_{}_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.lora_r, num_gpus, time_str)
+
+     # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
+     model = create_model('./configs/oft_ldm_v15.yaml').cpu()
+     model.model.requires_grad_(False)
+     print(f'Total parameters not requiring grad: {sum([p.numel() for p in model.model.parameters() if not p.requires_grad])}')
+
+     # inject trainable adapter parameters
+     if 'oft' in args.resume_path:
+         unet_lora_params, train_names = inject_trainable_oft(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
+     elif 'hra' in args.resume_path:
+         unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
+     elif 'lora' in args.resume_path:
+         unet_lora_params, train_names = inject_trainable_lora(model.model, rank=args.lora_r, network_alpha=None)
+
+     print(f'Total parameters requiring grad: {sum([p.numel() for p in model.model.parameters() if p.requires_grad])}')
+
+     model.load_state_dict(load_state_dict(resume_path, location='cpu'))
+     model.learning_rate = learning_rate
+     model.sd_locked = sd_locked
+     model.only_mid_control = only_mid_control
+
+     checkpoint_callback = ModelCheckpoint(
+         dirpath='log/image_log_' + experiment,
+         filename='model-{epoch:02d}',
+         save_top_k=-1,
+         save_last=True,
+         every_n_epochs=1,
+         monitor=None,  # No specific metric to monitor for saving
+     )
+
+     # Misc
+     train_dataloader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
+     val_dataloader = DataLoader(val_dataset, num_workers=num_workers, batch_size=1, shuffle=False)
+
+     logger = ImageLogger(
+         val_dataloader=val_dataloader,
+         batch_frequency=logger_freq,
+         experiment=experiment,
+         plot_frequency=plot_frequency,
+         num_samples=num_samples,
+     )
+
+     trainer = pl.Trainer(
+         max_epochs=max_epochs,
+         gpus=num_gpus,
+         precision=32,
+         callbacks=[logger, checkpoint_callback],
+     )
+
+     # Train!
+     last_model_path = 'log/image_log_' + experiment + '/last.ckpt'
+     if os.path.exists(last_model_path):
+         trainer.fit(model, train_dataloader, ckpt_path=last_model_path)
+     else:
+         trainer.fit(model, train_dataloader)
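`train.py` freezes the whole base model and then relies on `inject_trainable_hra` to swap each targeted `nn.Linear` for an `HRAInjectedLinear` in place through the parent's `_modules` dict, so only the new `hra_u` parameters receive gradients. A dependency-free sketch of that swap mechanic, using stand-in classes (`Module`, `Linear`, `Injected` are illustrative, not the real torch types):

```python
class Module:
    """Minimal stand-in for nn.Module, just enough to show the swap."""
    def __init__(self):
        self._modules = {}


class Linear(Module):
    def __init__(self):
        super().__init__()
        self.trainable = False  # frozen by model.requires_grad_(False)


class Injected(Module):
    def __init__(self, wrapped):
        super().__init__()
        self._modules['fixed_linear'] = wrapped  # frozen original kept inside
        self.trainable = True                    # stands in for the new hra_u


block = Module()
block._modules['to_q'] = Linear()

# the in-place swap performed during injection:
block._modules['to_q'] = Injected(block._modules['to_q'])
```

The parent keeps the same attribute name (`to_q`), so the rest of the network is untouched; only the wrapper's own parameters are trainable while the wrapped weight stays frozen.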
generation/env.yml ADDED
@@ -0,0 +1,172 @@
+ name: generation
+ dependencies:
+   - _libgcc_mutex=0.1=main
+   - _openmp_mutex=5.1=1_gnu
+   - ca-certificates=2023.12.12=h06a4308_0
+   - ld_impl_linux-64=2.38=h1181459_1
+   - libffi=3.4.4=h6a678d5_0
+   - libgcc-ng=11.2.0=h1234567_1
+   - libgomp=11.2.0=h1234567_1
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - ncurses=6.4=h6a678d5_0
+   - openssl=3.0.13=h7f8727e_0
+   - pip=23.3.1=py39h06a4308_0
+   - python=3.9.18=h955ad1f_0
+   - readline=8.2=h5eee18b_0
+   - setuptools=68.2.2=py39h06a4308_0
+   - sqlite=3.41.2=h5eee18b_0
+   - tk=8.6.12=h1ccaba5_0
+   - wheel=0.41.2=py39h06a4308_0
+   - xz=5.4.6=h5eee18b_0
+   - zlib=1.2.13=h5eee18b_0
+   - pip:
+     - absl-py==2.1.0
+     - accelerate==0.21.0
+     - addict==2.4.0
+     - aiofiles==23.2.1
+     - aiohttp==3.9.3
+     - aiosignal==1.3.1
+     - altair==5.2.0
+     - annotated-types==0.6.0
+     - antlr4-python3-runtime==4.8
+     - anyio==4.3.0
+     - appdirs==1.4.4
+     - async-timeout==4.0.3
+     - attrs==23.2.0
+     - basicsr==1.4.2
+     - certifi==2024.2.2
+     - charset-normalizer==3.3.2
+     - click==8.1.7
+     - conda-pack==0.7.1
+     - contourpy==1.2.0
+     - cycler==0.12.1
+     - datasets==2.14.2
+     - diffusers==0.17.1
+     - dill==0.3.7
+     - docker-pycreds==0.4.0
+     - einops==0.7.0
+     - exceptiongroup==1.2.0
+     - face-alignment==1.3.4
+     - fastapi==0.110.0
+     - ffmpy==0.3.2
+     - filelock==3.13.1
+     - fire==0.6.0
+     - fonttools==4.49.0
+     - frozenlist==1.4.1
+     - fsspec==2024.2.0
+     - ftfy==6.1.3
+     - future==1.0.0
+     - gitdb==4.0.11
+     - gitpython==3.1.42
+     - gradio==3.16.2
+     - grpcio==1.62.1
+     - h11==0.14.0
+     - httpcore==1.0.4
+     - httpx==0.27.0
+     - huggingface-hub==0.21.4
+     - idna==3.6
+     - imageio==2.34.0
+     - importlib-metadata==7.0.2
+     - importlib-resources==6.3.0
+     - jinja2==3.1.3
+     - jsonschema==4.21.1
+     - jsonschema-specifications==2023.12.1
+     - kiwisolver==1.4.5
+     - lazy-loader==0.3
+     - lightning-utilities==0.10.1
+     - linkify-it-py==2.0.3
+     - llvmlite==0.42.0
+     - lmdb==1.4.1
+     - lpips==0.1.4
+     - markdown==3.5.2
+     - markdown-it-py==3.0.0
+     - markupsafe==2.1.5
+     - matplotlib==3.8.3
+     - mdit-py-plugins==0.4.0
+     - mdurl==0.1.2
+     - mpmath==1.3.0
+     - multidict==6.0.5
+     - multiprocess==0.70.15
+     - networkx==3.2.1
+     - numba==0.59.0
+     - numpy==1.26.4
+     - nvidia-cublas-cu12==12.1.3.1
+     - nvidia-cuda-cupti-cu12==12.1.105
+     - nvidia-cuda-nvrtc-cu12==12.1.105
+     - nvidia-cuda-runtime-cu12==12.1.105
+     - nvidia-cudnn-cu12==8.9.2.26
+     - nvidia-cufft-cu12==11.0.2.54
+     - nvidia-curand-cu12==10.3.2.106
+     - nvidia-cusolver-cu12==11.4.5.107
+     - nvidia-cusparse-cu12==12.1.0.106
+     - nvidia-nccl-cu12==2.19.3
+     - nvidia-nvjitlink-cu12==12.4.99
+     - nvidia-nvtx-cu12==12.1.105
+     - omegaconf==2.1.1
+     - open-clip-torch==2.0.2
+     - opencv-python==4.9.0.80
+     - orjson==3.9.15
+     - packaging==24.0
+     - pandas==2.2.1
+     - pillow==10.2.0
+     - platformdirs==4.2.0
+     - protobuf==4.25.3
+     - psutil==5.9.8
+     - pyarrow==15.0.1
+     - pycocotools==2.0.7
+     - pycryptodome==3.20.0
+     - pydantic==2.6.4
+     - pydantic-core==2.16.3
+     - pydeprecate==0.3.1
+     - pydub==0.25.1
+     - pyparsing==3.1.2
+     - python-dateutil==2.9.0.post0
+     - python-multipart==0.0.9
+     - pytorch-fid==0.3.0
+     - pytorch-lightning==1.5.0
+     - pytz==2024.1
+     - pyyaml==6.0.1
+     - referencing==0.33.0
+     - regex==2023.12.25
+     - requests==2.31.0
+     - rpds-py==0.18.0
+     - safetensors==0.4.2
+     - scikit-image==0.22.0
+     - scipy==1.12.0
+     - sentry-sdk==1.42.0
+     - setproctitle==1.3.3
+     - six==1.16.0
+     - smmap==5.0.1
+     - sniffio==1.3.1
+     - starlette==0.36.3
+     - sympy==1.12
+     - tb-nightly==2.17.0a20240313
+     - tensorboard==2.16.2
+     - tensorboard-data-server==0.7.2
+     - termcolor==2.4.0
+     - tifffile==2024.2.12
+     - timm==0.9.16
+     - tokenizers==0.13.3
+     - tomli==2.0.1
+     - toolz==0.12.1
+     - torch==2.2.1
+     - torch-fidelity==0.3.0
+     - torchaudio==2.2.1
+     - torchmetrics==1.3.1
+     - torchvision==0.17.1
+     - tqdm==4.66.2
+     - transformers==4.25.1
+     - triton==2.2.0
+     - typing-extensions==4.10.0
+     - tzdata==2024.1
+     - uc-micro-py==1.0.3
+     - urllib3==2.2.1
+     - uvicorn==0.28.0
+     - wandb==0.16.4
+     - wcwidth==0.2.13
+     - websockets==12.0
+     - werkzeug==3.0.1
+     - xxhash==3.4.1
+     - yapf==0.40.2
+     - yarl==1.9.4
+     - zipp==3.18.0