Add files using upload-large-folder tool
Browse files- README.md +283 -0
- assets/OHRFT_scheme.png +0 -0
- assets/figure_nlp.png +0 -0
- generation/control/ControlNet/.gitignore +143 -0
- generation/control/ControlNet/LICENSE +201 -0
- generation/control/ControlNet/README.md +348 -0
- generation/control/ControlNet/cldm/cldm.py +435 -0
- generation/control/ControlNet/cldm/ddim_hacked.py +317 -0
- generation/control/ControlNet/cldm/hack.py +111 -0
- generation/control/ControlNet/cldm/logger.py +76 -0
- generation/control/ControlNet/cldm/model.py +28 -0
- generation/control/ControlNet/config.py +1 -0
- generation/control/ControlNet/docs/annotator.md +49 -0
- generation/control/ControlNet/docs/faq.md +21 -0
- generation/control/ControlNet/docs/low_vram.md +15 -0
- generation/control/ControlNet/docs/train.md +276 -0
- generation/control/ControlNet/environment.yaml +35 -0
- generation/control/ControlNet/gradio_annotator.py +160 -0
- generation/control/ControlNet/gradio_canny2image.py +97 -0
- generation/control/ControlNet/gradio_depth2image.py +98 -0
- generation/control/ControlNet/gradio_fake_scribble2image.py +102 -0
- generation/control/ControlNet/gradio_hed2image.py +98 -0
- generation/control/ControlNet/gradio_hough2image.py +100 -0
- generation/control/ControlNet/gradio_normal2image.py +99 -0
- generation/control/ControlNet/gradio_pose2image.py +98 -0
- generation/control/ControlNet/gradio_scribble2image.py +92 -0
- generation/control/ControlNet/gradio_scribble2image_interactive.py +102 -0
- generation/control/ControlNet/gradio_seg2image.py +97 -0
- generation/control/ControlNet/ldm/data/__init__.py +0 -0
- generation/control/ControlNet/ldm/models/autoencoder.py +219 -0
- generation/control/ControlNet/ldm/models/diffusion/__init__.py +0 -0
- generation/control/ControlNet/ldm/models/diffusion/ddim.py +336 -0
- generation/control/ControlNet/ldm/util.py +197 -0
- generation/control/ControlNet/share.py +8 -0
- generation/control/ControlNet/tool_add_control.py +50 -0
- generation/control/ControlNet/tool_add_control_sd21.py +50 -0
- generation/control/ControlNet/tool_transfer_control.py +59 -0
- generation/control/ControlNet/tutorial_dataset.py +39 -0
- generation/control/ControlNet/tutorial_dataset_test.py +12 -0
- generation/control/ControlNet/tutorial_train.py +35 -0
- generation/control/ControlNet/tutorial_train_sd21.py +35 -0
- generation/control/download_ade20k.sh +10 -0
- generation/control/download_celebhq.sh +10 -0
- generation/control/eval_canny.py +130 -0
- generation/control/eval_landmark.py +127 -0
- generation/control/generation.py +238 -0
- generation/control/hra.py +254 -0
- generation/control/tool_add_hra.py +81 -0
- generation/control/train.py +149 -0
- generation/env.yml +172 -0
README.md
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align=center>
|
| 2 |
+
|
| 3 |
+
# [NeurIPS 2024 Spotlight] Bridging The Gap between Low-rank and Orthogonal Adaptation via Householder Reflection Adaptation
|
| 4 |
+
|
| 5 |
+
[](https://arxiv.org/pdf/2405.17484)
|
| 6 |
+
[](https://huggingface.co/docs/peft/en/package_reference/hra)
|
| 7 |
+
|
| 8 |
+
</div>
|
| 9 |
+
|
| 10 |
+
<div align="center">
|
| 11 |
+
<img src="assets/OHRFT_scheme.png" width="1100"/>
|
| 12 |
+
</div>
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
## Introduction
|
| 17 |
+
|
| 18 |
+
This repository includes the official implementation of [HRA](https://arxiv.org/pdf/2405.17484).
|
| 19 |
+
We propose a simple yet effective adapter-based orthogonal fine-tuning method, HRA.
|
| 20 |
+
Given a pre-trained model, our method fine-tunes its layers by multiplying each frozen weight matrix with an orthogonal matrix constructed by a chain of learnable Householder reflections (HRs).
|
| 21 |
+
|
| 22 |
+
## Usage
|
| 23 |
+
|
| 24 |
+
### Subject-driven Generation
|
| 25 |
+
|
| 26 |
+
<div align="center">
|
| 27 |
+
<img src="assets/subject.png" width="600"/>
|
| 28 |
+
</div>
|
| 29 |
+
|
| 30 |
+
Given several images of a specific subject and a textual prompt, subject-driven generation aims to generate images of the same subject in a context aligning with the prompt.
|
| 31 |
+
|
| 32 |
+
#### Environment Setup
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
cd generation
|
| 36 |
+
conda env create -f env.yml
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
#### Prepare Dataset
|
| 40 |
+
|
| 41 |
+
Download [dreambooth](https://github.com/google/dreambooth) dataset by running this script.
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
cd subject
|
| 45 |
+
bash download_dreambooth.sh
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
After downloading the data, your directory structure should look like this:
|
| 49 |
+
|
| 50 |
+
```
|
| 51 |
+
dreambooth
|
| 52 |
+
├── dataset
|
| 53 |
+
│ ├── backpack
|
| 54 |
+
│ └── backpack_dog
|
| 55 |
+
│ ...
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
You can also put your custom images into `dreambooth/dataset`.
|
| 59 |
+
|
| 60 |
+
#### Finetune
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
prompt_idx=0
|
| 64 |
+
class_idx=0
|
| 65 |
+
./train_dreambooth.sh $prompt_idx $class_idx
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
where the `$prompt_idx` corresponds to different prompts ranging from 0 to 24 and the `$class_idx` corresponds to different subjects ranging from 0 to 29.
|
| 69 |
+
|
| 70 |
+
Launch the training script with `accelerate` and pass hyperparameters, as well as LoRa-specific arguments to it such as:
|
| 71 |
+
|
| 72 |
+
- `use_hra`: Enables HRA in the training script.
|
| 73 |
+
- `hra_r`: the number of HRs (i.e., r) across different layers, expressed in `int`.
|
| 74 |
+
As r increases, the number of trainable parameters increases, which generally leads to improved performance.
|
| 75 |
+
However, this also results in higher memory consumption and longer computation times.
|
| 76 |
+
Therefore, r is usually set to 8.
|
| 77 |
+
**Note**, please set r to an even number to avoid potential issues during initialization.
|
| 78 |
+
- `hra_apply_GS`: Applys Gram-Schmidt orthogonalization. Default is `false`.
|
| 79 |
+
- `hra_bias`: specify if the `bias` paramteres should be traind. Can be `none`, `all` or `hra_only`.
|
| 80 |
+
|
| 81 |
+
#### Evaluation
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python evaluate.py
|
| 85 |
+
python get_result.py
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
### Controllable Generation
|
| 89 |
+
|
| 90 |
+
<div align="center">
|
| 91 |
+
<img src="assets/control.png" width="650"/>
|
| 92 |
+
</div>
|
| 93 |
+
|
| 94 |
+
Controllable generation aims to generate images aligning with a textual prompt and additional control signals (such as facial landmark annotations, canny edges, and segmentation maps).
|
| 95 |
+
|
| 96 |
+
#### Prepare Dataset
|
| 97 |
+
|
| 98 |
+
Download ADE20K and CelebA-HQ datasets by running this script.
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
cd control
|
| 102 |
+
bash download_ade20k.sh
|
| 103 |
+
bash download_celebhq.sh
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
For COCO dataset, we follow [OFT](https://github.com/Zeju1997/oft) to download and preprocess it.
|
| 107 |
+
|
| 108 |
+
After downloading the data, your directory structure should look like this:
|
| 109 |
+
|
| 110 |
+
```
|
| 111 |
+
data
|
| 112 |
+
├── ADE20K
|
| 113 |
+
│ ├── train
|
| 114 |
+
│ │ ├── color
|
| 115 |
+
│ │ ├── segm
|
| 116 |
+
│ │ └── prompt_train_blip.json
|
| 117 |
+
│ └── val
|
| 118 |
+
│ │ ├── color
|
| 119 |
+
│ │ ├── segm
|
| 120 |
+
│ │ └── prompt_val_blip.json
|
| 121 |
+
└── COCO
|
| 122 |
+
│ ├── train
|
| 123 |
+
│ │ ├── color
|
| 124 |
+
│ │ ├── depth
|
| 125 |
+
...
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
#### Prepare pre-trained model
|
| 129 |
+
|
| 130 |
+
Download the pre-trained model weights [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) and save it in the `models` directory.
|
| 131 |
+
|
| 132 |
+
#### Fine-tuning
|
| 133 |
+
|
| 134 |
+
1. Create the model with additional **HRA** parameters:
|
| 135 |
+
```bash
|
| 136 |
+
python tool_add_hra.py \
|
| 137 |
+
--input_path=./models/v1-5-pruned.ckpt \
|
| 138 |
+
--output_path=./models/hra_r_8.ckpt \
|
| 139 |
+
--r=8
|
| 140 |
+
```
|
| 141 |
+
2. Specify the control signal and dataset. Train the model specify the same hyperparameters as above:
|
| 142 |
+
```bash
|
| 143 |
+
python train.py \
|
| 144 |
+
--r=8 \
|
| 145 |
+
--control=segm
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
#### Generation
|
| 149 |
+
1. After finetuning with **HRA**, run inference to generate images based on control signal. Because the inference takes some time, to perform large scale evaluation, we split the dataset into different sub-datasets and run inference on multiple gpus:
|
| 150 |
+
```bash
|
| 151 |
+
python generation.py
|
| 152 |
+
--r=8 \
|
| 153 |
+
--control=segm
|
| 154 |
+
```
|
| 155 |
+
1. To evaluate **HRA** results on the three tasks: canny edge to image (C2I) on the COCO dataset, landmark to face (L2F) on the CelebA-HQ dataset, and segmentation map to image (S2I) on the ADE20K dataset, run the following scripts on the generated images.
|
| 156 |
+
```bash
|
| 157 |
+
python eval_landmark.py
|
| 158 |
+
```
|
| 159 |
+
```bash
|
| 160 |
+
python eval_canny.py
|
| 161 |
+
```
|
| 162 |
+
Note, for evaluating the segmentation map-to-image (S2I) task, please install the [Segformer](https://github.com/NVlabs/SegFormer) repository. Run the following testing command on both the original and generated images.
|
| 163 |
+
```bash
|
| 164 |
+
python tools/test.py local_configs/segformer/B4/segformer.b4.512x512.ade.160k.py ./weights/segformer.b4.512x512.ade.160k.pth
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
### Natural Language Understanding
|
| 168 |
+
|
| 169 |
+
<div align="center">
|
| 170 |
+
<img src="assets/figure_nlp.png" width="300"/>
|
| 171 |
+
</div>
|
| 172 |
+
|
| 173 |
+
We adapt [DeBERTaV3-base](https://arxiv.org/abs/2111.09543) and test the performance of the adapted models on [General Language Understanding Evaluation (GLUE) benchmark](https://gluebenchmark.com/).
|
| 174 |
+
|
| 175 |
+
#### Environment Setup
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
cd nlu
|
| 179 |
+
conda env create -f env.yml
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
Before fine-tuning, you need to install the dependencies.
|
| 183 |
+
|
| 184 |
+
```bash
|
| 185 |
+
python setup.py install
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
#### Prepare Dataset
|
| 189 |
+
|
| 190 |
+
Run this scipt to download glue dataset.
|
| 191 |
+
|
| 192 |
+
```bash
|
| 193 |
+
cache_dir=/tmp/DeBERTa/
|
| 194 |
+
cd experiments/glue
|
| 195 |
+
./download_data.sh $cache_dir/glue_tasks
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
#### Finetune
|
| 199 |
+
|
| 200 |
+
Run tasks.
|
| 201 |
+
|
| 202 |
+
```bash
|
| 203 |
+
./mnli.sh
|
| 204 |
+
./cola.sh
|
| 205 |
+
./mrpc.sh
|
| 206 |
+
./qnli.sh
|
| 207 |
+
./qqp.sh
|
| 208 |
+
./rte.sh
|
| 209 |
+
./sst2.sh
|
| 210 |
+
./stsb.sh
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
### Mathematical reasoning
|
| 214 |
+
We have not yet completed the integration of HRA code into PEFT. Before that, if you want to try using the HRA method to fine-tune large models, you can follow the steps below.
|
| 215 |
+
|
| 216 |
+
Go to the llama folder
|
| 217 |
+
```bash
|
| 218 |
+
cd llama
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
#### Environment Setup
|
| 222 |
+
We recommend using Python 3.10 for your environment and use the conda to install it.
|
| 223 |
+
```bash
|
| 224 |
+
conda create -n pytorch python=3.10
|
| 225 |
+
```
|
| 226 |
+
Then install the required packages with the following command:
|
| 227 |
+
```bash
|
| 228 |
+
pip install -r requirements.txt
|
| 229 |
+
```
|
| 230 |
+
Please note that the peft package and transformer package must be downloaded with the versions consistent with those listed in the requirements file.
|
| 231 |
+
|
| 232 |
+
After completing the download, please replace the **oft** folder inside the **peft/tuners** within your running environment's **python/site-packages** with the **oft** folder from the current directory.
|
| 233 |
+
|
| 234 |
+
The path for the oft folder in the environment should be:
|
| 235 |
+
|
| 236 |
+
```bash
|
| 237 |
+
/your_path/anaconda3/envs/pytorch/lib/python3.10/site-packages/peft/tuners/
|
| 238 |
+
```
|
| 239 |
+
The **layer.py** in the current oft directory is implemented for when λ is not infinity.
|
| 240 |
+
|
| 241 |
+
If you want to simulate when λ is infinity, please replace **layer.py** with **layer_GS_HRA.py**, and set the hyperparameter λ to 0 during training.
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
#### Prepare Dataset
|
| 245 |
+
The dataset we use for fine-tuning is MetaMathQA-40K, which can be downloaded through this [link](https://huggingface.co/datasets/meta-math/MetaMathQA-40K).
|
| 246 |
+
#### Prepare model
|
| 247 |
+
The model we use for fine-tuning is llama2. You can choose the model you want to fine-tune.
|
| 248 |
+
#### Finetune
|
| 249 |
+
Run the following code to complete the fine-tuning:
|
| 250 |
+
```bash
|
| 251 |
+
bash tune.sh
|
| 252 |
+
```
|
| 253 |
+
Please note that you need to change the dataset path, the path of the pre-trained model, and you can change the parameters according to your needs in tune.sh. That is:
|
| 254 |
+
```bash
|
| 255 |
+
BASE_MODEL="YOUR_MODEL_PATH"
|
| 256 |
+
DATA_PATH="YOUR_DATA_PATH"
|
| 257 |
+
OUTPUT="YOUR_MODEL_SAVED_PATH"
|
| 258 |
+
```
|
| 259 |
+
#### Evaluation
|
| 260 |
+
After the training is complete, you can run the following command to test:
|
| 261 |
+
```bash
|
| 262 |
+
bash test.sh
|
| 263 |
+
```
|
| 264 |
+
Please note to change the model path in it:
|
| 265 |
+
```bash
|
| 266 |
+
BASE_MODEL="YOUR_MODEL_PATH"
|
| 267 |
+
OUTPUT="YOUR_MODEL_SAVED_PATH"
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
## 📌 Citing our work
|
| 273 |
+
If you find our work useful, please cite it:
|
| 274 |
+
```bibtex
|
| 275 |
+
@inproceedings{yuanbridging,
|
| 276 |
+
title={Bridging The Gap between Low-rank and Orthogonal Adaptation via Householder Reflection Adaptation},
|
| 277 |
+
author={Yuan, Shen and Liu, Haotian and Xu, Hongteng},
|
| 278 |
+
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
|
| 279 |
+
year={2024}
|
| 280 |
+
}
|
| 281 |
+
```
|
| 282 |
+
|
| 283 |
+
|
assets/OHRFT_scheme.png
ADDED
|
assets/figure_nlp.png
ADDED
|
generation/control/ControlNet/.gitignore
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.idea/
|
| 2 |
+
|
| 3 |
+
training/
|
| 4 |
+
lightning_logs/
|
| 5 |
+
image_log/
|
| 6 |
+
|
| 7 |
+
*.pth
|
| 8 |
+
*.pt
|
| 9 |
+
*.ckpt
|
| 10 |
+
*.safetensors
|
| 11 |
+
|
| 12 |
+
gradio_pose2image_private.py
|
| 13 |
+
gradio_canny2image_private.py
|
| 14 |
+
|
| 15 |
+
# Byte-compiled / optimized / DLL files
|
| 16 |
+
__pycache__/
|
| 17 |
+
*.py[cod]
|
| 18 |
+
*$py.class
|
| 19 |
+
|
| 20 |
+
# C extensions
|
| 21 |
+
*.so
|
| 22 |
+
|
| 23 |
+
# Distribution / packaging
|
| 24 |
+
.Python
|
| 25 |
+
build/
|
| 26 |
+
develop-eggs/
|
| 27 |
+
dist/
|
| 28 |
+
downloads/
|
| 29 |
+
eggs/
|
| 30 |
+
.eggs/
|
| 31 |
+
lib/
|
| 32 |
+
lib64/
|
| 33 |
+
parts/
|
| 34 |
+
sdist/
|
| 35 |
+
var/
|
| 36 |
+
wheels/
|
| 37 |
+
pip-wheel-metadata/
|
| 38 |
+
share/python-wheels/
|
| 39 |
+
*.egg-info/
|
| 40 |
+
.installed.cfg
|
| 41 |
+
*.egg
|
| 42 |
+
MANIFEST
|
| 43 |
+
|
| 44 |
+
# PyInstaller
|
| 45 |
+
# Usually these files are written by a python script from a template
|
| 46 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 47 |
+
*.manifest
|
| 48 |
+
*.spec
|
| 49 |
+
|
| 50 |
+
# Installer logs
|
| 51 |
+
pip-log.txt
|
| 52 |
+
pip-delete-this-directory.txt
|
| 53 |
+
|
| 54 |
+
# Unit test / coverage reports
|
| 55 |
+
htmlcov/
|
| 56 |
+
.tox/
|
| 57 |
+
.nox/
|
| 58 |
+
.coverage
|
| 59 |
+
.coverage.*
|
| 60 |
+
.cache
|
| 61 |
+
nosetests.xml
|
| 62 |
+
coverage.xml
|
| 63 |
+
*.cover
|
| 64 |
+
*.py,cover
|
| 65 |
+
.hypothesis/
|
| 66 |
+
.pytest_cache/
|
| 67 |
+
|
| 68 |
+
# Translations
|
| 69 |
+
*.mo
|
| 70 |
+
*.pot
|
| 71 |
+
|
| 72 |
+
# Django stuff:
|
| 73 |
+
*.log
|
| 74 |
+
local_settings.py
|
| 75 |
+
db.sqlite3
|
| 76 |
+
db.sqlite3-journal
|
| 77 |
+
|
| 78 |
+
# Flask stuff:
|
| 79 |
+
instance/
|
| 80 |
+
.webassets-cache
|
| 81 |
+
|
| 82 |
+
# Scrapy stuff:
|
| 83 |
+
.scrapy
|
| 84 |
+
|
| 85 |
+
# Sphinx documentation
|
| 86 |
+
docs/_build/
|
| 87 |
+
|
| 88 |
+
# PyBuilder
|
| 89 |
+
target/
|
| 90 |
+
|
| 91 |
+
# Jupyter Notebook
|
| 92 |
+
.ipynb_checkpoints
|
| 93 |
+
|
| 94 |
+
# IPython
|
| 95 |
+
profile_default/
|
| 96 |
+
ipython_config.py
|
| 97 |
+
|
| 98 |
+
# pyenv
|
| 99 |
+
.python-version
|
| 100 |
+
|
| 101 |
+
# pipenv
|
| 102 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 103 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 104 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 105 |
+
# install all needed dependencies.
|
| 106 |
+
#Pipfile.lock
|
| 107 |
+
|
| 108 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 109 |
+
__pypackages__/
|
| 110 |
+
|
| 111 |
+
# Celery stuff
|
| 112 |
+
celerybeat-schedule
|
| 113 |
+
celerybeat.pid
|
| 114 |
+
|
| 115 |
+
# SageMath parsed files
|
| 116 |
+
*.sage.py
|
| 117 |
+
|
| 118 |
+
# Environments
|
| 119 |
+
.env
|
| 120 |
+
.venv
|
| 121 |
+
env/
|
| 122 |
+
venv/
|
| 123 |
+
ENV/
|
| 124 |
+
env.bak/
|
| 125 |
+
venv.bak/
|
| 126 |
+
|
| 127 |
+
# Spyder project settings
|
| 128 |
+
.spyderproject
|
| 129 |
+
.spyproject
|
| 130 |
+
|
| 131 |
+
# Rope project settings
|
| 132 |
+
.ropeproject
|
| 133 |
+
|
| 134 |
+
# mkdocs documentation
|
| 135 |
+
/site
|
| 136 |
+
|
| 137 |
+
# mypy
|
| 138 |
+
.mypy_cache/
|
| 139 |
+
.dmypy.json
|
| 140 |
+
dmypy.json
|
| 141 |
+
|
| 142 |
+
# Pyre type checker
|
| 143 |
+
.pyre/
|
generation/control/ControlNet/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
generation/control/ControlNet/README.md
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# News: A nightly version of ControlNet 1.1 is released!
|
| 2 |
+
|
| 3 |
+
[ControlNet 1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly) is released. Those new models will be merged to this repo after we make sure that everything is good.
|
| 4 |
+
|
| 5 |
+
# Below is ControlNet 1.0
|
| 6 |
+
|
| 7 |
+
Official implementation of [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543).
|
| 8 |
+
|
| 9 |
+
ControlNet is a neural network structure to control diffusion models by adding extra conditions.
|
| 10 |
+
|
| 11 |
+

|
| 12 |
+
|
| 13 |
+
It copys the weights of neural network blocks into a "locked" copy and a "trainable" copy.
|
| 14 |
+
|
| 15 |
+
The "trainable" one learns your condition. The "locked" one preserves your model.
|
| 16 |
+
|
| 17 |
+
Thanks to this, training with small dataset of image pairs will not destroy the production-ready diffusion models.
|
| 18 |
+
|
| 19 |
+
The "zero convolution" is 1×1 convolution with both weight and bias initialized as zeros.
|
| 20 |
+
|
| 21 |
+
Before training, all zero convolutions output zeros, and ControlNet will not cause any distortion.
|
| 22 |
+
|
| 23 |
+
No layer is trained from scratch. You are still fine-tuning. Your original model is safe.
|
| 24 |
+
|
| 25 |
+
This allows training on small-scale or even personal devices.
|
| 26 |
+
|
| 27 |
+
This is also friendly to merge/replacement/offsetting of models/weights/blocks/layers.
|
| 28 |
+
|
| 29 |
+
### FAQ
|
| 30 |
+
|
| 31 |
+
**Q:** But wait, if the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why "zero convolution" works?
|
| 32 |
+
|
| 33 |
+
**A:** This is not true. [See an explanation here](docs/faq.md).
|
| 34 |
+
|
| 35 |
+
# Stable Diffusion + ControlNet
|
| 36 |
+
|
| 37 |
+
By repeating the above simple structure 14 times, we can control stable diffusion in this way:
|
| 38 |
+
|
| 39 |
+

|
| 40 |
+
|
| 41 |
+
In this way, the ControlNet can **reuse** the SD encoder as a **deep, strong, robust, and powerful backbone** to learn diverse controls. Many evidences (like [this](https://jerryxu.net/ODISE/) and [this](https://vpd.ivg-research.xyz/)) validate that the SD encoder is an excellent backbone.
|
| 42 |
+
|
| 43 |
+
Note that the way we connect layers is computational efficient. The original SD encoder does not need to store gradients (the locked original SD Encoder Block 1234 and Middle). The required GPU memory is not much larger than original SD, although many layers are added. Great!
|
| 44 |
+
|
| 45 |
+
# Features & News
|
| 46 |
+
|
| 47 |
+
2023/0/14 - We released [ControlNet 1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly). Those new models will be merged to this repo after we make sure that everything is good.
|
| 48 |
+
|
| 49 |
+
2023/03/03 - We released a discussion - [Precomputed ControlNet: Speed up ControlNet by 45%, but is it necessary?](https://github.com/lllyasviel/ControlNet/discussions/216)
|
| 50 |
+
|
| 51 |
+
2023/02/26 - We released a blog - [Ablation Study: Why ControlNets use deep encoder? What if it was lighter? Or even an MLP?](https://github.com/lllyasviel/ControlNet/discussions/188)
|
| 52 |
+
|
| 53 |
+
2023/02/20 - Implementation for non-prompt mode released. See also [Guess Mode / Non-Prompt Mode](#guess-anchor).
|
| 54 |
+
|
| 55 |
+
2023/02/12 - Now you can play with any community model by [Transferring the ControlNet](https://github.com/lllyasviel/ControlNet/discussions/12).
|
| 56 |
+
|
| 57 |
+
2023/02/11 - [Low VRAM mode](docs/low_vram.md) is added. Please use this mode if you are using 8GB GPU(s) or if you want larger batch size.
|
| 58 |
+
|
| 59 |
+
# Production-Ready Pretrained Models
|
| 60 |
+
|
| 61 |
+
First create a new conda environment
|
| 62 |
+
|
| 63 |
+
conda env create -f environment.yaml
|
| 64 |
+
conda activate control
|
| 65 |
+
|
| 66 |
+
All models and detectors can be downloaded from [our Hugging Face page](https://huggingface.co/lllyasviel/ControlNet). Make sure that SD models are put in "ControlNet/models" and detectors are put in "ControlNet/annotator/ckpts". Make sure that you download all necessary pretrained weights and detector models from that Hugging Face page, including HED edge detection model, Midas depth estimation model, Openpose, and so on.
|
| 67 |
+
|
| 68 |
+
We provide 9 Gradio apps with these models.
|
| 69 |
+
|
| 70 |
+
All test images can be found at the folder "test_imgs".
|
| 71 |
+
|
| 72 |
+
## ControlNet with Canny Edge
|
| 73 |
+
|
| 74 |
+
Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection)
|
| 75 |
+
|
| 76 |
+
python gradio_canny2image.py
|
| 77 |
+
|
| 78 |
+
The Gradio app also allows you to change the Canny edge thresholds. Just try it for more details.
|
| 79 |
+
|
| 80 |
+
Prompt: "bird"
|
| 81 |
+

|
| 82 |
+
|
| 83 |
+
Prompt: "cute dog"
|
| 84 |
+

|
| 85 |
+
|
| 86 |
+
## ControlNet with M-LSD Lines
|
| 87 |
+
|
| 88 |
+
Stable Diffusion 1.5 + ControlNet (using simple M-LSD straight line detection)
|
| 89 |
+
|
| 90 |
+
python gradio_hough2image.py
|
| 91 |
+
|
| 92 |
+
The Gradio app also allows you to change the M-LSD thresholds. Just try it for more details.
|
| 93 |
+
|
| 94 |
+
Prompt: "room"
|
| 95 |
+

|
| 96 |
+
|
| 97 |
+
Prompt: "building"
|
| 98 |
+

|
| 99 |
+
|
| 100 |
+
## ControlNet with HED Boundary
|
| 101 |
+
|
| 102 |
+
Stable Diffusion 1.5 + ControlNet (using soft HED Boundary)
|
| 103 |
+
|
| 104 |
+
python gradio_hed2image.py
|
| 105 |
+
|
| 106 |
+
The soft HED Boundary will preserve many details in input images, making this app suitable for recoloring and stylizing. Just try it for more details.
|
| 107 |
+
|
| 108 |
+
Prompt: "oil painting of handsome old man, masterpiece"
|
| 109 |
+

|
| 110 |
+
|
| 111 |
+
Prompt: "Cyberpunk robot"
|
| 112 |
+

|
| 113 |
+
|
| 114 |
+
## ControlNet with User Scribbles
|
| 115 |
+
|
| 116 |
+
Stable Diffusion 1.5 + ControlNet (using Scribbles)
|
| 117 |
+
|
| 118 |
+
python gradio_scribble2image.py
|
| 119 |
+
|
| 120 |
+
Note that the UI is based on Gradio, and Gradio is somewhat difficult to customize. Right now you need to draw scribbles outside the UI (using your favorite drawing software, for example, MS Paint) and then import the scribble image to Gradio.
|
| 121 |
+
|
| 122 |
+
Prompt: "turtle"
|
| 123 |
+

|
| 124 |
+
|
| 125 |
+
Prompt: "hot air balloon"
|
| 126 |
+

|
| 127 |
+
|
| 128 |
+
### Interactive Interface
|
| 129 |
+
|
| 130 |
+
We actually provide an interactive interface
|
| 131 |
+
|
| 132 |
+
python gradio_scribble2image_interactive.py
|
| 133 |
+
|
| 134 |
+
~~However, because gradio is very [buggy](https://github.com/gradio-app/gradio/issues/3166) and difficult to customize, right now, user need to first set canvas width and heights and then click "Open drawing canvas" to get a drawing area. Please do not upload image to that drawing canvas. Also, the drawing area is very small; it should be bigger. But I failed to find out how to make it larger. Again, gradio is really buggy.~~ (Now fixed, will update asap)
|
| 135 |
+
|
| 136 |
+
The below dog sketch is drawn by me. Perhaps we should draw a better dog for showcase.
|
| 137 |
+
|
| 138 |
+
Prompt: "dog in a room"
|
| 139 |
+

|
| 140 |
+
|
| 141 |
+
## ControlNet with Fake Scribbles
|
| 142 |
+
|
| 143 |
+
Stable Diffusion 1.5 + ControlNet (using fake scribbles)
|
| 144 |
+
|
| 145 |
+
python gradio_fake_scribble2image.py
|
| 146 |
+
|
| 147 |
+
Sometimes we are lazy, and we do not want to draw scribbles. This script use the exactly same scribble-based model but use a simple algorithm to synthesize scribbles from input images.
|
| 148 |
+
|
| 149 |
+
Prompt: "bag"
|
| 150 |
+

|
| 151 |
+
|
| 152 |
+
Prompt: "shose" (Note that "shose" is a typo; it should be "shoes". But it still seems to work.)
|
| 153 |
+

|
| 154 |
+
|
| 155 |
+
## ControlNet with Human Pose
|
| 156 |
+
|
| 157 |
+
Stable Diffusion 1.5 + ControlNet (using human pose)
|
| 158 |
+
|
| 159 |
+
python gradio_pose2image.py
|
| 160 |
+
|
| 161 |
+
Apparently, this model deserves a better UI to directly manipulate pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then the Openpose will detect the pose for you.
|
| 162 |
+
|
| 163 |
+
Prompt: "Chief in the kitchen"
|
| 164 |
+

|
| 165 |
+
|
| 166 |
+
Prompt: "An astronaut on the moon"
|
| 167 |
+

|
| 168 |
+
|
| 169 |
+
## ControlNet with Semantic Segmentation
|
| 170 |
+
|
| 171 |
+
Stable Diffusion 1.5 + ControlNet (using semantic segmentation)
|
| 172 |
+
|
| 173 |
+
python gradio_seg2image.py
|
| 174 |
+
|
| 175 |
+
This model use ADE20K's segmentation protocol. Again, this model deserves a better UI to directly draw the segmentations. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then a model called Uniformer will detect the segmentations for you. Just try it for more details.
|
| 176 |
+
|
| 177 |
+
Prompt: "House"
|
| 178 |
+

|
| 179 |
+
|
| 180 |
+
Prompt: "River"
|
| 181 |
+

|
| 182 |
+
|
| 183 |
+
## ControlNet with Depth
|
| 184 |
+
|
| 185 |
+
Stable Diffusion 1.5 + ControlNet (using depth map)
|
| 186 |
+
|
| 187 |
+
python gradio_depth2image.py
|
| 188 |
+
|
| 189 |
+
Great! Now SD 1.5 also have a depth control. FINALLY. So many possibilities (considering SD1.5 has much more community models than SD2).
|
| 190 |
+
|
| 191 |
+
Note that different from Stability's model, the ControlNet receive the full 512×512 depth map, rather than 64×64 depth. Note that Stability's SD2 depth model use 64*64 depth maps. This means that the ControlNet will preserve more details in the depth map.
|
| 192 |
+
|
| 193 |
+
This is always a strength because if users do not want to preserve more details, they can simply use another SD to post-process an i2i. But if they want to preserve more details, ControlNet becomes their only choice. Again, SD2 uses 64×64 depth, we use 512×512.
|
| 194 |
+
|
| 195 |
+
Prompt: "Stormtrooper's lecture"
|
| 196 |
+

|
| 197 |
+
|
| 198 |
+
## ControlNet with Normal Map
|
| 199 |
+
|
| 200 |
+
Stable Diffusion 1.5 + ControlNet (using normal map)
|
| 201 |
+
|
| 202 |
+
python gradio_normal2image.py
|
| 203 |
+
|
| 204 |
+
This model use normal map. Rightnow in the APP, the normal is computed from the midas depth map and a user threshold (to determine how many area is background with identity normal face to viewer, tune the "Normal background threshold" in the gradio app to get a feeling).
|
| 205 |
+
|
| 206 |
+
Prompt: "Cute toy"
|
| 207 |
+

|
| 208 |
+
|
| 209 |
+
Prompt: "Plaster statue of Abraham Lincoln"
|
| 210 |
+

|
| 211 |
+
|
| 212 |
+
Compared to depth model, this model seems to be a bit better at preserving the geometry. This is intuitive: minor details are not salient in depth maps, but are salient in normal maps. Below is the depth result with same inputs. You can see that the hairstyle of the man in the input image is modified by depth model, but preserved by the normal model.
|
| 213 |
+
|
| 214 |
+
Prompt: "Plaster statue of Abraham Lincoln"
|
| 215 |
+

|
| 216 |
+
|
| 217 |
+
## ControlNet with Anime Line Drawing
|
| 218 |
+
|
| 219 |
+
We also trained a relatively simple ControlNet for anime line drawings. This tool may be useful for artistic creations. (Although the image details in the results is a bit modified, since it still diffuse latent images.)
|
| 220 |
+
|
| 221 |
+
This model is not available right now. We need to evaluate the potential risks before releasing this model. Nevertheless, you may be interested in [transferring the ControlNet to any community model](https://github.com/lllyasviel/ControlNet/discussions/12).
|
| 222 |
+
|
| 223 |
+

|
| 224 |
+
|
| 225 |
+
<a id="guess-anchor"></a>
|
| 226 |
+
|
| 227 |
+
# Guess Mode / Non-Prompt Mode
|
| 228 |
+
|
| 229 |
+
The "guess mode" (or called non-prompt mode) will completely unleash all the power of the very powerful ControlNet encoder.
|
| 230 |
+
|
| 231 |
+
See also the blog - [Ablation Study: Why ControlNets use deep encoder? What if it was lighter? Or even an MLP?](https://github.com/lllyasviel/ControlNet/discussions/188)
|
| 232 |
+
|
| 233 |
+
You need to manually check the "Guess Mode" toggle to enable this mode.
|
| 234 |
+
|
| 235 |
+
In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts.
|
| 236 |
+
|
| 237 |
+
**Let's have fun with some very challenging experimental settings!**
|
| 238 |
+
|
| 239 |
+
**No prompts. No "positive" prompts. No "negative" prompts. No extra caption detector. One single diffusion loop.**
|
| 240 |
+
|
| 241 |
+
For this mode, we recommend to use 50 steps and guidance scale between 3 and 5.
|
| 242 |
+
|
| 243 |
+

|
| 244 |
+
|
| 245 |
+
No prompts:
|
| 246 |
+
|
| 247 |
+

|
| 248 |
+
|
| 249 |
+
Note that the below example is 768×768. No prompts. No "positive" prompts. No "negative" prompts.
|
| 250 |
+
|
| 251 |
+

|
| 252 |
+
|
| 253 |
+
By tuning the parameters, you can get some very intereting results like below:
|
| 254 |
+
|
| 255 |
+

|
| 256 |
+
|
| 257 |
+
Because no prompt is available, the ControlNet encoder will "guess" what is in the control map. Sometimes the guess result is really interesting. Because diffusion algorithm can essentially give multiple results, the ControlNet seems able to give multiple guesses, like this:
|
| 258 |
+
|
| 259 |
+

|
| 260 |
+
|
| 261 |
+
Without prompt, the HED seems good at generating images look like paintings when the control strength is relatively low:
|
| 262 |
+
|
| 263 |
+

|
| 264 |
+
|
| 265 |
+
The Guess Mode is also supported in [WebUI Plugin](https://github.com/Mikubill/sd-webui-controlnet):
|
| 266 |
+
|
| 267 |
+

|
| 268 |
+
|
| 269 |
+
No prompts. Default WebUI parameters. Pure random results with the seed being 12345. Standard SD1.5. Input scribble is in "test_imgs" folder to reproduce.
|
| 270 |
+
|
| 271 |
+

|
| 272 |
+
|
| 273 |
+
Below is another challenging example:
|
| 274 |
+
|
| 275 |
+

|
| 276 |
+
|
| 277 |
+
No prompts. Default WebUI parameters. Pure random results with the seed being 12345. Standard SD1.5. Input scribble is in "test_imgs" folder to reproduce.
|
| 278 |
+
|
| 279 |
+

|
| 280 |
+
|
| 281 |
+
Note that in the guess mode, you will still be able to input prompts. The only difference is that the model will "try harder" to guess what is in the control map even if you do not provide the prompt. Just try it yourself!
|
| 282 |
+
|
| 283 |
+
Besides, if you write some scripts (like BLIP) to generate image captions from the "guess mode" images, and then use the generated captions as prompts to diffuse again, you will get a SOTA pipeline for fully automatic conditional image generating.
|
| 284 |
+
|
| 285 |
+
# Combining Multiple ControlNets
|
| 286 |
+
|
| 287 |
+
ControlNets are composable: more than one ControlNet can be easily composed to multi-condition control.
|
| 288 |
+
|
| 289 |
+
Right now this feature is in experimental stage in the [Mikubill' A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet):
|
| 290 |
+
|
| 291 |
+

|
| 292 |
+
|
| 293 |
+

|
| 294 |
+
|
| 295 |
+
As long as the models are controlling the same SD, the "boundary" between different research projects does not even exist. This plugin also allows different methods to work together!
|
| 296 |
+
|
| 297 |
+
# Use ControlNet in Any Community Model (SD1.X)
|
| 298 |
+
|
| 299 |
+
This is an experimental feature.
|
| 300 |
+
|
| 301 |
+
[See the steps here](https://github.com/lllyasviel/ControlNet/discussions/12).
|
| 302 |
+
|
| 303 |
+
Or you may want to use the [Mikubill' A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet) which is plug-and-play and does not need manual merging.
|
| 304 |
+
|
| 305 |
+
# Annotate Your Own Data
|
| 306 |
+
|
| 307 |
+
We provide simple python scripts to process images.
|
| 308 |
+
|
| 309 |
+
[See a gradio example here](docs/annotator.md).
|
| 310 |
+
|
| 311 |
+
# Train with Your Own Data
|
| 312 |
+
|
| 313 |
+
Training a ControlNet is as easy as (or even easier than) training a simple pix2pix.
|
| 314 |
+
|
| 315 |
+
[See the steps here](docs/train.md).
|
| 316 |
+
|
| 317 |
+
# Related Resources
|
| 318 |
+
|
| 319 |
+
Special Thank to the great project - [Mikubill' A1111 Webui Plugin](https://github.com/Mikubill/sd-webui-controlnet) !
|
| 320 |
+
|
| 321 |
+
We also thank Hysts for making [Hugging Face Space](https://huggingface.co/spaces/hysts/ControlNet) as well as more than 65 models in that amazing [Colab list](https://github.com/camenduru/controlnet-colab)!
|
| 322 |
+
|
| 323 |
+
Thank haofanwang for making [ControlNet-for-Diffusers](https://github.com/haofanwang/ControlNet-for-Diffusers)!
|
| 324 |
+
|
| 325 |
+
We also thank all authors for making Controlnet DEMOs, including but not limited to [fffiloni](https://huggingface.co/spaces/fffiloni/ControlNet-Video), [other-model](https://huggingface.co/spaces/hysts/ControlNet-with-other-models), [ThereforeGames](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/7784), [RamAnanth1](https://huggingface.co/spaces/RamAnanth1/ControlNet), etc!
|
| 326 |
+
|
| 327 |
+
Besides, you may also want to read these amazing related works:
|
| 328 |
+
|
| 329 |
+
[Composer: Creative and Controllable Image Synthesis with Composable Conditions](https://github.com/damo-vilab/composer): A much bigger model to control diffusion!
|
| 330 |
+
|
| 331 |
+
[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://github.com/TencentARC/T2I-Adapter): A much smaller model to control stable diffusion!
|
| 332 |
+
|
| 333 |
+
[ControlLoRA: A Light Neural Network To Control Stable Diffusion Spatial Information](https://github.com/HighCWu/ControlLoRA): Implement Controlnet using LORA!
|
| 334 |
+
|
| 335 |
+
And these amazing recent projects: [InstructPix2Pix Learning to Follow Image Editing Instructions](https://www.timothybrooks.com/instruct-pix2pix), [Pix2pix-zero: Zero-shot Image-to-Image Translation](https://github.com/pix2pixzero/pix2pix-zero), [Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation](https://github.com/MichalGeyer/plug-and-play), [MaskSketch: Unpaired Structure-guided Masked Image Generation](https://arxiv.org/abs/2302.05496), [SEGA: Instructing Diffusion using Semantic Dimensions](https://arxiv.org/abs/2301.12247), [Universal Guidance for Diffusion Models](https://github.com/arpitbansal297/Universal-Guided-Diffusion), [Region-Aware Diffusion for Zero-shot Text-driven Image Editing](https://github.com/haha-lisa/RDM-Region-Aware-Diffusion-Model), [Domain Expansion of Image Generators](https://arxiv.org/abs/2301.05225), [Image Mixer](https://twitter.com/LambdaAPI/status/1626327289288957956), [MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation](https://multidiffusion.github.io/)
|
| 336 |
+
|
| 337 |
+
# Citation
|
| 338 |
+
|
| 339 |
+
@misc{zhang2023adding,
|
| 340 |
+
title={Adding Conditional Control to Text-to-Image Diffusion Models},
|
| 341 |
+
author={Lvmin Zhang and Anyi Rao and Maneesh Agrawala},
|
| 342 |
+
booktitle={IEEE International Conference on Computer Vision (ICCV)}
|
| 343 |
+
year={2023},
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
[Arxiv Link](https://arxiv.org/abs/2302.05543)
|
| 347 |
+
|
| 348 |
+
[Supplementary Materials](https://lllyasviel.github.io/misc/202309/cnet_supp.pdf)
|
generation/control/ControlNet/cldm/cldm.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import einops
|
| 2 |
+
import torch
|
| 3 |
+
import torch as th
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from ldm.modules.diffusionmodules.util import (
|
| 7 |
+
conv_nd,
|
| 8 |
+
linear,
|
| 9 |
+
zero_module,
|
| 10 |
+
timestep_embedding,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
from torchvision.utils import make_grid
|
| 15 |
+
from ldm.modules.attention import SpatialTransformer
|
| 16 |
+
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
| 17 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion
|
| 18 |
+
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
| 19 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ControlledUnetModel(UNetModel):
|
| 23 |
+
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
|
| 24 |
+
hs = []
|
| 25 |
+
with torch.no_grad():
|
| 26 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
| 27 |
+
emb = self.time_embed(t_emb)
|
| 28 |
+
h = x.type(self.dtype)
|
| 29 |
+
for module in self.input_blocks:
|
| 30 |
+
h = module(h, emb, context)
|
| 31 |
+
hs.append(h)
|
| 32 |
+
h = self.middle_block(h, emb, context)
|
| 33 |
+
|
| 34 |
+
if control is not None:
|
| 35 |
+
h += control.pop()
|
| 36 |
+
|
| 37 |
+
for i, module in enumerate(self.output_blocks):
|
| 38 |
+
if only_mid_control or control is None:
|
| 39 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
| 40 |
+
else:
|
| 41 |
+
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
|
| 42 |
+
h = module(h, emb, context)
|
| 43 |
+
|
| 44 |
+
h = h.type(x.dtype)
|
| 45 |
+
return self.out(h)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ControlNet(nn.Module):
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
image_size,
|
| 52 |
+
in_channels,
|
| 53 |
+
model_channels,
|
| 54 |
+
hint_channels,
|
| 55 |
+
num_res_blocks,
|
| 56 |
+
attention_resolutions,
|
| 57 |
+
dropout=0,
|
| 58 |
+
channel_mult=(1, 2, 4, 8),
|
| 59 |
+
conv_resample=True,
|
| 60 |
+
dims=2,
|
| 61 |
+
use_checkpoint=False,
|
| 62 |
+
use_fp16=False,
|
| 63 |
+
num_heads=-1,
|
| 64 |
+
num_head_channels=-1,
|
| 65 |
+
num_heads_upsample=-1,
|
| 66 |
+
use_scale_shift_norm=False,
|
| 67 |
+
resblock_updown=False,
|
| 68 |
+
use_new_attention_order=False,
|
| 69 |
+
use_spatial_transformer=False, # custom transformer support
|
| 70 |
+
transformer_depth=1, # custom transformer support
|
| 71 |
+
context_dim=None, # custom transformer support
|
| 72 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
| 73 |
+
legacy=True,
|
| 74 |
+
disable_self_attentions=None,
|
| 75 |
+
num_attention_blocks=None,
|
| 76 |
+
disable_middle_self_attn=False,
|
| 77 |
+
use_linear_in_transformer=False,
|
| 78 |
+
):
|
| 79 |
+
super().__init__()
|
| 80 |
+
if use_spatial_transformer:
|
| 81 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
| 82 |
+
|
| 83 |
+
if context_dim is not None:
|
| 84 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
| 85 |
+
from omegaconf.listconfig import ListConfig
|
| 86 |
+
if type(context_dim) == ListConfig:
|
| 87 |
+
context_dim = list(context_dim)
|
| 88 |
+
|
| 89 |
+
if num_heads_upsample == -1:
|
| 90 |
+
num_heads_upsample = num_heads
|
| 91 |
+
|
| 92 |
+
if num_heads == -1:
|
| 93 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
| 94 |
+
|
| 95 |
+
if num_head_channels == -1:
|
| 96 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
| 97 |
+
|
| 98 |
+
self.dims = dims
|
| 99 |
+
self.image_size = image_size
|
| 100 |
+
self.in_channels = in_channels
|
| 101 |
+
self.model_channels = model_channels
|
| 102 |
+
if isinstance(num_res_blocks, int):
|
| 103 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
| 104 |
+
else:
|
| 105 |
+
if len(num_res_blocks) != len(channel_mult):
|
| 106 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
| 107 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
| 108 |
+
self.num_res_blocks = num_res_blocks
|
| 109 |
+
if disable_self_attentions is not None:
|
| 110 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
| 111 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
| 112 |
+
if num_attention_blocks is not None:
|
| 113 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
| 114 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
| 115 |
+
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
| 116 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
| 117 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
| 118 |
+
f"attention will still not be set.")
|
| 119 |
+
|
| 120 |
+
self.attention_resolutions = attention_resolutions
|
| 121 |
+
self.dropout = dropout
|
| 122 |
+
self.channel_mult = channel_mult
|
| 123 |
+
self.conv_resample = conv_resample
|
| 124 |
+
self.use_checkpoint = use_checkpoint
|
| 125 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
| 126 |
+
self.num_heads = num_heads
|
| 127 |
+
self.num_head_channels = num_head_channels
|
| 128 |
+
self.num_heads_upsample = num_heads_upsample
|
| 129 |
+
self.predict_codebook_ids = n_embed is not None
|
| 130 |
+
|
| 131 |
+
time_embed_dim = model_channels * 4
|
| 132 |
+
self.time_embed = nn.Sequential(
|
| 133 |
+
linear(model_channels, time_embed_dim),
|
| 134 |
+
nn.SiLU(),
|
| 135 |
+
linear(time_embed_dim, time_embed_dim),
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
self.input_blocks = nn.ModuleList(
|
| 139 |
+
[
|
| 140 |
+
TimestepEmbedSequential(
|
| 141 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
| 142 |
+
)
|
| 143 |
+
]
|
| 144 |
+
)
|
| 145 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
| 146 |
+
|
| 147 |
+
self.input_hint_block = TimestepEmbedSequential(
|
| 148 |
+
conv_nd(dims, hint_channels, 16, 3, padding=1),
|
| 149 |
+
nn.SiLU(),
|
| 150 |
+
conv_nd(dims, 16, 16, 3, padding=1),
|
| 151 |
+
nn.SiLU(),
|
| 152 |
+
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
| 153 |
+
nn.SiLU(),
|
| 154 |
+
conv_nd(dims, 32, 32, 3, padding=1),
|
| 155 |
+
nn.SiLU(),
|
| 156 |
+
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
| 157 |
+
nn.SiLU(),
|
| 158 |
+
conv_nd(dims, 96, 96, 3, padding=1),
|
| 159 |
+
nn.SiLU(),
|
| 160 |
+
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
| 161 |
+
nn.SiLU(),
|
| 162 |
+
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self._feature_size = model_channels
|
| 166 |
+
input_block_chans = [model_channels]
|
| 167 |
+
ch = model_channels
|
| 168 |
+
ds = 1
|
| 169 |
+
for level, mult in enumerate(channel_mult):
|
| 170 |
+
for nr in range(self.num_res_blocks[level]):
|
| 171 |
+
layers = [
|
| 172 |
+
ResBlock(
|
| 173 |
+
ch,
|
| 174 |
+
time_embed_dim,
|
| 175 |
+
dropout,
|
| 176 |
+
out_channels=mult * model_channels,
|
| 177 |
+
dims=dims,
|
| 178 |
+
use_checkpoint=use_checkpoint,
|
| 179 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 180 |
+
)
|
| 181 |
+
]
|
| 182 |
+
ch = mult * model_channels
|
| 183 |
+
if ds in attention_resolutions:
|
| 184 |
+
if num_head_channels == -1:
|
| 185 |
+
dim_head = ch // num_heads
|
| 186 |
+
else:
|
| 187 |
+
num_heads = ch // num_head_channels
|
| 188 |
+
dim_head = num_head_channels
|
| 189 |
+
if legacy:
|
| 190 |
+
# num_heads = 1
|
| 191 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
| 192 |
+
if exists(disable_self_attentions):
|
| 193 |
+
disabled_sa = disable_self_attentions[level]
|
| 194 |
+
else:
|
| 195 |
+
disabled_sa = False
|
| 196 |
+
|
| 197 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
| 198 |
+
layers.append(
|
| 199 |
+
AttentionBlock(
|
| 200 |
+
ch,
|
| 201 |
+
use_checkpoint=use_checkpoint,
|
| 202 |
+
num_heads=num_heads,
|
| 203 |
+
num_head_channels=dim_head,
|
| 204 |
+
use_new_attention_order=use_new_attention_order,
|
| 205 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
| 206 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
| 207 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
| 208 |
+
use_checkpoint=use_checkpoint
|
| 209 |
+
)
|
| 210 |
+
)
|
| 211 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
| 212 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
| 213 |
+
self._feature_size += ch
|
| 214 |
+
input_block_chans.append(ch)
|
| 215 |
+
if level != len(channel_mult) - 1:
|
| 216 |
+
out_ch = ch
|
| 217 |
+
self.input_blocks.append(
|
| 218 |
+
TimestepEmbedSequential(
|
| 219 |
+
ResBlock(
|
| 220 |
+
ch,
|
| 221 |
+
time_embed_dim,
|
| 222 |
+
dropout,
|
| 223 |
+
out_channels=out_ch,
|
| 224 |
+
dims=dims,
|
| 225 |
+
use_checkpoint=use_checkpoint,
|
| 226 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 227 |
+
down=True,
|
| 228 |
+
)
|
| 229 |
+
if resblock_updown
|
| 230 |
+
else Downsample(
|
| 231 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
| 232 |
+
)
|
| 233 |
+
)
|
| 234 |
+
)
|
| 235 |
+
ch = out_ch
|
| 236 |
+
input_block_chans.append(ch)
|
| 237 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
| 238 |
+
ds *= 2
|
| 239 |
+
self._feature_size += ch
|
| 240 |
+
|
| 241 |
+
if num_head_channels == -1:
|
| 242 |
+
dim_head = ch // num_heads
|
| 243 |
+
else:
|
| 244 |
+
num_heads = ch // num_head_channels
|
| 245 |
+
dim_head = num_head_channels
|
| 246 |
+
if legacy:
|
| 247 |
+
# num_heads = 1
|
| 248 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
| 249 |
+
self.middle_block = TimestepEmbedSequential(
|
| 250 |
+
ResBlock(
|
| 251 |
+
ch,
|
| 252 |
+
time_embed_dim,
|
| 253 |
+
dropout,
|
| 254 |
+
dims=dims,
|
| 255 |
+
use_checkpoint=use_checkpoint,
|
| 256 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 257 |
+
),
|
| 258 |
+
AttentionBlock(
|
| 259 |
+
ch,
|
| 260 |
+
use_checkpoint=use_checkpoint,
|
| 261 |
+
num_heads=num_heads,
|
| 262 |
+
num_head_channels=dim_head,
|
| 263 |
+
use_new_attention_order=use_new_attention_order,
|
| 264 |
+
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
| 265 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
| 266 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
| 267 |
+
use_checkpoint=use_checkpoint
|
| 268 |
+
),
|
| 269 |
+
ResBlock(
|
| 270 |
+
ch,
|
| 271 |
+
time_embed_dim,
|
| 272 |
+
dropout,
|
| 273 |
+
dims=dims,
|
| 274 |
+
use_checkpoint=use_checkpoint,
|
| 275 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 276 |
+
),
|
| 277 |
+
)
|
| 278 |
+
self.middle_block_out = self.make_zero_conv(ch)
|
| 279 |
+
self._feature_size += ch
|
| 280 |
+
|
| 281 |
+
def make_zero_conv(self, channels):
|
| 282 |
+
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
| 283 |
+
|
| 284 |
+
def forward(self, x, hint, timesteps, context, **kwargs):
|
| 285 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
| 286 |
+
emb = self.time_embed(t_emb)
|
| 287 |
+
|
| 288 |
+
guided_hint = self.input_hint_block(hint, emb, context)
|
| 289 |
+
|
| 290 |
+
outs = []
|
| 291 |
+
|
| 292 |
+
h = x.type(self.dtype)
|
| 293 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
| 294 |
+
if guided_hint is not None:
|
| 295 |
+
h = module(h, emb, context)
|
| 296 |
+
h += guided_hint
|
| 297 |
+
guided_hint = None
|
| 298 |
+
else:
|
| 299 |
+
h = module(h, emb, context)
|
| 300 |
+
outs.append(zero_conv(h, emb, context))
|
| 301 |
+
|
| 302 |
+
h = self.middle_block(h, emb, context)
|
| 303 |
+
outs.append(self.middle_block_out(h, emb, context))
|
| 304 |
+
|
| 305 |
+
return outs
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
class ControlLDM(LatentDiffusion):
|
| 309 |
+
|
| 310 |
+
def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
|
| 311 |
+
super().__init__(*args, **kwargs)
|
| 312 |
+
self.control_model = instantiate_from_config(control_stage_config)
|
| 313 |
+
self.control_key = control_key
|
| 314 |
+
self.only_mid_control = only_mid_control
|
| 315 |
+
self.control_scales = [1.0] * 13
|
| 316 |
+
|
| 317 |
+
@torch.no_grad()
|
| 318 |
+
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
| 319 |
+
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
|
| 320 |
+
control = batch[self.control_key]
|
| 321 |
+
if bs is not None:
|
| 322 |
+
control = control[:bs]
|
| 323 |
+
control = control.to(self.device)
|
| 324 |
+
control = einops.rearrange(control, 'b h w c -> b c h w')
|
| 325 |
+
control = control.to(memory_format=torch.contiguous_format).float()
|
| 326 |
+
return x, dict(c_crossattn=[c], c_concat=[control])
|
| 327 |
+
|
| 328 |
+
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
|
| 329 |
+
assert isinstance(cond, dict)
|
| 330 |
+
diffusion_model = self.model.diffusion_model
|
| 331 |
+
|
| 332 |
+
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
| 333 |
+
|
| 334 |
+
if cond['c_concat'] is None:
|
| 335 |
+
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
|
| 336 |
+
else:
|
| 337 |
+
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
|
| 338 |
+
control = [c * scale for c, scale in zip(control, self.control_scales)]
|
| 339 |
+
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
| 340 |
+
|
| 341 |
+
return eps
|
| 342 |
+
|
| 343 |
+
@torch.no_grad()
|
| 344 |
+
def get_unconditional_conditioning(self, N):
|
| 345 |
+
return self.get_learned_conditioning([""] * N)
|
| 346 |
+
|
| 347 |
+
@torch.no_grad()
|
| 348 |
+
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
| 349 |
+
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
| 350 |
+
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
|
| 351 |
+
use_ema_scope=True,
|
| 352 |
+
**kwargs):
|
| 353 |
+
use_ddim = ddim_steps is not None
|
| 354 |
+
|
| 355 |
+
log = dict()
|
| 356 |
+
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
| 357 |
+
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
|
| 358 |
+
N = min(z.shape[0], N)
|
| 359 |
+
n_row = min(z.shape[0], n_row)
|
| 360 |
+
log["reconstruction"] = self.decode_first_stage(z)
|
| 361 |
+
log["control"] = c_cat * 2.0 - 1.0
|
| 362 |
+
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
|
| 363 |
+
|
| 364 |
+
if plot_diffusion_rows:
|
| 365 |
+
# get diffusion row
|
| 366 |
+
diffusion_row = list()
|
| 367 |
+
z_start = z[:n_row]
|
| 368 |
+
for t in range(self.num_timesteps):
|
| 369 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
| 370 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
| 371 |
+
t = t.to(self.device).long()
|
| 372 |
+
noise = torch.randn_like(z_start)
|
| 373 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
| 374 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
| 375 |
+
|
| 376 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
| 377 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
| 378 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
| 379 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
| 380 |
+
log["diffusion_row"] = diffusion_grid
|
| 381 |
+
|
| 382 |
+
if sample:
|
| 383 |
+
# get denoise row
|
| 384 |
+
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
| 385 |
+
batch_size=N, ddim=use_ddim,
|
| 386 |
+
ddim_steps=ddim_steps, eta=ddim_eta)
|
| 387 |
+
x_samples = self.decode_first_stage(samples)
|
| 388 |
+
log["samples"] = x_samples
|
| 389 |
+
if plot_denoise_rows:
|
| 390 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
| 391 |
+
log["denoise_row"] = denoise_grid
|
| 392 |
+
|
| 393 |
+
if unconditional_guidance_scale > 1.0:
|
| 394 |
+
uc_cross = self.get_unconditional_conditioning(N)
|
| 395 |
+
uc_cat = c_cat # torch.zeros_like(c_cat)
|
| 396 |
+
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
| 397 |
+
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
| 398 |
+
batch_size=N, ddim=use_ddim,
|
| 399 |
+
ddim_steps=ddim_steps, eta=ddim_eta,
|
| 400 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 401 |
+
unconditional_conditioning=uc_full,
|
| 402 |
+
)
|
| 403 |
+
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
| 404 |
+
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
| 405 |
+
|
| 406 |
+
return log
|
| 407 |
+
|
| 408 |
+
@torch.no_grad()
|
| 409 |
+
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
| 410 |
+
ddim_sampler = DDIMSampler(self)
|
| 411 |
+
b, c, h, w = cond["c_concat"][0].shape
|
| 412 |
+
shape = (self.channels, h // 8, w // 8)
|
| 413 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
| 414 |
+
return samples, intermediates
|
| 415 |
+
|
| 416 |
+
def configure_optimizers(self):
|
| 417 |
+
lr = self.learning_rate
|
| 418 |
+
params = list(self.control_model.parameters())
|
| 419 |
+
if not self.sd_locked:
|
| 420 |
+
params += list(self.model.diffusion_model.output_blocks.parameters())
|
| 421 |
+
params += list(self.model.diffusion_model.out.parameters())
|
| 422 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
| 423 |
+
return opt
|
| 424 |
+
|
| 425 |
+
def low_vram_shift(self, is_diffusing):
|
| 426 |
+
if is_diffusing:
|
| 427 |
+
self.model = self.model.cuda()
|
| 428 |
+
self.control_model = self.control_model.cuda()
|
| 429 |
+
self.first_stage_model = self.first_stage_model.cpu()
|
| 430 |
+
self.cond_stage_model = self.cond_stage_model.cpu()
|
| 431 |
+
else:
|
| 432 |
+
self.model = self.model.cpu()
|
| 433 |
+
self.control_model = self.control_model.cpu()
|
| 434 |
+
self.first_stage_model = self.first_stage_model.cuda()
|
| 435 |
+
self.cond_stage_model = self.cond_stage_model.cuda()
|
generation/control/ControlNet/cldm/ddim_hacked.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DDIMSampler(object):
|
| 11 |
+
def __init__(self, model, schedule="linear", **kwargs):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.model = model
|
| 14 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
| 15 |
+
self.schedule = schedule
|
| 16 |
+
|
| 17 |
+
def register_buffer(self, name, attr):
|
| 18 |
+
if type(attr) == torch.Tensor:
|
| 19 |
+
if attr.device != torch.device("cuda"):
|
| 20 |
+
attr = attr.to(torch.device("cuda"))
|
| 21 |
+
setattr(self, name, attr)
|
| 22 |
+
|
| 23 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
| 24 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
| 25 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
| 26 |
+
alphas_cumprod = self.model.alphas_cumprod
|
| 27 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
| 28 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
| 29 |
+
|
| 30 |
+
self.register_buffer('betas', to_torch(self.model.betas))
|
| 31 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
| 32 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
| 33 |
+
|
| 34 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 35 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
| 36 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
| 37 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
| 38 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
| 39 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
| 40 |
+
|
| 41 |
+
# ddim sampling parameters
|
| 42 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
| 43 |
+
ddim_timesteps=self.ddim_timesteps,
|
| 44 |
+
eta=ddim_eta,verbose=verbose)
|
| 45 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
| 46 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
| 47 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
| 48 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
| 49 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
| 50 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
| 51 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
| 52 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
| 53 |
+
|
| 54 |
+
@torch.no_grad()
|
| 55 |
+
def sample(self,
|
| 56 |
+
S,
|
| 57 |
+
batch_size,
|
| 58 |
+
shape,
|
| 59 |
+
conditioning=None,
|
| 60 |
+
callback=None,
|
| 61 |
+
normals_sequence=None,
|
| 62 |
+
img_callback=None,
|
| 63 |
+
quantize_x0=False,
|
| 64 |
+
eta=0.,
|
| 65 |
+
mask=None,
|
| 66 |
+
x0=None,
|
| 67 |
+
temperature=1.,
|
| 68 |
+
noise_dropout=0.,
|
| 69 |
+
score_corrector=None,
|
| 70 |
+
corrector_kwargs=None,
|
| 71 |
+
verbose=True,
|
| 72 |
+
x_T=None,
|
| 73 |
+
log_every_t=100,
|
| 74 |
+
unconditional_guidance_scale=1.,
|
| 75 |
+
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 76 |
+
dynamic_threshold=None,
|
| 77 |
+
ucg_schedule=None,
|
| 78 |
+
**kwargs
|
| 79 |
+
):
|
| 80 |
+
if conditioning is not None:
|
| 81 |
+
if isinstance(conditioning, dict):
|
| 82 |
+
ctmp = conditioning[list(conditioning.keys())[0]]
|
| 83 |
+
while isinstance(ctmp, list): ctmp = ctmp[0]
|
| 84 |
+
cbs = ctmp.shape[0]
|
| 85 |
+
if cbs != batch_size:
|
| 86 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 87 |
+
|
| 88 |
+
elif isinstance(conditioning, list):
|
| 89 |
+
for ctmp in conditioning:
|
| 90 |
+
if ctmp.shape[0] != batch_size:
|
| 91 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 92 |
+
|
| 93 |
+
else:
|
| 94 |
+
if conditioning.shape[0] != batch_size:
|
| 95 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
| 96 |
+
|
| 97 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
| 98 |
+
# sampling
|
| 99 |
+
C, H, W = shape
|
| 100 |
+
size = (batch_size, C, H, W)
|
| 101 |
+
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
| 102 |
+
|
| 103 |
+
samples, intermediates = self.ddim_sampling(conditioning, size,
|
| 104 |
+
callback=callback,
|
| 105 |
+
img_callback=img_callback,
|
| 106 |
+
quantize_denoised=quantize_x0,
|
| 107 |
+
mask=mask, x0=x0,
|
| 108 |
+
ddim_use_original_steps=False,
|
| 109 |
+
noise_dropout=noise_dropout,
|
| 110 |
+
temperature=temperature,
|
| 111 |
+
score_corrector=score_corrector,
|
| 112 |
+
corrector_kwargs=corrector_kwargs,
|
| 113 |
+
x_T=x_T,
|
| 114 |
+
log_every_t=log_every_t,
|
| 115 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 116 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 117 |
+
dynamic_threshold=dynamic_threshold,
|
| 118 |
+
ucg_schedule=ucg_schedule
|
| 119 |
+
)
|
| 120 |
+
return samples, intermediates
|
| 121 |
+
|
| 122 |
+
@torch.no_grad()
|
| 123 |
+
def ddim_sampling(self, cond, shape,
|
| 124 |
+
x_T=None, ddim_use_original_steps=False,
|
| 125 |
+
callback=None, timesteps=None, quantize_denoised=False,
|
| 126 |
+
mask=None, x0=None, img_callback=None, log_every_t=100,
|
| 127 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
| 128 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
| 129 |
+
ucg_schedule=None):
|
| 130 |
+
device = self.model.betas.device
|
| 131 |
+
b = shape[0]
|
| 132 |
+
if x_T is None:
|
| 133 |
+
img = torch.randn(shape, device=device)
|
| 134 |
+
else:
|
| 135 |
+
img = x_T
|
| 136 |
+
|
| 137 |
+
if timesteps is None:
|
| 138 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
| 139 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
| 140 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
| 141 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
| 142 |
+
|
| 143 |
+
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
| 144 |
+
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
| 145 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
| 146 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 147 |
+
|
| 148 |
+
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
| 149 |
+
|
| 150 |
+
for i, step in enumerate(iterator):
|
| 151 |
+
index = total_steps - i - 1
|
| 152 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
| 153 |
+
|
| 154 |
+
if mask is not None:
|
| 155 |
+
assert x0 is not None
|
| 156 |
+
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
| 157 |
+
img = img_orig * mask + (1. - mask) * img
|
| 158 |
+
|
| 159 |
+
if ucg_schedule is not None:
|
| 160 |
+
assert len(ucg_schedule) == len(time_range)
|
| 161 |
+
unconditional_guidance_scale = ucg_schedule[i]
|
| 162 |
+
|
| 163 |
+
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
| 164 |
+
quantize_denoised=quantize_denoised, temperature=temperature,
|
| 165 |
+
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
| 166 |
+
corrector_kwargs=corrector_kwargs,
|
| 167 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 168 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 169 |
+
dynamic_threshold=dynamic_threshold)
|
| 170 |
+
img, pred_x0 = outs
|
| 171 |
+
if callback: callback(i)
|
| 172 |
+
if img_callback: img_callback(pred_x0, i)
|
| 173 |
+
|
| 174 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
| 175 |
+
intermediates['x_inter'].append(img)
|
| 176 |
+
intermediates['pred_x0'].append(pred_x0)
|
| 177 |
+
|
| 178 |
+
return img, intermediates
|
| 179 |
+
|
| 180 |
+
@torch.no_grad()
|
| 181 |
+
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
| 182 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
| 183 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
| 184 |
+
dynamic_threshold=None):
|
| 185 |
+
b, *_, device = *x.shape, x.device
|
| 186 |
+
|
| 187 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
| 188 |
+
model_output = self.model.apply_model(x, t, c)
|
| 189 |
+
else:
|
| 190 |
+
model_t = self.model.apply_model(x, t, c)
|
| 191 |
+
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
|
| 192 |
+
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
| 193 |
+
|
| 194 |
+
if self.model.parameterization == "v":
|
| 195 |
+
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
| 196 |
+
else:
|
| 197 |
+
e_t = model_output
|
| 198 |
+
|
| 199 |
+
if score_corrector is not None:
|
| 200 |
+
assert self.model.parameterization == "eps", 'not implemented'
|
| 201 |
+
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
| 202 |
+
|
| 203 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
| 204 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
| 205 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
| 206 |
+
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
| 207 |
+
# select parameters corresponding to the currently considered timestep
|
| 208 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
| 209 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
| 210 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
| 211 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
| 212 |
+
|
| 213 |
+
# current prediction for x_0
|
| 214 |
+
if self.model.parameterization != "v":
|
| 215 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
| 216 |
+
else:
|
| 217 |
+
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
| 218 |
+
|
| 219 |
+
if quantize_denoised:
|
| 220 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
| 221 |
+
|
| 222 |
+
if dynamic_threshold is not None:
|
| 223 |
+
raise NotImplementedError()
|
| 224 |
+
|
| 225 |
+
# direction pointing to x_t
|
| 226 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
| 227 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
| 228 |
+
if noise_dropout > 0.:
|
| 229 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 230 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
| 231 |
+
return x_prev, pred_x0
|
| 232 |
+
|
| 233 |
+
@torch.no_grad()
|
| 234 |
+
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
| 235 |
+
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
|
| 236 |
+
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
| 237 |
+
num_reference_steps = timesteps.shape[0]
|
| 238 |
+
|
| 239 |
+
assert t_enc <= num_reference_steps
|
| 240 |
+
num_steps = t_enc
|
| 241 |
+
|
| 242 |
+
if use_original_steps:
|
| 243 |
+
alphas_next = self.alphas_cumprod[:num_steps]
|
| 244 |
+
alphas = self.alphas_cumprod_prev[:num_steps]
|
| 245 |
+
else:
|
| 246 |
+
alphas_next = self.ddim_alphas[:num_steps]
|
| 247 |
+
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
| 248 |
+
|
| 249 |
+
x_next = x0
|
| 250 |
+
intermediates = []
|
| 251 |
+
inter_steps = []
|
| 252 |
+
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
| 253 |
+
t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
|
| 254 |
+
if unconditional_guidance_scale == 1.:
|
| 255 |
+
noise_pred = self.model.apply_model(x_next, t, c)
|
| 256 |
+
else:
|
| 257 |
+
assert unconditional_conditioning is not None
|
| 258 |
+
e_t_uncond, noise_pred = torch.chunk(
|
| 259 |
+
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
|
| 260 |
+
torch.cat((unconditional_conditioning, c))), 2)
|
| 261 |
+
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
|
| 262 |
+
|
| 263 |
+
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
| 264 |
+
weighted_noise_pred = alphas_next[i].sqrt() * (
|
| 265 |
+
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
|
| 266 |
+
x_next = xt_weighted + weighted_noise_pred
|
| 267 |
+
if return_intermediates and i % (
|
| 268 |
+
num_steps // return_intermediates) == 0 and i < num_steps - 1:
|
| 269 |
+
intermediates.append(x_next)
|
| 270 |
+
inter_steps.append(i)
|
| 271 |
+
elif return_intermediates and i >= num_steps - 2:
|
| 272 |
+
intermediates.append(x_next)
|
| 273 |
+
inter_steps.append(i)
|
| 274 |
+
if callback: callback(i)
|
| 275 |
+
|
| 276 |
+
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
| 277 |
+
if return_intermediates:
|
| 278 |
+
out.update({'intermediates': intermediates})
|
| 279 |
+
return x_next, out
|
| 280 |
+
|
| 281 |
+
@torch.no_grad()
|
| 282 |
+
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
| 283 |
+
# fast, but does not allow for exact reconstruction
|
| 284 |
+
# t serves as an index to gather the correct alphas
|
| 285 |
+
if use_original_steps:
|
| 286 |
+
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
| 287 |
+
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
| 288 |
+
else:
|
| 289 |
+
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
| 290 |
+
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
| 291 |
+
|
| 292 |
+
if noise is None:
|
| 293 |
+
noise = torch.randn_like(x0)
|
| 294 |
+
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
| 295 |
+
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
| 296 |
+
|
| 297 |
+
@torch.no_grad()
|
| 298 |
+
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
| 299 |
+
use_original_steps=False, callback=None):
|
| 300 |
+
|
| 301 |
+
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
| 302 |
+
timesteps = timesteps[:t_start]
|
| 303 |
+
|
| 304 |
+
time_range = np.flip(timesteps)
|
| 305 |
+
total_steps = timesteps.shape[0]
|
| 306 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 307 |
+
|
| 308 |
+
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
| 309 |
+
x_dec = x_latent
|
| 310 |
+
for i, step in enumerate(iterator):
|
| 311 |
+
index = total_steps - i - 1
|
| 312 |
+
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
| 313 |
+
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
| 314 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 315 |
+
unconditional_conditioning=unconditional_conditioning)
|
| 316 |
+
if callback: callback(i)
|
| 317 |
+
return x_dec
|
generation/control/ControlNet/cldm/hack.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import einops
|
| 3 |
+
|
| 4 |
+
import ldm.modules.encoders.modules
|
| 5 |
+
import ldm.modules.attention
|
| 6 |
+
|
| 7 |
+
from transformers import logging
|
| 8 |
+
from ldm.modules.attention import default
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def disable_verbosity():
|
| 12 |
+
logging.set_verbosity_error()
|
| 13 |
+
print('logging improved.')
|
| 14 |
+
return
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def enable_sliced_attention():
|
| 18 |
+
ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
|
| 19 |
+
print('Enabled sliced_attention.')
|
| 20 |
+
return
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def hack_everything(clip_skip=0):
|
| 24 |
+
disable_verbosity()
|
| 25 |
+
ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
|
| 26 |
+
ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
|
| 27 |
+
print('Enabled clip hacks.')
|
| 28 |
+
return
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Written by Lvmin
|
| 32 |
+
def _hacked_clip_forward(self, text):
|
| 33 |
+
PAD = self.tokenizer.pad_token_id
|
| 34 |
+
EOS = self.tokenizer.eos_token_id
|
| 35 |
+
BOS = self.tokenizer.bos_token_id
|
| 36 |
+
|
| 37 |
+
def tokenize(t):
|
| 38 |
+
return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
|
| 39 |
+
|
| 40 |
+
def transformer_encode(t):
|
| 41 |
+
if self.clip_skip > 1:
|
| 42 |
+
rt = self.transformer(input_ids=t, output_hidden_states=True)
|
| 43 |
+
return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
|
| 44 |
+
else:
|
| 45 |
+
return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
|
| 46 |
+
|
| 47 |
+
def split(x):
|
| 48 |
+
return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
|
| 49 |
+
|
| 50 |
+
def pad(x, p, i):
|
| 51 |
+
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
|
| 52 |
+
|
| 53 |
+
raw_tokens_list = tokenize(text)
|
| 54 |
+
tokens_list = []
|
| 55 |
+
|
| 56 |
+
for raw_tokens in raw_tokens_list:
|
| 57 |
+
raw_tokens_123 = split(raw_tokens)
|
| 58 |
+
raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
|
| 59 |
+
raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
|
| 60 |
+
tokens_list.append(raw_tokens_123)
|
| 61 |
+
|
| 62 |
+
tokens_list = torch.IntTensor(tokens_list).to(self.device)
|
| 63 |
+
|
| 64 |
+
feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
|
| 65 |
+
y = transformer_encode(feed)
|
| 66 |
+
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
|
| 67 |
+
|
| 68 |
+
return z
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
|
| 72 |
+
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
|
| 73 |
+
h = self.heads
|
| 74 |
+
|
| 75 |
+
q = self.to_q(x)
|
| 76 |
+
context = default(context, x)
|
| 77 |
+
k = self.to_k(context)
|
| 78 |
+
v = self.to_v(context)
|
| 79 |
+
del context, x
|
| 80 |
+
|
| 81 |
+
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
| 82 |
+
|
| 83 |
+
limit = k.shape[0]
|
| 84 |
+
att_step = 1
|
| 85 |
+
q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
|
| 86 |
+
k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
|
| 87 |
+
v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
|
| 88 |
+
|
| 89 |
+
q_chunks.reverse()
|
| 90 |
+
k_chunks.reverse()
|
| 91 |
+
v_chunks.reverse()
|
| 92 |
+
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
| 93 |
+
del k, q, v
|
| 94 |
+
for i in range(0, limit, att_step):
|
| 95 |
+
q_buffer = q_chunks.pop()
|
| 96 |
+
k_buffer = k_chunks.pop()
|
| 97 |
+
v_buffer = v_chunks.pop()
|
| 98 |
+
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
| 99 |
+
|
| 100 |
+
del k_buffer, q_buffer
|
| 101 |
+
# attention, what we cannot get enough of, by chunks
|
| 102 |
+
|
| 103 |
+
sim_buffer = sim_buffer.softmax(dim=-1)
|
| 104 |
+
|
| 105 |
+
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
| 106 |
+
del v_buffer
|
| 107 |
+
sim[i:i + att_step, :, :] = sim_buffer
|
| 108 |
+
|
| 109 |
+
del sim_buffer
|
| 110 |
+
sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
| 111 |
+
return self.to_out(sim)
|
generation/control/ControlNet/cldm/logger.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from pytorch_lightning.callbacks import Callback
|
| 8 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ImageLogger(Callback):
|
| 12 |
+
def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
|
| 13 |
+
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
|
| 14 |
+
log_images_kwargs=None):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.rescale = rescale
|
| 17 |
+
self.batch_freq = batch_frequency
|
| 18 |
+
self.max_images = max_images
|
| 19 |
+
if not increase_log_steps:
|
| 20 |
+
self.log_steps = [self.batch_freq]
|
| 21 |
+
self.clamp = clamp
|
| 22 |
+
self.disabled = disabled
|
| 23 |
+
self.log_on_batch_idx = log_on_batch_idx
|
| 24 |
+
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
| 25 |
+
self.log_first_step = log_first_step
|
| 26 |
+
|
| 27 |
+
@rank_zero_only
|
| 28 |
+
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
|
| 29 |
+
root = os.path.join(save_dir, "image_log", split)
|
| 30 |
+
for k in images:
|
| 31 |
+
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
| 32 |
+
if self.rescale:
|
| 33 |
+
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
| 34 |
+
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
| 35 |
+
grid = grid.numpy()
|
| 36 |
+
grid = (grid * 255).astype(np.uint8)
|
| 37 |
+
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
|
| 38 |
+
path = os.path.join(root, filename)
|
| 39 |
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
| 40 |
+
Image.fromarray(grid).save(path)
|
| 41 |
+
|
| 42 |
+
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
| 43 |
+
check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
|
| 44 |
+
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
|
| 45 |
+
hasattr(pl_module, "log_images") and
|
| 46 |
+
callable(pl_module.log_images) and
|
| 47 |
+
self.max_images > 0):
|
| 48 |
+
logger = type(pl_module.logger)
|
| 49 |
+
|
| 50 |
+
is_train = pl_module.training
|
| 51 |
+
if is_train:
|
| 52 |
+
pl_module.eval()
|
| 53 |
+
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
|
| 56 |
+
|
| 57 |
+
for k in images:
|
| 58 |
+
N = min(images[k].shape[0], self.max_images)
|
| 59 |
+
images[k] = images[k][:N]
|
| 60 |
+
if isinstance(images[k], torch.Tensor):
|
| 61 |
+
images[k] = images[k].detach().cpu()
|
| 62 |
+
if self.clamp:
|
| 63 |
+
images[k] = torch.clamp(images[k], -1., 1.)
|
| 64 |
+
|
| 65 |
+
self.log_local(pl_module.logger.save_dir, split, images,
|
| 66 |
+
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
| 67 |
+
|
| 68 |
+
if is_train:
|
| 69 |
+
pl_module.train()
|
| 70 |
+
|
| 71 |
+
def check_frequency(self, check_idx):
|
| 72 |
+
return check_idx % self.batch_freq == 0
|
| 73 |
+
|
| 74 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
| 75 |
+
if not self.disabled:
|
| 76 |
+
self.log_img(pl_module, batch, batch_idx, split="train")
|
generation/control/ControlNet/cldm/model.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from omegaconf import OmegaConf
|
| 5 |
+
from ldm.util import instantiate_from_config
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_state_dict(d):
|
| 9 |
+
return d.get('state_dict', d)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_state_dict(ckpt_path, location='cpu'):
|
| 13 |
+
_, extension = os.path.splitext(ckpt_path)
|
| 14 |
+
if extension.lower() == ".safetensors":
|
| 15 |
+
import safetensors.torch
|
| 16 |
+
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
|
| 17 |
+
else:
|
| 18 |
+
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
|
| 19 |
+
state_dict = get_state_dict(state_dict)
|
| 20 |
+
print(f'Loaded state_dict from [{ckpt_path}]')
|
| 21 |
+
return state_dict
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def create_model(config_path):
|
| 25 |
+
config = OmegaConf.load(config_path)
|
| 26 |
+
model = instantiate_from_config(config.model).cpu()
|
| 27 |
+
print(f'Loaded model config from [{config_path}]')
|
| 28 |
+
return model
|
generation/control/ControlNet/config.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
save_memory = False
|
generation/control/ControlNet/docs/annotator.md
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automatic Annotations
|
| 2 |
+
|
| 3 |
+
We provide gradio examples to obtain annotations that are aligned to our pretrained production-ready models.
|
| 4 |
+
|
| 5 |
+
Just run
|
| 6 |
+
|
| 7 |
+
python gradio_annotator.py
|
| 8 |
+
|
| 9 |
+
Since everyone has different habit to organize their datasets, we do not hard code any scripts for batch processing. But "gradio_annotator.py" is written in a super readable way, and modifying it to annotate your images should be easy.
|
| 10 |
+
|
| 11 |
+
In the gradio UI of "gradio_annotator.py" we have the following interfaces:
|
| 12 |
+
|
| 13 |
+
### Canny Edge
|
| 14 |
+
|
| 15 |
+
Be careful about "black edge and white background" or "white edge and black background".
|
| 16 |
+
|
| 17 |
+

|
| 18 |
+
|
| 19 |
+
### HED Edge
|
| 20 |
+
|
| 21 |
+
Be careful about "black edge and white background" or "white edge and black background".
|
| 22 |
+
|
| 23 |
+

|
| 24 |
+
|
| 25 |
+
### MLSD Edge
|
| 26 |
+
|
| 27 |
+
Be careful about "black edge and white background" or "white edge and black background".
|
| 28 |
+
|
| 29 |
+

|
| 30 |
+
|
| 31 |
+
### MIDAS Depth and Normal
|
| 32 |
+
|
| 33 |
+
Be careful about RGB or BGR in normal maps.
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
### Openpose
|
| 38 |
+
|
| 39 |
+
Be careful about RGB or BGR in pose maps.
|
| 40 |
+
|
| 41 |
+
For our production-ready model, the hand pose option is turned off.
|
| 42 |
+
|
| 43 |
+

|
| 44 |
+
|
| 45 |
+
### Uniformer Segmentation
|
| 46 |
+
|
| 47 |
+
Be careful about RGB or BGR in segmentation maps.
|
| 48 |
+
|
| 49 |
+

|
generation/control/ControlNet/docs/faq.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FAQs
|
| 2 |
+
|
| 3 |
+
**Q:** If the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why "zero convolution" works?
|
| 4 |
+
|
| 5 |
+
**A:** This is wrong. Let us consider a very simple
|
| 6 |
+
|
| 7 |
+
$$y=wx+b$$
|
| 8 |
+
|
| 9 |
+
and we have
|
| 10 |
+
|
| 11 |
+
$$\partial y/\partial w=x, \partial y/\partial x=w, \partial y/\partial b=1$$
|
| 12 |
+
|
| 13 |
+
and if $w=0$ and $x \neq 0$, then
|
| 14 |
+
|
| 15 |
+
$$\partial y/\partial w \neq 0, \partial y/\partial x=0, \partial y/\partial b\neq 0$$
|
| 16 |
+
|
| 17 |
+
which means as long as $x \neq 0$, one gradient descent iteration will make $w$ non-zero. Then
|
| 18 |
+
|
| 19 |
+
$$\partial y/\partial x\neq 0$$
|
| 20 |
+
|
| 21 |
+
so that the zero convolutions will progressively become a common conv layer with non-zero weights.
|
generation/control/ControlNet/docs/low_vram.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Enable Low VRAM Mode
|
| 2 |
+
|
| 3 |
+
If you are using 8GB GPU card (or if you want larger batch size), please open "config.py", and then set
|
| 4 |
+
|
| 5 |
+
```python
|
| 6 |
+
save_memory = True
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
This feature is still being tested - not all graphics cards are guaranteed to succeed.
|
| 10 |
+
|
| 11 |
+
But it should be neat as I can diffuse at a batch size of 12 now.
|
| 12 |
+
|
| 13 |
+
(prompt "man")
|
| 14 |
+
|
| 15 |
+

|
generation/control/ControlNet/docs/train.md
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Train a ControlNet to Control SD
|
| 2 |
+
|
| 3 |
+
You are here because you want to control SD in your own way, maybe you have an idea for your perfect research project, and you will annotate some data or have already annotated your own dataset automatically or manually. Herein, the control can be anything that can be converted to images, such as edges, keypoints, segments, etc.
|
| 4 |
+
|
| 5 |
+
Before moving on to your own dataset, we highly recommend to first try the toy dataset, Fill50K, as a sanity check. This will help you get a "feeling" for the training. You will know how long it will take for the model to converge and whether your device will be able to complete the training in an acceptable amount of time. And what it "feels" like when the model converges.
|
| 6 |
+
|
| 7 |
+
We hope that after you read this page, you will find that training a ControlNet is as easy as (or easier than) training a pix2pix.
|
| 8 |
+
|
| 9 |
+
## Step 0 - Design your control
|
| 10 |
+
|
| 11 |
+
Let us take a look at a very simple task to control SD to fill color in circles.
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+
This is simple: we want to control SD to fill a circle with colors, and the prompt contains some description of our target.
|
| 16 |
+
|
| 17 |
+
Stable diffusion is trained on billions of images, and it already knows what is "cyan", what is "circle", what is "pink", and what is "background".
|
| 18 |
+
|
| 19 |
+
But it does not know the meaning of that "Control Image (Source Image)". Our target is to let it know.
|
| 20 |
+
|
| 21 |
+
## Step 1 - Get a dataset
|
| 22 |
+
|
| 23 |
+
Just download the Fill50K dataset from [our huggingface page](https://huggingface.co/lllyasviel/ControlNet) (training/fill50k.zip, the file is only 200M!). Make sure that the data is decompressed as
|
| 24 |
+
|
| 25 |
+
ControlNet/training/fill50k/prompt.json
|
| 26 |
+
ControlNet/training/fill50k/source/X.png
|
| 27 |
+
ControlNet/training/fill50k/target/X.png
|
| 28 |
+
|
| 29 |
+
In the folder "fill50k/source", you will have 50k images of circle lines.
|
| 30 |
+
|
| 31 |
+

|
| 32 |
+
|
| 33 |
+
In the folder "fill50k/target", you will have 50k images of filled circles.
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
In the "fill50k/prompt.json", you will have their filenames and prompts. Each prompt is like "a balabala color circle in some other color background."
|
| 38 |
+
|
| 39 |
+

|
| 40 |
+
|
| 41 |
+
## Step 2 - Load the dataset
|
| 42 |
+
|
| 43 |
+
Then you need to write a simple script to read this dataset for pytorch. (In fact we have written it for you in "tutorial_dataset.py".)
|
| 44 |
+
|
| 45 |
+
```python
|
| 46 |
+
import json
|
| 47 |
+
import cv2
|
| 48 |
+
import numpy as np
|
| 49 |
+
|
| 50 |
+
from torch.utils.data import Dataset
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class MyDataset(Dataset):
|
| 54 |
+
def __init__(self):
|
| 55 |
+
self.data = []
|
| 56 |
+
with open('./training/fill50k/prompt.json', 'rt') as f:
|
| 57 |
+
for line in f:
|
| 58 |
+
self.data.append(json.loads(line))
|
| 59 |
+
|
| 60 |
+
def __len__(self):
|
| 61 |
+
return len(self.data)
|
| 62 |
+
|
| 63 |
+
def __getitem__(self, idx):
|
| 64 |
+
item = self.data[idx]
|
| 65 |
+
|
| 66 |
+
source_filename = item['source']
|
| 67 |
+
target_filename = item['target']
|
| 68 |
+
prompt = item['prompt']
|
| 69 |
+
|
| 70 |
+
source = cv2.imread('./training/fill50k/' + source_filename)
|
| 71 |
+
target = cv2.imread('./training/fill50k/' + target_filename)
|
| 72 |
+
|
| 73 |
+
# Do not forget that OpenCV read images in BGR order.
|
| 74 |
+
source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
|
| 75 |
+
target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
|
| 76 |
+
|
| 77 |
+
# Normalize source images to [0, 1].
|
| 78 |
+
source = source.astype(np.float32) / 255.0
|
| 79 |
+
|
| 80 |
+
# Normalize target images to [-1, 1].
|
| 81 |
+
target = (target.astype(np.float32) / 127.5) - 1.0
|
| 82 |
+
|
| 83 |
+
return dict(jpg=target, txt=prompt, hint=source)
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
This will make your dataset into an array-like object in python. You can test this dataset simply by accessing the array, like this
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
from tutorial_dataset import MyDataset
|
| 91 |
+
|
| 92 |
+
dataset = MyDataset()
|
| 93 |
+
print(len(dataset))
|
| 94 |
+
|
| 95 |
+
item = dataset[1234]
|
| 96 |
+
jpg = item['jpg']
|
| 97 |
+
txt = item['txt']
|
| 98 |
+
hint = item['hint']
|
| 99 |
+
print(txt)
|
| 100 |
+
print(jpg.shape)
|
| 101 |
+
print(hint.shape)
|
| 102 |
+
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
The outputs of this simple test on my machine are
|
| 106 |
+
|
| 107 |
+
50000
|
| 108 |
+
burly wood circle with orange background
|
| 109 |
+
(512, 512, 3)
|
| 110 |
+
(512, 512, 3)
|
| 111 |
+
|
| 112 |
+
And this code is in "tutorial_dataset_test.py".
|
| 113 |
+
|
| 114 |
+
In this way, the dataset is an array-like object with 50000 items. Each item is a dict with three entry "jpg", "txt", and "hint". The "jpg" is the target image, the "hint" is the control image, and the "txt" is the prompt.
|
| 115 |
+
|
| 116 |
+
Do not ask us why we use these three names - this is related to the dark history of a library called LDM.
|
| 117 |
+
|
| 118 |
+
## Step 3 - What SD model do you want to control?
|
| 119 |
+
|
| 120 |
+
Then you need to decide which Stable Diffusion Model you want to control. In this example, we will just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main).
|
| 121 |
+
|
| 122 |
+
(Or ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main) if you are using SD2.)
|
| 123 |
+
|
| 124 |
+
Then you need to attach a control net to the SD model. The architecture is
|
| 125 |
+
|
| 126 |
+

|
| 127 |
+
|
| 128 |
+
Note that all weights inside the ControlNet are also copied from SD so that no layer is trained from scratch, and you are still finetuning the entire model.
|
| 129 |
+
|
| 130 |
+
We provide a simple script for you to achieve this easily. If your SD filename is "./models/v1-5-pruned.ckpt" and you want the script to save the processed model (SD+ControlNet) at location "./models/control_sd15_ini.ckpt", you can just run:
|
| 131 |
+
|
| 132 |
+
python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt
|
| 133 |
+
|
| 134 |
+
Or if you are using SD2:
|
| 135 |
+
|
| 136 |
+
python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt
|
| 137 |
+
|
| 138 |
+
You may also use other filenames as long as the command is "python tool_add_control.py input_path output_path".
|
| 139 |
+
|
| 140 |
+
This is the correct output from my machine:
|
| 141 |
+
|
| 142 |
+

|
| 143 |
+
|
| 144 |
+
## Step 4 - Train!
|
| 145 |
+
|
| 146 |
+
Happy! We finally come to the most exciting part: training!
|
| 147 |
+
|
| 148 |
+
The training code in "tutorial_train.py" is actually surprisingly simple:
|
| 149 |
+
|
| 150 |
+
```python
|
| 151 |
+
import pytorch_lightning as pl
|
| 152 |
+
from torch.utils.data import DataLoader
|
| 153 |
+
from tutorial_dataset import MyDataset
|
| 154 |
+
from cldm.logger import ImageLogger
|
| 155 |
+
from cldm.model import create_model, load_state_dict
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# Configs
|
| 159 |
+
resume_path = './models/control_sd15_ini.ckpt'
|
| 160 |
+
batch_size = 4
|
| 161 |
+
logger_freq = 300
|
| 162 |
+
learning_rate = 1e-5
|
| 163 |
+
sd_locked = True
|
| 164 |
+
only_mid_control = False
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
|
| 168 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 169 |
+
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
|
| 170 |
+
model.learning_rate = learning_rate
|
| 171 |
+
model.sd_locked = sd_locked
|
| 172 |
+
model.only_mid_control = only_mid_control
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# Misc
|
| 176 |
+
dataset = MyDataset()
|
| 177 |
+
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
|
| 178 |
+
logger = ImageLogger(batch_frequency=logger_freq)
|
| 179 |
+
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Train!
|
| 183 |
+
trainer.fit(model, dataloader)
|
| 184 |
+
|
| 185 |
+
```
|
| 186 |
+
(or "tutorial_train_sd21.py" if you are using SD2)
|
| 187 |
+
|
| 188 |
+
Thanks to our organized dataset pytorch object and the power of pytorch_lightning, the entire code is just super short.
|
| 189 |
+
|
| 190 |
+
Now, you may take a look at [Pytorch Lightning Official DOC](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#trainer) to find out how to enable many useful features like gradient accumulation, multiple GPU training, accelerated dataset loading, flexible checkpoint saving, etc. All these only need about one line of code. Great!
|
| 191 |
+
|
| 192 |
+
Note that if you find OOM, perhaps you need to enable [Low VRAM mode](low_vram.md), and perhaps you also need to use smaller batch size and gradient accumulation. Or you may also want to use some “advanced” tricks like sliced attention or xformers. For example:
|
| 193 |
+
|
| 194 |
+
```python
|
| 195 |
+
# Configs
|
| 196 |
+
batch_size = 1
|
| 197 |
+
|
| 198 |
+
# Misc
|
| 199 |
+
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger], accumulate_grad_batches=4) # But this will be 4x slower
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
Note that training with 8 GB laptop GPU is challenging. We will need some GPU memory optimization at least as good as automatic1111’s UI. This may require expert modifications to the code.
|
| 203 |
+
|
| 204 |
+
### Screenshots
|
| 205 |
+
|
| 206 |
+
The training is fast. After 4000 steps (batch size 4, learning rate 1e-5, about 50 minutes on PCIE 40G), the results on my machine (in an output folder "image_log") is
|
| 207 |
+
|
| 208 |
+
Control:
|
| 209 |
+
|
| 210 |
+

|
| 211 |
+
|
| 212 |
+
Prompt:
|
| 213 |
+
|
| 214 |
+

|
| 215 |
+
|
| 216 |
+
Prediction:
|
| 217 |
+
|
| 218 |
+

|
| 219 |
+
|
| 220 |
+
Ground Truth:
|
| 221 |
+
|
| 222 |
+

|
| 223 |
+
|
| 224 |
+
Note that the SD's capability is preserved. Even training on this super aligned dataset, it still draws some random textures and those snow decorations. (Besides, note that the ground truth looks a bit modified because it is converted from SD's latent image.)
|
| 225 |
+
|
| 226 |
+
Larger batch size and longer training will further improve this. Adequate training will make the filling perfect.
|
| 227 |
+
|
| 228 |
+
Of course, training SD to fill circles is meaningless, but this is a successful beginning of your story.
|
| 229 |
+
|
| 230 |
+
Let us work together to control large models more and more.
|
| 231 |
+
|
| 232 |
+
## Other options
|
| 233 |
+
|
| 234 |
+
Beyond standard things, we also provide two important parameters "sd_locked" and "only_mid_control" that you need to know.
|
| 235 |
+
|
| 236 |
+
### only_mid_control
|
| 237 |
+
|
| 238 |
+
By default, only_mid_control is False. When it is True, you will train the below architecture.
|
| 239 |
+
|
| 240 |
+

|
| 241 |
+
|
| 242 |
+
This can be helpful when your computation power is limited and want to speed up the training, or when you want to facilitate the "global" context learning. Note that sometimes you may pause training, set it to True, resume training, and pause again, and set it again, and resume again.
|
| 243 |
+
|
| 244 |
+
If your computation device is good, perhaps you do not need this. But I also know some artists are willing to train a model on their laptop for a month - in that case, perhaps this option can be useful.
|
| 245 |
+
|
| 246 |
+
### sd_locked
|
| 247 |
+
|
| 248 |
+
By default, sd_locked is True. When it is False, you will train the below architecture.
|
| 249 |
+
|
| 250 |
+

|
| 251 |
+
|
| 252 |
+
This will unlock some layers in SD and you will train them as a whole.
|
| 253 |
+
|
| 254 |
+
This option is DANGEROUS! If your dataset is not good enough, this may downgrade the capability of your SD model.
|
| 255 |
+
|
| 256 |
+
However, this option is also very useful when you are training on images with some specific style, or when you are training with special datasets (like medical dataset with X-ray images or geographic datasets with lots of Google Maps). You can understand this as simultaneously training the ControlNet and something like a DreamBooth.
|
| 257 |
+
|
| 258 |
+
Also, if your dataset is large, you may want to end the training with a few thousands of steps with those layer unlocked. This usually improve the "problem-specific" solutions a little. You may try it yourself to feel the difference.
|
| 259 |
+
|
| 260 |
+
Also, if you unlock some original layers, you may want a lower learning rate, like 2e-6.
|
| 261 |
+
|
| 262 |
+
## More Consideration: Sudden Converge Phenomenon and Gradient Accumulation
|
| 263 |
+
|
| 264 |
+

|
| 265 |
+
|
| 266 |
+
Because we use zero convolutions, the SD should always be able to predict meaningful images. (If it cannot, the training has already failed.)
|
| 267 |
+
|
| 268 |
+
You will always find that at some iterations, the model "suddenly" be able to fit some training conditions. This means that you will get a basically usable model at about 3k to 7k steps (future training will improve it, but that model after the first "sudden converge" should be basically functional).
|
| 269 |
+
|
| 270 |
+
Note that 3k to 7k steps is not very large, and you should consider larger batch size rather than more training steps. If you can observe the "sudden converge" at 3k step using batch size 4, then, rather than train it with 300k further steps, a better idea is to use 100× gradient accumulation to re-train that 3k steps with 100× batch size. Note that perhaps we should not do this *too* extremely (perhaps 100x accumulation is too extreme), but you should consider that, since "sudden converge" will *always* happen at that certain point, getting a better converge is more important.
|
| 271 |
+
|
| 272 |
+
Because that "sudden converge" always happens, lets say "sudden converge" will happen at 3k step and our money can optimize 90k step, then we have two options: (1) train 3k steps, sudden converge, then train 87k steps. (2) 30x gradient accumulation, train 3k steps (90k real computation steps), then sudden converge.
|
| 273 |
+
|
| 274 |
+
In my experiments, (2) is usually better than (1). However, in real cases, perhaps you may need to balance the steps before and after the "sudden converge" on your own to find a balance. The training after "sudden converge" is also important.
|
| 275 |
+
|
| 276 |
+
But usually, if your logic batch size is already bigger than 256, then further extending the batch size is not very meaningful. In that case, perhaps a better idea is to train more steps. I tried some "common" logic batch size at 64 or 96 or 128 (by gradient accumulation), it seems that many complicated conditions can be solved very well already.
|
generation/control/ControlNet/environment.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: control
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- python=3.8.5
|
| 7 |
+
- pip=20.3
|
| 8 |
+
- cudatoolkit=11.3
|
| 9 |
+
- pytorch=1.12.1
|
| 10 |
+
- torchvision=0.13.1
|
| 11 |
+
- numpy=1.23.1
|
| 12 |
+
- pip:
|
| 13 |
+
- gradio==3.16.2
|
| 14 |
+
- albumentations==1.3.0
|
| 15 |
+
- opencv-contrib-python==4.3.0.36
|
| 16 |
+
- imageio==2.9.0
|
| 17 |
+
- imageio-ffmpeg==0.4.2
|
| 18 |
+
- pytorch-lightning==1.5.0
|
| 19 |
+
- omegaconf==2.1.1
|
| 20 |
+
- test-tube>=0.7.5
|
| 21 |
+
- streamlit==1.12.1
|
| 22 |
+
- einops==0.3.0
|
| 23 |
+
- transformers==4.19.2
|
| 24 |
+
- webdataset==0.2.5
|
| 25 |
+
- kornia==0.6
|
| 26 |
+
- open_clip_torch==2.0.2
|
| 27 |
+
- invisible-watermark>=0.1.5
|
| 28 |
+
- streamlit-drawable-canvas==0.8.0
|
| 29 |
+
- torchmetrics==0.6.0
|
| 30 |
+
- timm==0.6.12
|
| 31 |
+
- addict==2.4.0
|
| 32 |
+
- yapf==0.32.0
|
| 33 |
+
- prettytable==3.6.0
|
| 34 |
+
- safetensors==0.2.7
|
| 35 |
+
- basicsr==1.4.2
|
generation/control/ControlNet/gradio_annotator.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
from annotator.util import resize_image, HWC3
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
model_canny = None
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def canny(img, res, l, h):
|
| 10 |
+
img = resize_image(HWC3(img), res)
|
| 11 |
+
global model_canny
|
| 12 |
+
if model_canny is None:
|
| 13 |
+
from annotator.canny import CannyDetector
|
| 14 |
+
model_canny = CannyDetector()
|
| 15 |
+
result = model_canny(img, l, h)
|
| 16 |
+
return [result]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
model_hed = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def hed(img, res):
|
| 23 |
+
img = resize_image(HWC3(img), res)
|
| 24 |
+
global model_hed
|
| 25 |
+
if model_hed is None:
|
| 26 |
+
from annotator.hed import HEDdetector
|
| 27 |
+
model_hed = HEDdetector()
|
| 28 |
+
result = model_hed(img)
|
| 29 |
+
return [result]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
model_mlsd = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def mlsd(img, res, thr_v, thr_d):
|
| 36 |
+
img = resize_image(HWC3(img), res)
|
| 37 |
+
global model_mlsd
|
| 38 |
+
if model_mlsd is None:
|
| 39 |
+
from annotator.mlsd import MLSDdetector
|
| 40 |
+
model_mlsd = MLSDdetector()
|
| 41 |
+
result = model_mlsd(img, thr_v, thr_d)
|
| 42 |
+
return [result]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
model_midas = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def midas(img, res, a):
|
| 49 |
+
img = resize_image(HWC3(img), res)
|
| 50 |
+
global model_midas
|
| 51 |
+
if model_midas is None:
|
| 52 |
+
from annotator.midas import MidasDetector
|
| 53 |
+
model_midas = MidasDetector()
|
| 54 |
+
results = model_midas(img, a)
|
| 55 |
+
return results
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
model_openpose = None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def openpose(img, res, has_hand):
|
| 62 |
+
img = resize_image(HWC3(img), res)
|
| 63 |
+
global model_openpose
|
| 64 |
+
if model_openpose is None:
|
| 65 |
+
from annotator.openpose import OpenposeDetector
|
| 66 |
+
model_openpose = OpenposeDetector()
|
| 67 |
+
result, _ = model_openpose(img, has_hand)
|
| 68 |
+
return [result]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
model_uniformer = None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def uniformer(img, res):
|
| 75 |
+
img = resize_image(HWC3(img), res)
|
| 76 |
+
global model_uniformer
|
| 77 |
+
if model_uniformer is None:
|
| 78 |
+
from annotator.uniformer import UniformerDetector
|
| 79 |
+
model_uniformer = UniformerDetector()
|
| 80 |
+
result = model_uniformer(img)
|
| 81 |
+
return [result]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
block = gr.Blocks().queue()
|
| 85 |
+
with block:
|
| 86 |
+
with gr.Row():
|
| 87 |
+
gr.Markdown("## Canny Edge")
|
| 88 |
+
with gr.Row():
|
| 89 |
+
with gr.Column():
|
| 90 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 91 |
+
low_threshold = gr.Slider(label="low_threshold", minimum=1, maximum=255, value=100, step=1)
|
| 92 |
+
high_threshold = gr.Slider(label="high_threshold", minimum=1, maximum=255, value=200, step=1)
|
| 93 |
+
resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
|
| 94 |
+
run_button = gr.Button(label="Run")
|
| 95 |
+
with gr.Column():
|
| 96 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
|
| 97 |
+
run_button.click(fn=canny, inputs=[input_image, resolution, low_threshold, high_threshold], outputs=[gallery])
|
| 98 |
+
|
| 99 |
+
with gr.Row():
|
| 100 |
+
gr.Markdown("## HED Edge")
|
| 101 |
+
with gr.Row():
|
| 102 |
+
with gr.Column():
|
| 103 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 104 |
+
resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
|
| 105 |
+
run_button = gr.Button(label="Run")
|
| 106 |
+
with gr.Column():
|
| 107 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
|
| 108 |
+
run_button.click(fn=hed, inputs=[input_image, resolution], outputs=[gallery])
|
| 109 |
+
|
| 110 |
+
with gr.Row():
|
| 111 |
+
gr.Markdown("## MLSD Edge")
|
| 112 |
+
with gr.Row():
|
| 113 |
+
with gr.Column():
|
| 114 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 115 |
+
value_threshold = gr.Slider(label="value_threshold", minimum=0.01, maximum=2.0, value=0.1, step=0.01)
|
| 116 |
+
distance_threshold = gr.Slider(label="distance_threshold", minimum=0.01, maximum=20.0, value=0.1, step=0.01)
|
| 117 |
+
resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=384, step=64)
|
| 118 |
+
run_button = gr.Button(label="Run")
|
| 119 |
+
with gr.Column():
|
| 120 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
|
| 121 |
+
run_button.click(fn=mlsd, inputs=[input_image, resolution, value_threshold, distance_threshold], outputs=[gallery])
|
| 122 |
+
|
| 123 |
+
with gr.Row():
|
| 124 |
+
gr.Markdown("## MIDAS Depth and Normal")
|
| 125 |
+
with gr.Row():
|
| 126 |
+
with gr.Column():
|
| 127 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 128 |
+
alpha = gr.Slider(label="alpha", minimum=0.1, maximum=20.0, value=6.2, step=0.01)
|
| 129 |
+
resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=384, step=64)
|
| 130 |
+
run_button = gr.Button(label="Run")
|
| 131 |
+
with gr.Column():
|
| 132 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
|
| 133 |
+
run_button.click(fn=midas, inputs=[input_image, resolution, alpha], outputs=[gallery])
|
| 134 |
+
|
| 135 |
+
with gr.Row():
|
| 136 |
+
gr.Markdown("## Openpose")
|
| 137 |
+
with gr.Row():
|
| 138 |
+
with gr.Column():
|
| 139 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 140 |
+
hand = gr.Checkbox(label='detect hand', value=False)
|
| 141 |
+
resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
|
| 142 |
+
run_button = gr.Button(label="Run")
|
| 143 |
+
with gr.Column():
|
| 144 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
|
| 145 |
+
run_button.click(fn=openpose, inputs=[input_image, resolution, hand], outputs=[gallery])
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
with gr.Row():
|
| 149 |
+
gr.Markdown("## Uniformer Segmentation")
|
| 150 |
+
with gr.Row():
|
| 151 |
+
with gr.Column():
|
| 152 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 153 |
+
resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
|
| 154 |
+
run_button = gr.Button(label="Run")
|
| 155 |
+
with gr.Column():
|
| 156 |
+
gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
|
| 157 |
+
run_button.click(fn=uniformer, inputs=[input_image, resolution], outputs=[gallery])
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_canny2image.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from annotator.canny import CannyDetector
|
| 14 |
+
from cldm.model import create_model, load_state_dict
|
| 15 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
apply_canny = CannyDetector()
|
| 19 |
+
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
|
| 22 |
+
model = model.cuda()
|
| 23 |
+
ddim_sampler = DDIMSampler(model)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
img = resize_image(HWC3(input_image), image_resolution)
|
| 29 |
+
H, W, C = img.shape
|
| 30 |
+
|
| 31 |
+
detected_map = apply_canny(img, low_threshold, high_threshold)
|
| 32 |
+
detected_map = HWC3(detected_map)
|
| 33 |
+
|
| 34 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 35 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 36 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 37 |
+
|
| 38 |
+
if seed == -1:
|
| 39 |
+
seed = random.randint(0, 65535)
|
| 40 |
+
seed_everything(seed)
|
| 41 |
+
|
| 42 |
+
if config.save_memory:
|
| 43 |
+
model.low_vram_shift(is_diffusing=False)
|
| 44 |
+
|
| 45 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 46 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 47 |
+
shape = (4, H // 8, W // 8)
|
| 48 |
+
|
| 49 |
+
if config.save_memory:
|
| 50 |
+
model.low_vram_shift(is_diffusing=True)
|
| 51 |
+
|
| 52 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 53 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 54 |
+
shape, cond, verbose=False, eta=eta,
|
| 55 |
+
unconditional_guidance_scale=scale,
|
| 56 |
+
unconditional_conditioning=un_cond)
|
| 57 |
+
|
| 58 |
+
if config.save_memory:
|
| 59 |
+
model.low_vram_shift(is_diffusing=False)
|
| 60 |
+
|
| 61 |
+
x_samples = model.decode_first_stage(samples)
|
| 62 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 63 |
+
|
| 64 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 65 |
+
return [255 - detected_map] + results
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
block = gr.Blocks().queue()
|
| 69 |
+
with block:
|
| 70 |
+
with gr.Row():
|
| 71 |
+
gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
|
| 72 |
+
with gr.Row():
|
| 73 |
+
with gr.Column():
|
| 74 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 75 |
+
prompt = gr.Textbox(label="Prompt")
|
| 76 |
+
run_button = gr.Button(label="Run")
|
| 77 |
+
with gr.Accordion("Advanced options", open=False):
|
| 78 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 79 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 80 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 81 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 82 |
+
low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
|
| 83 |
+
high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
|
| 84 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 85 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 86 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 87 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 88 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 89 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 90 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 91 |
+
with gr.Column():
|
| 92 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 93 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold]
|
| 94 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_depth2image.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from annotator.midas import MidasDetector
|
| 14 |
+
from cldm.model import create_model, load_state_dict
|
| 15 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
apply_midas = MidasDetector()
|
| 19 |
+
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda'))
|
| 22 |
+
model = model.cuda()
|
| 23 |
+
ddim_sampler = DDIMSampler(model)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
input_image = HWC3(input_image)
|
| 29 |
+
detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
|
| 30 |
+
detected_map = HWC3(detected_map)
|
| 31 |
+
img = resize_image(input_image, image_resolution)
|
| 32 |
+
H, W, C = img.shape
|
| 33 |
+
|
| 34 |
+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 35 |
+
|
| 36 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 37 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 38 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 39 |
+
|
| 40 |
+
if seed == -1:
|
| 41 |
+
seed = random.randint(0, 65535)
|
| 42 |
+
seed_everything(seed)
|
| 43 |
+
|
| 44 |
+
if config.save_memory:
|
| 45 |
+
model.low_vram_shift(is_diffusing=False)
|
| 46 |
+
|
| 47 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 48 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 49 |
+
shape = (4, H // 8, W // 8)
|
| 50 |
+
|
| 51 |
+
if config.save_memory:
|
| 52 |
+
model.low_vram_shift(is_diffusing=True)
|
| 53 |
+
|
| 54 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 55 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 56 |
+
shape, cond, verbose=False, eta=eta,
|
| 57 |
+
unconditional_guidance_scale=scale,
|
| 58 |
+
unconditional_conditioning=un_cond)
|
| 59 |
+
|
| 60 |
+
if config.save_memory:
|
| 61 |
+
model.low_vram_shift(is_diffusing=False)
|
| 62 |
+
|
| 63 |
+
x_samples = model.decode_first_stage(samples)
|
| 64 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 65 |
+
|
| 66 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 67 |
+
return [detected_map] + results
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
block = gr.Blocks().queue()
|
| 71 |
+
with block:
|
| 72 |
+
with gr.Row():
|
| 73 |
+
gr.Markdown("## Control Stable Diffusion with Depth Maps")
|
| 74 |
+
with gr.Row():
|
| 75 |
+
with gr.Column():
|
| 76 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 77 |
+
prompt = gr.Textbox(label="Prompt")
|
| 78 |
+
run_button = gr.Button(label="Run")
|
| 79 |
+
with gr.Accordion("Advanced options", open=False):
|
| 80 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 81 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 82 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 83 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 84 |
+
detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1)
|
| 85 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 86 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 87 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 88 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 89 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 90 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 91 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 92 |
+
with gr.Column():
|
| 93 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 94 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
| 95 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_fake_scribble2image.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from annotator.hed import HEDdetector, nms
|
| 14 |
+
from cldm.model import create_model, load_state_dict
|
| 15 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
apply_hed = HEDdetector()
|
| 19 |
+
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
|
| 22 |
+
model = model.cuda()
|
| 23 |
+
ddim_sampler = DDIMSampler(model)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
input_image = HWC3(input_image)
|
| 29 |
+
detected_map = apply_hed(resize_image(input_image, detect_resolution))
|
| 30 |
+
detected_map = HWC3(detected_map)
|
| 31 |
+
img = resize_image(input_image, image_resolution)
|
| 32 |
+
H, W, C = img.shape
|
| 33 |
+
|
| 34 |
+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 35 |
+
detected_map = nms(detected_map, 127, 3.0)
|
| 36 |
+
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
|
| 37 |
+
detected_map[detected_map > 4] = 255
|
| 38 |
+
detected_map[detected_map < 255] = 0
|
| 39 |
+
|
| 40 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 41 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 42 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 43 |
+
|
| 44 |
+
if seed == -1:
|
| 45 |
+
seed = random.randint(0, 65535)
|
| 46 |
+
seed_everything(seed)
|
| 47 |
+
|
| 48 |
+
if config.save_memory:
|
| 49 |
+
model.low_vram_shift(is_diffusing=False)
|
| 50 |
+
|
| 51 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 52 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 53 |
+
shape = (4, H // 8, W // 8)
|
| 54 |
+
|
| 55 |
+
if config.save_memory:
|
| 56 |
+
model.low_vram_shift(is_diffusing=True)
|
| 57 |
+
|
| 58 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 59 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 60 |
+
shape, cond, verbose=False, eta=eta,
|
| 61 |
+
unconditional_guidance_scale=scale,
|
| 62 |
+
unconditional_conditioning=un_cond)
|
| 63 |
+
|
| 64 |
+
if config.save_memory:
|
| 65 |
+
model.low_vram_shift(is_diffusing=False)
|
| 66 |
+
|
| 67 |
+
x_samples = model.decode_first_stage(samples)
|
| 68 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 69 |
+
|
| 70 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 71 |
+
return [255 - detected_map] + results
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
block = gr.Blocks().queue()
|
| 75 |
+
with block:
|
| 76 |
+
with gr.Row():
|
| 77 |
+
gr.Markdown("## Control Stable Diffusion with Fake Scribble Maps")
|
| 78 |
+
with gr.Row():
|
| 79 |
+
with gr.Column():
|
| 80 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 81 |
+
prompt = gr.Textbox(label="Prompt")
|
| 82 |
+
run_button = gr.Button(label="Run")
|
| 83 |
+
with gr.Accordion("Advanced options", open=False):
|
| 84 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 85 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 86 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 87 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 88 |
+
detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
|
| 89 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 90 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 91 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 92 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 93 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 94 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 95 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 96 |
+
with gr.Column():
|
| 97 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 98 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
| 99 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_hed2image.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from annotator.hed import HEDdetector
|
| 14 |
+
from cldm.model import create_model, load_state_dict
|
| 15 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
apply_hed = HEDdetector()
|
| 19 |
+
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cuda'))
|
| 22 |
+
model = model.cuda()
|
| 23 |
+
ddim_sampler = DDIMSampler(model)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
input_image = HWC3(input_image)
|
| 29 |
+
detected_map = apply_hed(resize_image(input_image, detect_resolution))
|
| 30 |
+
detected_map = HWC3(detected_map)
|
| 31 |
+
img = resize_image(input_image, image_resolution)
|
| 32 |
+
H, W, C = img.shape
|
| 33 |
+
|
| 34 |
+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 35 |
+
|
| 36 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 37 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 38 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 39 |
+
|
| 40 |
+
if seed == -1:
|
| 41 |
+
seed = random.randint(0, 65535)
|
| 42 |
+
seed_everything(seed)
|
| 43 |
+
|
| 44 |
+
if config.save_memory:
|
| 45 |
+
model.low_vram_shift(is_diffusing=False)
|
| 46 |
+
|
| 47 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 48 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 49 |
+
shape = (4, H // 8, W // 8)
|
| 50 |
+
|
| 51 |
+
if config.save_memory:
|
| 52 |
+
model.low_vram_shift(is_diffusing=True)
|
| 53 |
+
|
| 54 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 55 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 56 |
+
shape, cond, verbose=False, eta=eta,
|
| 57 |
+
unconditional_guidance_scale=scale,
|
| 58 |
+
unconditional_conditioning=un_cond)
|
| 59 |
+
|
| 60 |
+
if config.save_memory:
|
| 61 |
+
model.low_vram_shift(is_diffusing=False)
|
| 62 |
+
|
| 63 |
+
x_samples = model.decode_first_stage(samples)
|
| 64 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 65 |
+
|
| 66 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 67 |
+
return [detected_map] + results
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
block = gr.Blocks().queue()
|
| 71 |
+
with block:
|
| 72 |
+
with gr.Row():
|
| 73 |
+
gr.Markdown("## Control Stable Diffusion with HED Maps")
|
| 74 |
+
with gr.Row():
|
| 75 |
+
with gr.Column():
|
| 76 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 77 |
+
prompt = gr.Textbox(label="Prompt")
|
| 78 |
+
run_button = gr.Button(label="Run")
|
| 79 |
+
with gr.Accordion("Advanced options", open=False):
|
| 80 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 81 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 82 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 83 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 84 |
+
detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
|
| 85 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 86 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 87 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 88 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 89 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 90 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 91 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 92 |
+
with gr.Column():
|
| 93 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 94 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
| 95 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_hough2image.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from annotator.mlsd import MLSDdetector
|
| 14 |
+
from cldm.model import create_model, load_state_dict
|
| 15 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
apply_mlsd = MLSDdetector()
|
| 19 |
+
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location='cuda'))
|
| 22 |
+
model = model.cuda()
|
| 23 |
+
ddim_sampler = DDIMSampler(model)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
input_image = HWC3(input_image)
|
| 29 |
+
detected_map = apply_mlsd(resize_image(input_image, detect_resolution), value_threshold, distance_threshold)
|
| 30 |
+
detected_map = HWC3(detected_map)
|
| 31 |
+
img = resize_image(input_image, image_resolution)
|
| 32 |
+
H, W, C = img.shape
|
| 33 |
+
|
| 34 |
+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
|
| 35 |
+
|
| 36 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 37 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 38 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 39 |
+
|
| 40 |
+
if seed == -1:
|
| 41 |
+
seed = random.randint(0, 65535)
|
| 42 |
+
seed_everything(seed)
|
| 43 |
+
|
| 44 |
+
if config.save_memory:
|
| 45 |
+
model.low_vram_shift(is_diffusing=False)
|
| 46 |
+
|
| 47 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 48 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 49 |
+
shape = (4, H // 8, W // 8)
|
| 50 |
+
|
| 51 |
+
if config.save_memory:
|
| 52 |
+
model.low_vram_shift(is_diffusing=True)
|
| 53 |
+
|
| 54 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 55 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 56 |
+
shape, cond, verbose=False, eta=eta,
|
| 57 |
+
unconditional_guidance_scale=scale,
|
| 58 |
+
unconditional_conditioning=un_cond)
|
| 59 |
+
|
| 60 |
+
if config.save_memory:
|
| 61 |
+
model.low_vram_shift(is_diffusing=False)
|
| 62 |
+
|
| 63 |
+
x_samples = model.decode_first_stage(samples)
|
| 64 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 65 |
+
|
| 66 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 67 |
+
return [255 - cv2.dilate(detected_map, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)] + results
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
block = gr.Blocks().queue()
|
| 71 |
+
with block:
|
| 72 |
+
with gr.Row():
|
| 73 |
+
gr.Markdown("## Control Stable Diffusion with Hough Line Maps")
|
| 74 |
+
with gr.Row():
|
| 75 |
+
with gr.Column():
|
| 76 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 77 |
+
prompt = gr.Textbox(label="Prompt")
|
| 78 |
+
run_button = gr.Button(label="Run")
|
| 79 |
+
with gr.Accordion("Advanced options", open=False):
|
| 80 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 81 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 82 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 83 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 84 |
+
detect_resolution = gr.Slider(label="Hough Resolution", minimum=128, maximum=1024, value=512, step=1)
|
| 85 |
+
value_threshold = gr.Slider(label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01)
|
| 86 |
+
distance_threshold = gr.Slider(label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01)
|
| 87 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 88 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 89 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 90 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 91 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 92 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 93 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 94 |
+
with gr.Column():
|
| 95 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 96 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold]
|
| 97 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_normal2image.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from annotator.midas import MidasDetector
|
| 14 |
+
from cldm.model import create_model, load_state_dict
|
| 15 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
apply_midas = MidasDetector()
|
| 19 |
+
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location='cuda'))
|
| 22 |
+
model = model.cuda()
|
| 23 |
+
ddim_sampler = DDIMSampler(model)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
input_image = HWC3(input_image)
|
| 29 |
+
_, detected_map = apply_midas(resize_image(input_image, detect_resolution), bg_th=bg_threshold)
|
| 30 |
+
detected_map = HWC3(detected_map)
|
| 31 |
+
img = resize_image(input_image, image_resolution)
|
| 32 |
+
H, W, C = img.shape
|
| 33 |
+
|
| 34 |
+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 35 |
+
|
| 36 |
+
control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().cuda() / 255.0
|
| 37 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 38 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 39 |
+
|
| 40 |
+
if seed == -1:
|
| 41 |
+
seed = random.randint(0, 65535)
|
| 42 |
+
seed_everything(seed)
|
| 43 |
+
|
| 44 |
+
if config.save_memory:
|
| 45 |
+
model.low_vram_shift(is_diffusing=False)
|
| 46 |
+
|
| 47 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 48 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 49 |
+
shape = (4, H // 8, W // 8)
|
| 50 |
+
|
| 51 |
+
if config.save_memory:
|
| 52 |
+
model.low_vram_shift(is_diffusing=True)
|
| 53 |
+
|
| 54 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 55 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 56 |
+
shape, cond, verbose=False, eta=eta,
|
| 57 |
+
unconditional_guidance_scale=scale,
|
| 58 |
+
unconditional_conditioning=un_cond)
|
| 59 |
+
|
| 60 |
+
if config.save_memory:
|
| 61 |
+
model.low_vram_shift(is_diffusing=False)
|
| 62 |
+
|
| 63 |
+
x_samples = model.decode_first_stage(samples)
|
| 64 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 65 |
+
|
| 66 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 67 |
+
return [detected_map] + results
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
block = gr.Blocks().queue()
|
| 71 |
+
with block:
|
| 72 |
+
with gr.Row():
|
| 73 |
+
gr.Markdown("## Control Stable Diffusion with Normal Maps")
|
| 74 |
+
with gr.Row():
|
| 75 |
+
with gr.Column():
|
| 76 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 77 |
+
prompt = gr.Textbox(label="Prompt")
|
| 78 |
+
run_button = gr.Button(label="Run")
|
| 79 |
+
with gr.Accordion("Advanced options", open=False):
|
| 80 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 81 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 82 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 83 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 84 |
+
detect_resolution = gr.Slider(label="Normal Resolution", minimum=128, maximum=1024, value=384, step=1)
|
| 85 |
+
bg_threshold = gr.Slider(label="Normal background threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.01)
|
| 86 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 87 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 88 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 89 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 90 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 91 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 92 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 93 |
+
with gr.Column():
|
| 94 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 95 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold]
|
| 96 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_pose2image.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from annotator.openpose import OpenposeDetector
|
| 14 |
+
from cldm.model import create_model, load_state_dict
|
| 15 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
apply_openpose = OpenposeDetector()
|
| 19 |
+
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda'))
|
| 22 |
+
model = model.cuda()
|
| 23 |
+
ddim_sampler = DDIMSampler(model)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
input_image = HWC3(input_image)
|
| 29 |
+
detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
|
| 30 |
+
detected_map = HWC3(detected_map)
|
| 31 |
+
img = resize_image(input_image, image_resolution)
|
| 32 |
+
H, W, C = img.shape
|
| 33 |
+
|
| 34 |
+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
|
| 35 |
+
|
| 36 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 37 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 38 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 39 |
+
|
| 40 |
+
if seed == -1:
|
| 41 |
+
seed = random.randint(0, 65535)
|
| 42 |
+
seed_everything(seed)
|
| 43 |
+
|
| 44 |
+
if config.save_memory:
|
| 45 |
+
model.low_vram_shift(is_diffusing=False)
|
| 46 |
+
|
| 47 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 48 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 49 |
+
shape = (4, H // 8, W // 8)
|
| 50 |
+
|
| 51 |
+
if config.save_memory:
|
| 52 |
+
model.low_vram_shift(is_diffusing=True)
|
| 53 |
+
|
| 54 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 55 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 56 |
+
shape, cond, verbose=False, eta=eta,
|
| 57 |
+
unconditional_guidance_scale=scale,
|
| 58 |
+
unconditional_conditioning=un_cond)
|
| 59 |
+
|
| 60 |
+
if config.save_memory:
|
| 61 |
+
model.low_vram_shift(is_diffusing=False)
|
| 62 |
+
|
| 63 |
+
x_samples = model.decode_first_stage(samples)
|
| 64 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 65 |
+
|
| 66 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 67 |
+
return [detected_map] + results
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
block = gr.Blocks().queue()
|
| 71 |
+
with block:
|
| 72 |
+
with gr.Row():
|
| 73 |
+
gr.Markdown("## Control Stable Diffusion with Human Pose")
|
| 74 |
+
with gr.Row():
|
| 75 |
+
with gr.Column():
|
| 76 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 77 |
+
prompt = gr.Textbox(label="Prompt")
|
| 78 |
+
run_button = gr.Button(label="Run")
|
| 79 |
+
with gr.Accordion("Advanced options", open=False):
|
| 80 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 81 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 82 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 83 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 84 |
+
detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=1024, value=512, step=1)
|
| 85 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 86 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 87 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 88 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 89 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 90 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 91 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 92 |
+
with gr.Column():
|
| 93 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 94 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
| 95 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_scribble2image.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from cldm.model import create_model, load_state_dict
|
| 14 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 18 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
|
| 19 |
+
model = model.cuda()
|
| 20 |
+
ddim_sampler = DDIMSampler(model)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
| 24 |
+
with torch.no_grad():
|
| 25 |
+
img = resize_image(HWC3(input_image), image_resolution)
|
| 26 |
+
H, W, C = img.shape
|
| 27 |
+
|
| 28 |
+
detected_map = np.zeros_like(img, dtype=np.uint8)
|
| 29 |
+
detected_map[np.min(img, axis=2) < 127] = 255
|
| 30 |
+
|
| 31 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 32 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 33 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 34 |
+
|
| 35 |
+
if seed == -1:
|
| 36 |
+
seed = random.randint(0, 65535)
|
| 37 |
+
seed_everything(seed)
|
| 38 |
+
|
| 39 |
+
if config.save_memory:
|
| 40 |
+
model.low_vram_shift(is_diffusing=False)
|
| 41 |
+
|
| 42 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 43 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 44 |
+
shape = (4, H // 8, W // 8)
|
| 45 |
+
|
| 46 |
+
if config.save_memory:
|
| 47 |
+
model.low_vram_shift(is_diffusing=True)
|
| 48 |
+
|
| 49 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 50 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 51 |
+
shape, cond, verbose=False, eta=eta,
|
| 52 |
+
unconditional_guidance_scale=scale,
|
| 53 |
+
unconditional_conditioning=un_cond)
|
| 54 |
+
|
| 55 |
+
if config.save_memory:
|
| 56 |
+
model.low_vram_shift(is_diffusing=False)
|
| 57 |
+
|
| 58 |
+
x_samples = model.decode_first_stage(samples)
|
| 59 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 60 |
+
|
| 61 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 62 |
+
return [255 - detected_map] + results
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
block = gr.Blocks().queue()
|
| 66 |
+
with block:
|
| 67 |
+
with gr.Row():
|
| 68 |
+
gr.Markdown("## Control Stable Diffusion with Scribble Maps")
|
| 69 |
+
with gr.Row():
|
| 70 |
+
with gr.Column():
|
| 71 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 72 |
+
prompt = gr.Textbox(label="Prompt")
|
| 73 |
+
run_button = gr.Button(label="Run")
|
| 74 |
+
with gr.Accordion("Advanced options", open=False):
|
| 75 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 76 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 77 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 78 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 79 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 80 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 81 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 82 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 83 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 84 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 85 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 86 |
+
with gr.Column():
|
| 87 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 88 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
| 89 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_scribble2image_interactive.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from cldm.model import create_model, load_state_dict
|
| 14 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 18 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
|
| 19 |
+
model = model.cuda()
|
| 20 |
+
ddim_sampler = DDIMSampler(model)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
| 24 |
+
with torch.no_grad():
|
| 25 |
+
img = resize_image(HWC3(input_image['mask'][:, :, 0]), image_resolution)
|
| 26 |
+
H, W, C = img.shape
|
| 27 |
+
|
| 28 |
+
detected_map = np.zeros_like(img, dtype=np.uint8)
|
| 29 |
+
detected_map[np.min(img, axis=2) > 127] = 255
|
| 30 |
+
|
| 31 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 32 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 33 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 34 |
+
|
| 35 |
+
if seed == -1:
|
| 36 |
+
seed = random.randint(0, 65535)
|
| 37 |
+
seed_everything(seed)
|
| 38 |
+
|
| 39 |
+
if config.save_memory:
|
| 40 |
+
model.low_vram_shift(is_diffusing=False)
|
| 41 |
+
|
| 42 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 43 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 44 |
+
shape = (4, H // 8, W // 8)
|
| 45 |
+
|
| 46 |
+
if config.save_memory:
|
| 47 |
+
model.low_vram_shift(is_diffusing=True)
|
| 48 |
+
|
| 49 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 50 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 51 |
+
shape, cond, verbose=False, eta=eta,
|
| 52 |
+
unconditional_guidance_scale=scale,
|
| 53 |
+
unconditional_conditioning=un_cond)
|
| 54 |
+
|
| 55 |
+
if config.save_memory:
|
| 56 |
+
model.low_vram_shift(is_diffusing=False)
|
| 57 |
+
|
| 58 |
+
x_samples = model.decode_first_stage(samples)
|
| 59 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 60 |
+
|
| 61 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 62 |
+
return [255 - detected_map] + results
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def create_canvas(w, h):
|
| 66 |
+
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
block = gr.Blocks().queue()
|
| 70 |
+
with block:
|
| 71 |
+
with gr.Row():
|
| 72 |
+
gr.Markdown("## Control Stable Diffusion with Interactive Scribbles")
|
| 73 |
+
with gr.Row():
|
| 74 |
+
with gr.Column():
|
| 75 |
+
canvas_width = gr.Slider(label="Canvas Width", minimum=256, maximum=1024, value=512, step=1)
|
| 76 |
+
canvas_height = gr.Slider(label="Canvas Height", minimum=256, maximum=1024, value=512, step=1)
|
| 77 |
+
create_button = gr.Button(label="Start", value='Open drawing canvas!')
|
| 78 |
+
input_image = gr.Image(source='upload', type='numpy', tool='sketch')
|
| 79 |
+
gr.Markdown(value='Do not forget to change your brush width to make it thinner. (Gradio do not allow developers to set brush width so you need to do it manually.) '
|
| 80 |
+
'Just click on the small pencil icon in the upper right corner of the above block.')
|
| 81 |
+
create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
|
| 82 |
+
prompt = gr.Textbox(label="Prompt")
|
| 83 |
+
run_button = gr.Button(label="Run")
|
| 84 |
+
with gr.Accordion("Advanced options", open=False):
|
| 85 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 86 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 87 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 88 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 89 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 90 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 91 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 92 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 93 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 94 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 95 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 96 |
+
with gr.Column():
|
| 97 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 98 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
| 99 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/gradio_seg2image.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
import config
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import einops
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
from pytorch_lightning import seed_everything
|
| 12 |
+
from annotator.util import resize_image, HWC3
|
| 13 |
+
from annotator.uniformer import UniformerDetector
|
| 14 |
+
from cldm.model import create_model, load_state_dict
|
| 15 |
+
from cldm.ddim_hacked import DDIMSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
apply_uniformer = UniformerDetector()
|
| 19 |
+
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict('./models/control_sd15_seg.pth', location='cuda'))
|
| 22 |
+
model = model.cuda()
|
| 23 |
+
ddim_sampler = DDIMSampler(model)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
input_image = HWC3(input_image)
|
| 29 |
+
detected_map = apply_uniformer(resize_image(input_image, detect_resolution))
|
| 30 |
+
img = resize_image(input_image, image_resolution)
|
| 31 |
+
H, W, C = img.shape
|
| 32 |
+
|
| 33 |
+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
|
| 34 |
+
|
| 35 |
+
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
|
| 36 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 37 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 38 |
+
|
| 39 |
+
if seed == -1:
|
| 40 |
+
seed = random.randint(0, 65535)
|
| 41 |
+
seed_everything(seed)
|
| 42 |
+
|
| 43 |
+
if config.save_memory:
|
| 44 |
+
model.low_vram_shift(is_diffusing=False)
|
| 45 |
+
|
| 46 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 47 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 48 |
+
shape = (4, H // 8, W // 8)
|
| 49 |
+
|
| 50 |
+
if config.save_memory:
|
| 51 |
+
model.low_vram_shift(is_diffusing=True)
|
| 52 |
+
|
| 53 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 54 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 55 |
+
shape, cond, verbose=False, eta=eta,
|
| 56 |
+
unconditional_guidance_scale=scale,
|
| 57 |
+
unconditional_conditioning=un_cond)
|
| 58 |
+
|
| 59 |
+
if config.save_memory:
|
| 60 |
+
model.low_vram_shift(is_diffusing=False)
|
| 61 |
+
|
| 62 |
+
x_samples = model.decode_first_stage(samples)
|
| 63 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 64 |
+
|
| 65 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 66 |
+
return [detected_map] + results
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
block = gr.Blocks().queue()
|
| 70 |
+
with block:
|
| 71 |
+
with gr.Row():
|
| 72 |
+
gr.Markdown("## Control Stable Diffusion with Segmentation Maps")
|
| 73 |
+
with gr.Row():
|
| 74 |
+
with gr.Column():
|
| 75 |
+
input_image = gr.Image(source='upload', type="numpy")
|
| 76 |
+
prompt = gr.Textbox(label="Prompt")
|
| 77 |
+
run_button = gr.Button(label="Run")
|
| 78 |
+
with gr.Accordion("Advanced options", open=False):
|
| 79 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
|
| 80 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
|
| 81 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 82 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 83 |
+
detect_resolution = gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024, value=512, step=1)
|
| 84 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 85 |
+
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
|
| 86 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 87 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 88 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 89 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 90 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 91 |
+
with gr.Column():
|
| 92 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
| 93 |
+
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
|
| 94 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
block.launch(server_name='0.0.0.0')
|
generation/control/ControlNet/ldm/data/__init__.py
ADDED
|
File without changes
|
generation/control/ControlNet/ldm/models/autoencoder.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytorch_lightning as pl
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
|
| 6 |
+
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
| 7 |
+
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
| 8 |
+
|
| 9 |
+
from ldm.util import instantiate_from_config
|
| 10 |
+
from ldm.modules.ema import LitEma
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AutoencoderKL(pl.LightningModule):
|
| 14 |
+
def __init__(self,
|
| 15 |
+
ddconfig,
|
| 16 |
+
lossconfig,
|
| 17 |
+
embed_dim,
|
| 18 |
+
ckpt_path=None,
|
| 19 |
+
ignore_keys=[],
|
| 20 |
+
image_key="image",
|
| 21 |
+
colorize_nlabels=None,
|
| 22 |
+
monitor=None,
|
| 23 |
+
ema_decay=None,
|
| 24 |
+
learn_logvar=False
|
| 25 |
+
):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.learn_logvar = learn_logvar
|
| 28 |
+
self.image_key = image_key
|
| 29 |
+
self.encoder = Encoder(**ddconfig)
|
| 30 |
+
self.decoder = Decoder(**ddconfig)
|
| 31 |
+
self.loss = instantiate_from_config(lossconfig)
|
| 32 |
+
assert ddconfig["double_z"]
|
| 33 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
| 34 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
| 35 |
+
self.embed_dim = embed_dim
|
| 36 |
+
if colorize_nlabels is not None:
|
| 37 |
+
assert type(colorize_nlabels)==int
|
| 38 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
| 39 |
+
if monitor is not None:
|
| 40 |
+
self.monitor = monitor
|
| 41 |
+
|
| 42 |
+
self.use_ema = ema_decay is not None
|
| 43 |
+
if self.use_ema:
|
| 44 |
+
self.ema_decay = ema_decay
|
| 45 |
+
assert 0. < ema_decay < 1.
|
| 46 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
| 47 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
| 48 |
+
|
| 49 |
+
if ckpt_path is not None:
|
| 50 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
| 51 |
+
|
| 52 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
| 53 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
| 54 |
+
keys = list(sd.keys())
|
| 55 |
+
for k in keys:
|
| 56 |
+
for ik in ignore_keys:
|
| 57 |
+
if k.startswith(ik):
|
| 58 |
+
print("Deleting key {} from state_dict.".format(k))
|
| 59 |
+
del sd[k]
|
| 60 |
+
self.load_state_dict(sd, strict=False)
|
| 61 |
+
print(f"Restored from {path}")
|
| 62 |
+
|
| 63 |
+
@contextmanager
|
| 64 |
+
def ema_scope(self, context=None):
|
| 65 |
+
if self.use_ema:
|
| 66 |
+
self.model_ema.store(self.parameters())
|
| 67 |
+
self.model_ema.copy_to(self)
|
| 68 |
+
if context is not None:
|
| 69 |
+
print(f"{context}: Switched to EMA weights")
|
| 70 |
+
try:
|
| 71 |
+
yield None
|
| 72 |
+
finally:
|
| 73 |
+
if self.use_ema:
|
| 74 |
+
self.model_ema.restore(self.parameters())
|
| 75 |
+
if context is not None:
|
| 76 |
+
print(f"{context}: Restored training weights")
|
| 77 |
+
|
| 78 |
+
def on_train_batch_end(self, *args, **kwargs):
|
| 79 |
+
if self.use_ema:
|
| 80 |
+
self.model_ema(self)
|
| 81 |
+
|
| 82 |
+
def encode(self, x):
|
| 83 |
+
h = self.encoder(x)
|
| 84 |
+
moments = self.quant_conv(h)
|
| 85 |
+
posterior = DiagonalGaussianDistribution(moments)
|
| 86 |
+
return posterior
|
| 87 |
+
|
| 88 |
+
def decode(self, z):
|
| 89 |
+
z = self.post_quant_conv(z)
|
| 90 |
+
dec = self.decoder(z)
|
| 91 |
+
return dec
|
| 92 |
+
|
| 93 |
+
def forward(self, input, sample_posterior=True):
|
| 94 |
+
posterior = self.encode(input)
|
| 95 |
+
if sample_posterior:
|
| 96 |
+
z = posterior.sample()
|
| 97 |
+
else:
|
| 98 |
+
z = posterior.mode()
|
| 99 |
+
dec = self.decode(z)
|
| 100 |
+
return dec, posterior
|
| 101 |
+
|
| 102 |
+
def get_input(self, batch, k):
|
| 103 |
+
x = batch[k]
|
| 104 |
+
if len(x.shape) == 3:
|
| 105 |
+
x = x[..., None]
|
| 106 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
| 107 |
+
return x
|
| 108 |
+
|
| 109 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
| 110 |
+
inputs = self.get_input(batch, self.image_key)
|
| 111 |
+
reconstructions, posterior = self(inputs)
|
| 112 |
+
|
| 113 |
+
if optimizer_idx == 0:
|
| 114 |
+
# train encoder+decoder+logvar
|
| 115 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
| 116 |
+
last_layer=self.get_last_layer(), split="train")
|
| 117 |
+
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
| 118 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
| 119 |
+
return aeloss
|
| 120 |
+
|
| 121 |
+
if optimizer_idx == 1:
|
| 122 |
+
# train the discriminator
|
| 123 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
| 124 |
+
last_layer=self.get_last_layer(), split="train")
|
| 125 |
+
|
| 126 |
+
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
| 127 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
| 128 |
+
return discloss
|
| 129 |
+
|
| 130 |
+
def validation_step(self, batch, batch_idx):
|
| 131 |
+
log_dict = self._validation_step(batch, batch_idx)
|
| 132 |
+
with self.ema_scope():
|
| 133 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
| 134 |
+
return log_dict
|
| 135 |
+
|
| 136 |
+
def _validation_step(self, batch, batch_idx, postfix=""):
|
| 137 |
+
inputs = self.get_input(batch, self.image_key)
|
| 138 |
+
reconstructions, posterior = self(inputs)
|
| 139 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
| 140 |
+
last_layer=self.get_last_layer(), split="val"+postfix)
|
| 141 |
+
|
| 142 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
| 143 |
+
last_layer=self.get_last_layer(), split="val"+postfix)
|
| 144 |
+
|
| 145 |
+
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
|
| 146 |
+
self.log_dict(log_dict_ae)
|
| 147 |
+
self.log_dict(log_dict_disc)
|
| 148 |
+
return self.log_dict
|
| 149 |
+
|
| 150 |
+
def configure_optimizers(self):
|
| 151 |
+
lr = self.learning_rate
|
| 152 |
+
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
|
| 153 |
+
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
|
| 154 |
+
if self.learn_logvar:
|
| 155 |
+
print(f"{self.__class__.__name__}: Learning logvar")
|
| 156 |
+
ae_params_list.append(self.loss.logvar)
|
| 157 |
+
opt_ae = torch.optim.Adam(ae_params_list,
|
| 158 |
+
lr=lr, betas=(0.5, 0.9))
|
| 159 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
| 160 |
+
lr=lr, betas=(0.5, 0.9))
|
| 161 |
+
return [opt_ae, opt_disc], []
|
| 162 |
+
|
| 163 |
+
def get_last_layer(self):
|
| 164 |
+
return self.decoder.conv_out.weight
|
| 165 |
+
|
| 166 |
+
@torch.no_grad()
|
| 167 |
+
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
| 168 |
+
log = dict()
|
| 169 |
+
x = self.get_input(batch, self.image_key)
|
| 170 |
+
x = x.to(self.device)
|
| 171 |
+
if not only_inputs:
|
| 172 |
+
xrec, posterior = self(x)
|
| 173 |
+
if x.shape[1] > 3:
|
| 174 |
+
# colorize with random projection
|
| 175 |
+
assert xrec.shape[1] > 3
|
| 176 |
+
x = self.to_rgb(x)
|
| 177 |
+
xrec = self.to_rgb(xrec)
|
| 178 |
+
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
| 179 |
+
log["reconstructions"] = xrec
|
| 180 |
+
if log_ema or self.use_ema:
|
| 181 |
+
with self.ema_scope():
|
| 182 |
+
xrec_ema, posterior_ema = self(x)
|
| 183 |
+
if x.shape[1] > 3:
|
| 184 |
+
# colorize with random projection
|
| 185 |
+
assert xrec_ema.shape[1] > 3
|
| 186 |
+
xrec_ema = self.to_rgb(xrec_ema)
|
| 187 |
+
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
|
| 188 |
+
log["reconstructions_ema"] = xrec_ema
|
| 189 |
+
log["inputs"] = x
|
| 190 |
+
return log
|
| 191 |
+
|
| 192 |
+
def to_rgb(self, x):
|
| 193 |
+
assert self.image_key == "segmentation"
|
| 194 |
+
if not hasattr(self, "colorize"):
|
| 195 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
| 196 |
+
x = F.conv2d(x, weight=self.colorize)
|
| 197 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
| 198 |
+
return x
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class IdentityFirstStage(torch.nn.Module):
|
| 202 |
+
def __init__(self, *args, vq_interface=False, **kwargs):
|
| 203 |
+
self.vq_interface = vq_interface
|
| 204 |
+
super().__init__()
|
| 205 |
+
|
| 206 |
+
def encode(self, x, *args, **kwargs):
|
| 207 |
+
return x
|
| 208 |
+
|
| 209 |
+
def decode(self, x, *args, **kwargs):
|
| 210 |
+
return x
|
| 211 |
+
|
| 212 |
+
def quantize(self, x, *args, **kwargs):
|
| 213 |
+
if self.vq_interface:
|
| 214 |
+
return x, None, [None, None, None]
|
| 215 |
+
return x
|
| 216 |
+
|
| 217 |
+
def forward(self, x, *args, **kwargs):
|
| 218 |
+
return x
|
| 219 |
+
|
generation/control/ControlNet/ldm/models/diffusion/__init__.py
ADDED
|
File without changes
|
generation/control/ControlNet/ldm/models/diffusion/ddim.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DDIMSampler(object):
|
| 11 |
+
def __init__(self, model, schedule="linear", **kwargs):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.model = model
|
| 14 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
| 15 |
+
self.schedule = schedule
|
| 16 |
+
|
| 17 |
+
def register_buffer(self, name, attr):
|
| 18 |
+
if type(attr) == torch.Tensor:
|
| 19 |
+
if attr.device != torch.device("cuda"):
|
| 20 |
+
attr = attr.to(torch.device("cuda"))
|
| 21 |
+
setattr(self, name, attr)
|
| 22 |
+
|
| 23 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
| 24 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
| 25 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
| 26 |
+
alphas_cumprod = self.model.alphas_cumprod
|
| 27 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
| 28 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
| 29 |
+
|
| 30 |
+
self.register_buffer('betas', to_torch(self.model.betas))
|
| 31 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
| 32 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
| 33 |
+
|
| 34 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 35 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
| 36 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
| 37 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
| 38 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
| 39 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
| 40 |
+
|
| 41 |
+
# ddim sampling parameters
|
| 42 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
| 43 |
+
ddim_timesteps=self.ddim_timesteps,
|
| 44 |
+
eta=ddim_eta,verbose=verbose)
|
| 45 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
| 46 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
| 47 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
| 48 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
| 49 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
| 50 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
| 51 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
| 52 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
| 53 |
+
|
| 54 |
+
@torch.no_grad()
|
| 55 |
+
def sample(self,
|
| 56 |
+
S,
|
| 57 |
+
batch_size,
|
| 58 |
+
shape,
|
| 59 |
+
conditioning=None,
|
| 60 |
+
callback=None,
|
| 61 |
+
normals_sequence=None,
|
| 62 |
+
img_callback=None,
|
| 63 |
+
quantize_x0=False,
|
| 64 |
+
eta=0.,
|
| 65 |
+
mask=None,
|
| 66 |
+
x0=None,
|
| 67 |
+
temperature=1.,
|
| 68 |
+
noise_dropout=0.,
|
| 69 |
+
score_corrector=None,
|
| 70 |
+
corrector_kwargs=None,
|
| 71 |
+
verbose=True,
|
| 72 |
+
x_T=None,
|
| 73 |
+
log_every_t=100,
|
| 74 |
+
unconditional_guidance_scale=1.,
|
| 75 |
+
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 76 |
+
dynamic_threshold=None,
|
| 77 |
+
ucg_schedule=None,
|
| 78 |
+
**kwargs
|
| 79 |
+
):
|
| 80 |
+
if conditioning is not None:
|
| 81 |
+
if isinstance(conditioning, dict):
|
| 82 |
+
ctmp = conditioning[list(conditioning.keys())[0]]
|
| 83 |
+
while isinstance(ctmp, list): ctmp = ctmp[0]
|
| 84 |
+
cbs = ctmp.shape[0]
|
| 85 |
+
if cbs != batch_size:
|
| 86 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 87 |
+
|
| 88 |
+
elif isinstance(conditioning, list):
|
| 89 |
+
for ctmp in conditioning:
|
| 90 |
+
if ctmp.shape[0] != batch_size:
|
| 91 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 92 |
+
|
| 93 |
+
else:
|
| 94 |
+
if conditioning.shape[0] != batch_size:
|
| 95 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
| 96 |
+
|
| 97 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
| 98 |
+
# sampling
|
| 99 |
+
C, H, W = shape
|
| 100 |
+
size = (batch_size, C, H, W)
|
| 101 |
+
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
| 102 |
+
|
| 103 |
+
samples, intermediates = self.ddim_sampling(conditioning, size,
|
| 104 |
+
callback=callback,
|
| 105 |
+
img_callback=img_callback,
|
| 106 |
+
quantize_denoised=quantize_x0,
|
| 107 |
+
mask=mask, x0=x0,
|
| 108 |
+
ddim_use_original_steps=False,
|
| 109 |
+
noise_dropout=noise_dropout,
|
| 110 |
+
temperature=temperature,
|
| 111 |
+
score_corrector=score_corrector,
|
| 112 |
+
corrector_kwargs=corrector_kwargs,
|
| 113 |
+
x_T=x_T,
|
| 114 |
+
log_every_t=log_every_t,
|
| 115 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 116 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 117 |
+
dynamic_threshold=dynamic_threshold,
|
| 118 |
+
ucg_schedule=ucg_schedule
|
| 119 |
+
)
|
| 120 |
+
return samples, intermediates
|
| 121 |
+
|
| 122 |
+
@torch.no_grad()
|
| 123 |
+
def ddim_sampling(self, cond, shape,
|
| 124 |
+
x_T=None, ddim_use_original_steps=False,
|
| 125 |
+
callback=None, timesteps=None, quantize_denoised=False,
|
| 126 |
+
mask=None, x0=None, img_callback=None, log_every_t=100,
|
| 127 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
| 128 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
| 129 |
+
ucg_schedule=None):
|
| 130 |
+
device = self.model.betas.device
|
| 131 |
+
b = shape[0]
|
| 132 |
+
if x_T is None:
|
| 133 |
+
img = torch.randn(shape, device=device)
|
| 134 |
+
else:
|
| 135 |
+
img = x_T
|
| 136 |
+
|
| 137 |
+
if timesteps is None:
|
| 138 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
| 139 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
| 140 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
| 141 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
| 142 |
+
|
| 143 |
+
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
| 144 |
+
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
| 145 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
| 146 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 147 |
+
|
| 148 |
+
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
| 149 |
+
|
| 150 |
+
for i, step in enumerate(iterator):
|
| 151 |
+
index = total_steps - i - 1
|
| 152 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
| 153 |
+
|
| 154 |
+
if mask is not None:
|
| 155 |
+
assert x0 is not None
|
| 156 |
+
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
| 157 |
+
img = img_orig * mask + (1. - mask) * img
|
| 158 |
+
|
| 159 |
+
if ucg_schedule is not None:
|
| 160 |
+
assert len(ucg_schedule) == len(time_range)
|
| 161 |
+
unconditional_guidance_scale = ucg_schedule[i]
|
| 162 |
+
|
| 163 |
+
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
| 164 |
+
quantize_denoised=quantize_denoised, temperature=temperature,
|
| 165 |
+
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
| 166 |
+
corrector_kwargs=corrector_kwargs,
|
| 167 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 168 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 169 |
+
dynamic_threshold=dynamic_threshold)
|
| 170 |
+
img, pred_x0 = outs
|
| 171 |
+
if callback: callback(i)
|
| 172 |
+
if img_callback: img_callback(pred_x0, i)
|
| 173 |
+
|
| 174 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
| 175 |
+
intermediates['x_inter'].append(img)
|
| 176 |
+
intermediates['pred_x0'].append(pred_x0)
|
| 177 |
+
|
| 178 |
+
return img, intermediates
|
| 179 |
+
|
| 180 |
+
@torch.no_grad()
|
| 181 |
+
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
| 182 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
| 183 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
| 184 |
+
dynamic_threshold=None):
|
| 185 |
+
b, *_, device = *x.shape, x.device
|
| 186 |
+
|
| 187 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
| 188 |
+
model_output = self.model.apply_model(x, t, c)
|
| 189 |
+
else:
|
| 190 |
+
x_in = torch.cat([x] * 2)
|
| 191 |
+
t_in = torch.cat([t] * 2)
|
| 192 |
+
if isinstance(c, dict):
|
| 193 |
+
assert isinstance(unconditional_conditioning, dict)
|
| 194 |
+
c_in = dict()
|
| 195 |
+
for k in c:
|
| 196 |
+
if isinstance(c[k], list):
|
| 197 |
+
c_in[k] = [torch.cat([
|
| 198 |
+
unconditional_conditioning[k][i],
|
| 199 |
+
c[k][i]]) for i in range(len(c[k]))]
|
| 200 |
+
else:
|
| 201 |
+
c_in[k] = torch.cat([
|
| 202 |
+
unconditional_conditioning[k],
|
| 203 |
+
c[k]])
|
| 204 |
+
elif isinstance(c, list):
|
| 205 |
+
c_in = list()
|
| 206 |
+
assert isinstance(unconditional_conditioning, list)
|
| 207 |
+
for i in range(len(c)):
|
| 208 |
+
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
| 209 |
+
else:
|
| 210 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
| 211 |
+
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
| 212 |
+
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
| 213 |
+
|
| 214 |
+
if self.model.parameterization == "v":
|
| 215 |
+
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
| 216 |
+
else:
|
| 217 |
+
e_t = model_output
|
| 218 |
+
|
| 219 |
+
if score_corrector is not None:
|
| 220 |
+
assert self.model.parameterization == "eps", 'not implemented'
|
| 221 |
+
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
| 222 |
+
|
| 223 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
| 224 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
| 225 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
| 226 |
+
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
| 227 |
+
# select parameters corresponding to the currently considered timestep
|
| 228 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
| 229 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
| 230 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
| 231 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
| 232 |
+
|
| 233 |
+
# current prediction for x_0
|
| 234 |
+
if self.model.parameterization != "v":
|
| 235 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
| 236 |
+
else:
|
| 237 |
+
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
| 238 |
+
|
| 239 |
+
if quantize_denoised:
|
| 240 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
| 241 |
+
|
| 242 |
+
if dynamic_threshold is not None:
|
| 243 |
+
raise NotImplementedError()
|
| 244 |
+
|
| 245 |
+
# direction pointing to x_t
|
| 246 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
| 247 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
| 248 |
+
if noise_dropout > 0.:
|
| 249 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 250 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
| 251 |
+
return x_prev, pred_x0
|
| 252 |
+
|
| 253 |
+
@torch.no_grad()
|
| 254 |
+
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
| 255 |
+
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
|
| 256 |
+
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
|
| 257 |
+
|
| 258 |
+
assert t_enc <= num_reference_steps
|
| 259 |
+
num_steps = t_enc
|
| 260 |
+
|
| 261 |
+
if use_original_steps:
|
| 262 |
+
alphas_next = self.alphas_cumprod[:num_steps]
|
| 263 |
+
alphas = self.alphas_cumprod_prev[:num_steps]
|
| 264 |
+
else:
|
| 265 |
+
alphas_next = self.ddim_alphas[:num_steps]
|
| 266 |
+
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
| 267 |
+
|
| 268 |
+
x_next = x0
|
| 269 |
+
intermediates = []
|
| 270 |
+
inter_steps = []
|
| 271 |
+
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
| 272 |
+
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
|
| 273 |
+
if unconditional_guidance_scale == 1.:
|
| 274 |
+
noise_pred = self.model.apply_model(x_next, t, c)
|
| 275 |
+
else:
|
| 276 |
+
assert unconditional_conditioning is not None
|
| 277 |
+
e_t_uncond, noise_pred = torch.chunk(
|
| 278 |
+
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
|
| 279 |
+
torch.cat((unconditional_conditioning, c))), 2)
|
| 280 |
+
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
|
| 281 |
+
|
| 282 |
+
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
| 283 |
+
weighted_noise_pred = alphas_next[i].sqrt() * (
|
| 284 |
+
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
|
| 285 |
+
x_next = xt_weighted + weighted_noise_pred
|
| 286 |
+
if return_intermediates and i % (
|
| 287 |
+
num_steps // return_intermediates) == 0 and i < num_steps - 1:
|
| 288 |
+
intermediates.append(x_next)
|
| 289 |
+
inter_steps.append(i)
|
| 290 |
+
elif return_intermediates and i >= num_steps - 2:
|
| 291 |
+
intermediates.append(x_next)
|
| 292 |
+
inter_steps.append(i)
|
| 293 |
+
if callback: callback(i)
|
| 294 |
+
|
| 295 |
+
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
| 296 |
+
if return_intermediates:
|
| 297 |
+
out.update({'intermediates': intermediates})
|
| 298 |
+
return x_next, out
|
| 299 |
+
|
| 300 |
+
@torch.no_grad()
|
| 301 |
+
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
| 302 |
+
# fast, but does not allow for exact reconstruction
|
| 303 |
+
# t serves as an index to gather the correct alphas
|
| 304 |
+
if use_original_steps:
|
| 305 |
+
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
| 306 |
+
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
| 307 |
+
else:
|
| 308 |
+
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
| 309 |
+
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
| 310 |
+
|
| 311 |
+
if noise is None:
|
| 312 |
+
noise = torch.randn_like(x0)
|
| 313 |
+
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
| 314 |
+
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
| 315 |
+
|
| 316 |
+
@torch.no_grad()
|
| 317 |
+
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
| 318 |
+
use_original_steps=False, callback=None):
|
| 319 |
+
|
| 320 |
+
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
| 321 |
+
timesteps = timesteps[:t_start]
|
| 322 |
+
|
| 323 |
+
time_range = np.flip(timesteps)
|
| 324 |
+
total_steps = timesteps.shape[0]
|
| 325 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 326 |
+
|
| 327 |
+
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
| 328 |
+
x_dec = x_latent
|
| 329 |
+
for i, step in enumerate(iterator):
|
| 330 |
+
index = total_steps - i - 1
|
| 331 |
+
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
| 332 |
+
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
| 333 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 334 |
+
unconditional_conditioning=unconditional_conditioning)
|
| 335 |
+
if callback: callback(i)
|
| 336 |
+
return x_dec
|
generation/control/ControlNet/ldm/util.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import optim
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from inspect import isfunction
|
| 8 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def log_txt_as_img(wh, xc, size=10):
|
| 12 |
+
# wh a tuple of (width, height)
|
| 13 |
+
# xc a list of captions to plot
|
| 14 |
+
b = len(xc)
|
| 15 |
+
txts = list()
|
| 16 |
+
for bi in range(b):
|
| 17 |
+
txt = Image.new("RGB", wh, color="white")
|
| 18 |
+
draw = ImageDraw.Draw(txt)
|
| 19 |
+
font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
|
| 20 |
+
nc = int(40 * (wh[0] / 256))
|
| 21 |
+
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
| 25 |
+
except UnicodeEncodeError:
|
| 26 |
+
print("Cant encode string for logging. Skipping.")
|
| 27 |
+
|
| 28 |
+
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
| 29 |
+
txts.append(txt)
|
| 30 |
+
txts = np.stack(txts)
|
| 31 |
+
txts = torch.tensor(txts)
|
| 32 |
+
return txts
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def ismap(x):
|
| 36 |
+
if not isinstance(x, torch.Tensor):
|
| 37 |
+
return False
|
| 38 |
+
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def isimage(x):
|
| 42 |
+
if not isinstance(x,torch.Tensor):
|
| 43 |
+
return False
|
| 44 |
+
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def exists(x):
|
| 48 |
+
return x is not None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def default(val, d):
|
| 52 |
+
if exists(val):
|
| 53 |
+
return val
|
| 54 |
+
return d() if isfunction(d) else d
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def mean_flat(tensor):
|
| 58 |
+
"""
|
| 59 |
+
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
| 60 |
+
Take the mean over all non-batch dimensions.
|
| 61 |
+
"""
|
| 62 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def count_params(model, verbose=False):
|
| 66 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 67 |
+
if verbose:
|
| 68 |
+
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
| 69 |
+
return total_params
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def instantiate_from_config(config):
|
| 73 |
+
if not "target" in config:
|
| 74 |
+
if config == '__is_first_stage__':
|
| 75 |
+
return None
|
| 76 |
+
elif config == "__is_unconditional__":
|
| 77 |
+
return None
|
| 78 |
+
raise KeyError("Expected key `target` to instantiate.")
|
| 79 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_obj_from_str(string, reload=False):
|
| 83 |
+
module, cls = string.rsplit(".", 1)
|
| 84 |
+
if reload:
|
| 85 |
+
module_imp = importlib.import_module(module)
|
| 86 |
+
importlib.reload(module_imp)
|
| 87 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class AdamWwithEMAandWings(optim.Optimizer):
|
| 91 |
+
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
| 92 |
+
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
| 93 |
+
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
| 94 |
+
ema_power=1., param_names=()):
|
| 95 |
+
"""AdamW that saves EMA versions of the parameters."""
|
| 96 |
+
if not 0.0 <= lr:
|
| 97 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
| 98 |
+
if not 0.0 <= eps:
|
| 99 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
| 100 |
+
if not 0.0 <= betas[0] < 1.0:
|
| 101 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
| 102 |
+
if not 0.0 <= betas[1] < 1.0:
|
| 103 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
| 104 |
+
if not 0.0 <= weight_decay:
|
| 105 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
| 106 |
+
if not 0.0 <= ema_decay <= 1.0:
|
| 107 |
+
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
| 108 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
| 109 |
+
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
| 110 |
+
ema_power=ema_power, param_names=param_names)
|
| 111 |
+
super().__init__(params, defaults)
|
| 112 |
+
|
| 113 |
+
def __setstate__(self, state):
|
| 114 |
+
super().__setstate__(state)
|
| 115 |
+
for group in self.param_groups:
|
| 116 |
+
group.setdefault('amsgrad', False)
|
| 117 |
+
|
| 118 |
+
@torch.no_grad()
|
| 119 |
+
def step(self, closure=None):
|
| 120 |
+
"""Performs a single optimization step.
|
| 121 |
+
Args:
|
| 122 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 123 |
+
and returns the loss.
|
| 124 |
+
"""
|
| 125 |
+
loss = None
|
| 126 |
+
if closure is not None:
|
| 127 |
+
with torch.enable_grad():
|
| 128 |
+
loss = closure()
|
| 129 |
+
|
| 130 |
+
for group in self.param_groups:
|
| 131 |
+
params_with_grad = []
|
| 132 |
+
grads = []
|
| 133 |
+
exp_avgs = []
|
| 134 |
+
exp_avg_sqs = []
|
| 135 |
+
ema_params_with_grad = []
|
| 136 |
+
state_sums = []
|
| 137 |
+
max_exp_avg_sqs = []
|
| 138 |
+
state_steps = []
|
| 139 |
+
amsgrad = group['amsgrad']
|
| 140 |
+
beta1, beta2 = group['betas']
|
| 141 |
+
ema_decay = group['ema_decay']
|
| 142 |
+
ema_power = group['ema_power']
|
| 143 |
+
|
| 144 |
+
for p in group['params']:
|
| 145 |
+
if p.grad is None:
|
| 146 |
+
continue
|
| 147 |
+
params_with_grad.append(p)
|
| 148 |
+
if p.grad.is_sparse:
|
| 149 |
+
raise RuntimeError('AdamW does not support sparse gradients')
|
| 150 |
+
grads.append(p.grad)
|
| 151 |
+
|
| 152 |
+
state = self.state[p]
|
| 153 |
+
|
| 154 |
+
# State initialization
|
| 155 |
+
if len(state) == 0:
|
| 156 |
+
state['step'] = 0
|
| 157 |
+
# Exponential moving average of gradient values
|
| 158 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
| 159 |
+
# Exponential moving average of squared gradient values
|
| 160 |
+
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
| 161 |
+
if amsgrad:
|
| 162 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
| 163 |
+
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
| 164 |
+
# Exponential moving average of parameter values
|
| 165 |
+
state['param_exp_avg'] = p.detach().float().clone()
|
| 166 |
+
|
| 167 |
+
exp_avgs.append(state['exp_avg'])
|
| 168 |
+
exp_avg_sqs.append(state['exp_avg_sq'])
|
| 169 |
+
ema_params_with_grad.append(state['param_exp_avg'])
|
| 170 |
+
|
| 171 |
+
if amsgrad:
|
| 172 |
+
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
| 173 |
+
|
| 174 |
+
# update the steps for each param group update
|
| 175 |
+
state['step'] += 1
|
| 176 |
+
# record the step after step update
|
| 177 |
+
state_steps.append(state['step'])
|
| 178 |
+
|
| 179 |
+
optim._functional.adamw(params_with_grad,
|
| 180 |
+
grads,
|
| 181 |
+
exp_avgs,
|
| 182 |
+
exp_avg_sqs,
|
| 183 |
+
max_exp_avg_sqs,
|
| 184 |
+
state_steps,
|
| 185 |
+
amsgrad=amsgrad,
|
| 186 |
+
beta1=beta1,
|
| 187 |
+
beta2=beta2,
|
| 188 |
+
lr=group['lr'],
|
| 189 |
+
weight_decay=group['weight_decay'],
|
| 190 |
+
eps=group['eps'],
|
| 191 |
+
maximize=False)
|
| 192 |
+
|
| 193 |
+
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
| 194 |
+
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
| 195 |
+
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
| 196 |
+
|
| 197 |
+
return loss
|
generation/control/ControlNet/share.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import config
|
| 2 |
+
from cldm.hack import disable_verbosity, enable_sliced_attention
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
disable_verbosity()
|
| 6 |
+
|
| 7 |
+
if config.save_memory:
|
| 8 |
+
enable_sliced_attention()
|
generation/control/ControlNet/tool_add_control.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
assert len(sys.argv) == 3, 'Args are wrong.'
|
| 5 |
+
|
| 6 |
+
input_path = sys.argv[1]
|
| 7 |
+
output_path = sys.argv[2]
|
| 8 |
+
|
| 9 |
+
assert os.path.exists(input_path), 'Input model does not exist.'
|
| 10 |
+
assert not os.path.exists(output_path), 'Output filename already exists.'
|
| 11 |
+
assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from share import *
|
| 15 |
+
from cldm.model import create_model
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_node_name(name, parent_name):
|
| 19 |
+
if len(name) <= len(parent_name):
|
| 20 |
+
return False, ''
|
| 21 |
+
p = name[:len(parent_name)]
|
| 22 |
+
if p != parent_name:
|
| 23 |
+
return False, ''
|
| 24 |
+
return True, name[len(parent_name):]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
model = create_model(config_path='./models/cldm_v15.yaml')
|
| 28 |
+
|
| 29 |
+
pretrained_weights = torch.load(input_path)
|
| 30 |
+
if 'state_dict' in pretrained_weights:
|
| 31 |
+
pretrained_weights = pretrained_weights['state_dict']
|
| 32 |
+
|
| 33 |
+
scratch_dict = model.state_dict()
|
| 34 |
+
|
| 35 |
+
target_dict = {}
|
| 36 |
+
for k in scratch_dict.keys():
|
| 37 |
+
is_control, name = get_node_name(k, 'control_')
|
| 38 |
+
if is_control:
|
| 39 |
+
copy_k = 'model.diffusion_' + name
|
| 40 |
+
else:
|
| 41 |
+
copy_k = k
|
| 42 |
+
if copy_k in pretrained_weights:
|
| 43 |
+
target_dict[k] = pretrained_weights[copy_k].clone()
|
| 44 |
+
else:
|
| 45 |
+
target_dict[k] = scratch_dict[k].clone()
|
| 46 |
+
print(f'These weights are newly added: {k}')
|
| 47 |
+
|
| 48 |
+
model.load_state_dict(target_dict, strict=True)
|
| 49 |
+
torch.save(model.state_dict(), output_path)
|
| 50 |
+
print('Done.')
|
generation/control/ControlNet/tool_add_control_sd21.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
assert len(sys.argv) == 3, 'Args are wrong.'
|
| 5 |
+
|
| 6 |
+
input_path = sys.argv[1]
|
| 7 |
+
output_path = sys.argv[2]
|
| 8 |
+
|
| 9 |
+
assert os.path.exists(input_path), 'Input model does not exist.'
|
| 10 |
+
assert not os.path.exists(output_path), 'Output filename already exists.'
|
| 11 |
+
assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from share import *
|
| 15 |
+
from cldm.model import create_model
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_node_name(name, parent_name):
|
| 19 |
+
if len(name) <= len(parent_name):
|
| 20 |
+
return False, ''
|
| 21 |
+
p = name[:len(parent_name)]
|
| 22 |
+
if p != parent_name:
|
| 23 |
+
return False, ''
|
| 24 |
+
return True, name[len(parent_name):]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
model = create_model(config_path='./models/cldm_v21.yaml')
|
| 28 |
+
|
| 29 |
+
pretrained_weights = torch.load(input_path)
|
| 30 |
+
if 'state_dict' in pretrained_weights:
|
| 31 |
+
pretrained_weights = pretrained_weights['state_dict']
|
| 32 |
+
|
| 33 |
+
scratch_dict = model.state_dict()
|
| 34 |
+
|
| 35 |
+
target_dict = {}
|
| 36 |
+
for k in scratch_dict.keys():
|
| 37 |
+
is_control, name = get_node_name(k, 'control_')
|
| 38 |
+
if is_control:
|
| 39 |
+
copy_k = 'model.diffusion_' + name
|
| 40 |
+
else:
|
| 41 |
+
copy_k = k
|
| 42 |
+
if copy_k in pretrained_weights:
|
| 43 |
+
target_dict[k] = pretrained_weights[copy_k].clone()
|
| 44 |
+
else:
|
| 45 |
+
target_dict[k] = scratch_dict[k].clone()
|
| 46 |
+
print(f'These weights are newly added: {k}')
|
| 47 |
+
|
| 48 |
+
model.load_state_dict(target_dict, strict=True)
|
| 49 |
+
torch.save(model.state_dict(), output_path)
|
| 50 |
+
print('Done.')
|
generation/control/ControlNet/tool_transfer_control.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
path_sd15 = './models/v1-5-pruned.ckpt'
|
| 2 |
+
path_sd15_with_control = './models/control_sd15_openpose.pth'
|
| 3 |
+
path_input = './models/anything-v3-full.safetensors'
|
| 4 |
+
path_output = './models/control_any3_openpose.pth'
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
assert os.path.exists(path_sd15), 'Input path_sd15 does not exists!'
|
| 11 |
+
assert os.path.exists(path_sd15_with_control), 'Input path_sd15_with_control does not exists!'
|
| 12 |
+
assert os.path.exists(path_input), 'Input path_input does not exists!'
|
| 13 |
+
assert os.path.exists(os.path.dirname(path_output)), 'Output folder not exists!'
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from share import *
|
| 18 |
+
from cldm.model import load_state_dict
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
sd15_state_dict = load_state_dict(path_sd15)
|
| 22 |
+
sd15_with_control_state_dict = load_state_dict(path_sd15_with_control)
|
| 23 |
+
input_state_dict = load_state_dict(path_input)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_node_name(name, parent_name):
|
| 27 |
+
if len(name) <= len(parent_name):
|
| 28 |
+
return False, ''
|
| 29 |
+
p = name[:len(parent_name)]
|
| 30 |
+
if p != parent_name:
|
| 31 |
+
return False, ''
|
| 32 |
+
return True, name[len(parent_name):]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
keys = sd15_with_control_state_dict.keys()
|
| 36 |
+
|
| 37 |
+
final_state_dict = {}
|
| 38 |
+
for key in keys:
|
| 39 |
+
is_first_stage, _ = get_node_name(key, 'first_stage_model')
|
| 40 |
+
is_cond_stage, _ = get_node_name(key, 'cond_stage_model')
|
| 41 |
+
if is_first_stage or is_cond_stage:
|
| 42 |
+
final_state_dict[key] = input_state_dict[key]
|
| 43 |
+
continue
|
| 44 |
+
p = sd15_with_control_state_dict[key]
|
| 45 |
+
is_control, node_name = get_node_name(key, 'control_')
|
| 46 |
+
if is_control:
|
| 47 |
+
sd15_key_name = 'model.diffusion_' + node_name
|
| 48 |
+
else:
|
| 49 |
+
sd15_key_name = key
|
| 50 |
+
if sd15_key_name in input_state_dict:
|
| 51 |
+
p_new = p + input_state_dict[sd15_key_name] - sd15_state_dict[sd15_key_name]
|
| 52 |
+
# print(f'Offset clone from [{sd15_key_name}] to [{key}]')
|
| 53 |
+
else:
|
| 54 |
+
p_new = p
|
| 55 |
+
# print(f'Direct clone to [{key}]')
|
| 56 |
+
final_state_dict[key] = p_new
|
| 57 |
+
|
| 58 |
+
torch.save(final_state_dict, path_output)
|
| 59 |
+
print('Transferred model saved at ' + path_output)
|
generation/control/ControlNet/tutorial_dataset.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MyDataset(Dataset):
|
| 9 |
+
def __init__(self):
|
| 10 |
+
self.data = []
|
| 11 |
+
with open('./training/fill50k/prompt.json', 'rt') as f:
|
| 12 |
+
for line in f:
|
| 13 |
+
self.data.append(json.loads(line))
|
| 14 |
+
|
| 15 |
+
def __len__(self):
|
| 16 |
+
return len(self.data)
|
| 17 |
+
|
| 18 |
+
def __getitem__(self, idx):
|
| 19 |
+
item = self.data[idx]
|
| 20 |
+
|
| 21 |
+
source_filename = item['source']
|
| 22 |
+
target_filename = item['target']
|
| 23 |
+
prompt = item['prompt']
|
| 24 |
+
|
| 25 |
+
source = cv2.imread('./training/fill50k/' + source_filename)
|
| 26 |
+
target = cv2.imread('./training/fill50k/' + target_filename)
|
| 27 |
+
|
| 28 |
+
# Do not forget that OpenCV read images in BGR order.
|
| 29 |
+
source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
|
| 30 |
+
target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
|
| 31 |
+
|
| 32 |
+
# Normalize source images to [0, 1].
|
| 33 |
+
source = source.astype(np.float32) / 255.0
|
| 34 |
+
|
| 35 |
+
# Normalize target images to [-1, 1].
|
| 36 |
+
target = (target.astype(np.float32) / 127.5) - 1.0
|
| 37 |
+
|
| 38 |
+
return dict(jpg=target, txt=prompt, hint=source)
|
| 39 |
+
|
generation/control/ControlNet/tutorial_dataset_test.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tutorial_dataset import MyDataset
|
| 2 |
+
|
| 3 |
+
dataset = MyDataset()
|
| 4 |
+
print(len(dataset))
|
| 5 |
+
|
| 6 |
+
item = dataset[1234]
|
| 7 |
+
jpg = item['jpg']
|
| 8 |
+
txt = item['txt']
|
| 9 |
+
hint = item['hint']
|
| 10 |
+
print(txt)
|
| 11 |
+
print(jpg.shape)
|
| 12 |
+
print(hint.shape)
|
generation/control/ControlNet/tutorial_train.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
|
| 3 |
+
import pytorch_lightning as pl
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from tutorial_dataset import MyDataset
|
| 6 |
+
from cldm.logger import ImageLogger
|
| 7 |
+
from cldm.model import create_model, load_state_dict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Configs
|
| 11 |
+
resume_path = './models/control_sd15_ini.ckpt'
|
| 12 |
+
batch_size = 4
|
| 13 |
+
logger_freq = 300
|
| 14 |
+
learning_rate = 1e-5
|
| 15 |
+
sd_locked = True
|
| 16 |
+
only_mid_control = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
|
| 20 |
+
model = create_model('./models/cldm_v15.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
|
| 22 |
+
model.learning_rate = learning_rate
|
| 23 |
+
model.sd_locked = sd_locked
|
| 24 |
+
model.only_mid_control = only_mid_control
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Misc
|
| 28 |
+
dataset = MyDataset()
|
| 29 |
+
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
|
| 30 |
+
logger = ImageLogger(batch_frequency=logger_freq)
|
| 31 |
+
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Train!
|
| 35 |
+
trainer.fit(model, dataloader)
|
generation/control/ControlNet/tutorial_train_sd21.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from share import *
|
| 2 |
+
|
| 3 |
+
import pytorch_lightning as pl
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from tutorial_dataset import MyDataset
|
| 6 |
+
from cldm.logger import ImageLogger
|
| 7 |
+
from cldm.model import create_model, load_state_dict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Configs
|
| 11 |
+
resume_path = './models/control_sd21_ini.ckpt'
|
| 12 |
+
batch_size = 4
|
| 13 |
+
logger_freq = 300
|
| 14 |
+
learning_rate = 1e-5
|
| 15 |
+
sd_locked = True
|
| 16 |
+
only_mid_control = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
|
| 20 |
+
model = create_model('./models/cldm_v21.yaml').cpu()
|
| 21 |
+
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
|
| 22 |
+
model.learning_rate = learning_rate
|
| 23 |
+
model.sd_locked = sd_locked
|
| 24 |
+
model.only_mid_control = only_mid_control
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Misc
|
| 28 |
+
dataset = MyDataset()
|
| 29 |
+
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
|
| 30 |
+
logger = ImageLogger(batch_frequency=logger_freq)
|
| 31 |
+
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Train!
|
| 35 |
+
trainer.fit(model, dataloader)
|
generation/control/download_ade20k.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
if [ ! -d "data" ]; then
|
| 4 |
+
mkdir data
|
| 5 |
+
fi
|
| 6 |
+
cd data
|
| 7 |
+
|
| 8 |
+
wget -O ade20k.zip https://keeper.mpdl.mpg.de/f/80b2fc97ffc3430c98de/?dl=1
|
| 9 |
+
unzip ade20k.zip
|
| 10 |
+
rm ade20k.zip
|
generation/control/download_celebhq.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
if [ ! -d "data" ]; then
|
| 4 |
+
mkdir data
|
| 5 |
+
fi
|
| 6 |
+
cd data
|
| 7 |
+
|
| 8 |
+
wget -O celechq-text.zip https://keeper.mpdl.mpg.de/f/72c34a6017cb40b896e9/?dl=1
|
| 9 |
+
unzip celechq-text.zip
|
| 10 |
+
rm celechq-text.zip
|
generation/control/eval_canny.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import cv2
|
| 7 |
+
import torch
|
| 8 |
+
from torch.autograd import Variable
|
| 9 |
+
import torchvision
|
| 10 |
+
from torch.utils.data import DataLoader, Dataset
|
| 11 |
+
from torchvision.datasets import CocoCaptions
|
| 12 |
+
from torchvision.transforms import ToTensor
|
| 13 |
+
import torchvision.transforms as transforms
|
| 14 |
+
from torchvision.utils import save_image
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
# from net_canny import Net
|
| 18 |
+
from ControlNet.annotator.canny import CannyDetector
|
| 19 |
+
from ControlNet.annotator.util import resize_image, HWC3
|
| 20 |
+
|
| 21 |
+
# from pytorch_msssim import ssim, ms_ssim, SSIM
|
| 22 |
+
|
| 23 |
+
class ResultFolderDataset(Dataset):
|
| 24 |
+
def __init__(self, data_dir, results_dir, n, transform=None):
|
| 25 |
+
self.data_dir = data_dir
|
| 26 |
+
self.results_dir = results_dir
|
| 27 |
+
self.n = n
|
| 28 |
+
self.image_paths = sorted([f for f in os.listdir(data_dir) if f.lower().endswith(('.png'))])
|
| 29 |
+
# self.image_paths = sorted([os.path.join(folder, f) for f in os.listdir(folder) if f.lower().endswith('_{}.png'.format(n))])
|
| 30 |
+
self.transform = transform
|
| 31 |
+
|
| 32 |
+
def __len__(self):
|
| 33 |
+
return len(self.image_paths)
|
| 34 |
+
|
| 35 |
+
def __getitem__(self, idx):
|
| 36 |
+
image_name = self.image_paths[idx]
|
| 37 |
+
source_path = os.path.join(self.data_dir, image_name)
|
| 38 |
+
|
| 39 |
+
base_name = image_name.split('_')[1].split('.')[0] # Extract 'x' from 'image_x.png'
|
| 40 |
+
image_name2 = f'result_{base_name}_{self.n}.png'
|
| 41 |
+
result_path = os.path.join(self.results_dir, image_name2)
|
| 42 |
+
|
| 43 |
+
source_image = Image.open(source_path) #.convert('RGB')
|
| 44 |
+
result_image = Image.open(result_path) #.convert('RGB')
|
| 45 |
+
if self.transform:
|
| 46 |
+
source_image = self.transform(source_image)
|
| 47 |
+
result_image = self.transform(result_image)
|
| 48 |
+
return source_image, result_image
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def calculate_metrics(pred, target):
|
| 52 |
+
intersection = np.logical_and(pred, target)
|
| 53 |
+
union = np.logical_or(pred, target)
|
| 54 |
+
iou_score = np.sum(intersection) / np.sum(union)
|
| 55 |
+
|
| 56 |
+
accuracy = np.sum(pred == target) / target.size
|
| 57 |
+
|
| 58 |
+
tp = np.sum(intersection) # True positive
|
| 59 |
+
fp = np.sum(pred) - tp # False positive
|
| 60 |
+
fn = np.sum(target) - tp # False negative
|
| 61 |
+
|
| 62 |
+
f1_score = (2 * tp) / (2 * tp + fp + fn)
|
| 63 |
+
|
| 64 |
+
return iou_score, accuracy, f1_score
|
| 65 |
+
|
| 66 |
+
if __name__ == '__main__':
|
| 67 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 68 |
+
|
| 69 |
+
low_threshold = 100
|
| 70 |
+
high_threshold = 200
|
| 71 |
+
|
| 72 |
+
n = 0
|
| 73 |
+
# epoch = 10
|
| 74 |
+
# experiment = './log/image_log_oft_COCO_canny_eps_1-3_r_4_cayley_4gpu'
|
| 75 |
+
experiment = 'log/image_log_householder_gramschmidt_COCO_canny_eps_7e-06_pe_diff_mlp_l_8_8gpu_2024-05-19-21-22-24-466032'
|
| 76 |
+
|
| 77 |
+
if 'train_with_norm' in experiment:
|
| 78 |
+
epoch = 4
|
| 79 |
+
else:
|
| 80 |
+
if 'COCO' in experiment:
|
| 81 |
+
epoch = 10
|
| 82 |
+
else:
|
| 83 |
+
epoch = 19
|
| 84 |
+
|
| 85 |
+
data_dir = os.path.join(experiment, 'source', str(epoch))
|
| 86 |
+
result_dir = os.path.join(experiment, 'results', str(epoch))
|
| 87 |
+
json_file = os.path.join(experiment, 'results.json')
|
| 88 |
+
|
| 89 |
+
# Define the transforms to apply to the images
|
| 90 |
+
transform = transforms.Compose([
|
| 91 |
+
# transforms.Resize((512, 512)),
|
| 92 |
+
# transforms.CenterCrop(512),
|
| 93 |
+
transforms.ToTensor()
|
| 94 |
+
])
|
| 95 |
+
|
| 96 |
+
dataset = ResultFolderDataset(data_dir, result_dir, n=n, transform=transform)
|
| 97 |
+
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
|
| 98 |
+
|
| 99 |
+
apply_canny = CannyDetector()
|
| 100 |
+
|
| 101 |
+
loss = 0
|
| 102 |
+
iou_score_mean = 0
|
| 103 |
+
accuracy_mean = 0
|
| 104 |
+
f1_score_mean = 0
|
| 105 |
+
ssim_mean = 0
|
| 106 |
+
for i, data in tqdm(enumerate(data_loader), total=len(data_loader)):
|
| 107 |
+
source_image, result_image = data
|
| 108 |
+
# Convert the tensor to a numpy array and transpose it to have the channels last (H, W, C)
|
| 109 |
+
source_image_np = source_image.squeeze(0).permute(1, 2, 0).numpy()
|
| 110 |
+
result_image_np = result_image.squeeze(0).permute(1, 2, 0).numpy()
|
| 111 |
+
|
| 112 |
+
# Convert the image to 8-bit unsigned integers (0-255)
|
| 113 |
+
source_image_np = (source_image_np * 255).astype(np.uint8)
|
| 114 |
+
result_image_np = (result_image_np * 255).astype(np.uint8)
|
| 115 |
+
|
| 116 |
+
source_detected_map = apply_canny(source_image_np, low_threshold, high_threshold) / 255
|
| 117 |
+
result_detected_map = apply_canny(result_image_np, low_threshold, high_threshold) / 255
|
| 118 |
+
|
| 119 |
+
iou_score, accuracy, f1_score = calculate_metrics(result_detected_map, source_detected_map)
|
| 120 |
+
|
| 121 |
+
iou_score_mean = iou_score_mean + iou_score
|
| 122 |
+
accuracy_mean = accuracy_mean + accuracy
|
| 123 |
+
f1_score_mean = f1_score_mean + f1_score
|
| 124 |
+
|
| 125 |
+
iou_score_mean = iou_score_mean / len(dataset)
|
| 126 |
+
accuracy_mean = accuracy_mean / len(dataset)
|
| 127 |
+
f1_score_mean = f1_score_mean / len(dataset)
|
| 128 |
+
|
| 129 |
+
print(experiment)
|
| 130 |
+
print('[Canny]', '[IOU]:', iou_score_mean, '[F1 Score]:', f1_score_mean)
|
generation/control/eval_landmark.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from glob import glob
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from skimage.io import imread, imsave
|
| 7 |
+
import cv2
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype = np.int32) - 1
|
| 11 |
+
def plot_kpts(image, kpts, color = 'g'):
|
| 12 |
+
''' Draw 68 key points
|
| 13 |
+
Args:
|
| 14 |
+
image: the input image
|
| 15 |
+
kpt: (68, 3).
|
| 16 |
+
'''
|
| 17 |
+
if color == 'r':
|
| 18 |
+
c = (255, 0, 0)
|
| 19 |
+
elif color == 'g':
|
| 20 |
+
c = (0, 255, 0)
|
| 21 |
+
elif color == 'b':
|
| 22 |
+
c = (255, 0, 0)
|
| 23 |
+
image = image.copy()
|
| 24 |
+
kpts = kpts.copy()
|
| 25 |
+
radius = max(int(min(image.shape[0], image.shape[1])/200), 1)
|
| 26 |
+
for i in range(kpts.shape[0]):
|
| 27 |
+
st = kpts[i, :2]
|
| 28 |
+
if kpts.shape[1]==4:
|
| 29 |
+
if kpts[i, 3] > 0.5:
|
| 30 |
+
c = (0, 255, 0)
|
| 31 |
+
else:
|
| 32 |
+
c = (0, 0, 255)
|
| 33 |
+
image = cv2.circle(image,(int(st[0]), int(st[1])), radius, c, radius*2)
|
| 34 |
+
if i in end_list:
|
| 35 |
+
continue
|
| 36 |
+
ed = kpts[i + 1, :2]
|
| 37 |
+
image = cv2.line(image, (int(st[0]), int(st[1])), (int(ed[0]), int(ed[1])), (255, 255, 255), radius)
|
| 38 |
+
return image
|
| 39 |
+
|
| 40 |
+
def plot_points(image, kpts, color = 'w'):
|
| 41 |
+
''' Draw 68 key points
|
| 42 |
+
Args:
|
| 43 |
+
image: the input image
|
| 44 |
+
kpt: (n, 3).
|
| 45 |
+
'''
|
| 46 |
+
if color == 'r':
|
| 47 |
+
c = (255, 0, 0)
|
| 48 |
+
elif color == 'g':
|
| 49 |
+
c = (0, 255, 0)
|
| 50 |
+
elif color == 'b':
|
| 51 |
+
c = (0, 0, 255)
|
| 52 |
+
elif color == 'y':
|
| 53 |
+
c = (0, 255, 255)
|
| 54 |
+
elif color == 'w':
|
| 55 |
+
c = (255, 255, 255)
|
| 56 |
+
image = image.copy()
|
| 57 |
+
kpts = kpts.copy()
|
| 58 |
+
kpts = kpts.astype(np.int32)
|
| 59 |
+
radius = max(int(min(image.shape[0], image.shape[1])/200), 1)
|
| 60 |
+
for i in range(kpts.shape[0]):
|
| 61 |
+
st = kpts[i, :2]
|
| 62 |
+
image = cv2.circle(image,(int(st[0]), int(st[1])), radius, c, radius*2)
|
| 63 |
+
return image
|
| 64 |
+
|
| 65 |
+
def generate_landmark2d(inputpath, savepath, n=0, device='cuda:0', vis=False):
|
| 66 |
+
print(f'generate 2d landmarks')
|
| 67 |
+
os.makedirs(savepath, exist_ok=True)
|
| 68 |
+
import face_alignment
|
| 69 |
+
detect_model = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device, flip_input=False)
|
| 70 |
+
|
| 71 |
+
imagepath_list = glob(os.path.join(inputpath, '*_{}.png'.format(n)))
|
| 72 |
+
imagepath_list = sorted(imagepath_list)
|
| 73 |
+
for imagepath in tqdm(imagepath_list):
|
| 74 |
+
name = Path(imagepath).stem
|
| 75 |
+
|
| 76 |
+
image = imread(imagepath)[:,:,:3]
|
| 77 |
+
out = detect_model.get_landmarks(image)
|
| 78 |
+
if out is None:
|
| 79 |
+
continue
|
| 80 |
+
kpt = out[0].squeeze()
|
| 81 |
+
np.savetxt(os.path.join(savepath, f'{name}.txt'), kpt)
|
| 82 |
+
if vis:
|
| 83 |
+
image = cv2.imread(imagepath)
|
| 84 |
+
image_point = plot_kpts(image, kpt)
|
| 85 |
+
# check
|
| 86 |
+
cv2.imwrite(os.path.join(savepath, f'{name}_overlay.jpg'), image_point)
|
| 87 |
+
# background = np.zeros_like(image)
|
| 88 |
+
# cv2.imwrite(os.path.join(savepath, f'{name}_line.jpg'), plot_kpts(background, kpt))
|
| 89 |
+
# cv2.imwrite(os.path.join(savepath, f'{name}_point.jpg'), plot_points(background, kpt))
|
| 90 |
+
# exit()
|
| 91 |
+
|
| 92 |
+
def landmark_comparison(lmk_folder, gt_lmk_folder, n=0):
|
| 93 |
+
print(f'calculate reprojection error')
|
| 94 |
+
lmk_err = []
|
| 95 |
+
gt_lmk_folder = './data/celebhq-text/celeba-hq-landmark2d'
|
| 96 |
+
with open('./data/celebhq-text/prompt_val_blip_full.json', 'rt') as f: # fill50k, COCO
|
| 97 |
+
for line in f:
|
| 98 |
+
val_data = json.loads(line)
|
| 99 |
+
# for i in tqdm(range(2000)):
|
| 100 |
+
for i in tqdm(range(len(val_data))):
|
| 101 |
+
# import ipdb; ipdb.set_trace()
|
| 102 |
+
# line = val_data[n]
|
| 103 |
+
line = val_data[i]
|
| 104 |
+
|
| 105 |
+
img_name = line["image"][:-4]
|
| 106 |
+
lmk1_path = os.path.join(gt_lmk_folder, f'{img_name}.txt')
|
| 107 |
+
|
| 108 |
+
lmk1 = np.loadtxt(lmk1_path) / 2
|
| 109 |
+
lmk2_path = os.path.join(lmk_folder, f'result_{i}_{n}.txt')
|
| 110 |
+
if not os.path.exists(lmk2_path):
|
| 111 |
+
print(f'{lmk2_path} not exist')
|
| 112 |
+
continue
|
| 113 |
+
lmk2 = np.loadtxt(lmk2_path)
|
| 114 |
+
lmk_err.append(np.mean(np.linalg.norm(lmk1 - lmk2, axis=1)))
|
| 115 |
+
print(np.mean(lmk_err))
|
| 116 |
+
np.save(os.path.join(lmk_folder, 'lmk_err.npy'), lmk_err)
|
| 117 |
+
|
| 118 |
+
n = 0
|
| 119 |
+
epoch = 19
|
| 120 |
+
gt_lmk_folder = './data/celebhq-text/celeba-hq-landmark2d'
|
| 121 |
+
# input_folder = os.path.join('./data/image_log_opt_lora_CelebA_landmark_lr_5-6_pe_diff_mlp_r_4_cayley_4gpu/results', str(epoch))
|
| 122 |
+
input_folder = os.path.join('log/image_log_householder_none_ADE20K_segm_eps_7e-06_pe_diff_mlp_l_8_8gpu_2024-05-15-19-33-41-650524/results', str(epoch))
|
| 123 |
+
# input_folder = os.path.join('log/image_log_oft_CelebA_landmark_eps_0.001_pe_diff_mlp_r_4_8gpu_2024-03-21-19-07-34-175825/train_with_norm/results', str(epoch))
|
| 124 |
+
save_folder = os.path.join(input_folder, 'landmark')
|
| 125 |
+
|
| 126 |
+
generate_landmark2d(input_folder, save_folder, n, device='cuda:0', vis=False)
|
| 127 |
+
landmark_comparison(save_folder, gt_lmk_folder, n)
|
generation/control/generation.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from oldm.hack import disable_verbosity
|
| 2 |
+
disable_verbosity()
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
import cv2
|
| 7 |
+
import einops
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import random
|
| 12 |
+
import json
|
| 13 |
+
import argparse
|
| 14 |
+
|
| 15 |
+
file_path = os.path.abspath(__file__)
|
| 16 |
+
parent_dir = os.path.abspath(os.path.dirname(file_path) + '/..')
|
| 17 |
+
if parent_dir not in sys.path:
|
| 18 |
+
sys.path.append(parent_dir)
|
| 19 |
+
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from pytorch_lightning import seed_everything
|
| 22 |
+
from oldm.model import create_model, load_state_dict
|
| 23 |
+
from oldm.ddim_hacked import DDIMSampler
|
| 24 |
+
from oft import inject_trainable_oft, inject_trainable_oft_conv, inject_trainable_oft_extended, inject_trainable_oft_with_norm
|
| 25 |
+
from hra import inject_trainable_hra
|
| 26 |
+
from lora import inject_trainable_lora
|
| 27 |
+
|
| 28 |
+
from dataset.utils import return_dataset
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def process(input_image, prompt, hint_image, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
|
| 32 |
+
with torch.no_grad():
|
| 33 |
+
# img = resize_image(HWC3(input_image), image_resolution)
|
| 34 |
+
H, W, C = input_image.shape
|
| 35 |
+
|
| 36 |
+
#detected_map = apply_canny(input_image, low_threshold, high_threshold)
|
| 37 |
+
#detected_map = HWC3(detected_map)
|
| 38 |
+
|
| 39 |
+
# control = torch.from_numpy(hint_image.copy()).float().cuda() / 255.0
|
| 40 |
+
control = torch.from_numpy(hint_image.copy()).float().cuda()
|
| 41 |
+
control = torch.stack([control for _ in range(num_samples)], dim=0)
|
| 42 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 43 |
+
|
| 44 |
+
if seed == -1:
|
| 45 |
+
seed = random.randint(0, 65535)
|
| 46 |
+
seed_everything(seed)
|
| 47 |
+
|
| 48 |
+
# if config.save_memory:
|
| 49 |
+
# model.low_vram_shift(is_diffusing=False)
|
| 50 |
+
|
| 51 |
+
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 52 |
+
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
|
| 53 |
+
shape = (4, H // 8, W // 8)
|
| 54 |
+
|
| 55 |
+
# if config.save_memory:
|
| 56 |
+
# model.low_vram_shift(is_diffusing=True)
|
| 57 |
+
|
| 58 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 59 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
|
| 60 |
+
shape, cond, verbose=False, eta=eta,
|
| 61 |
+
unconditional_guidance_scale=scale,
|
| 62 |
+
unconditional_conditioning=un_cond)
|
| 63 |
+
|
| 64 |
+
# if config.save_memory:
|
| 65 |
+
# model.low_vram_shift(is_diffusing=False)
|
| 66 |
+
|
| 67 |
+
x_samples = model.decode_first_stage(samples)
|
| 68 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 69 |
+
|
| 70 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 71 |
+
# return [255 - hint_image] + results
|
| 72 |
+
return [input_image] + [hint_image] + results
|
| 73 |
+
|
| 74 |
+
parser = argparse.ArgumentParser()
|
| 75 |
+
|
| 76 |
+
parser.add_argument('--d', type=int, help='the index of GPU', default=0)
|
| 77 |
+
|
| 78 |
+
# HRA
|
| 79 |
+
parser.add_argument('--hra_r', type=int, default=8)
|
| 80 |
+
parser.add_argument('--hra_apply_GS', action="store_true", default=False)
|
| 81 |
+
|
| 82 |
+
# OFT
|
| 83 |
+
parser.add_argument('--oft_r', type=int, default=4)
|
| 84 |
+
parser.add_argument('--oft_eps',
|
| 85 |
+
type=float,
|
| 86 |
+
choices=[1e-3, 2e-5, 7e-6],
|
| 87 |
+
default=7e-6,
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument('--oft_coft', action="store_true", default=True)
|
| 90 |
+
parser.add_argument('--oft_block_share', action="store_true", default=False)
|
| 91 |
+
parser.add_argument('--img_ID', type=int, default=1)
|
| 92 |
+
parser.add_argument('--num_samples', type=int, default=1)
|
| 93 |
+
parser.add_argument('--batch', type=int, default=20)
|
| 94 |
+
parser.add_argument('--sd_locked', action="store_true", default=True)
|
| 95 |
+
parser.add_argument('--only_mid_control', action="store_true", default=False)
|
| 96 |
+
parser.add_argument('--num_gpus', type=int, default=8)
|
| 97 |
+
# parser.add_argument('--time_str', type=str, default=datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))
|
| 98 |
+
parser.add_argument('--time_str', type=str, default='2024-03-18-10-55-21-089985')
|
| 99 |
+
parser.add_argument('--epoch', type=int, default=19)
|
| 100 |
+
parser.add_argument('--control',
|
| 101 |
+
type=str,
|
| 102 |
+
help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]',
|
| 103 |
+
default="segm")
|
| 104 |
+
|
| 105 |
+
args = parser.parse_args()
|
| 106 |
+
|
| 107 |
+
if __name__ == '__main__':
|
| 108 |
+
# Configs
|
| 109 |
+
epoch = args.epoch
|
| 110 |
+
control = args.control
|
| 111 |
+
_, dataset, data_name, logger_freq, max_epochs = return_dataset(control, full=True)
|
| 112 |
+
|
| 113 |
+
# specify the experiment name
|
| 114 |
+
# experiment = './log/image_log_oft_{}_{}_eps_{}_pe_diff_mlp_r_{}_{}gpu'.format(data_name, control, args.eps, args.r, args.num_gpus)
|
| 115 |
+
|
| 116 |
+
num_gpus = args.num_gpus
|
| 117 |
+
time_str = args.time_str
|
| 118 |
+
# experiment = 'log/image_log_oft_{}_{}_eps_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.eps, args.r, num_gpus, time_str)
|
| 119 |
+
experiment = './log/image_log_hra_0.0_ADE20K_segm_pe_diff_mlp_r_8_8gpu_2024-06-27-19-57-34-979197'
|
| 120 |
+
# experiment = './log/image_log_oft_ADE20K_segm_eps_0.001_pe_diff_mlp_r_4_8gpu_2024-03-25-21-04-17-549433/train_with_norm'
|
| 121 |
+
|
| 122 |
+
assert args.control in experiment
|
| 123 |
+
|
| 124 |
+
if 'train_with_norm' in experiment:
|
| 125 |
+
epoch = 4
|
| 126 |
+
else:
|
| 127 |
+
if 'COCO' in experiment:
|
| 128 |
+
epoch = 9
|
| 129 |
+
else:
|
| 130 |
+
epoch = 19
|
| 131 |
+
|
| 132 |
+
resume_path = os.path.join(experiment, f'model-epoch={epoch:02d}.ckpt')
|
| 133 |
+
sd_locked = args.sd_locked
|
| 134 |
+
only_mid_control = args.only_mid_control
|
| 135 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 136 |
+
|
| 137 |
+
# Result directory
|
| 138 |
+
result_dir = os.path.join(experiment, 'results', str(epoch))
|
| 139 |
+
os.makedirs(result_dir, exist_ok=True)
|
| 140 |
+
source_dir = os.path.join(experiment, 'source', str(epoch))
|
| 141 |
+
os.makedirs(source_dir, exist_ok=True)
|
| 142 |
+
hint_dir = os.path.join(experiment, 'hints', str(epoch))
|
| 143 |
+
os.makedirs(hint_dir, exist_ok=True)
|
| 144 |
+
|
| 145 |
+
model = create_model('./configs/oft_ldm_v15.yaml').cpu()
|
| 146 |
+
model.model.requires_grad_(False)
|
| 147 |
+
|
| 148 |
+
if 'hra' in experiment:
|
| 149 |
+
unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
|
| 150 |
+
elif 'lora' in experiment:
|
| 151 |
+
unet_lora_params, train_names = inject_trainable_lora(model.model, rank=args.r, network_alpha=None)
|
| 152 |
+
else:
|
| 153 |
+
if 'train_with_norm' in experiment:
|
| 154 |
+
unet_opt_params, train_names = inject_trainable_oft_with_norm(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
|
| 155 |
+
else:
|
| 156 |
+
unet_lora_params, train_names = inject_trainable_oft(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
|
| 157 |
+
# unet_lora_params, train_names = inject_trainable_oft_conv(model.model, r=args.r, eps=args.eps, is_coft=args.coft, block_share=args.block_share)
|
| 158 |
+
# unet_lora_params, train_names = inject_trainable_oft_extended(model.model, r=args.r, eps=args.eps, is_coft=args.coft, block_share=args.block_share)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
model.load_state_dict(load_state_dict(resume_path, location='cuda'))
|
| 162 |
+
model = model.cuda()
|
| 163 |
+
ddim_sampler = DDIMSampler(model)
|
| 164 |
+
|
| 165 |
+
# pack = range(0, len(dataset), args.batch)
|
| 166 |
+
# formatted_data = {}
|
| 167 |
+
# for index in range(args.batch):
|
| 168 |
+
# # import ipdb; ipdb.set_trace()
|
| 169 |
+
# start_point = pack[args.img_ID]
|
| 170 |
+
# idx = start_point + index
|
| 171 |
+
|
| 172 |
+
# canny
|
| 173 |
+
# img_list = [378, 441, 0, 31, 115, 182, 59, 60, 66, 269, ]
|
| 174 |
+
# landmark
|
| 175 |
+
# img_list = [139, 179, 197, 144, 54, 71, 76, 98, 100, 277, ]
|
| 176 |
+
# segm
|
| 177 |
+
# img_list = [14, 667, 576, 1387, 1603, 1697, 987, 1830, 1232, 1881, ]
|
| 178 |
+
|
| 179 |
+
# for idx in img_list:
|
| 180 |
+
|
| 181 |
+
num_pack = len(dataset) // args.num_gpus
|
| 182 |
+
start_idx = args.d * num_pack
|
| 183 |
+
end_idx = (args.d + 1) * num_pack if args.d < args.num_gpus - 1 else len(dataset)
|
| 184 |
+
|
| 185 |
+
for idx in range(start_idx, end_idx):
|
| 186 |
+
|
| 187 |
+
data = dataset[idx]
|
| 188 |
+
input_image, prompt, hint = data['jpg'], data['txt'], data['hint']
|
| 189 |
+
# input_image, hint = input_image.to(device), hint.to(device)
|
| 190 |
+
|
| 191 |
+
if not os.path.exists(os.path.join(result_dir, f'result_{idx}_0.png')):
|
| 192 |
+
result_images = process(
|
| 193 |
+
input_image=input_image,
|
| 194 |
+
prompt=prompt,
|
| 195 |
+
hint_image=hint,
|
| 196 |
+
a_prompt="",
|
| 197 |
+
n_prompt="",
|
| 198 |
+
num_samples=args.num_samples,
|
| 199 |
+
image_resolution=512,
|
| 200 |
+
ddim_steps=50,
|
| 201 |
+
guess_mode=False,
|
| 202 |
+
strength=1,
|
| 203 |
+
scale=9.0,
|
| 204 |
+
seed=-1,
|
| 205 |
+
eta=0.0,
|
| 206 |
+
low_threshold=100,
|
| 207 |
+
high_threshold=200,
|
| 208 |
+
)
|
| 209 |
+
for i, image in enumerate(result_images):
|
| 210 |
+
if i == 0:
|
| 211 |
+
image = ((image + 1) * 127.5).clip(0, 255).astype(np.uint8)
|
| 212 |
+
pil_image = Image.fromarray(image)
|
| 213 |
+
output_path = os.path.join(source_dir, f'image_{idx}.png')
|
| 214 |
+
pil_image.save(output_path)
|
| 215 |
+
elif i == 1:
|
| 216 |
+
image = (image * 255).clip(0, 255).astype(np.uint8)
|
| 217 |
+
# Convert numpy array to PIL Image
|
| 218 |
+
pil_image = Image.fromarray(image)
|
| 219 |
+
# Save PIL Image to file
|
| 220 |
+
output_path = os.path.join(hint_dir, f'hint_{idx}.png')
|
| 221 |
+
pil_image.save(output_path)
|
| 222 |
+
else:
|
| 223 |
+
n = i - 2
|
| 224 |
+
# Convert numpy array to PIL Image
|
| 225 |
+
pil_image = Image.fromarray(image)
|
| 226 |
+
# Save PIL Image to file
|
| 227 |
+
output_path = os.path.join(result_dir, f'result_{idx}_{n}.png')
|
| 228 |
+
pil_image.save(output_path)
|
| 229 |
+
|
| 230 |
+
# formatted_data[f"item{idx}"] = {
|
| 231 |
+
# "image_name": f'result_{idx}.png',
|
| 232 |
+
# "prompt": prompt
|
| 233 |
+
# }
|
| 234 |
+
|
| 235 |
+
# with open(os.path.join(experiment, 'results_{}.json'.format(img_ID)), 'w') as f:
|
| 236 |
+
# json.dump(formatted_data, f)
|
| 237 |
+
|
| 238 |
+
|
generation/control/hra.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script utilizes code from lora available at:
|
| 3 |
+
https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
| 4 |
+
|
| 5 |
+
Original Author: Simo Ryu
|
| 6 |
+
License: Apache License 2.0
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import math
|
| 12 |
+
from itertools import groupby
|
| 13 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
| 14 |
+
|
| 15 |
+
import pickle
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import PIL
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from safetensors.torch import safe_open
|
| 27 |
+
from safetensors.torch import save_file as safe_save
|
| 28 |
+
|
| 29 |
+
safetensors_available = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
from .safe_open import safe_open
|
| 32 |
+
|
| 33 |
+
def safe_save(
|
| 34 |
+
tensors: Dict[str, torch.Tensor],
|
| 35 |
+
filename: str,
|
| 36 |
+
metadata: Optional[Dict[str, str]] = None,
|
| 37 |
+
) -> None:
|
| 38 |
+
raise EnvironmentError(
|
| 39 |
+
"Saving safetensors requires the safetensors library. Please install with pip or similar."
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
safetensors_available = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def project(R, eps):
|
| 46 |
+
I = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
|
| 47 |
+
diff = R - I
|
| 48 |
+
norm_diff = torch.norm(diff)
|
| 49 |
+
if norm_diff <= eps:
|
| 50 |
+
return R
|
| 51 |
+
else:
|
| 52 |
+
return I + eps * (diff / norm_diff)
|
| 53 |
+
|
| 54 |
+
def project_batch(R, eps=1e-5):
|
| 55 |
+
# scaling factor for each of the smaller block matrix
|
| 56 |
+
eps = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
|
| 57 |
+
I = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
|
| 58 |
+
diff = R - I
|
| 59 |
+
norm_diff = torch.norm(R - I, dim=(1, 2), keepdim=True)
|
| 60 |
+
mask = (norm_diff <= eps).bool()
|
| 61 |
+
out = torch.where(mask, R, I + eps * (diff / norm_diff))
|
| 62 |
+
return out
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class HRAInjectedLinear(nn.Module):
|
| 66 |
+
def __init__(
|
| 67 |
+
self, in_features, out_features, bias=False, r=8, apply_GS=False,
|
| 68 |
+
):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.in_features=in_features
|
| 71 |
+
self.out_features=out_features
|
| 72 |
+
|
| 73 |
+
self.r = r
|
| 74 |
+
self.apply_GS = apply_GS
|
| 75 |
+
|
| 76 |
+
half_u = torch.zeros(in_features, r // 2)
|
| 77 |
+
nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
|
| 78 |
+
self.hra_u = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1), requires_grad=True)
|
| 79 |
+
|
| 80 |
+
self.fixed_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
orig_weight = self.fixed_linear.weight.data
|
| 84 |
+
if self.apply_GS:
|
| 85 |
+
weight = [(self.hra_u[:, 0] / self.hra_u[:, 0].norm()).view(-1, 1)]
|
| 86 |
+
for i in range(1, self.r):
|
| 87 |
+
ui = self.hra_u[:, i].view(-1, 1)
|
| 88 |
+
for j in range(i):
|
| 89 |
+
ui = ui - (weight[j].t() @ ui) * weight[j]
|
| 90 |
+
weight.append((ui / ui.norm()).view(-1, 1))
|
| 91 |
+
weight = torch.cat(weight, dim=1)
|
| 92 |
+
new_weight = orig_weight @ (torch.eye(self.in_features, device=x.device) - 2 * weight @ weight.t())
|
| 93 |
+
|
| 94 |
+
else:
|
| 95 |
+
new_weight = orig_weight
|
| 96 |
+
hra_u_norm = self.hra_u / self.hra_u.norm(dim=0)
|
| 97 |
+
for i in range(self.r):
|
| 98 |
+
ui = hra_u_norm[:, i].view(-1, 1)
|
| 99 |
+
new_weight = torch.mm(new_weight, torch.eye(self.in_features, device=x.device) - 2 * ui @ ui.t())
|
| 100 |
+
|
| 101 |
+
out = nn.functional.linear(input=x, weight=new_weight, bias=self.fixed_linear.bias)
|
| 102 |
+
return out
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
|
| 106 |
+
|
| 107 |
+
UNET_CONV_TARGET_REPLACE = {"ResBlock"}
|
| 108 |
+
|
| 109 |
+
UNET_EXTENDED_TARGET_REPLACE = {"ResBlock", "CrossAttention", "Attention", "GEGLU"}
|
| 110 |
+
|
| 111 |
+
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
|
| 112 |
+
|
| 113 |
+
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
|
| 114 |
+
|
| 115 |
+
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
|
| 116 |
+
|
| 117 |
+
EMBED_FLAG = "<embed>"
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _find_children(
|
| 121 |
+
model,
|
| 122 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 123 |
+
):
|
| 124 |
+
"""
|
| 125 |
+
Find all modules of a certain class (or union of classes).
|
| 126 |
+
Returns all matching modules, along with the parent of those moduless and the
|
| 127 |
+
names they are referenced by.
|
| 128 |
+
"""
|
| 129 |
+
result = []
|
| 130 |
+
for parent in model.modules():
|
| 131 |
+
for name, module in parent.named_children():
|
| 132 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
| 133 |
+
result.append((parent, name, module)) # Append the result to the list
|
| 134 |
+
|
| 135 |
+
return result # Return the list instead of using 'yield'
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _find_modules_v2(
|
| 139 |
+
model,
|
| 140 |
+
ancestor_class: Optional[Set[str]] = None,
|
| 141 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 142 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [
|
| 143 |
+
HRAInjectedLinear,
|
| 144 |
+
],
|
| 145 |
+
):
|
| 146 |
+
"""
|
| 147 |
+
Find all modules of a certain class (or union of classes) that are direct or
|
| 148 |
+
indirect descendants of other modules of a certain class (or union of classes).
|
| 149 |
+
Returns all matching modules, along with the parent of those moduless and the
|
| 150 |
+
names they are referenced by.
|
| 151 |
+
"""
|
| 152 |
+
|
| 153 |
+
# Get the targets we should replace all linears under
|
| 154 |
+
if ancestor_class is not None:
|
| 155 |
+
ancestors = (
|
| 156 |
+
module
|
| 157 |
+
for module in model.modules()
|
| 158 |
+
if module.__class__.__name__ in ancestor_class
|
| 159 |
+
)
|
| 160 |
+
else:
|
| 161 |
+
# the first modules is the most senior father class.
|
| 162 |
+
# this, incase you want to naively iterate over all modules.
|
| 163 |
+
for module in model.modules():
|
| 164 |
+
ancestor_class = module.__class__.__name__
|
| 165 |
+
break
|
| 166 |
+
ancestors = (
|
| 167 |
+
module
|
| 168 |
+
for module in model.modules()
|
| 169 |
+
if module.__class__.__name__ in ancestor_class
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
results = []
|
| 173 |
+
# For each target find every linear_class module that isn't a child of a HRAInjectedLinear
|
| 174 |
+
for ancestor in ancestors:
|
| 175 |
+
for fullname, module in ancestor.named_modules():
|
| 176 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
| 177 |
+
# Find the direct parent if this is a descendant, not a child, of target
|
| 178 |
+
*path, name = fullname.split(".")
|
| 179 |
+
parent = ancestor
|
| 180 |
+
while path:
|
| 181 |
+
parent = parent.get_submodule(path.pop(0))
|
| 182 |
+
# Skip this linear if it's a child of a HRAInjectedLinear
|
| 183 |
+
if exclude_children_of and any(
|
| 184 |
+
[isinstance(parent, _class) for _class in exclude_children_of]
|
| 185 |
+
):
|
| 186 |
+
continue
|
| 187 |
+
results.append((parent, name, module)) # Append the result to the list
|
| 188 |
+
|
| 189 |
+
return results # Return the list instead of using 'yield'
|
| 190 |
+
|
| 191 |
+
def _find_modules_old(
|
| 192 |
+
model,
|
| 193 |
+
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
|
| 194 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 195 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [HRAInjectedLinear],
|
| 196 |
+
):
|
| 197 |
+
ret = []
|
| 198 |
+
for _module in model.modules():
|
| 199 |
+
if _module.__class__.__name__ in ancestor_class:
|
| 200 |
+
|
| 201 |
+
for name, _child_module in _module.named_modules():
|
| 202 |
+
if _child_module.__class__ in search_class:
|
| 203 |
+
ret.append((_module, name, _child_module))
|
| 204 |
+
|
| 205 |
+
return ret
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
_find_modules = _find_modules_v2
|
| 209 |
+
# _find_modules = _find_modules_old
|
| 210 |
+
|
| 211 |
+
def inject_trainable_hra(
|
| 212 |
+
model: nn.Module,
|
| 213 |
+
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
|
| 214 |
+
verbose: bool = False,
|
| 215 |
+
r: int = 8,
|
| 216 |
+
apply_GS: str = False,
|
| 217 |
+
):
|
| 218 |
+
"""
|
| 219 |
+
inject hra into model, and returns hra parameter groups.
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
require_grad_params = []
|
| 223 |
+
names = []
|
| 224 |
+
|
| 225 |
+
for _module, name, _child_module in _find_modules(
|
| 226 |
+
model, target_replace_module, search_class=[nn.Linear]
|
| 227 |
+
):
|
| 228 |
+
|
| 229 |
+
weight = _child_module.weight
|
| 230 |
+
bias = _child_module.bias
|
| 231 |
+
if verbose:
|
| 232 |
+
print("HRA Injection : injecting hra into ", name)
|
| 233 |
+
print("HRA Injection : weight shape", weight.shape)
|
| 234 |
+
_tmp = HRAInjectedLinear(
|
| 235 |
+
_child_module.in_features,
|
| 236 |
+
_child_module.out_features,
|
| 237 |
+
_child_module.bias is not None,
|
| 238 |
+
r=r,
|
| 239 |
+
apply_GS=apply_GS,
|
| 240 |
+
)
|
| 241 |
+
_tmp.fixed_linear.weight = weight
|
| 242 |
+
if bias is not None:
|
| 243 |
+
_tmp.fixed_linear.bias = bias
|
| 244 |
+
|
| 245 |
+
# switch the module
|
| 246 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
| 247 |
+
_module._modules[name] = _tmp
|
| 248 |
+
|
| 249 |
+
require_grad_params.append(_module._modules[name].hra_u)
|
| 250 |
+
_module._modules[name].hra_u.requires_grad = True
|
| 251 |
+
|
| 252 |
+
names.append(name)
|
| 253 |
+
|
| 254 |
+
return require_grad_params, names
|
generation/control/tool_add_hra.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script utilizes code from ControlNet available at:
|
| 3 |
+
https://github.com/lllyasviel/ControlNet/blob/main/tool_add_control.py
|
| 4 |
+
|
| 5 |
+
Original Author: Lvmin Zhang
|
| 6 |
+
License: Apache License 2.0
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
os.environ['HF_HOME'] = '/tmp'
|
| 12 |
+
|
| 13 |
+
# assert len(sys.argv) == 3, 'Args are wrong.'
|
| 14 |
+
|
| 15 |
+
# input_path = sys.argv[1]
|
| 16 |
+
# output_path = sys.argv[2]
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from oldm.hack import disable_verbosity
|
| 20 |
+
disable_verbosity()
|
| 21 |
+
from oldm.model import create_model
|
| 22 |
+
from hra import inject_trainable_hra
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
|
| 26 |
+
parser = argparse.ArgumentParser()
|
| 27 |
+
parser.add_argument('--input_path', type=str, default='./models/v1-5-pruned.ckpt')
|
| 28 |
+
parser.add_argument('--output_path', type=str, default='./models/hra_half_init_l_8.ckpt')
|
| 29 |
+
parser.add_argument('--r', type=int, default=8)
|
| 30 |
+
parser.add_argument('--apply_GS', action='store_true', default=False)
|
| 31 |
+
args = parser.parse_args()
|
| 32 |
+
|
| 33 |
+
# args.output_path = f'./models/hra_none_l_8.ckpt'
|
| 34 |
+
|
| 35 |
+
assert os.path.exists(args.input_path), 'Input model does not exist.'
|
| 36 |
+
# assert not os.path.exists(output_path), 'Output filename already exists.'
|
| 37 |
+
assert os.path.exists(os.path.dirname(args.output_path)), 'Output path is not valid.'
|
| 38 |
+
|
| 39 |
+
def get_node_name(name, parent_name):
|
| 40 |
+
if len(name) <= len(parent_name):
|
| 41 |
+
return False, ''
|
| 42 |
+
p = name[:len(parent_name)]
|
| 43 |
+
if p != parent_name:
|
| 44 |
+
return False, ''
|
| 45 |
+
return True, name[len(parent_name):]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
model = create_model(config_path='./configs/oft_ldm_v15.yaml')
|
| 49 |
+
model.model.requires_grad_(False)
|
| 50 |
+
|
| 51 |
+
unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.r, apply_GS=args.apply_GS)
|
| 52 |
+
|
| 53 |
+
pretrained_weights = torch.load(args.input_path)
|
| 54 |
+
if 'state_dict' in pretrained_weights:
|
| 55 |
+
pretrained_weights = pretrained_weights['state_dict']
|
| 56 |
+
|
| 57 |
+
scratch_dict = model.state_dict()
|
| 58 |
+
|
| 59 |
+
target_dict = {}
|
| 60 |
+
names = []
|
| 61 |
+
for k in scratch_dict.keys():
|
| 62 |
+
names.append(k)
|
| 63 |
+
|
| 64 |
+
if k in pretrained_weights:
|
| 65 |
+
target_dict[k] = pretrained_weights[k].clone()
|
| 66 |
+
else:
|
| 67 |
+
if 'fixed_linear.' in k:
|
| 68 |
+
copy_k = k.replace('fixed_linear.', '')
|
| 69 |
+
target_dict[k] = pretrained_weights[copy_k].clone()
|
| 70 |
+
else:
|
| 71 |
+
target_dict[k] = scratch_dict[k].clone()
|
| 72 |
+
print(f'These weights are newly added: {k}')
|
| 73 |
+
|
| 74 |
+
with open('HRA_model_names.txt', 'w') as file:
|
| 75 |
+
for element in names:
|
| 76 |
+
file.write(element + '\n')
|
| 77 |
+
|
| 78 |
+
model.load_state_dict(target_dict, strict=True)
|
| 79 |
+
torch.save(model.state_dict(), args.output_path)
|
| 80 |
+
# print('没有保存模型')
|
| 81 |
+
print('Done.')
|
generation/control/train.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from oldm.hack import disable_verbosity
|
| 2 |
+
disable_verbosity()
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import torch
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
file_path = os.path.abspath(__file__)
|
| 10 |
+
parent_dir = os.path.abspath(os.path.dirname(file_path) + '/..')
|
| 11 |
+
if parent_dir not in sys.path:
|
| 12 |
+
sys.path.append(parent_dir)
|
| 13 |
+
|
| 14 |
+
import pytorch_lightning as pl
|
| 15 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
from oldm.logger import ImageLogger
|
| 18 |
+
from oldm.model import create_model, load_state_dict
|
| 19 |
+
from dataset.utils import return_dataset
|
| 20 |
+
|
| 21 |
+
from oft import inject_trainable_oft, inject_trainable_oft_conv, inject_trainable_oft_extended, inject_trainable_oft_with_norm
|
| 22 |
+
from hra import inject_trainable_hra
|
| 23 |
+
from lora import inject_trainable_lora
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
|
| 30 |
+
# HRA
|
| 31 |
+
parser.add_argument('--hra_r', type=int, default=8)
|
| 32 |
+
parser.add_argument('--hra_apply_GS', action='store_true', default=False)
|
| 33 |
+
parser.add_argument('--hra_lamb', type=float, default=0.0)
|
| 34 |
+
|
| 35 |
+
# OFT
|
| 36 |
+
parser.add_argument('--oft_r', type=int, default=4)
|
| 37 |
+
parser.add_argument('--oft_eps', type=float, default=7e-6)
|
| 38 |
+
parser.add_argument('--oft_coft', action="store_true", default=True)
|
| 39 |
+
parser.add_argument('--oft_block_share', action="store_true", default=False)
|
| 40 |
+
|
| 41 |
+
parser.add_argument('--batch_size', type=int, default=8)
|
| 42 |
+
parser.add_argument('--num_samples', type=int, default=1)
|
| 43 |
+
parser.add_argument('--plot_frequency', type=int, default=100)
|
| 44 |
+
parser.add_argument('--learning_rate', type=float, default=9e-4)
|
| 45 |
+
parser.add_argument('--sd_locked', action="store_true", default=True)
|
| 46 |
+
parser.add_argument('--only_mid_control', action="store_true", default=False)
|
| 47 |
+
parser.add_argument('--num_gpus', type=int, default=torch.cuda.device_count())
|
| 48 |
+
parser.add_argument('--resume_path',
|
| 49 |
+
type=str,
|
| 50 |
+
default='./models/hra_half_init_l_8.ckpt',
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument('--time_str', type=str, default=datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))
|
| 53 |
+
parser.add_argument('--num_workers', type=int, default=8)
|
| 54 |
+
parser.add_argument('--control',
|
| 55 |
+
type=str,
|
| 56 |
+
help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]',
|
| 57 |
+
default="segm")
|
| 58 |
+
|
| 59 |
+
args = parser.parse_args()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
|
| 63 |
+
# specify the control signal and dataset
|
| 64 |
+
control = args.control
|
| 65 |
+
|
| 66 |
+
# create dataset
|
| 67 |
+
train_dataset, val_dataset, data_name, logger_freq, max_epochs = return_dataset(control) # , n_samples=n_samples)
|
| 68 |
+
|
| 69 |
+
# Configs
|
| 70 |
+
resume_path = args.resume_path
|
| 71 |
+
|
| 72 |
+
batch_size = args.batch_size
|
| 73 |
+
num_samples = args.num_samples
|
| 74 |
+
plot_frequency = args.plot_frequency
|
| 75 |
+
learning_rate = args.learning_rate
|
| 76 |
+
sd_locked = args.sd_locked
|
| 77 |
+
only_mid_control = args.only_mid_control
|
| 78 |
+
num_gpus = args.num_gpus
|
| 79 |
+
time_str = args.time_str
|
| 80 |
+
num_workers = args.num_workers
|
| 81 |
+
|
| 82 |
+
for arg in vars(args):
|
| 83 |
+
print(f'{arg}: {getattr(args, arg)}')
|
| 84 |
+
print(f'data_name: {data_name}\nlogger_freq: {logger_freq}\nmax_epochs: {max_epochs}')
|
| 85 |
+
|
| 86 |
+
if 'oft' in args.resume_path:
|
| 87 |
+
experiment = 'oft_{}_{}_eps_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.oft_eps, args.oft_r, num_gpus, time_str)
|
| 88 |
+
elif 'hra' in args.resume_path:
|
| 89 |
+
if args.hra_apply_GS:
|
| 90 |
+
experiment = 'hra_apply_GS_{}_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.hra_r, num_gpus, time_str)
|
| 91 |
+
else:
|
| 92 |
+
experiment = 'hra_{}_{}_pe_diff_mlp_r_{}_lambda_{}_lr_{}_{}gpu_{}'.format(data_name, control, args.hra_r, args.hra_lamb, args.learning_rate, num_gpus, time_str)
|
| 93 |
+
elif 'lora' in args.resume_path:
|
| 94 |
+
experiment = 'lora_{}_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.r, num_gpus, time_str)
|
| 95 |
+
|
| 96 |
+
# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
|
| 97 |
+
model = create_model('./configs/oft_ldm_v15.yaml').cpu()
|
| 98 |
+
model.model.requires_grad_(False)
|
| 99 |
+
print(f'Total parameters not requiring grad: {sum([p.numel() for p in model.model.parameters() if p.requires_grad == False])}')
|
| 100 |
+
|
| 101 |
+
# inject trainable oft parameters
|
| 102 |
+
if 'oft' in args.resume_path:
|
| 103 |
+
unet_lora_params, train_names = inject_trainable_oft(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
|
| 104 |
+
elif 'hra' in args.resume_path:
|
| 105 |
+
unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
|
| 106 |
+
elif 'lora' in args.resume_path:
|
| 107 |
+
unet_lora_params, train_names = inject_trainable_lora(model.model, rank=args.r, network_alpha=None)
|
| 108 |
+
|
| 109 |
+
print(f'Total parameters requiring grad: {sum([p.numel() for p in model.model.parameters() if p.requires_grad == True])}')
|
| 110 |
+
|
| 111 |
+
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
|
| 112 |
+
model.learning_rate = learning_rate
|
| 113 |
+
model.sd_locked = sd_locked
|
| 114 |
+
model.only_mid_control = only_mid_control
|
| 115 |
+
|
| 116 |
+
checkpoint_callback = ModelCheckpoint(
|
| 117 |
+
dirpath='log/image_log_' + experiment,
|
| 118 |
+
filename='model-{epoch:02d}',
|
| 119 |
+
save_top_k=-1,
|
| 120 |
+
save_last=True,
|
| 121 |
+
every_n_epochs=1,
|
| 122 |
+
monitor=None, # No specific metric to monitor for saving
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Misc
|
| 126 |
+
train_dataloader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
|
| 127 |
+
val_dataloader = DataLoader(val_dataset, num_workers=num_workers, batch_size=1, shuffle=False)
|
| 128 |
+
|
| 129 |
+
logger = ImageLogger(
|
| 130 |
+
val_dataloader=val_dataloader,
|
| 131 |
+
batch_frequency=logger_freq,
|
| 132 |
+
experiment=experiment,
|
| 133 |
+
plot_frequency=plot_frequency,
|
| 134 |
+
num_samples=num_samples,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
trainer = pl.Trainer(
|
| 138 |
+
max_epochs=max_epochs,
|
| 139 |
+
gpus=num_gpus,
|
| 140 |
+
precision=32,
|
| 141 |
+
callbacks=[logger, checkpoint_callback],
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Train!
|
| 145 |
+
last_model_path = 'log/image_log_' + experiment + '/last.ckpt'
|
| 146 |
+
if os.path.exists(last_model_path):
|
| 147 |
+
trainer.fit(model, train_dataloader, ckpt_path=last_model_path)
|
| 148 |
+
else:
|
| 149 |
+
trainer.fit(model, train_dataloader)
|
generation/env.yml
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: generation
|
| 2 |
+
dependencies:
|
| 3 |
+
- _libgcc_mutex=0.1=main
|
| 4 |
+
- _openmp_mutex=5.1=1_gnu
|
| 5 |
+
- ca-certificates=2023.12.12=h06a4308_0
|
| 6 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
| 7 |
+
- libffi=3.4.4=h6a678d5_0
|
| 8 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 9 |
+
- libgomp=11.2.0=h1234567_1
|
| 10 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 11 |
+
- ncurses=6.4=h6a678d5_0
|
| 12 |
+
- openssl=3.0.13=h7f8727e_0
|
| 13 |
+
- pip=23.3.1=py39h06a4308_0
|
| 14 |
+
- python=3.9.18=h955ad1f_0
|
| 15 |
+
- readline=8.2=h5eee18b_0
|
| 16 |
+
- setuptools=68.2.2=py39h06a4308_0
|
| 17 |
+
- sqlite=3.41.2=h5eee18b_0
|
| 18 |
+
- tk=8.6.12=h1ccaba5_0
|
| 19 |
+
- wheel=0.41.2=py39h06a4308_0
|
| 20 |
+
- xz=5.4.6=h5eee18b_0
|
| 21 |
+
- zlib=1.2.13=h5eee18b_0
|
| 22 |
+
- pip:
|
| 23 |
+
- absl-py==2.1.0
|
| 24 |
+
- accelerate==0.21.0
|
| 25 |
+
- addict==2.4.0
|
| 26 |
+
- aiofiles==23.2.1
|
| 27 |
+
- aiohttp==3.9.3
|
| 28 |
+
- aiosignal==1.3.1
|
| 29 |
+
- altair==5.2.0
|
| 30 |
+
- annotated-types==0.6.0
|
| 31 |
+
- antlr4-python3-runtime==4.8
|
| 32 |
+
- anyio==4.3.0
|
| 33 |
+
- appdirs==1.4.4
|
| 34 |
+
- async-timeout==4.0.3
|
| 35 |
+
- attrs==23.2.0
|
| 36 |
+
- basicsr==1.4.2
|
| 37 |
+
- certifi==2024.2.2
|
| 38 |
+
- charset-normalizer==3.3.2
|
| 39 |
+
- click==8.1.7
|
| 40 |
+
- conda-pack==0.7.1
|
| 41 |
+
- contourpy==1.2.0
|
| 42 |
+
- cycler==0.12.1
|
| 43 |
+
- datasets==2.14.2
|
| 44 |
+
- diffusers==0.17.1
|
| 45 |
+
- dill==0.3.7
|
| 46 |
+
- docker-pycreds==0.4.0
|
| 47 |
+
- einops==0.7.0
|
| 48 |
+
- exceptiongroup==1.2.0
|
| 49 |
+
- face-alignment==1.3.4
|
| 50 |
+
- fastapi==0.110.0
|
| 51 |
+
- ffmpy==0.3.2
|
| 52 |
+
- filelock==3.13.1
|
| 53 |
+
- fire==0.6.0
|
| 54 |
+
- fonttools==4.49.0
|
| 55 |
+
- frozenlist==1.4.1
|
| 56 |
+
- fsspec==2024.2.0
|
| 57 |
+
- ftfy==6.1.3
|
| 58 |
+
- future==1.0.0
|
| 59 |
+
- gitdb==4.0.11
|
| 60 |
+
- gitpython==3.1.42
|
| 61 |
+
- gradio==3.16.2
|
| 62 |
+
- grpcio==1.62.1
|
| 63 |
+
- h11==0.14.0
|
| 64 |
+
- httpcore==1.0.4
|
| 65 |
+
- httpx==0.27.0
|
| 66 |
+
- huggingface-hub==0.21.4
|
| 67 |
+
- idna==3.6
|
| 68 |
+
- imageio==2.34.0
|
| 69 |
+
- importlib-metadata==7.0.2
|
| 70 |
+
- importlib-resources==6.3.0
|
| 71 |
+
- jinja2==3.1.3
|
| 72 |
+
- jsonschema==4.21.1
|
| 73 |
+
- jsonschema-specifications==2023.12.1
|
| 74 |
+
- kiwisolver==1.4.5
|
| 75 |
+
- lazy-loader==0.3
|
| 76 |
+
- lightning-utilities==0.10.1
|
| 77 |
+
- linkify-it-py==2.0.3
|
| 78 |
+
- llvmlite==0.42.0
|
| 79 |
+
- lmdb==1.4.1
|
| 80 |
+
- lpips==0.1.4
|
| 81 |
+
- markdown==3.5.2
|
| 82 |
+
- markdown-it-py==3.0.0
|
| 83 |
+
- markupsafe==2.1.5
|
| 84 |
+
- matplotlib==3.8.3
|
| 85 |
+
- mdit-py-plugins==0.4.0
|
| 86 |
+
- mdurl==0.1.2
|
| 87 |
+
- mpmath==1.3.0
|
| 88 |
+
- multidict==6.0.5
|
| 89 |
+
- multiprocess==0.70.15
|
| 90 |
+
- networkx==3.2.1
|
| 91 |
+
- numba==0.59.0
|
| 92 |
+
- numpy==1.26.4
|
| 93 |
+
- nvidia-cublas-cu12==12.1.3.1
|
| 94 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
| 95 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
| 96 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
| 97 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
| 98 |
+
- nvidia-cufft-cu12==11.0.2.54
|
| 99 |
+
- nvidia-curand-cu12==10.3.2.106
|
| 100 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
| 101 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
| 102 |
+
- nvidia-nccl-cu12==2.19.3
|
| 103 |
+
- nvidia-nvjitlink-cu12==12.4.99
|
| 104 |
+
- nvidia-nvtx-cu12==12.1.105
|
| 105 |
+
- omegaconf==2.1.1
|
| 106 |
+
- open-clip-torch==2.0.2
|
| 107 |
+
- opencv-python==4.9.0.80
|
| 108 |
+
- orjson==3.9.15
|
| 109 |
+
- packaging==24.0
|
| 110 |
+
- pandas==2.2.1
|
| 111 |
+
- pillow==10.2.0
|
| 112 |
+
- platformdirs==4.2.0
|
| 113 |
+
- protobuf==4.25.3
|
| 114 |
+
- psutil==5.9.8
|
| 115 |
+
- pyarrow==15.0.1
|
| 116 |
+
- pycocotools==2.0.7
|
| 117 |
+
- pycryptodome==3.20.0
|
| 118 |
+
- pydantic==2.6.4
|
| 119 |
+
- pydantic-core==2.16.3
|
| 120 |
+
- pydeprecate==0.3.1
|
| 121 |
+
- pydub==0.25.1
|
| 122 |
+
- pyparsing==3.1.2
|
| 123 |
+
- python-dateutil==2.9.0.post0
|
| 124 |
+
- python-multipart==0.0.9
|
| 125 |
+
- pytorch-fid==0.3.0
|
| 126 |
+
- pytorch-lightning==1.5.0
|
| 127 |
+
- pytz==2024.1
|
| 128 |
+
- pyyaml==6.0.1
|
| 129 |
+
- referencing==0.33.0
|
| 130 |
+
- regex==2023.12.25
|
| 131 |
+
- requests==2.31.0
|
| 132 |
+
- rpds-py==0.18.0
|
| 133 |
+
- safetensors==0.4.2
|
| 134 |
+
- scikit-image==0.22.0
|
| 135 |
+
- scipy==1.12.0
|
| 136 |
+
- sentry-sdk==1.42.0
|
| 137 |
+
- setproctitle==1.3.3
|
| 138 |
+
- six==1.16.0
|
| 139 |
+
- smmap==5.0.1
|
| 140 |
+
- sniffio==1.3.1
|
| 141 |
+
- starlette==0.36.3
|
| 142 |
+
- sympy==1.12
|
| 143 |
+
- tb-nightly==2.17.0a20240313
|
| 144 |
+
- tensorboard==2.16.2
|
| 145 |
+
- tensorboard-data-server==0.7.2
|
| 146 |
+
- termcolor==2.4.0
|
| 147 |
+
- tifffile==2024.2.12
|
| 148 |
+
- timm==0.9.16
|
| 149 |
+
- tokenizers==0.13.3
|
| 150 |
+
- tomli==2.0.1
|
| 151 |
+
- toolz==0.12.1
|
| 152 |
+
- torch==2.2.1
|
| 153 |
+
- torch-fidelity==0.3.0
|
| 154 |
+
- torchaudio==2.2.1
|
| 155 |
+
- torchmetrics==1.3.1
|
| 156 |
+
- torchvision==0.17.1
|
| 157 |
+
- tqdm==4.66.2
|
| 158 |
+
- transformers==4.25.1
|
| 159 |
+
- triton==2.2.0
|
| 160 |
+
- typing-extensions==4.10.0
|
| 161 |
+
- tzdata==2024.1
|
| 162 |
+
- uc-micro-py==1.0.3
|
| 163 |
+
- urllib3==2.2.1
|
| 164 |
+
- uvicorn==0.28.0
|
| 165 |
+
- wandb==0.16.4
|
| 166 |
+
- wcwidth==0.2.13
|
| 167 |
+
- websockets==12.0
|
| 168 |
+
- werkzeug==3.0.1
|
| 169 |
+
- xxhash==3.4.1
|
| 170 |
+
- yapf==0.40.2
|
| 171 |
+
- yarl==1.9.4
|
| 172 |
+
- zipp==3.18.0
|