Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- diffusers/examples/kandinsky2_2/text_to_image/README.md +317 -0
- diffusers/examples/kandinsky2_2/text_to_image/requirements.txt +7 -0
- diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +929 -0
- diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +812 -0
- diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +844 -0
- diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +958 -0
- diffusers/examples/research_projects/anytext/README.md +40 -0
- diffusers/examples/research_projects/anytext/anytext.py +0 -0
- diffusers/examples/research_projects/anytext/anytext_controlnet.py +463 -0
- diffusers/examples/research_projects/anytext/ocr_recog/RNN.py +209 -0
- diffusers/examples/research_projects/anytext/ocr_recog/RecCTCHead.py +45 -0
- diffusers/examples/research_projects/anytext/ocr_recog/RecModel.py +49 -0
- diffusers/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py +197 -0
- diffusers/examples/research_projects/anytext/ocr_recog/RecSVTR.py +570 -0
- diffusers/examples/research_projects/anytext/ocr_recog/common.py +74 -0
- diffusers/examples/research_projects/anytext/ocr_recog/en_dict.txt +95 -0
- diffusers/examples/research_projects/consistency_training/README.md +24 -0
- diffusers/examples/research_projects/consistency_training/requirements.txt +6 -0
- diffusers/examples/research_projects/consistency_training/train_cm_ct_unconditional.py +1438 -0
- diffusers/examples/research_projects/diffusion_dpo/README.md +94 -0
- diffusers/examples/research_projects/diffusion_dpo/requirements.txt +8 -0
- diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +982 -0
- diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +1140 -0
- diffusers/examples/research_projects/diffusion_orpo/README.md +1 -0
- diffusers/examples/research_projects/diffusion_orpo/requirements.txt +7 -0
- diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py +1092 -0
- diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py +1095 -0
- diffusers/examples/research_projects/dreambooth_inpaint/README.md +118 -0
- diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt +7 -0
- diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py +812 -0
- diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py +831 -0
- diffusers/examples/research_projects/flux_lora_quantization/README.md +167 -0
- diffusers/examples/research_projects/flux_lora_quantization/accelerate.yaml +17 -0
- diffusers/examples/research_projects/flux_lora_quantization/compute_embeddings.py +107 -0
- diffusers/examples/research_projects/flux_lora_quantization/ds2.yaml +23 -0
- diffusers/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py +1200 -0
- diffusers/examples/research_projects/intel_opts/README.md +36 -0
- diffusers/examples/research_projects/intel_opts/inference_bf16.py +56 -0
- diffusers/examples/research_projects/intel_opts/textual_inversion/README.md +68 -0
- diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt +7 -0
- diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +646 -0
- diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md +93 -0
- diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt +7 -0
- diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py +112 -0
- diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py +996 -0
- diffusers/examples/research_projects/ip_adapter/README.md +226 -0
- diffusers/examples/research_projects/ip_adapter/requirements.txt +4 -0
- diffusers/examples/research_projects/ip_adapter/tutorial_train_faceid.py +415 -0
- diffusers/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py +422 -0
- diffusers/examples/research_projects/ip_adapter/tutorial_train_plus.py +445 -0
diffusers/examples/kandinsky2_2/text_to_image/README.md
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Kandinsky2.2 text-to-image fine-tuning
|
| 2 |
+
|
| 3 |
+
Kandinsky 2.2 includes a prior pipeline that generates image embeddings from text prompts, and a decoder pipeline that generates the output image based on the image embeddings. We provide `train_text_to_image_prior.py` and `train_text_to_image_decoder.py` scripts to show you how to fine-tune the Kandinsky prior and decoder models separately based on your own dataset. To achieve the best results, you should fine-tune **_both_** your prior and decoder models.
|
| 4 |
+
|
| 5 |
+
___Note___:
|
| 6 |
+
|
| 7 |
+
___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
## Running locally with PyTorch
|
| 11 |
+
|
| 12 |
+
Before running the scripts, make sure to install the library's training dependencies:
|
| 13 |
+
|
| 14 |
+
**Important**
|
| 15 |
+
|
| 16 |
+
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
| 17 |
+
```bash
|
| 18 |
+
git clone https://github.com/huggingface/diffusers
|
| 19 |
+
cd diffusers
|
| 20 |
+
pip install .
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
Then cd in the example folder and run
|
| 24 |
+
```bash
|
| 25 |
+
pip install -r requirements.txt
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
accelerate config
|
| 32 |
+
```
|
| 33 |
+
For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the --push_to_hub flag.
|
| 34 |
+
|
| 35 |
+
___
|
| 36 |
+
|
| 37 |
+
### Naruto example
|
| 38 |
+
|
| 39 |
+
For all our examples, we will directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag. In order to do that, you have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to the [User Access Tokens](https://huggingface.co/docs/hub/security-tokens) guide.
|
| 40 |
+
|
| 41 |
+
Run the following command to authenticate your token
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
huggingface-cli login
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
We also use [Weights and Biases](https://docs.wandb.ai/quickstart) logging by default, because it is really useful to monitor the training progress by regularly generating sample images during training. To install wandb, run
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
pip install wandb
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
To disable wandb logging, remove the `--report_to=="wandb"` and `--validation_prompts="A robot naruto, 4k photo"` flags from below examples
|
| 54 |
+
|
| 55 |
+
#### Fine-tune decoder
|
| 56 |
+
<br>
|
| 57 |
+
|
| 58 |
+
<!-- accelerate_snippet_start -->
|
| 59 |
+
```bash
|
| 60 |
+
export DATASET_NAME="lambdalabs/naruto-blip-captions"
|
| 61 |
+
|
| 62 |
+
accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
|
| 63 |
+
--dataset_name=$DATASET_NAME \
|
| 64 |
+
--resolution=768 \
|
| 65 |
+
--train_batch_size=1 \
|
| 66 |
+
--gradient_accumulation_steps=4 \
|
| 67 |
+
--gradient_checkpointing \
|
| 68 |
+
--max_train_steps=15000 \
|
| 69 |
+
--learning_rate=1e-05 \
|
| 70 |
+
--max_grad_norm=1 \
|
| 71 |
+
--checkpoints_total_limit=3 \
|
| 72 |
+
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
| 73 |
+
--validation_prompts="A robot naruto, 4k photo" \
|
| 74 |
+
--report_to="wandb" \
|
| 75 |
+
--push_to_hub \
|
| 76 |
+
--output_dir="kandi2-decoder-naruto-model"
|
| 77 |
+
```
|
| 78 |
+
<!-- accelerate_snippet_end -->
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
To train on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in the [ImageFolder with metadata](https://huggingface.co/docs/datasets/en/image_load#imagefolder-with-metadata) guide.
|
| 82 |
+
If you wish to use custom loading logic, you should modify the script and we have left pointers for that in the training script.
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
export TRAIN_DIR="path_to_your_dataset"
|
| 86 |
+
|
| 87 |
+
accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
|
| 88 |
+
--train_data_dir=$TRAIN_DIR \
|
| 89 |
+
--resolution=768 \
|
| 90 |
+
--train_batch_size=1 \
|
| 91 |
+
--gradient_accumulation_steps=4 \
|
| 92 |
+
--gradient_checkpointing \
|
| 93 |
+
--max_train_steps=15000 \
|
| 94 |
+
--learning_rate=1e-05 \
|
| 95 |
+
--max_grad_norm=1 \
|
| 96 |
+
--checkpoints_total_limit=3 \
|
| 97 |
+
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
| 98 |
+
--validation_prompts="A robot naruto, 4k photo" \
|
| 99 |
+
--report_to="wandb" \
|
| 100 |
+
--push_to_hub \
|
| 101 |
+
--output_dir="kandi22-decoder-naruto-model"
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `kandi22-decoder-naruto-model`. To load the fine-tuned model for inference just pass that path to `AutoPipelineForText2Image`
|
| 106 |
+
|
| 107 |
+
```python
|
| 108 |
+
from diffusers import AutoPipelineForText2Image
|
| 109 |
+
import torch
|
| 110 |
+
|
| 111 |
+
pipe = AutoPipelineForText2Image.from_pretrained(output_dir, torch_dtype=torch.float16)
|
| 112 |
+
pipe.enable_model_cpu_offload()
|
| 113 |
+
|
| 114 |
+
prompt='A robot naruto, 4k photo'
|
| 115 |
+
images = pipe(prompt=prompt).images
|
| 116 |
+
images[0].save("robot-naruto.png")
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet
|
| 120 |
+
```python
|
| 121 |
+
from diffusers import AutoPipelineForText2Image, UNet2DConditionModel
|
| 122 |
+
|
| 123 |
+
model_path = "path_to_saved_model"
|
| 124 |
+
|
| 125 |
+
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet")
|
| 126 |
+
|
| 127 |
+
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet, torch_dtype=torch.float16)
|
| 128 |
+
pipe.enable_model_cpu_offload()
|
| 129 |
+
|
| 130 |
+
image = pipe(prompt="A robot naruto, 4k photo").images[0]
|
| 131 |
+
image.save("robot-naruto.png")
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
#### Fine-tune prior
|
| 135 |
+
|
| 136 |
+
You can fine-tune the Kandinsky prior model with `train_text_to_image_prior.py` script. Note that we currently do not support `--gradient_checkpointing` for prior model fine-tuning.
|
| 137 |
+
|
| 138 |
+
<br>
|
| 139 |
+
|
| 140 |
+
<!-- accelerate_snippet_start -->
|
| 141 |
+
```bash
|
| 142 |
+
export DATASET_NAME="lambdalabs/naruto-blip-captions"
|
| 143 |
+
|
| 144 |
+
accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \
|
| 145 |
+
--dataset_name=$DATASET_NAME \
|
| 146 |
+
--resolution=768 \
|
| 147 |
+
--train_batch_size=1 \
|
| 148 |
+
--gradient_accumulation_steps=4 \
|
| 149 |
+
--max_train_steps=15000 \
|
| 150 |
+
--learning_rate=1e-05 \
|
| 151 |
+
--max_grad_norm=1 \
|
| 152 |
+
--checkpoints_total_limit=3 \
|
| 153 |
+
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
| 154 |
+
--validation_prompts="A robot naruto, 4k photo" \
|
| 155 |
+
--report_to="wandb" \
|
| 156 |
+
--push_to_hub \
|
| 157 |
+
--output_dir="kandi2-prior-naruto-model"
|
| 158 |
+
```
|
| 159 |
+
<!-- accelerate_snippet_end -->
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
To perform inference with the fine-tuned prior model, you will need to first create a prior pipeline by passing the `output_dir` to `DiffusionPipeline`. Then create a `KandinskyV22CombinedPipeline` from a pretrained or fine-tuned decoder checkpoint along with all the modules of the prior pipeline you just created.
|
| 163 |
+
|
| 164 |
+
```python
|
| 165 |
+
from diffusers import AutoPipelineForText2Image, DiffusionPipeline
|
| 166 |
+
import torch
|
| 167 |
+
|
| 168 |
+
pipe_prior = DiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)
|
| 169 |
+
prior_components = {"prior_" + k: v for k,v in pipe_prior.components.items()}
|
| 170 |
+
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components, torch_dtype=torch.float16)
|
| 171 |
+
|
| 172 |
+
pipe.enable_model_cpu_offload()
|
| 173 |
+
prompt='A robot naruto, 4k photo'
|
| 174 |
+
images = pipe(prompt=prompt, negative_prompt=negative_prompt).images
|
| 175 |
+
images[0]
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
If you want to use a fine-tuned decoder checkpoint along with your fine-tuned prior checkpoint, you can simply replace the "kandinsky-community/kandinsky-2-2-decoder" in above code with your custom model repo name. Note that in order to be able to create a `KandinskyV22CombinedPipeline`, your model repository need to have a prior tag. If you have created your model repo using our training script, the prior tag is automatically included.
|
| 179 |
+
|
| 180 |
+
#### Training with multiple GPUs
|
| 181 |
+
|
| 182 |
+
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
| 183 |
+
for running distributed training with `accelerate`. Here is an example command:
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
export DATASET_NAME="lambdalabs/naruto-blip-captions"
|
| 187 |
+
|
| 188 |
+
accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image_decoder.py \
|
| 189 |
+
--dataset_name=$DATASET_NAME \
|
| 190 |
+
--resolution=768 \
|
| 191 |
+
--train_batch_size=1 \
|
| 192 |
+
--gradient_accumulation_steps=4 \
|
| 193 |
+
--gradient_checkpointing \
|
| 194 |
+
--max_train_steps=15000 \
|
| 195 |
+
--learning_rate=1e-05 \
|
| 196 |
+
--max_grad_norm=1 \
|
| 197 |
+
--checkpoints_total_limit=3 \
|
| 198 |
+
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
| 199 |
+
--validation_prompts="A robot naruto, 4k photo" \
|
| 200 |
+
--report_to="wandb" \
|
| 201 |
+
--push_to_hub \
|
| 202 |
+
--output_dir="kandi2-decoder-naruto-model"
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
#### Training with Min-SNR weighting
|
| 207 |
+
|
| 208 |
+
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://huggingface.co/papers/2303.09556) which helps achieve faster convergence
|
| 209 |
+
by rebalancing the loss. Enable the `--snr_gamma` argument and set it to the recommended
|
| 210 |
+
value of 5.0.
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
## Training with LoRA
|
| 214 |
+
|
| 215 |
+
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
|
| 216 |
+
|
| 217 |
+
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
|
| 218 |
+
|
| 219 |
+
- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
|
| 220 |
+
- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable.
|
| 221 |
+
- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter.
|
| 222 |
+
|
| 223 |
+
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
|
| 224 |
+
|
| 225 |
+
With LoRA, it's possible to fine-tune Kandinsky 2.2 on a custom image-caption pair dataset
|
| 226 |
+
on consumer GPUs like Tesla T4, Tesla V100.
|
| 227 |
+
|
| 228 |
+
### Training
|
| 229 |
+
|
| 230 |
+
First, you need to set up your development environment as explained in the [installation](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) and the [Narutos dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
#### Train decoder
|
| 234 |
+
|
| 235 |
+
```bash
|
| 236 |
+
export DATASET_NAME="lambdalabs/naruto-blip-captions"
|
| 237 |
+
|
| 238 |
+
accelerate launch --mixed_precision="fp16" train_text_to_image_decoder_lora.py \
|
| 239 |
+
--dataset_name=$DATASET_NAME --caption_column="text" \
|
| 240 |
+
--resolution=768 \
|
| 241 |
+
--train_batch_size=1 \
|
| 242 |
+
--num_train_epochs=100 --checkpointing_steps=5000 \
|
| 243 |
+
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
|
| 244 |
+
--seed=42 \
|
| 245 |
+
--rank=4 \
|
| 246 |
+
--gradient_checkpointing \
|
| 247 |
+
--output_dir="kandi22-decoder-naruto-lora" \
|
| 248 |
+
--validation_prompt="cute dragon creature" --report_to="wandb" \
|
| 249 |
+
--push_to_hub \
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
#### Train prior
|
| 253 |
+
|
| 254 |
+
```bash
|
| 255 |
+
export DATASET_NAME="lambdalabs/naruto-blip-captions"
|
| 256 |
+
|
| 257 |
+
accelerate launch --mixed_precision="fp16" train_text_to_image_prior_lora.py \
|
| 258 |
+
--dataset_name=$DATASET_NAME --caption_column="text" \
|
| 259 |
+
--resolution=768 \
|
| 260 |
+
--train_batch_size=1 \
|
| 261 |
+
--num_train_epochs=100 --checkpointing_steps=5000 \
|
| 262 |
+
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
|
| 263 |
+
--seed=42 \
|
| 264 |
+
--rank=4 \
|
| 265 |
+
--output_dir="kandi22-prior-naruto-lora" \
|
| 266 |
+
--validation_prompt="cute dragon creature" --report_to="wandb" \
|
| 267 |
+
--push_to_hub \
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run above scripts in consumer GPUs like T4 or V100.___**
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
### Inference
|
| 274 |
+
|
| 275 |
+
#### Inference using fine-tuned LoRA checkpoint for decoder
|
| 276 |
+
|
| 277 |
+
Once you have trained a Kandinsky decoder model using the above command, inference can be done with the `AutoPipelineForText2Image` after loading the trained LoRA weights. You need to pass the `output_dir` for loading the LoRA weights, which in this case is `kandi22-decoder-naruto-lora`.
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
```python
|
| 281 |
+
from diffusers import AutoPipelineForText2Image
|
| 282 |
+
import torch
|
| 283 |
+
|
| 284 |
+
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
|
| 285 |
+
pipe.unet.load_attn_procs(output_dir)
|
| 286 |
+
pipe.enable_model_cpu_offload()
|
| 287 |
+
|
| 288 |
+
prompt='A robot naruto, 4k photo'
|
| 289 |
+
image = pipe(prompt=prompt).images[0]
|
| 290 |
+
image.save("robot_naruto.png")
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
#### Inference using fine-tuned LoRA checkpoint for prior
|
| 294 |
+
|
| 295 |
+
```python
|
| 296 |
+
from diffusers import AutoPipelineForText2Image
|
| 297 |
+
import torch
|
| 298 |
+
|
| 299 |
+
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
|
| 300 |
+
pipe.prior_prior.load_attn_procs(output_dir)
|
| 301 |
+
pipe.enable_model_cpu_offload()
|
| 302 |
+
|
| 303 |
+
prompt='A robot naruto, 4k photo'
|
| 304 |
+
image = pipe(prompt=prompt).images[0]
|
| 305 |
+
image.save("robot_naruto.png")
|
| 306 |
+
image
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
### Training with xFormers:
|
| 310 |
+
|
| 311 |
+
You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.
|
| 312 |
+
|
| 313 |
+
xFormers training is not available for fine-tuning the prior model.
|
| 314 |
+
|
| 315 |
+
**Note**:
|
| 316 |
+
|
| 317 |
+
According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training in some GPUs. If you observe that problem, please install a development version as indicated in that comment.
|
diffusers/examples/kandinsky2_2/text_to_image/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=0.16.0
|
| 2 |
+
torchvision
|
| 3 |
+
transformers>=4.25.1
|
| 4 |
+
datasets
|
| 5 |
+
ftfy
|
| 6 |
+
tensorboard
|
| 7 |
+
Jinja2
|
diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
ADDED
|
@@ -0,0 +1,929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import logging
|
| 18 |
+
import math
|
| 19 |
+
import os
|
| 20 |
+
import shutil
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import accelerate
|
| 24 |
+
import datasets
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torch.utils.checkpoint
|
| 29 |
+
import transformers
|
| 30 |
+
from accelerate import Accelerator
|
| 31 |
+
from accelerate.logging import get_logger
|
| 32 |
+
from accelerate.state import AcceleratorState
|
| 33 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 34 |
+
from datasets import load_dataset
|
| 35 |
+
from huggingface_hub import create_repo, upload_folder
|
| 36 |
+
from packaging import version
|
| 37 |
+
from PIL import Image
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 40 |
+
from transformers.utils import ContextManagers
|
| 41 |
+
|
| 42 |
+
import diffusers
|
| 43 |
+
from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
|
| 44 |
+
from diffusers.optimization import get_scheduler
|
| 45 |
+
from diffusers.training_utils import EMAModel, compute_snr
|
| 46 |
+
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
| 47 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if is_wandb_available():
|
| 51 |
+
import wandb
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 55 |
+
check_min_version("0.35.0.dev0")
|
| 56 |
+
|
| 57 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def save_model_card(
|
| 61 |
+
args,
|
| 62 |
+
repo_id: str,
|
| 63 |
+
images=None,
|
| 64 |
+
repo_folder=None,
|
| 65 |
+
):
|
| 66 |
+
img_str = ""
|
| 67 |
+
if len(images) > 0:
|
| 68 |
+
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
|
| 69 |
+
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
|
| 70 |
+
img_str += "\n"
|
| 71 |
+
|
| 72 |
+
yaml = f"""
|
| 73 |
+
---
|
| 74 |
+
license: creativeml-openrail-m
|
| 75 |
+
base_model: {args.pretrained_decoder_model_name_or_path}
|
| 76 |
+
datasets:
|
| 77 |
+
- {args.dataset_name}
|
| 78 |
+
prior:
|
| 79 |
+
- {args.pretrained_prior_model_name_or_path}
|
| 80 |
+
tags:
|
| 81 |
+
- kandinsky
|
| 82 |
+
- text-to-image
|
| 83 |
+
- diffusers
|
| 84 |
+
- diffusers-training
|
| 85 |
+
inference: true
|
| 86 |
+
---
|
| 87 |
+
"""
|
| 88 |
+
model_card = f"""
|
| 89 |
+
# Finetuning - {repo_id}
|
| 90 |
+
|
| 91 |
+
This pipeline was finetuned from **{args.pretrained_decoder_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
|
| 92 |
+
{img_str}
|
| 93 |
+
|
| 94 |
+
## Pipeline usage
|
| 95 |
+
|
| 96 |
+
You can use the pipeline like so:
|
| 97 |
+
|
| 98 |
+
```python
|
| 99 |
+
from diffusers import DiffusionPipeline
|
| 100 |
+
import torch
|
| 101 |
+
|
| 102 |
+
pipeline = AutoPipelineForText2Image.from_pretrained("{repo_id}", torch_dtype=torch.float16)
|
| 103 |
+
prompt = "{args.validation_prompts[0]}"
|
| 104 |
+
image = pipeline(prompt).images[0]
|
| 105 |
+
image.save("my_image.png")
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## Training info
|
| 109 |
+
|
| 110 |
+
These are the key hyperparameters used during training:
|
| 111 |
+
|
| 112 |
+
* Epochs: {args.num_train_epochs}
|
| 113 |
+
* Learning rate: {args.learning_rate}
|
| 114 |
+
* Batch size: {args.train_batch_size}
|
| 115 |
+
* Gradient accumulation steps: {args.gradient_accumulation_steps}
|
| 116 |
+
* Image resolution: {args.resolution}
|
| 117 |
+
* Mixed-precision: {args.mixed_precision}
|
| 118 |
+
|
| 119 |
+
"""
|
| 120 |
+
wandb_info = ""
|
| 121 |
+
if is_wandb_available():
|
| 122 |
+
wandb_run_url = None
|
| 123 |
+
if wandb.run is not None:
|
| 124 |
+
wandb_run_url = wandb.run.url
|
| 125 |
+
|
| 126 |
+
if wandb_run_url is not None:
|
| 127 |
+
wandb_info = f"""
|
| 128 |
+
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
model_card += wandb_info
|
| 132 |
+
|
| 133 |
+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
| 134 |
+
f.write(yaml + model_card)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def log_validation(vae, image_encoder, image_processor, unet, args, accelerator, weight_dtype, epoch):
|
| 138 |
+
logger.info("Running validation... ")
|
| 139 |
+
|
| 140 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 141 |
+
args.pretrained_decoder_model_name_or_path,
|
| 142 |
+
vae=accelerator.unwrap_model(vae),
|
| 143 |
+
prior_image_encoder=accelerator.unwrap_model(image_encoder),
|
| 144 |
+
prior_image_processor=image_processor,
|
| 145 |
+
unet=accelerator.unwrap_model(unet),
|
| 146 |
+
torch_dtype=weight_dtype,
|
| 147 |
+
)
|
| 148 |
+
pipeline = pipeline.to(accelerator.device)
|
| 149 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 150 |
+
|
| 151 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 152 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
| 153 |
+
|
| 154 |
+
if args.seed is None:
|
| 155 |
+
generator = None
|
| 156 |
+
else:
|
| 157 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
| 158 |
+
|
| 159 |
+
images = []
|
| 160 |
+
for i in range(len(args.validation_prompts)):
|
| 161 |
+
with torch.autocast("cuda"):
|
| 162 |
+
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
|
| 163 |
+
|
| 164 |
+
images.append(image)
|
| 165 |
+
|
| 166 |
+
for tracker in accelerator.trackers:
|
| 167 |
+
if tracker.name == "tensorboard":
|
| 168 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 169 |
+
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
| 170 |
+
elif tracker.name == "wandb":
|
| 171 |
+
tracker.log(
|
| 172 |
+
{
|
| 173 |
+
"validation": [
|
| 174 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
|
| 175 |
+
for i, image in enumerate(images)
|
| 176 |
+
]
|
| 177 |
+
}
|
| 178 |
+
)
|
| 179 |
+
else:
|
| 180 |
+
logger.warning(f"image logging not implemented for {tracker.name}")
|
| 181 |
+
|
| 182 |
+
del pipeline
|
| 183 |
+
torch.cuda.empty_cache()
|
| 184 |
+
|
| 185 |
+
return images
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def parse_args():
|
| 189 |
+
parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--pretrained_decoder_model_name_or_path",
|
| 192 |
+
type=str,
|
| 193 |
+
default="kandinsky-community/kandinsky-2-2-decoder",
|
| 194 |
+
required=False,
|
| 195 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--pretrained_prior_model_name_or_path",
|
| 199 |
+
type=str,
|
| 200 |
+
default="kandinsky-community/kandinsky-2-2-prior",
|
| 201 |
+
required=False,
|
| 202 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--dataset_name",
|
| 206 |
+
type=str,
|
| 207 |
+
default=None,
|
| 208 |
+
help=(
|
| 209 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
| 210 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
| 211 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
| 212 |
+
),
|
| 213 |
+
)
|
| 214 |
+
parser.add_argument(
|
| 215 |
+
"--dataset_config_name",
|
| 216 |
+
type=str,
|
| 217 |
+
default=None,
|
| 218 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
| 219 |
+
)
|
| 220 |
+
parser.add_argument(
|
| 221 |
+
"--train_data_dir",
|
| 222 |
+
type=str,
|
| 223 |
+
default=None,
|
| 224 |
+
help=(
|
| 225 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
| 226 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
| 227 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
| 228 |
+
),
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--max_train_samples",
|
| 235 |
+
type=int,
|
| 236 |
+
default=None,
|
| 237 |
+
help=(
|
| 238 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 239 |
+
"value if set."
|
| 240 |
+
),
|
| 241 |
+
)
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--validation_prompts",
|
| 244 |
+
type=str,
|
| 245 |
+
default=None,
|
| 246 |
+
nargs="+",
|
| 247 |
+
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--output_dir",
|
| 251 |
+
type=str,
|
| 252 |
+
default="kandi_2_2-model-finetuned",
|
| 253 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--cache_dir",
|
| 257 |
+
type=str,
|
| 258 |
+
default=None,
|
| 259 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 262 |
+
parser.add_argument(
|
| 263 |
+
"--resolution",
|
| 264 |
+
type=int,
|
| 265 |
+
default=512,
|
| 266 |
+
help=(
|
| 267 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 268 |
+
" resolution"
|
| 269 |
+
),
|
| 270 |
+
)
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
| 273 |
+
)
|
| 274 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--max_train_steps",
|
| 277 |
+
type=int,
|
| 278 |
+
default=None,
|
| 279 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 280 |
+
)
|
| 281 |
+
parser.add_argument(
|
| 282 |
+
"--gradient_accumulation_steps",
|
| 283 |
+
type=int,
|
| 284 |
+
default=1,
|
| 285 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 286 |
+
)
|
| 287 |
+
parser.add_argument(
|
| 288 |
+
"--gradient_checkpointing",
|
| 289 |
+
action="store_true",
|
| 290 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 291 |
+
)
|
| 292 |
+
parser.add_argument(
|
| 293 |
+
"--learning_rate",
|
| 294 |
+
type=float,
|
| 295 |
+
default=1e-4,
|
| 296 |
+
help="learning rate",
|
| 297 |
+
)
|
| 298 |
+
parser.add_argument(
|
| 299 |
+
"--lr_scheduler",
|
| 300 |
+
type=str,
|
| 301 |
+
default="constant",
|
| 302 |
+
help=(
|
| 303 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 304 |
+
' "constant", "constant_with_warmup"]'
|
| 305 |
+
),
|
| 306 |
+
)
|
| 307 |
+
parser.add_argument(
|
| 308 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 309 |
+
)
|
| 310 |
+
parser.add_argument(
|
| 311 |
+
"--snr_gamma",
|
| 312 |
+
type=float,
|
| 313 |
+
default=None,
|
| 314 |
+
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
| 315 |
+
"More details here: https://huggingface.co/papers/2303.09556.",
|
| 316 |
+
)
|
| 317 |
+
parser.add_argument(
|
| 318 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 319 |
+
)
|
| 320 |
+
parser.add_argument(
|
| 321 |
+
"--allow_tf32",
|
| 322 |
+
action="store_true",
|
| 323 |
+
help=(
|
| 324 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 325 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 326 |
+
),
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
| 329 |
+
parser.add_argument(
|
| 330 |
+
"--dataloader_num_workers",
|
| 331 |
+
type=int,
|
| 332 |
+
default=0,
|
| 333 |
+
help=(
|
| 334 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 335 |
+
),
|
| 336 |
+
)
|
| 337 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 338 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--adam_weight_decay",
|
| 341 |
+
type=float,
|
| 342 |
+
default=0.0,
|
| 343 |
+
required=False,
|
| 344 |
+
help="weight decay_to_use",
|
| 345 |
+
)
|
| 346 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 347 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 348 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 349 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 350 |
+
parser.add_argument(
|
| 351 |
+
"--hub_model_id",
|
| 352 |
+
type=str,
|
| 353 |
+
default=None,
|
| 354 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 355 |
+
)
|
| 356 |
+
parser.add_argument(
|
| 357 |
+
"--logging_dir",
|
| 358 |
+
type=str,
|
| 359 |
+
default="logs",
|
| 360 |
+
help=(
|
| 361 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 362 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 363 |
+
),
|
| 364 |
+
)
|
| 365 |
+
parser.add_argument(
|
| 366 |
+
"--mixed_precision",
|
| 367 |
+
type=str,
|
| 368 |
+
default=None,
|
| 369 |
+
choices=["no", "fp16", "bf16"],
|
| 370 |
+
help=(
|
| 371 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 372 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 373 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 374 |
+
),
|
| 375 |
+
)
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--report_to",
|
| 378 |
+
type=str,
|
| 379 |
+
default="tensorboard",
|
| 380 |
+
help=(
|
| 381 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 382 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 383 |
+
),
|
| 384 |
+
)
|
| 385 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 386 |
+
parser.add_argument(
|
| 387 |
+
"--checkpointing_steps",
|
| 388 |
+
type=int,
|
| 389 |
+
default=500,
|
| 390 |
+
help=(
|
| 391 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
| 392 |
+
" training using `--resume_from_checkpoint`."
|
| 393 |
+
),
|
| 394 |
+
)
|
| 395 |
+
parser.add_argument(
|
| 396 |
+
"--checkpoints_total_limit",
|
| 397 |
+
type=int,
|
| 398 |
+
default=None,
|
| 399 |
+
help=("Max number of checkpoints to store."),
|
| 400 |
+
)
|
| 401 |
+
parser.add_argument(
|
| 402 |
+
"--resume_from_checkpoint",
|
| 403 |
+
type=str,
|
| 404 |
+
default=None,
|
| 405 |
+
help=(
|
| 406 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 407 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 408 |
+
),
|
| 409 |
+
)
|
| 410 |
+
parser.add_argument(
|
| 411 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 412 |
+
)
|
| 413 |
+
parser.add_argument(
|
| 414 |
+
"--validation_epochs",
|
| 415 |
+
type=int,
|
| 416 |
+
default=5,
|
| 417 |
+
help="Run validation every X epochs.",
|
| 418 |
+
)
|
| 419 |
+
parser.add_argument(
|
| 420 |
+
"--tracker_project_name",
|
| 421 |
+
type=str,
|
| 422 |
+
default="text2image-fine-tune",
|
| 423 |
+
help=(
|
| 424 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
| 425 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
| 426 |
+
),
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
args = parser.parse_args()
|
| 430 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 431 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 432 |
+
args.local_rank = env_local_rank
|
| 433 |
+
|
| 434 |
+
# Sanity checks
|
| 435 |
+
if args.dataset_name is None and args.train_data_dir is None:
|
| 436 |
+
raise ValueError("Need either a dataset name or a training folder.")
|
| 437 |
+
|
| 438 |
+
return args
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def main():
|
| 442 |
+
args = parse_args()
|
| 443 |
+
|
| 444 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 445 |
+
raise ValueError(
|
| 446 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 447 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
| 451 |
+
accelerator_project_config = ProjectConfiguration(
|
| 452 |
+
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
| 453 |
+
)
|
| 454 |
+
accelerator = Accelerator(
|
| 455 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 456 |
+
mixed_precision=args.mixed_precision,
|
| 457 |
+
log_with=args.report_to,
|
| 458 |
+
project_config=accelerator_project_config,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
# Disable AMP for MPS.
|
| 462 |
+
if torch.backends.mps.is_available():
|
| 463 |
+
accelerator.native_amp = False
|
| 464 |
+
|
| 465 |
+
# Make one log on every process with the configuration for debugging.
|
| 466 |
+
logging.basicConfig(
|
| 467 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 468 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 469 |
+
level=logging.INFO,
|
| 470 |
+
)
|
| 471 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 472 |
+
if accelerator.is_local_main_process:
|
| 473 |
+
datasets.utils.logging.set_verbosity_warning()
|
| 474 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 475 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 476 |
+
else:
|
| 477 |
+
datasets.utils.logging.set_verbosity_error()
|
| 478 |
+
transformers.utils.logging.set_verbosity_error()
|
| 479 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 480 |
+
|
| 481 |
+
# If passed along, set the training seed now.
|
| 482 |
+
if args.seed is not None:
|
| 483 |
+
set_seed(args.seed)
|
| 484 |
+
|
| 485 |
+
# Handle the repository creation
|
| 486 |
+
if accelerator.is_main_process:
|
| 487 |
+
if args.output_dir is not None:
|
| 488 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 489 |
+
|
| 490 |
+
if args.push_to_hub:
|
| 491 |
+
repo_id = create_repo(
|
| 492 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 493 |
+
).repo_id
|
| 494 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="scheduler")
|
| 495 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
| 496 |
+
args.pretrained_prior_model_name_or_path, subfolder="image_processor"
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
def deepspeed_zero_init_disabled_context_manager():
|
| 500 |
+
"""
|
| 501 |
+
returns either a context list that includes one that will disable zero.Init or an empty context list
|
| 502 |
+
"""
|
| 503 |
+
deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
|
| 504 |
+
if deepspeed_plugin is None:
|
| 505 |
+
return []
|
| 506 |
+
|
| 507 |
+
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
|
| 508 |
+
|
| 509 |
+
weight_dtype = torch.float32
|
| 510 |
+
if accelerator.mixed_precision == "fp16":
|
| 511 |
+
weight_dtype = torch.float16
|
| 512 |
+
elif accelerator.mixed_precision == "bf16":
|
| 513 |
+
weight_dtype = torch.bfloat16
|
| 514 |
+
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
| 515 |
+
vae = VQModel.from_pretrained(
|
| 516 |
+
args.pretrained_decoder_model_name_or_path, subfolder="movq", torch_dtype=weight_dtype
|
| 517 |
+
).eval()
|
| 518 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
| 519 |
+
args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
|
| 520 |
+
).eval()
|
| 521 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
|
| 522 |
+
|
| 523 |
+
# Freeze vae and image_encoder
|
| 524 |
+
vae.requires_grad_(False)
|
| 525 |
+
image_encoder.requires_grad_(False)
|
| 526 |
+
|
| 527 |
+
# Set unet to trainable.
|
| 528 |
+
unet.train()
|
| 529 |
+
|
| 530 |
+
# Create EMA for the unet.
|
| 531 |
+
if args.use_ema:
|
| 532 |
+
ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
|
| 533 |
+
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
|
| 534 |
+
ema_unet.to(accelerator.device)
|
| 535 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 536 |
+
if is_xformers_available():
|
| 537 |
+
import xformers
|
| 538 |
+
|
| 539 |
+
xformers_version = version.parse(xformers.__version__)
|
| 540 |
+
if xformers_version == version.parse("0.0.16"):
|
| 541 |
+
logger.warning(
|
| 542 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 543 |
+
)
|
| 544 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 545 |
+
else:
|
| 546 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 547 |
+
|
| 548 |
+
# `accelerate` 0.16.0 will have better support for customized saving
|
| 549 |
+
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
| 550 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 551 |
+
def save_model_hook(models, weights, output_dir):
|
| 552 |
+
if args.use_ema:
|
| 553 |
+
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
| 554 |
+
|
| 555 |
+
for i, model in enumerate(models):
|
| 556 |
+
model.save_pretrained(os.path.join(output_dir, "unet"))
|
| 557 |
+
|
| 558 |
+
# make sure to pop weight so that corresponding model is not saved again
|
| 559 |
+
weights.pop()
|
| 560 |
+
|
| 561 |
+
def load_model_hook(models, input_dir):
|
| 562 |
+
if args.use_ema:
|
| 563 |
+
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
|
| 564 |
+
ema_unet.load_state_dict(load_model.state_dict())
|
| 565 |
+
ema_unet.to(accelerator.device)
|
| 566 |
+
del load_model
|
| 567 |
+
|
| 568 |
+
for i in range(len(models)):
|
| 569 |
+
# pop models so that they are not loaded again
|
| 570 |
+
model = models.pop()
|
| 571 |
+
|
| 572 |
+
# load diffusers style into model
|
| 573 |
+
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
|
| 574 |
+
model.register_to_config(**load_model.config)
|
| 575 |
+
|
| 576 |
+
model.load_state_dict(load_model.state_dict())
|
| 577 |
+
del load_model
|
| 578 |
+
|
| 579 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 580 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 581 |
+
|
| 582 |
+
if args.gradient_checkpointing:
|
| 583 |
+
unet.enable_gradient_checkpointing()
|
| 584 |
+
|
| 585 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 586 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 587 |
+
if args.allow_tf32:
|
| 588 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 589 |
+
|
| 590 |
+
if args.use_8bit_adam:
|
| 591 |
+
try:
|
| 592 |
+
import bitsandbytes as bnb
|
| 593 |
+
except ImportError:
|
| 594 |
+
raise ImportError(
|
| 595 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
optimizer_cls = bnb.optim.AdamW8bit
|
| 599 |
+
else:
|
| 600 |
+
optimizer_cls = torch.optim.AdamW
|
| 601 |
+
|
| 602 |
+
optimizer = optimizer_cls(
|
| 603 |
+
unet.parameters(),
|
| 604 |
+
lr=args.learning_rate,
|
| 605 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 606 |
+
weight_decay=args.adam_weight_decay,
|
| 607 |
+
eps=args.adam_epsilon,
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
| 611 |
+
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
| 612 |
+
|
| 613 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
| 614 |
+
# download the dataset.
|
| 615 |
+
if args.dataset_name is not None:
|
| 616 |
+
# Downloading and loading a dataset from the hub.
|
| 617 |
+
dataset = load_dataset(
|
| 618 |
+
args.dataset_name,
|
| 619 |
+
args.dataset_config_name,
|
| 620 |
+
cache_dir=args.cache_dir,
|
| 621 |
+
)
|
| 622 |
+
else:
|
| 623 |
+
data_files = {}
|
| 624 |
+
if args.train_data_dir is not None:
|
| 625 |
+
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
| 626 |
+
dataset = load_dataset(
|
| 627 |
+
"imagefolder",
|
| 628 |
+
data_files=data_files,
|
| 629 |
+
cache_dir=args.cache_dir,
|
| 630 |
+
)
|
| 631 |
+
# See more about loading custom images at
|
| 632 |
+
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
| 633 |
+
|
| 634 |
+
# Preprocessing the datasets.
|
| 635 |
+
# We need to tokenize inputs and targets.
|
| 636 |
+
column_names = dataset["train"].column_names
|
| 637 |
+
|
| 638 |
+
image_column = args.image_column
|
| 639 |
+
if image_column not in column_names:
|
| 640 |
+
raise ValueError(f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
|
| 641 |
+
|
| 642 |
+
def center_crop(image):
|
| 643 |
+
width, height = image.size
|
| 644 |
+
new_size = min(width, height)
|
| 645 |
+
left = (width - new_size) / 2
|
| 646 |
+
top = (height - new_size) / 2
|
| 647 |
+
right = (width + new_size) / 2
|
| 648 |
+
bottom = (height + new_size) / 2
|
| 649 |
+
return image.crop((left, top, right, bottom))
|
| 650 |
+
|
| 651 |
+
def train_transforms(img):
|
| 652 |
+
img = center_crop(img)
|
| 653 |
+
img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
|
| 654 |
+
img = np.array(img).astype(np.float32) / 127.5 - 1
|
| 655 |
+
img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
|
| 656 |
+
return img
|
| 657 |
+
|
| 658 |
+
def preprocess_train(examples):
|
| 659 |
+
images = [image.convert("RGB") for image in examples[image_column]]
|
| 660 |
+
examples["pixel_values"] = [train_transforms(image) for image in images]
|
| 661 |
+
examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
|
| 662 |
+
return examples
|
| 663 |
+
|
| 664 |
+
with accelerator.main_process_first():
|
| 665 |
+
if args.max_train_samples is not None:
|
| 666 |
+
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 667 |
+
# Set the training transforms
|
| 668 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
| 669 |
+
|
| 670 |
+
def collate_fn(examples):
|
| 671 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 672 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 673 |
+
clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
|
| 674 |
+
clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 675 |
+
return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values}
|
| 676 |
+
|
| 677 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 678 |
+
train_dataset,
|
| 679 |
+
shuffle=True,
|
| 680 |
+
collate_fn=collate_fn,
|
| 681 |
+
batch_size=args.train_batch_size,
|
| 682 |
+
num_workers=args.dataloader_num_workers,
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
# Scheduler and math around the number of training steps.
|
| 686 |
+
overrode_max_train_steps = False
|
| 687 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 688 |
+
if args.max_train_steps is None:
|
| 689 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 690 |
+
overrode_max_train_steps = True
|
| 691 |
+
|
| 692 |
+
lr_scheduler = get_scheduler(
|
| 693 |
+
args.lr_scheduler,
|
| 694 |
+
optimizer=optimizer,
|
| 695 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 696 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 697 |
+
)
|
| 698 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 699 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
| 700 |
+
)
|
| 701 |
+
# Move image_encode and vae to gpu and cast to weight_dtype
|
| 702 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 703 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 704 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 705 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 706 |
+
if overrode_max_train_steps:
|
| 707 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 708 |
+
# Afterwards we recalculate our number of training epochs
|
| 709 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 710 |
+
|
| 711 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 712 |
+
# The trackers initializes automatically on the main process.
|
| 713 |
+
if accelerator.is_main_process:
|
| 714 |
+
tracker_config = dict(vars(args))
|
| 715 |
+
tracker_config.pop("validation_prompts")
|
| 716 |
+
accelerator.init_trackers(args.tracker_project_name, tracker_config)
|
| 717 |
+
|
| 718 |
+
# Train!
|
| 719 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 720 |
+
|
| 721 |
+
logger.info("***** Running training *****")
|
| 722 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 723 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 724 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 725 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 726 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 727 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 728 |
+
global_step = 0
|
| 729 |
+
first_epoch = 0
|
| 730 |
+
if args.resume_from_checkpoint:
|
| 731 |
+
if args.resume_from_checkpoint != "latest":
|
| 732 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 733 |
+
else:
|
| 734 |
+
# Get the most recent checkpoint
|
| 735 |
+
dirs = os.listdir(args.output_dir)
|
| 736 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 737 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 738 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 739 |
+
|
| 740 |
+
if path is None:
|
| 741 |
+
accelerator.print(
|
| 742 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 743 |
+
)
|
| 744 |
+
args.resume_from_checkpoint = None
|
| 745 |
+
initial_global_step = 0
|
| 746 |
+
else:
|
| 747 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 748 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 749 |
+
global_step = int(path.split("-")[1])
|
| 750 |
+
|
| 751 |
+
initial_global_step = global_step
|
| 752 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 753 |
+
else:
|
| 754 |
+
initial_global_step = 0
|
| 755 |
+
|
| 756 |
+
progress_bar = tqdm(
|
| 757 |
+
range(0, args.max_train_steps),
|
| 758 |
+
initial=initial_global_step,
|
| 759 |
+
desc="Steps",
|
| 760 |
+
# Only show the progress bar once on each machine.
|
| 761 |
+
disable=not accelerator.is_local_main_process,
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 765 |
+
train_loss = 0.0
|
| 766 |
+
for step, batch in enumerate(train_dataloader):
|
| 767 |
+
with accelerator.accumulate(unet):
|
| 768 |
+
# Convert images to latent space
|
| 769 |
+
images = batch["pixel_values"].to(weight_dtype)
|
| 770 |
+
clip_images = batch["clip_pixel_values"].to(weight_dtype)
|
| 771 |
+
latents = vae.encode(images).latents
|
| 772 |
+
image_embeds = image_encoder(clip_images).image_embeds
|
| 773 |
+
# Sample noise that we'll add to the latents
|
| 774 |
+
noise = torch.randn_like(latents)
|
| 775 |
+
bsz = latents.shape[0]
|
| 776 |
+
# Sample a random timestep for each image
|
| 777 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 778 |
+
timesteps = timesteps.long()
|
| 779 |
+
|
| 780 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 781 |
+
|
| 782 |
+
target = noise
|
| 783 |
+
|
| 784 |
+
# Predict the noise residual and compute loss
|
| 785 |
+
added_cond_kwargs = {"image_embeds": image_embeds}
|
| 786 |
+
|
| 787 |
+
model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
|
| 788 |
+
|
| 789 |
+
if args.snr_gamma is None:
|
| 790 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
| 791 |
+
else:
|
| 792 |
+
# Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
|
| 793 |
+
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
| 794 |
+
# This is discussed in Section 4.2 of the same paper.
|
| 795 |
+
snr = compute_snr(noise_scheduler, timesteps)
|
| 796 |
+
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
| 797 |
+
dim=1
|
| 798 |
+
)[0]
|
| 799 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 800 |
+
mse_loss_weights = mse_loss_weights / snr
|
| 801 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 802 |
+
mse_loss_weights = mse_loss_weights / (snr + 1)
|
| 803 |
+
|
| 804 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 805 |
+
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
| 806 |
+
loss = loss.mean()
|
| 807 |
+
|
| 808 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 809 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
| 810 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
| 811 |
+
|
| 812 |
+
# Backpropagate
|
| 813 |
+
accelerator.backward(loss)
|
| 814 |
+
if accelerator.sync_gradients:
|
| 815 |
+
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
| 816 |
+
optimizer.step()
|
| 817 |
+
lr_scheduler.step()
|
| 818 |
+
optimizer.zero_grad()
|
| 819 |
+
|
| 820 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 821 |
+
if accelerator.sync_gradients:
|
| 822 |
+
if args.use_ema:
|
| 823 |
+
ema_unet.step(unet.parameters())
|
| 824 |
+
progress_bar.update(1)
|
| 825 |
+
global_step += 1
|
| 826 |
+
accelerator.log({"train_loss": train_loss}, step=global_step)
|
| 827 |
+
train_loss = 0.0
|
| 828 |
+
|
| 829 |
+
if global_step % args.checkpointing_steps == 0:
|
| 830 |
+
if accelerator.is_main_process:
|
| 831 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 832 |
+
if args.checkpoints_total_limit is not None:
|
| 833 |
+
checkpoints = os.listdir(args.output_dir)
|
| 834 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 835 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 836 |
+
|
| 837 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 838 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 839 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 840 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 841 |
+
|
| 842 |
+
logger.info(
|
| 843 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 844 |
+
)
|
| 845 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 846 |
+
|
| 847 |
+
for removing_checkpoint in removing_checkpoints:
|
| 848 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 849 |
+
shutil.rmtree(removing_checkpoint)
|
| 850 |
+
|
| 851 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 852 |
+
accelerator.save_state(save_path)
|
| 853 |
+
logger.info(f"Saved state to {save_path}")
|
| 854 |
+
|
| 855 |
+
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 856 |
+
progress_bar.set_postfix(**logs)
|
| 857 |
+
|
| 858 |
+
if global_step >= args.max_train_steps:
|
| 859 |
+
break
|
| 860 |
+
|
| 861 |
+
if accelerator.is_main_process:
|
| 862 |
+
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
|
| 863 |
+
if args.use_ema:
|
| 864 |
+
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
|
| 865 |
+
ema_unet.store(unet.parameters())
|
| 866 |
+
ema_unet.copy_to(unet.parameters())
|
| 867 |
+
log_validation(
|
| 868 |
+
vae,
|
| 869 |
+
image_encoder,
|
| 870 |
+
image_processor,
|
| 871 |
+
unet,
|
| 872 |
+
args,
|
| 873 |
+
accelerator,
|
| 874 |
+
weight_dtype,
|
| 875 |
+
global_step,
|
| 876 |
+
)
|
| 877 |
+
if args.use_ema:
|
| 878 |
+
# Switch back to the original UNet parameters.
|
| 879 |
+
ema_unet.restore(unet.parameters())
|
| 880 |
+
|
| 881 |
+
# Create the pipeline using the trained modules and save it.
|
| 882 |
+
accelerator.wait_for_everyone()
|
| 883 |
+
if accelerator.is_main_process:
|
| 884 |
+
unet = accelerator.unwrap_model(unet)
|
| 885 |
+
if args.use_ema:
|
| 886 |
+
ema_unet.copy_to(unet.parameters())
|
| 887 |
+
|
| 888 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 889 |
+
args.pretrained_decoder_model_name_or_path,
|
| 890 |
+
vae=vae,
|
| 891 |
+
unet=unet,
|
| 892 |
+
)
|
| 893 |
+
pipeline.decoder_pipe.save_pretrained(args.output_dir)
|
| 894 |
+
|
| 895 |
+
# Run a final round of inference.
|
| 896 |
+
images = []
|
| 897 |
+
if args.validation_prompts is not None:
|
| 898 |
+
logger.info("Running inference for collecting generated images...")
|
| 899 |
+
pipeline.torch_dtype = weight_dtype
|
| 900 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 901 |
+
pipeline.enable_model_cpu_offload()
|
| 902 |
+
|
| 903 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 904 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
| 905 |
+
|
| 906 |
+
if args.seed is None:
|
| 907 |
+
generator = None
|
| 908 |
+
else:
|
| 909 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
| 910 |
+
|
| 911 |
+
for i in range(len(args.validation_prompts)):
|
| 912 |
+
with torch.autocast("cuda"):
|
| 913 |
+
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
|
| 914 |
+
images.append(image)
|
| 915 |
+
|
| 916 |
+
if args.push_to_hub:
|
| 917 |
+
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
|
| 918 |
+
upload_folder(
|
| 919 |
+
repo_id=repo_id,
|
| 920 |
+
folder_path=args.output_dir,
|
| 921 |
+
commit_message="End of training",
|
| 922 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
accelerator.end_training()
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
if __name__ == "__main__":
|
| 929 |
+
main()
|
diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
ADDED
|
@@ -0,0 +1,812 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Fine-tuning script for Kandinsky with support for LoRA."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import shutil
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import datasets
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torch.utils.checkpoint
|
| 29 |
+
import transformers
|
| 30 |
+
from accelerate import Accelerator
|
| 31 |
+
from accelerate.logging import get_logger
|
| 32 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 33 |
+
from datasets import load_dataset
|
| 34 |
+
from huggingface_hub import create_repo, upload_folder
|
| 35 |
+
from PIL import Image
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 38 |
+
|
| 39 |
+
import diffusers
|
| 40 |
+
from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
|
| 41 |
+
from diffusers.loaders import AttnProcsLayers
|
| 42 |
+
from diffusers.models.attention_processor import LoRAAttnAddedKVProcessor
|
| 43 |
+
from diffusers.optimization import get_scheduler
|
| 44 |
+
from diffusers.training_utils import compute_snr
|
| 45 |
+
from diffusers.utils import check_min_version, is_wandb_available
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 49 |
+
check_min_version("0.35.0.dev0")
|
| 50 |
+
|
| 51 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
|
| 55 |
+
img_str = ""
|
| 56 |
+
for i, image in enumerate(images):
|
| 57 |
+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
| 58 |
+
img_str += f"\n"
|
| 59 |
+
|
| 60 |
+
yaml = f"""
|
| 61 |
+
---
|
| 62 |
+
license: creativeml-openrail-m
|
| 63 |
+
base_model: {base_model}
|
| 64 |
+
tags:
|
| 65 |
+
- kandinsky
|
| 66 |
+
- text-to-image
|
| 67 |
+
- diffusers
|
| 68 |
+
- diffusers-training
|
| 69 |
+
- lora
|
| 70 |
+
inference: true
|
| 71 |
+
---
|
| 72 |
+
"""
|
| 73 |
+
model_card = f"""
|
| 74 |
+
# LoRA text2image fine-tuning - {repo_id}
|
| 75 |
+
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
| 76 |
+
{img_str}
|
| 77 |
+
"""
|
| 78 |
+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
| 79 |
+
f.write(yaml + model_card)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def parse_args():
|
| 83 |
+
parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2 with LoRA.")
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--pretrained_decoder_model_name_or_path",
|
| 86 |
+
type=str,
|
| 87 |
+
default="kandinsky-community/kandinsky-2-2-decoder",
|
| 88 |
+
required=False,
|
| 89 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"--pretrained_prior_model_name_or_path",
|
| 93 |
+
type=str,
|
| 94 |
+
default="kandinsky-community/kandinsky-2-2-prior",
|
| 95 |
+
required=False,
|
| 96 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 97 |
+
)
|
| 98 |
+
parser.add_argument(
|
| 99 |
+
"--dataset_name",
|
| 100 |
+
type=str,
|
| 101 |
+
default=None,
|
| 102 |
+
help=(
|
| 103 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
| 104 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
| 105 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
| 106 |
+
),
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--dataset_config_name",
|
| 110 |
+
type=str,
|
| 111 |
+
default=None,
|
| 112 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--train_data_dir",
|
| 116 |
+
type=str,
|
| 117 |
+
default=None,
|
| 118 |
+
help=(
|
| 119 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
| 120 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
| 121 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
| 122 |
+
),
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
| 126 |
+
)
|
| 127 |
+
parser.add_argument(
|
| 128 |
+
"--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--num_validation_images",
|
| 132 |
+
type=int,
|
| 133 |
+
default=4,
|
| 134 |
+
help="Number of images that should be generated during validation with `validation_prompt`.",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--validation_epochs",
|
| 138 |
+
type=int,
|
| 139 |
+
default=1,
|
| 140 |
+
help=(
|
| 141 |
+
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
|
| 142 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
| 143 |
+
),
|
| 144 |
+
)
|
| 145 |
+
parser.add_argument(
|
| 146 |
+
"--max_train_samples",
|
| 147 |
+
type=int,
|
| 148 |
+
default=None,
|
| 149 |
+
help=(
|
| 150 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 151 |
+
"value if set."
|
| 152 |
+
),
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--output_dir",
|
| 156 |
+
type=str,
|
| 157 |
+
default="kandi_2_2-model-finetuned-lora",
|
| 158 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 159 |
+
)
|
| 160 |
+
parser.add_argument(
|
| 161 |
+
"--cache_dir",
|
| 162 |
+
type=str,
|
| 163 |
+
default=None,
|
| 164 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 165 |
+
)
|
| 166 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--resolution",
|
| 169 |
+
type=int,
|
| 170 |
+
default=512,
|
| 171 |
+
help=(
|
| 172 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 173 |
+
" resolution"
|
| 174 |
+
),
|
| 175 |
+
)
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
| 178 |
+
)
|
| 179 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--max_train_steps",
|
| 182 |
+
type=int,
|
| 183 |
+
default=None,
|
| 184 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 185 |
+
)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--gradient_accumulation_steps",
|
| 188 |
+
type=int,
|
| 189 |
+
default=1,
|
| 190 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--gradient_checkpointing",
|
| 194 |
+
action="store_true",
|
| 195 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--learning_rate",
|
| 199 |
+
type=float,
|
| 200 |
+
default=1e-4,
|
| 201 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 202 |
+
)
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--lr_scheduler",
|
| 205 |
+
type=str,
|
| 206 |
+
default="constant",
|
| 207 |
+
help=(
|
| 208 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 209 |
+
' "constant", "constant_with_warmup"]'
|
| 210 |
+
),
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--snr_gamma",
|
| 217 |
+
type=float,
|
| 218 |
+
default=None,
|
| 219 |
+
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
| 220 |
+
"More details here: https://huggingface.co/papers/2303.09556.",
|
| 221 |
+
)
|
| 222 |
+
parser.add_argument(
|
| 223 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 224 |
+
)
|
| 225 |
+
parser.add_argument(
|
| 226 |
+
"--allow_tf32",
|
| 227 |
+
action="store_true",
|
| 228 |
+
help=(
|
| 229 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 230 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 231 |
+
),
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--dataloader_num_workers",
|
| 235 |
+
type=int,
|
| 236 |
+
default=0,
|
| 237 |
+
help=(
|
| 238 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 239 |
+
),
|
| 240 |
+
)
|
| 241 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 242 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 243 |
+
parser.add_argument("--adam_weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
| 244 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 245 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 246 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 247 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 248 |
+
parser.add_argument(
|
| 249 |
+
"--hub_model_id",
|
| 250 |
+
type=str,
|
| 251 |
+
default=None,
|
| 252 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 253 |
+
)
|
| 254 |
+
parser.add_argument(
|
| 255 |
+
"--logging_dir",
|
| 256 |
+
type=str,
|
| 257 |
+
default="logs",
|
| 258 |
+
help=(
|
| 259 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 260 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 261 |
+
),
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--mixed_precision",
|
| 265 |
+
type=str,
|
| 266 |
+
default=None,
|
| 267 |
+
choices=["no", "fp16", "bf16"],
|
| 268 |
+
help=(
|
| 269 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 270 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 271 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 272 |
+
),
|
| 273 |
+
)
|
| 274 |
+
parser.add_argument(
|
| 275 |
+
"--report_to",
|
| 276 |
+
type=str,
|
| 277 |
+
default="tensorboard",
|
| 278 |
+
help=(
|
| 279 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 280 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 281 |
+
),
|
| 282 |
+
)
|
| 283 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
"--checkpointing_steps",
|
| 286 |
+
type=int,
|
| 287 |
+
default=500,
|
| 288 |
+
help=(
|
| 289 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
| 290 |
+
" training using `--resume_from_checkpoint`."
|
| 291 |
+
),
|
| 292 |
+
)
|
| 293 |
+
parser.add_argument(
|
| 294 |
+
"--checkpoints_total_limit",
|
| 295 |
+
type=int,
|
| 296 |
+
default=None,
|
| 297 |
+
help=("Max number of checkpoints to store."),
|
| 298 |
+
)
|
| 299 |
+
parser.add_argument(
|
| 300 |
+
"--resume_from_checkpoint",
|
| 301 |
+
type=str,
|
| 302 |
+
default=None,
|
| 303 |
+
help=(
|
| 304 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 305 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 306 |
+
),
|
| 307 |
+
)
|
| 308 |
+
parser.add_argument(
|
| 309 |
+
"--rank",
|
| 310 |
+
type=int,
|
| 311 |
+
default=4,
|
| 312 |
+
help=("The dimension of the LoRA update matrices."),
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
args = parser.parse_args()
|
| 316 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 317 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 318 |
+
args.local_rank = env_local_rank
|
| 319 |
+
|
| 320 |
+
# Sanity checks
|
| 321 |
+
if args.dataset_name is None and args.train_data_dir is None:
|
| 322 |
+
raise ValueError("Need either a dataset name or a training folder.")
|
| 323 |
+
|
| 324 |
+
return args
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def main():
|
| 328 |
+
args = parse_args()
|
| 329 |
+
|
| 330 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 331 |
+
raise ValueError(
|
| 332 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 333 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 337 |
+
accelerator_project_config = ProjectConfiguration(
|
| 338 |
+
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
| 339 |
+
)
|
| 340 |
+
accelerator = Accelerator(
|
| 341 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 342 |
+
mixed_precision=args.mixed_precision,
|
| 343 |
+
log_with=args.report_to,
|
| 344 |
+
project_config=accelerator_project_config,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# Disable AMP for MPS.
|
| 348 |
+
if torch.backends.mps.is_available():
|
| 349 |
+
accelerator.native_amp = False
|
| 350 |
+
|
| 351 |
+
if args.report_to == "wandb":
|
| 352 |
+
if not is_wandb_available():
|
| 353 |
+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
| 354 |
+
import wandb
|
| 355 |
+
|
| 356 |
+
# Make one log on every process with the configuration for debugging.
|
| 357 |
+
logging.basicConfig(
|
| 358 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 359 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 360 |
+
level=logging.INFO,
|
| 361 |
+
)
|
| 362 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 363 |
+
if accelerator.is_local_main_process:
|
| 364 |
+
datasets.utils.logging.set_verbosity_warning()
|
| 365 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 366 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 367 |
+
else:
|
| 368 |
+
datasets.utils.logging.set_verbosity_error()
|
| 369 |
+
transformers.utils.logging.set_verbosity_error()
|
| 370 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 371 |
+
|
| 372 |
+
# If passed along, set the training seed now.
|
| 373 |
+
if args.seed is not None:
|
| 374 |
+
set_seed(args.seed)
|
| 375 |
+
|
| 376 |
+
# Handle the repository creation
|
| 377 |
+
if accelerator.is_main_process:
|
| 378 |
+
if args.output_dir is not None:
|
| 379 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 380 |
+
|
| 381 |
+
if args.push_to_hub:
|
| 382 |
+
repo_id = create_repo(
|
| 383 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 384 |
+
).repo_id
|
| 385 |
+
# Load scheduler, tokenizer and models.
|
| 386 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="scheduler")
|
| 387 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
| 388 |
+
args.pretrained_prior_model_name_or_path, subfolder="image_processor"
|
| 389 |
+
)
|
| 390 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
| 391 |
+
args.pretrained_prior_model_name_or_path, subfolder="image_encoder"
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
vae = VQModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="movq")
|
| 395 |
+
|
| 396 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
|
| 397 |
+
# freeze parameters of models to save more memory
|
| 398 |
+
unet.requires_grad_(False)
|
| 399 |
+
vae.requires_grad_(False)
|
| 400 |
+
|
| 401 |
+
image_encoder.requires_grad_(False)
|
| 402 |
+
|
| 403 |
+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
| 404 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 405 |
+
weight_dtype = torch.float32
|
| 406 |
+
if accelerator.mixed_precision == "fp16":
|
| 407 |
+
weight_dtype = torch.float16
|
| 408 |
+
elif accelerator.mixed_precision == "bf16":
|
| 409 |
+
weight_dtype = torch.bfloat16
|
| 410 |
+
|
| 411 |
+
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
| 412 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 413 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 414 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 415 |
+
|
| 416 |
+
lora_attn_procs = {}
|
| 417 |
+
for name in unet.attn_processors.keys():
|
| 418 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
| 419 |
+
if name.startswith("mid_block"):
|
| 420 |
+
hidden_size = unet.config.block_out_channels[-1]
|
| 421 |
+
elif name.startswith("up_blocks"):
|
| 422 |
+
block_id = int(name[len("up_blocks.")])
|
| 423 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
| 424 |
+
elif name.startswith("down_blocks"):
|
| 425 |
+
block_id = int(name[len("down_blocks.")])
|
| 426 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
| 427 |
+
|
| 428 |
+
lora_attn_procs[name] = LoRAAttnAddedKVProcessor(
|
| 429 |
+
hidden_size=hidden_size,
|
| 430 |
+
cross_attention_dim=cross_attention_dim,
|
| 431 |
+
rank=args.rank,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
unet.set_attn_processor(lora_attn_procs)
|
| 435 |
+
|
| 436 |
+
lora_layers = AttnProcsLayers(unet.attn_processors)
|
| 437 |
+
|
| 438 |
+
if args.allow_tf32:
|
| 439 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 440 |
+
|
| 441 |
+
if args.use_8bit_adam:
|
| 442 |
+
try:
|
| 443 |
+
import bitsandbytes as bnb
|
| 444 |
+
except ImportError:
|
| 445 |
+
raise ImportError(
|
| 446 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
optimizer_cls = bnb.optim.AdamW8bit
|
| 450 |
+
else:
|
| 451 |
+
optimizer_cls = torch.optim.AdamW
|
| 452 |
+
|
| 453 |
+
optimizer = optimizer_cls(
|
| 454 |
+
lora_layers.parameters(),
|
| 455 |
+
lr=args.learning_rate,
|
| 456 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 457 |
+
weight_decay=args.adam_weight_decay,
|
| 458 |
+
eps=args.adam_epsilon,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
| 462 |
+
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
| 463 |
+
|
| 464 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
| 465 |
+
# download the dataset.
|
| 466 |
+
if args.dataset_name is not None:
|
| 467 |
+
# Downloading and loading a dataset from the hub.
|
| 468 |
+
dataset = load_dataset(
|
| 469 |
+
args.dataset_name,
|
| 470 |
+
args.dataset_config_name,
|
| 471 |
+
cache_dir=args.cache_dir,
|
| 472 |
+
)
|
| 473 |
+
else:
|
| 474 |
+
data_files = {}
|
| 475 |
+
if args.train_data_dir is not None:
|
| 476 |
+
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
| 477 |
+
dataset = load_dataset(
|
| 478 |
+
"imagefolder",
|
| 479 |
+
data_files=data_files,
|
| 480 |
+
cache_dir=args.cache_dir,
|
| 481 |
+
)
|
| 482 |
+
# See more about loading custom images at
|
| 483 |
+
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
| 484 |
+
|
| 485 |
+
# Preprocessing the datasets.
|
| 486 |
+
# We need to tokenize inputs and targets.
|
| 487 |
+
column_names = dataset["train"].column_names
|
| 488 |
+
|
| 489 |
+
image_column = args.image_column
|
| 490 |
+
if image_column not in column_names:
|
| 491 |
+
raise ValueError(f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
|
| 492 |
+
|
| 493 |
+
def center_crop(image):
|
| 494 |
+
width, height = image.size
|
| 495 |
+
new_size = min(width, height)
|
| 496 |
+
left = (width - new_size) / 2
|
| 497 |
+
top = (height - new_size) / 2
|
| 498 |
+
right = (width + new_size) / 2
|
| 499 |
+
bottom = (height + new_size) / 2
|
| 500 |
+
return image.crop((left, top, right, bottom))
|
| 501 |
+
|
| 502 |
+
def train_transforms(img):
|
| 503 |
+
img = center_crop(img)
|
| 504 |
+
img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
|
| 505 |
+
img = np.array(img).astype(np.float32) / 127.5 - 1
|
| 506 |
+
img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
|
| 507 |
+
return img
|
| 508 |
+
|
| 509 |
+
def preprocess_train(examples):
|
| 510 |
+
images = [image.convert("RGB") for image in examples[image_column]]
|
| 511 |
+
examples["pixel_values"] = [train_transforms(image) for image in images]
|
| 512 |
+
examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
|
| 513 |
+
return examples
|
| 514 |
+
|
| 515 |
+
with accelerator.main_process_first():
|
| 516 |
+
if args.max_train_samples is not None:
|
| 517 |
+
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 518 |
+
# Set the training transforms
|
| 519 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
| 520 |
+
|
| 521 |
+
def collate_fn(examples):
|
| 522 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 523 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 524 |
+
clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
|
| 525 |
+
clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 526 |
+
return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values}
|
| 527 |
+
|
| 528 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 529 |
+
train_dataset,
|
| 530 |
+
shuffle=True,
|
| 531 |
+
collate_fn=collate_fn,
|
| 532 |
+
batch_size=args.train_batch_size,
|
| 533 |
+
num_workers=args.dataloader_num_workers,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
# Scheduler and math around the number of training steps.
|
| 537 |
+
overrode_max_train_steps = False
|
| 538 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 539 |
+
if args.max_train_steps is None:
|
| 540 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 541 |
+
overrode_max_train_steps = True
|
| 542 |
+
|
| 543 |
+
lr_scheduler = get_scheduler(
|
| 544 |
+
args.lr_scheduler,
|
| 545 |
+
optimizer=optimizer,
|
| 546 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 547 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 548 |
+
)
|
| 549 |
+
# Prepare everything with our `accelerator`.
|
| 550 |
+
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 551 |
+
lora_layers, optimizer, train_dataloader, lr_scheduler
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 555 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 556 |
+
if overrode_max_train_steps:
|
| 557 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 558 |
+
# Afterwards we recalculate our number of training epochs
|
| 559 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 560 |
+
|
| 561 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 562 |
+
# The trackers initializes automatically on the main process.
|
| 563 |
+
if accelerator.is_main_process:
|
| 564 |
+
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
|
| 565 |
+
|
| 566 |
+
# Train!
|
| 567 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 568 |
+
|
| 569 |
+
logger.info("***** Running training *****")
|
| 570 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 571 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 572 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 573 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 574 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 575 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 576 |
+
global_step = 0
|
| 577 |
+
first_epoch = 0
|
| 578 |
+
|
| 579 |
+
# Potentially load in the weights and states from a previous save
|
| 580 |
+
if args.resume_from_checkpoint:
|
| 581 |
+
if args.resume_from_checkpoint != "latest":
|
| 582 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 583 |
+
else:
|
| 584 |
+
# Get the most recent checkpoint
|
| 585 |
+
dirs = os.listdir(args.output_dir)
|
| 586 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 587 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 588 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 589 |
+
|
| 590 |
+
if path is None:
|
| 591 |
+
accelerator.print(
|
| 592 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 593 |
+
)
|
| 594 |
+
args.resume_from_checkpoint = None
|
| 595 |
+
initial_global_step = 0
|
| 596 |
+
else:
|
| 597 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 598 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 599 |
+
global_step = int(path.split("-")[1])
|
| 600 |
+
|
| 601 |
+
initial_global_step = global_step
|
| 602 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 603 |
+
else:
|
| 604 |
+
initial_global_step = 0
|
| 605 |
+
|
| 606 |
+
progress_bar = tqdm(
|
| 607 |
+
range(0, args.max_train_steps),
|
| 608 |
+
initial=initial_global_step,
|
| 609 |
+
desc="Steps",
|
| 610 |
+
# Only show the progress bar once on each machine.
|
| 611 |
+
disable=not accelerator.is_local_main_process,
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 615 |
+
unet.train()
|
| 616 |
+
train_loss = 0.0
|
| 617 |
+
for step, batch in enumerate(train_dataloader):
|
| 618 |
+
with accelerator.accumulate(unet):
|
| 619 |
+
# Convert images to latent space
|
| 620 |
+
images = batch["pixel_values"].to(weight_dtype)
|
| 621 |
+
clip_images = batch["clip_pixel_values"].to(weight_dtype)
|
| 622 |
+
latents = vae.encode(images).latents
|
| 623 |
+
image_embeds = image_encoder(clip_images).image_embeds
|
| 624 |
+
# Sample noise that we'll add to the latents
|
| 625 |
+
noise = torch.randn_like(latents)
|
| 626 |
+
bsz = latents.shape[0]
|
| 627 |
+
# Sample a random timestep for each image
|
| 628 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 629 |
+
timesteps = timesteps.long()
|
| 630 |
+
|
| 631 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 632 |
+
|
| 633 |
+
target = noise
|
| 634 |
+
|
| 635 |
+
# Predict the noise residual and compute loss
|
| 636 |
+
added_cond_kwargs = {"image_embeds": image_embeds}
|
| 637 |
+
|
| 638 |
+
model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
|
| 639 |
+
|
| 640 |
+
if args.snr_gamma is None:
|
| 641 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
| 642 |
+
else:
|
| 643 |
+
# Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
|
| 644 |
+
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
| 645 |
+
# This is discussed in Section 4.2 of the same paper.
|
| 646 |
+
snr = compute_snr(noise_scheduler, timesteps)
|
| 647 |
+
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
| 648 |
+
dim=1
|
| 649 |
+
)[0]
|
| 650 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 651 |
+
mse_loss_weights = mse_loss_weights / snr
|
| 652 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 653 |
+
mse_loss_weights = mse_loss_weights / (snr + 1)
|
| 654 |
+
|
| 655 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 656 |
+
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
| 657 |
+
loss = loss.mean()
|
| 658 |
+
|
| 659 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 660 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
| 661 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
| 662 |
+
|
| 663 |
+
# Backpropagate
|
| 664 |
+
accelerator.backward(loss)
|
| 665 |
+
if accelerator.sync_gradients:
|
| 666 |
+
params_to_clip = lora_layers.parameters()
|
| 667 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 668 |
+
optimizer.step()
|
| 669 |
+
lr_scheduler.step()
|
| 670 |
+
optimizer.zero_grad()
|
| 671 |
+
|
| 672 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 673 |
+
if accelerator.sync_gradients:
|
| 674 |
+
progress_bar.update(1)
|
| 675 |
+
global_step += 1
|
| 676 |
+
accelerator.log({"train_loss": train_loss}, step=global_step)
|
| 677 |
+
train_loss = 0.0
|
| 678 |
+
|
| 679 |
+
if global_step % args.checkpointing_steps == 0:
|
| 680 |
+
if accelerator.is_main_process:
|
| 681 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 682 |
+
if args.checkpoints_total_limit is not None:
|
| 683 |
+
checkpoints = os.listdir(args.output_dir)
|
| 684 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 685 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 686 |
+
|
| 687 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 688 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 689 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 690 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 691 |
+
|
| 692 |
+
logger.info(
|
| 693 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 694 |
+
)
|
| 695 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 696 |
+
|
| 697 |
+
for removing_checkpoint in removing_checkpoints:
|
| 698 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 699 |
+
shutil.rmtree(removing_checkpoint)
|
| 700 |
+
|
| 701 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 702 |
+
accelerator.save_state(save_path)
|
| 703 |
+
logger.info(f"Saved state to {save_path}")
|
| 704 |
+
|
| 705 |
+
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 706 |
+
progress_bar.set_postfix(**logs)
|
| 707 |
+
|
| 708 |
+
if global_step >= args.max_train_steps:
|
| 709 |
+
break
|
| 710 |
+
|
| 711 |
+
if accelerator.is_main_process:
|
| 712 |
+
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
| 713 |
+
logger.info(
|
| 714 |
+
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
| 715 |
+
f" {args.validation_prompt}."
|
| 716 |
+
)
|
| 717 |
+
# create pipeline
|
| 718 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 719 |
+
args.pretrained_decoder_model_name_or_path,
|
| 720 |
+
unet=accelerator.unwrap_model(unet),
|
| 721 |
+
torch_dtype=weight_dtype,
|
| 722 |
+
)
|
| 723 |
+
pipeline = pipeline.to(accelerator.device)
|
| 724 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 725 |
+
|
| 726 |
+
# run inference
|
| 727 |
+
generator = torch.Generator(device=accelerator.device)
|
| 728 |
+
if args.seed is not None:
|
| 729 |
+
generator = generator.manual_seed(args.seed)
|
| 730 |
+
images = []
|
| 731 |
+
for _ in range(args.num_validation_images):
|
| 732 |
+
images.append(
|
| 733 |
+
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
for tracker in accelerator.trackers:
|
| 737 |
+
if tracker.name == "tensorboard":
|
| 738 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 739 |
+
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
| 740 |
+
if tracker.name == "wandb":
|
| 741 |
+
tracker.log(
|
| 742 |
+
{
|
| 743 |
+
"validation": [
|
| 744 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
| 745 |
+
for i, image in enumerate(images)
|
| 746 |
+
]
|
| 747 |
+
}
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
del pipeline
|
| 751 |
+
torch.cuda.empty_cache()
|
| 752 |
+
|
| 753 |
+
# Save the lora layers
|
| 754 |
+
accelerator.wait_for_everyone()
|
| 755 |
+
if accelerator.is_main_process:
|
| 756 |
+
unet = unet.to(torch.float32)
|
| 757 |
+
unet.save_attn_procs(args.output_dir)
|
| 758 |
+
|
| 759 |
+
if args.push_to_hub:
|
| 760 |
+
save_model_card(
|
| 761 |
+
repo_id,
|
| 762 |
+
images=images,
|
| 763 |
+
base_model=args.pretrained_decoder_model_name_or_path,
|
| 764 |
+
dataset_name=args.dataset_name,
|
| 765 |
+
repo_folder=args.output_dir,
|
| 766 |
+
)
|
| 767 |
+
upload_folder(
|
| 768 |
+
repo_id=repo_id,
|
| 769 |
+
folder_path=args.output_dir,
|
| 770 |
+
commit_message="End of training",
|
| 771 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
# Final inference
|
| 775 |
+
# Load previous pipeline
|
| 776 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 777 |
+
args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype
|
| 778 |
+
)
|
| 779 |
+
pipeline = pipeline.to(accelerator.device)
|
| 780 |
+
|
| 781 |
+
# load attention processors
|
| 782 |
+
pipeline.unet.load_attn_procs(args.output_dir)
|
| 783 |
+
|
| 784 |
+
# run inference
|
| 785 |
+
generator = torch.Generator(device=accelerator.device)
|
| 786 |
+
if args.seed is not None:
|
| 787 |
+
generator = generator.manual_seed(args.seed)
|
| 788 |
+
images = []
|
| 789 |
+
for _ in range(args.num_validation_images):
|
| 790 |
+
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
|
| 791 |
+
|
| 792 |
+
if accelerator.is_main_process:
|
| 793 |
+
for tracker in accelerator.trackers:
|
| 794 |
+
if len(images) != 0:
|
| 795 |
+
if tracker.name == "tensorboard":
|
| 796 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 797 |
+
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
| 798 |
+
if tracker.name == "wandb":
|
| 799 |
+
tracker.log(
|
| 800 |
+
{
|
| 801 |
+
"test": [
|
| 802 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
| 803 |
+
for i, image in enumerate(images)
|
| 804 |
+
]
|
| 805 |
+
}
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
accelerator.end_training()
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
if __name__ == "__main__":
|
| 812 |
+
main()
|
diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
ADDED
|
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import random
|
| 22 |
+
import shutil
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import datasets
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import torch.utils.checkpoint
|
| 30 |
+
import transformers
|
| 31 |
+
from accelerate import Accelerator
|
| 32 |
+
from accelerate.logging import get_logger
|
| 33 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 34 |
+
from datasets import load_dataset
|
| 35 |
+
from huggingface_hub import create_repo, upload_folder
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 38 |
+
|
| 39 |
+
import diffusers
|
| 40 |
+
from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
|
| 41 |
+
from diffusers.loaders import AttnProcsLayers
|
| 42 |
+
from diffusers.models.attention_processor import LoRAAttnProcessor
|
| 43 |
+
from diffusers.optimization import get_scheduler
|
| 44 |
+
from diffusers.training_utils import compute_snr
|
| 45 |
+
from diffusers.utils import check_min_version, is_wandb_available
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 49 |
+
check_min_version("0.35.0.dev0")
|
| 50 |
+
|
| 51 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
|
| 55 |
+
img_str = ""
|
| 56 |
+
for i, image in enumerate(images):
|
| 57 |
+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
| 58 |
+
img_str += f"\n"
|
| 59 |
+
|
| 60 |
+
yaml = f"""
|
| 61 |
+
---
|
| 62 |
+
license: creativeml-openrail-m
|
| 63 |
+
base_model: {base_model}
|
| 64 |
+
tags:
|
| 65 |
+
- kandinsky
|
| 66 |
+
- text-to-image
|
| 67 |
+
- diffusers
|
| 68 |
+
- diffusers-training
|
| 69 |
+
- lora
|
| 70 |
+
inference: true
|
| 71 |
+
---
|
| 72 |
+
"""
|
| 73 |
+
model_card = f"""
|
| 74 |
+
# LoRA text2image fine-tuning - {repo_id}
|
| 75 |
+
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
| 76 |
+
{img_str}
|
| 77 |
+
"""
|
| 78 |
+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
| 79 |
+
f.write(yaml + model_card)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def parse_args():
|
| 83 |
+
parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--pretrained_decoder_model_name_or_path",
|
| 86 |
+
type=str,
|
| 87 |
+
default="kandinsky-community/kandinsky-2-2-decoder",
|
| 88 |
+
required=False,
|
| 89 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"--pretrained_prior_model_name_or_path",
|
| 93 |
+
type=str,
|
| 94 |
+
default="kandinsky-community/kandinsky-2-2-prior",
|
| 95 |
+
required=False,
|
| 96 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 97 |
+
)
|
| 98 |
+
parser.add_argument(
|
| 99 |
+
"--dataset_name",
|
| 100 |
+
type=str,
|
| 101 |
+
default=None,
|
| 102 |
+
help=(
|
| 103 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
| 104 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
| 105 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
| 106 |
+
),
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--dataset_config_name",
|
| 110 |
+
type=str,
|
| 111 |
+
default=None,
|
| 112 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--train_data_dir",
|
| 116 |
+
type=str,
|
| 117 |
+
default=None,
|
| 118 |
+
help=(
|
| 119 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
| 120 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
| 121 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
| 122 |
+
),
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
| 126 |
+
)
|
| 127 |
+
parser.add_argument(
|
| 128 |
+
"--caption_column",
|
| 129 |
+
type=str,
|
| 130 |
+
default="text",
|
| 131 |
+
help="The column of the dataset containing a caption or a list of captions.",
|
| 132 |
+
)
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--num_validation_images",
|
| 138 |
+
type=int,
|
| 139 |
+
default=4,
|
| 140 |
+
help="Number of images that should be generated during validation with `validation_prompt`.",
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--validation_epochs",
|
| 144 |
+
type=int,
|
| 145 |
+
default=1,
|
| 146 |
+
help=(
|
| 147 |
+
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
|
| 148 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
| 149 |
+
),
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--max_train_samples",
|
| 153 |
+
type=int,
|
| 154 |
+
default=None,
|
| 155 |
+
help=(
|
| 156 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 157 |
+
"value if set."
|
| 158 |
+
),
|
| 159 |
+
)
|
| 160 |
+
parser.add_argument(
|
| 161 |
+
"--output_dir",
|
| 162 |
+
type=str,
|
| 163 |
+
default="kandi_2_2-model-finetuned-lora",
|
| 164 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 165 |
+
)
|
| 166 |
+
parser.add_argument(
|
| 167 |
+
"--cache_dir",
|
| 168 |
+
type=str,
|
| 169 |
+
default=None,
|
| 170 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--resolution",
|
| 175 |
+
type=int,
|
| 176 |
+
default=512,
|
| 177 |
+
help=(
|
| 178 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 179 |
+
" resolution"
|
| 180 |
+
),
|
| 181 |
+
)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--max_train_steps",
|
| 188 |
+
type=int,
|
| 189 |
+
default=None,
|
| 190 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--gradient_accumulation_steps",
|
| 194 |
+
type=int,
|
| 195 |
+
default=1,
|
| 196 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 197 |
+
)
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
"--learning_rate",
|
| 200 |
+
type=float,
|
| 201 |
+
default=1e-4,
|
| 202 |
+
help="learning rate",
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--lr_scheduler",
|
| 206 |
+
type=str,
|
| 207 |
+
default="constant",
|
| 208 |
+
help=(
|
| 209 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 210 |
+
' "constant", "constant_with_warmup"]'
|
| 211 |
+
),
|
| 212 |
+
)
|
| 213 |
+
parser.add_argument(
|
| 214 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--snr_gamma",
|
| 218 |
+
type=float,
|
| 219 |
+
default=None,
|
| 220 |
+
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
| 221 |
+
"More details here: https://huggingface.co/papers/2303.09556.",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 225 |
+
)
|
| 226 |
+
parser.add_argument(
|
| 227 |
+
"--allow_tf32",
|
| 228 |
+
action="store_true",
|
| 229 |
+
help=(
|
| 230 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 231 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 232 |
+
),
|
| 233 |
+
)
|
| 234 |
+
parser.add_argument(
|
| 235 |
+
"--dataloader_num_workers",
|
| 236 |
+
type=int,
|
| 237 |
+
default=0,
|
| 238 |
+
help=(
|
| 239 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 240 |
+
),
|
| 241 |
+
)
|
| 242 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 243 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 244 |
+
parser.add_argument(
|
| 245 |
+
"--adam_weight_decay",
|
| 246 |
+
type=float,
|
| 247 |
+
default=0.0,
|
| 248 |
+
required=False,
|
| 249 |
+
help="weight decay_to_use",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 252 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 253 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 254 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--hub_model_id",
|
| 257 |
+
type=str,
|
| 258 |
+
default=None,
|
| 259 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
"--logging_dir",
|
| 263 |
+
type=str,
|
| 264 |
+
default="logs",
|
| 265 |
+
help=(
|
| 266 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 267 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 268 |
+
),
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--mixed_precision",
|
| 272 |
+
type=str,
|
| 273 |
+
default=None,
|
| 274 |
+
choices=["no", "fp16", "bf16"],
|
| 275 |
+
help=(
|
| 276 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 277 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 278 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 279 |
+
),
|
| 280 |
+
)
|
| 281 |
+
parser.add_argument(
|
| 282 |
+
"--report_to",
|
| 283 |
+
type=str,
|
| 284 |
+
default="tensorboard",
|
| 285 |
+
help=(
|
| 286 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 287 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 288 |
+
),
|
| 289 |
+
)
|
| 290 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 291 |
+
parser.add_argument(
|
| 292 |
+
"--checkpointing_steps",
|
| 293 |
+
type=int,
|
| 294 |
+
default=500,
|
| 295 |
+
help=(
|
| 296 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
| 297 |
+
" training using `--resume_from_checkpoint`."
|
| 298 |
+
),
|
| 299 |
+
)
|
| 300 |
+
parser.add_argument(
|
| 301 |
+
"--checkpoints_total_limit",
|
| 302 |
+
type=int,
|
| 303 |
+
default=None,
|
| 304 |
+
help=("Max number of checkpoints to store."),
|
| 305 |
+
)
|
| 306 |
+
parser.add_argument(
|
| 307 |
+
"--resume_from_checkpoint",
|
| 308 |
+
type=str,
|
| 309 |
+
default=None,
|
| 310 |
+
help=(
|
| 311 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 312 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 313 |
+
),
|
| 314 |
+
)
|
| 315 |
+
parser.add_argument(
|
| 316 |
+
"--rank",
|
| 317 |
+
type=int,
|
| 318 |
+
default=4,
|
| 319 |
+
help=("The dimension of the LoRA update matrices."),
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
args = parser.parse_args()
|
| 323 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 324 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 325 |
+
args.local_rank = env_local_rank
|
| 326 |
+
|
| 327 |
+
# Sanity checks
|
| 328 |
+
if args.dataset_name is None and args.train_data_dir is None:
|
| 329 |
+
raise ValueError("Need either a dataset name or a training folder.")
|
| 330 |
+
|
| 331 |
+
return args
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
DATASET_NAME_MAPPING = {
|
| 335 |
+
"lambdalabs/naruto-blip-captions": ("image", "text"),
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def main():
|
| 340 |
+
args = parse_args()
|
| 341 |
+
|
| 342 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 343 |
+
raise ValueError(
|
| 344 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 345 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 349 |
+
|
| 350 |
+
accelerator_project_config = ProjectConfiguration(
|
| 351 |
+
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
| 352 |
+
)
|
| 353 |
+
accelerator = Accelerator(
|
| 354 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 355 |
+
mixed_precision=args.mixed_precision,
|
| 356 |
+
log_with=args.report_to,
|
| 357 |
+
project_config=accelerator_project_config,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
# Disable AMP for MPS.
|
| 361 |
+
if torch.backends.mps.is_available():
|
| 362 |
+
accelerator.native_amp = False
|
| 363 |
+
|
| 364 |
+
if args.report_to == "wandb":
|
| 365 |
+
if not is_wandb_available():
|
| 366 |
+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
| 367 |
+
import wandb
|
| 368 |
+
|
| 369 |
+
# Make one log on every process with the configuration for debugging.
|
| 370 |
+
logging.basicConfig(
|
| 371 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 372 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 373 |
+
level=logging.INFO,
|
| 374 |
+
)
|
| 375 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 376 |
+
if accelerator.is_local_main_process:
|
| 377 |
+
datasets.utils.logging.set_verbosity_warning()
|
| 378 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 379 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 380 |
+
else:
|
| 381 |
+
datasets.utils.logging.set_verbosity_error()
|
| 382 |
+
transformers.utils.logging.set_verbosity_error()
|
| 383 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 384 |
+
|
| 385 |
+
# If passed along, set the training seed now.
|
| 386 |
+
if args.seed is not None:
|
| 387 |
+
set_seed(args.seed)
|
| 388 |
+
|
| 389 |
+
# Handle the repository creation
|
| 390 |
+
if accelerator.is_main_process:
|
| 391 |
+
if args.output_dir is not None:
|
| 392 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 393 |
+
|
| 394 |
+
if args.push_to_hub:
|
| 395 |
+
repo_id = create_repo(
|
| 396 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 397 |
+
).repo_id
|
| 398 |
+
# Load scheduler, image_processor, tokenizer and models.
|
| 399 |
+
noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
|
| 400 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
| 401 |
+
args.pretrained_prior_model_name_or_path, subfolder="image_processor"
|
| 402 |
+
)
|
| 403 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
|
| 404 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
| 405 |
+
args.pretrained_prior_model_name_or_path, subfolder="image_encoder"
|
| 406 |
+
)
|
| 407 |
+
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
| 408 |
+
args.pretrained_prior_model_name_or_path, subfolder="text_encoder"
|
| 409 |
+
)
|
| 410 |
+
prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
| 411 |
+
# freeze parameters of models to save more memory
|
| 412 |
+
image_encoder.requires_grad_(False)
|
| 413 |
+
prior.requires_grad_(False)
|
| 414 |
+
text_encoder.requires_grad_(False)
|
| 415 |
+
weight_dtype = torch.float32
|
| 416 |
+
if accelerator.mixed_precision == "fp16":
|
| 417 |
+
weight_dtype = torch.float16
|
| 418 |
+
elif accelerator.mixed_precision == "bf16":
|
| 419 |
+
weight_dtype = torch.bfloat16
|
| 420 |
+
|
| 421 |
+
# Move image_encoder, text_encoder and prior to device and cast to weight_dtype
|
| 422 |
+
prior.to(accelerator.device, dtype=weight_dtype)
|
| 423 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 424 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 425 |
+
lora_attn_procs = {}
|
| 426 |
+
for name in prior.attn_processors.keys():
|
| 427 |
+
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=args.rank)
|
| 428 |
+
|
| 429 |
+
prior.set_attn_processor(lora_attn_procs)
|
| 430 |
+
lora_layers = AttnProcsLayers(prior.attn_processors)
|
| 431 |
+
|
| 432 |
+
if args.allow_tf32:
|
| 433 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 434 |
+
|
| 435 |
+
if args.use_8bit_adam:
|
| 436 |
+
try:
|
| 437 |
+
import bitsandbytes as bnb
|
| 438 |
+
except ImportError:
|
| 439 |
+
raise ImportError(
|
| 440 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
optimizer_cls = bnb.optim.AdamW8bit
|
| 444 |
+
else:
|
| 445 |
+
optimizer_cls = torch.optim.AdamW
|
| 446 |
+
|
| 447 |
+
optimizer = optimizer_cls(
|
| 448 |
+
lora_layers.parameters(),
|
| 449 |
+
lr=args.learning_rate,
|
| 450 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 451 |
+
weight_decay=args.adam_weight_decay,
|
| 452 |
+
eps=args.adam_epsilon,
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
| 456 |
+
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
| 457 |
+
|
| 458 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
| 459 |
+
# download the dataset.
|
| 460 |
+
if args.dataset_name is not None:
|
| 461 |
+
# Downloading and loading a dataset from the hub.
|
| 462 |
+
dataset = load_dataset(
|
| 463 |
+
args.dataset_name,
|
| 464 |
+
args.dataset_config_name,
|
| 465 |
+
cache_dir=args.cache_dir,
|
| 466 |
+
)
|
| 467 |
+
else:
|
| 468 |
+
data_files = {}
|
| 469 |
+
if args.train_data_dir is not None:
|
| 470 |
+
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
| 471 |
+
dataset = load_dataset(
|
| 472 |
+
"imagefolder",
|
| 473 |
+
data_files=data_files,
|
| 474 |
+
cache_dir=args.cache_dir,
|
| 475 |
+
)
|
| 476 |
+
# See more about loading custom images at
|
| 477 |
+
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
| 478 |
+
|
| 479 |
+
# Preprocessing the datasets.
|
| 480 |
+
# We need to tokenize inputs and targets.
|
| 481 |
+
column_names = dataset["train"].column_names
|
| 482 |
+
|
| 483 |
+
# 6. Get the column names for input/target.
|
| 484 |
+
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
| 485 |
+
if args.image_column is None:
|
| 486 |
+
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
| 487 |
+
else:
|
| 488 |
+
image_column = args.image_column
|
| 489 |
+
if image_column not in column_names:
|
| 490 |
+
raise ValueError(
|
| 491 |
+
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
| 492 |
+
)
|
| 493 |
+
if args.caption_column is None:
|
| 494 |
+
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
| 495 |
+
else:
|
| 496 |
+
caption_column = args.caption_column
|
| 497 |
+
if caption_column not in column_names:
|
| 498 |
+
raise ValueError(
|
| 499 |
+
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
# Preprocessing the datasets.
|
| 503 |
+
# We need to tokenize input captions and transform the images.
|
| 504 |
+
def tokenize_captions(examples, is_train=True):
|
| 505 |
+
captions = []
|
| 506 |
+
for caption in examples[caption_column]:
|
| 507 |
+
if isinstance(caption, str):
|
| 508 |
+
captions.append(caption)
|
| 509 |
+
elif isinstance(caption, (list, np.ndarray)):
|
| 510 |
+
# take a random caption if there are multiple
|
| 511 |
+
captions.append(random.choice(caption) if is_train else caption[0])
|
| 512 |
+
else:
|
| 513 |
+
raise ValueError(
|
| 514 |
+
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
| 515 |
+
)
|
| 516 |
+
inputs = tokenizer(
|
| 517 |
+
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
| 518 |
+
)
|
| 519 |
+
text_input_ids = inputs.input_ids
|
| 520 |
+
text_mask = inputs.attention_mask.bool()
|
| 521 |
+
return text_input_ids, text_mask
|
| 522 |
+
|
| 523 |
+
def preprocess_train(examples):
|
| 524 |
+
images = [image.convert("RGB") for image in examples[image_column]]
|
| 525 |
+
examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
|
| 526 |
+
examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
|
| 527 |
+
return examples
|
| 528 |
+
|
| 529 |
+
with accelerator.main_process_first():
|
| 530 |
+
if args.max_train_samples is not None:
|
| 531 |
+
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 532 |
+
# Set the training transforms
|
| 533 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
| 534 |
+
|
| 535 |
+
def collate_fn(examples):
|
| 536 |
+
clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
|
| 537 |
+
clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 538 |
+
text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
|
| 539 |
+
text_mask = torch.stack([example["text_mask"] for example in examples])
|
| 540 |
+
return {"clip_pixel_values": clip_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
|
| 541 |
+
|
| 542 |
+
# DataLoaders creation:
|
| 543 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 544 |
+
train_dataset,
|
| 545 |
+
shuffle=True,
|
| 546 |
+
collate_fn=collate_fn,
|
| 547 |
+
batch_size=args.train_batch_size,
|
| 548 |
+
num_workers=args.dataloader_num_workers,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
# Scheduler and math around the number of training steps.
|
| 552 |
+
overrode_max_train_steps = False
|
| 553 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 554 |
+
if args.max_train_steps is None:
|
| 555 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 556 |
+
overrode_max_train_steps = True
|
| 557 |
+
|
| 558 |
+
lr_scheduler = get_scheduler(
|
| 559 |
+
args.lr_scheduler,
|
| 560 |
+
optimizer=optimizer,
|
| 561 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 562 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 563 |
+
)
|
| 564 |
+
clip_mean = prior.clip_mean.clone()
|
| 565 |
+
clip_std = prior.clip_std.clone()
|
| 566 |
+
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 567 |
+
lora_layers, optimizer, train_dataloader, lr_scheduler
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 571 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 572 |
+
if overrode_max_train_steps:
|
| 573 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 574 |
+
# Afterwards we recalculate our number of training epochs
|
| 575 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 576 |
+
|
| 577 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 578 |
+
# The trackers initializes automatically on the main process.
|
| 579 |
+
if accelerator.is_main_process:
|
| 580 |
+
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
|
| 581 |
+
|
| 582 |
+
# Train!
|
| 583 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 584 |
+
|
| 585 |
+
logger.info("***** Running training *****")
|
| 586 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 587 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 588 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 589 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 590 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 591 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 592 |
+
global_step = 0
|
| 593 |
+
first_epoch = 0
|
| 594 |
+
|
| 595 |
+
# Potentially load in the weights and states from a previous save
|
| 596 |
+
if args.resume_from_checkpoint:
|
| 597 |
+
if args.resume_from_checkpoint != "latest":
|
| 598 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 599 |
+
else:
|
| 600 |
+
# Get the most recent checkpoint
|
| 601 |
+
dirs = os.listdir(args.output_dir)
|
| 602 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 603 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 604 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 605 |
+
|
| 606 |
+
if path is None:
|
| 607 |
+
accelerator.print(
|
| 608 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 609 |
+
)
|
| 610 |
+
args.resume_from_checkpoint = None
|
| 611 |
+
initial_global_step = 0
|
| 612 |
+
else:
|
| 613 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 614 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 615 |
+
global_step = int(path.split("-")[1])
|
| 616 |
+
|
| 617 |
+
initial_global_step = global_step
|
| 618 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 619 |
+
|
| 620 |
+
else:
|
| 621 |
+
initial_global_step = 0
|
| 622 |
+
|
| 623 |
+
progress_bar = tqdm(
|
| 624 |
+
range(0, args.max_train_steps),
|
| 625 |
+
initial=initial_global_step,
|
| 626 |
+
desc="Steps",
|
| 627 |
+
# Only show the progress bar once on each machine.
|
| 628 |
+
disable=not accelerator.is_local_main_process,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
|
| 632 |
+
clip_std = clip_std.to(weight_dtype).to(accelerator.device)
|
| 633 |
+
|
| 634 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 635 |
+
prior.train()
|
| 636 |
+
train_loss = 0.0
|
| 637 |
+
for step, batch in enumerate(train_dataloader):
|
| 638 |
+
with accelerator.accumulate(prior):
|
| 639 |
+
# Convert images to latent space
|
| 640 |
+
text_input_ids, text_mask, clip_images = (
|
| 641 |
+
batch["text_input_ids"],
|
| 642 |
+
batch["text_mask"],
|
| 643 |
+
batch["clip_pixel_values"].to(weight_dtype),
|
| 644 |
+
)
|
| 645 |
+
with torch.no_grad():
|
| 646 |
+
text_encoder_output = text_encoder(text_input_ids)
|
| 647 |
+
prompt_embeds = text_encoder_output.text_embeds
|
| 648 |
+
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
| 649 |
+
|
| 650 |
+
image_embeds = image_encoder(clip_images).image_embeds
|
| 651 |
+
# Sample noise that we'll add to the image_embeds
|
| 652 |
+
noise = torch.randn_like(image_embeds)
|
| 653 |
+
bsz = image_embeds.shape[0]
|
| 654 |
+
# Sample a random timestep for each image
|
| 655 |
+
timesteps = torch.randint(
|
| 656 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device
|
| 657 |
+
)
|
| 658 |
+
timesteps = timesteps.long()
|
| 659 |
+
image_embeds = (image_embeds - clip_mean) / clip_std
|
| 660 |
+
noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
|
| 661 |
+
|
| 662 |
+
target = image_embeds
|
| 663 |
+
|
| 664 |
+
# Predict the noise residual and compute loss
|
| 665 |
+
model_pred = prior(
|
| 666 |
+
noisy_latents,
|
| 667 |
+
timestep=timesteps,
|
| 668 |
+
proj_embedding=prompt_embeds,
|
| 669 |
+
encoder_hidden_states=text_encoder_hidden_states,
|
| 670 |
+
attention_mask=text_mask,
|
| 671 |
+
).predicted_image_embedding
|
| 672 |
+
|
| 673 |
+
if args.snr_gamma is None:
|
| 674 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
| 675 |
+
else:
|
| 676 |
+
# Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
|
| 677 |
+
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
| 678 |
+
# This is discussed in Section 4.2 of the same paper.
|
| 679 |
+
snr = compute_snr(noise_scheduler, timesteps)
|
| 680 |
+
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
| 681 |
+
dim=1
|
| 682 |
+
)[0]
|
| 683 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 684 |
+
mse_loss_weights = mse_loss_weights / snr
|
| 685 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 686 |
+
mse_loss_weights = mse_loss_weights / (snr + 1)
|
| 687 |
+
|
| 688 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 689 |
+
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
| 690 |
+
loss = loss.mean()
|
| 691 |
+
|
| 692 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 693 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
| 694 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
| 695 |
+
|
| 696 |
+
# Backpropagate
|
| 697 |
+
accelerator.backward(loss)
|
| 698 |
+
if accelerator.sync_gradients:
|
| 699 |
+
accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
|
| 700 |
+
optimizer.step()
|
| 701 |
+
lr_scheduler.step()
|
| 702 |
+
optimizer.zero_grad()
|
| 703 |
+
|
| 704 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 705 |
+
if accelerator.sync_gradients:
|
| 706 |
+
progress_bar.update(1)
|
| 707 |
+
global_step += 1
|
| 708 |
+
accelerator.log({"train_loss": train_loss}, step=global_step)
|
| 709 |
+
train_loss = 0.0
|
| 710 |
+
|
| 711 |
+
if global_step % args.checkpointing_steps == 0:
|
| 712 |
+
if accelerator.is_main_process:
|
| 713 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 714 |
+
if args.checkpoints_total_limit is not None:
|
| 715 |
+
checkpoints = os.listdir(args.output_dir)
|
| 716 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 717 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 718 |
+
|
| 719 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 720 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 721 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 722 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 723 |
+
|
| 724 |
+
logger.info(
|
| 725 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 726 |
+
)
|
| 727 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 728 |
+
|
| 729 |
+
for removing_checkpoint in removing_checkpoints:
|
| 730 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 731 |
+
shutil.rmtree(removing_checkpoint)
|
| 732 |
+
|
| 733 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 734 |
+
accelerator.save_state(save_path)
|
| 735 |
+
logger.info(f"Saved state to {save_path}")
|
| 736 |
+
|
| 737 |
+
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 738 |
+
progress_bar.set_postfix(**logs)
|
| 739 |
+
|
| 740 |
+
if global_step >= args.max_train_steps:
|
| 741 |
+
break
|
| 742 |
+
|
| 743 |
+
if accelerator.is_main_process:
|
| 744 |
+
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
| 745 |
+
logger.info(
|
| 746 |
+
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
| 747 |
+
f" {args.validation_prompt}."
|
| 748 |
+
)
|
| 749 |
+
# create pipeline
|
| 750 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 751 |
+
args.pretrained_decoder_model_name_or_path,
|
| 752 |
+
prior_prior=accelerator.unwrap_model(prior),
|
| 753 |
+
torch_dtype=weight_dtype,
|
| 754 |
+
)
|
| 755 |
+
pipeline = pipeline.to(accelerator.device)
|
| 756 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 757 |
+
|
| 758 |
+
# run inference
|
| 759 |
+
generator = torch.Generator(device=accelerator.device)
|
| 760 |
+
if args.seed is not None:
|
| 761 |
+
generator = generator.manual_seed(args.seed)
|
| 762 |
+
images = []
|
| 763 |
+
for _ in range(args.num_validation_images):
|
| 764 |
+
images.append(
|
| 765 |
+
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
for tracker in accelerator.trackers:
|
| 769 |
+
if tracker.name == "tensorboard":
|
| 770 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 771 |
+
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
| 772 |
+
if tracker.name == "wandb":
|
| 773 |
+
tracker.log(
|
| 774 |
+
{
|
| 775 |
+
"validation": [
|
| 776 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
| 777 |
+
for i, image in enumerate(images)
|
| 778 |
+
]
|
| 779 |
+
}
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
del pipeline
|
| 783 |
+
torch.cuda.empty_cache()
|
| 784 |
+
|
| 785 |
+
# Save the lora layers
|
| 786 |
+
accelerator.wait_for_everyone()
|
| 787 |
+
if accelerator.is_main_process:
|
| 788 |
+
prior = prior.to(torch.float32)
|
| 789 |
+
prior.save_attn_procs(args.output_dir)
|
| 790 |
+
|
| 791 |
+
if args.push_to_hub:
|
| 792 |
+
save_model_card(
|
| 793 |
+
repo_id,
|
| 794 |
+
images=images,
|
| 795 |
+
base_model=args.pretrained_prior_model_name_or_path,
|
| 796 |
+
dataset_name=args.dataset_name,
|
| 797 |
+
repo_folder=args.output_dir,
|
| 798 |
+
)
|
| 799 |
+
upload_folder(
|
| 800 |
+
repo_id=repo_id,
|
| 801 |
+
folder_path=args.output_dir,
|
| 802 |
+
commit_message="End of training",
|
| 803 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
# Final inference
|
| 807 |
+
# Load previous pipeline
|
| 808 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 809 |
+
args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype
|
| 810 |
+
)
|
| 811 |
+
pipeline = pipeline.to(accelerator.device)
|
| 812 |
+
|
| 813 |
+
# load attention processors
|
| 814 |
+
pipeline.prior_prior.load_attn_procs(args.output_dir)
|
| 815 |
+
|
| 816 |
+
# run inference
|
| 817 |
+
generator = torch.Generator(device=accelerator.device)
|
| 818 |
+
if args.seed is not None:
|
| 819 |
+
generator = generator.manual_seed(args.seed)
|
| 820 |
+
images = []
|
| 821 |
+
for _ in range(args.num_validation_images):
|
| 822 |
+
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
|
| 823 |
+
|
| 824 |
+
if accelerator.is_main_process:
|
| 825 |
+
for tracker in accelerator.trackers:
|
| 826 |
+
if len(images) != 0:
|
| 827 |
+
if tracker.name == "tensorboard":
|
| 828 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 829 |
+
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
| 830 |
+
if tracker.name == "wandb":
|
| 831 |
+
tracker.log(
|
| 832 |
+
{
|
| 833 |
+
"test": [
|
| 834 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
| 835 |
+
for i, image in enumerate(images)
|
| 836 |
+
]
|
| 837 |
+
}
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
accelerator.end_training()
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
if __name__ == "__main__":
|
| 844 |
+
main()
|
diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
ADDED
|
@@ -0,0 +1,958 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import logging
|
| 18 |
+
import math
|
| 19 |
+
import os
|
| 20 |
+
import random
|
| 21 |
+
import shutil
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import accelerate
|
| 25 |
+
import datasets
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import torch.utils.checkpoint
|
| 30 |
+
import transformers
|
| 31 |
+
from accelerate import Accelerator
|
| 32 |
+
from accelerate.logging import get_logger
|
| 33 |
+
from accelerate.state import AcceleratorState
|
| 34 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 35 |
+
from datasets import load_dataset
|
| 36 |
+
from huggingface_hub import create_repo, upload_folder
|
| 37 |
+
from packaging import version
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 40 |
+
from transformers.utils import ContextManagers
|
| 41 |
+
|
| 42 |
+
import diffusers
|
| 43 |
+
from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
|
| 44 |
+
from diffusers.optimization import get_scheduler
|
| 45 |
+
from diffusers.training_utils import EMAModel, compute_snr
|
| 46 |
+
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_wandb_available():
|
| 50 |
+
import wandb
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 54 |
+
check_min_version("0.35.0.dev0")
|
| 55 |
+
|
| 56 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 57 |
+
|
| 58 |
+
DATASET_NAME_MAPPING = {
|
| 59 |
+
"lambdalabs/naruto-blip-captions": ("image", "text"),
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def save_model_card(
|
| 64 |
+
args,
|
| 65 |
+
repo_id: str,
|
| 66 |
+
images=None,
|
| 67 |
+
repo_folder=None,
|
| 68 |
+
):
|
| 69 |
+
img_str = ""
|
| 70 |
+
if len(images) > 0:
|
| 71 |
+
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
|
| 72 |
+
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
|
| 73 |
+
img_str += "\n"
|
| 74 |
+
|
| 75 |
+
yaml = f"""
|
| 76 |
+
---
|
| 77 |
+
license: creativeml-openrail-m
|
| 78 |
+
base_model: {args.pretrained_prior_model_name_or_path}
|
| 79 |
+
datasets:
|
| 80 |
+
- {args.dataset_name}
|
| 81 |
+
tags:
|
| 82 |
+
- kandinsky
|
| 83 |
+
- text-to-image
|
| 84 |
+
- diffusers
|
| 85 |
+
- diffusers-training
|
| 86 |
+
inference: true
|
| 87 |
+
---
|
| 88 |
+
"""
|
| 89 |
+
model_card = f"""
|
| 90 |
+
# Finetuning - {repo_id}
|
| 91 |
+
|
| 92 |
+
This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
|
| 93 |
+
{img_str}
|
| 94 |
+
|
| 95 |
+
## Pipeline usage
|
| 96 |
+
|
| 97 |
+
You can use the pipeline like so:
|
| 98 |
+
|
| 99 |
+
```python
|
| 100 |
+
from diffusers import DiffusionPipeline
|
| 101 |
+
import torch
|
| 102 |
+
|
| 103 |
+
pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
|
| 104 |
+
pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype=torch.float16)
|
| 105 |
+
prompt = "{args.validation_prompts[0]}"
|
| 106 |
+
image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
|
| 107 |
+
image = pipe_t2i(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0]
|
| 108 |
+
image.save("my_image.png")
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
## Training info
|
| 112 |
+
|
| 113 |
+
These are the key hyperparameters used during training:
|
| 114 |
+
|
| 115 |
+
* Epochs: {args.num_train_epochs}
|
| 116 |
+
* Learning rate: {args.learning_rate}
|
| 117 |
+
* Batch size: {args.train_batch_size}
|
| 118 |
+
* Gradient accumulation steps: {args.gradient_accumulation_steps}
|
| 119 |
+
* Image resolution: {args.resolution}
|
| 120 |
+
* Mixed-precision: {args.mixed_precision}
|
| 121 |
+
|
| 122 |
+
"""
|
| 123 |
+
wandb_info = ""
|
| 124 |
+
if is_wandb_available():
|
| 125 |
+
wandb_run_url = None
|
| 126 |
+
if wandb.run is not None:
|
| 127 |
+
wandb_run_url = wandb.run.url
|
| 128 |
+
|
| 129 |
+
if wandb_run_url is not None:
|
| 130 |
+
wandb_info = f"""
|
| 131 |
+
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
model_card += wandb_info
|
| 135 |
+
|
| 136 |
+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
| 137 |
+
f.write(yaml + model_card)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def log_validation(
|
| 141 |
+
image_encoder, image_processor, text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch
|
| 142 |
+
):
|
| 143 |
+
logger.info("Running validation... ")
|
| 144 |
+
|
| 145 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 146 |
+
args.pretrained_decoder_model_name_or_path,
|
| 147 |
+
prior_image_encoder=accelerator.unwrap_model(image_encoder),
|
| 148 |
+
prior_image_processor=image_processor,
|
| 149 |
+
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
| 150 |
+
prior_tokenizer=tokenizer,
|
| 151 |
+
prior_prior=accelerator.unwrap_model(prior),
|
| 152 |
+
torch_dtype=weight_dtype,
|
| 153 |
+
)
|
| 154 |
+
pipeline = pipeline.to(accelerator.device)
|
| 155 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 156 |
+
|
| 157 |
+
if args.seed is None:
|
| 158 |
+
generator = None
|
| 159 |
+
else:
|
| 160 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
| 161 |
+
|
| 162 |
+
images = []
|
| 163 |
+
for i in range(len(args.validation_prompts)):
|
| 164 |
+
with torch.autocast("cuda"):
|
| 165 |
+
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
|
| 166 |
+
|
| 167 |
+
images.append(image)
|
| 168 |
+
|
| 169 |
+
for tracker in accelerator.trackers:
|
| 170 |
+
if tracker.name == "tensorboard":
|
| 171 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 172 |
+
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
| 173 |
+
elif tracker.name == "wandb":
|
| 174 |
+
tracker.log(
|
| 175 |
+
{
|
| 176 |
+
"validation": [
|
| 177 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
|
| 178 |
+
for i, image in enumerate(images)
|
| 179 |
+
]
|
| 180 |
+
}
|
| 181 |
+
)
|
| 182 |
+
else:
|
| 183 |
+
logger.warning(f"image logging not implemented for {tracker.name}")
|
| 184 |
+
|
| 185 |
+
del pipeline
|
| 186 |
+
torch.cuda.empty_cache()
|
| 187 |
+
|
| 188 |
+
return images
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def parse_args():
|
| 192 |
+
parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--pretrained_decoder_model_name_or_path",
|
| 195 |
+
type=str,
|
| 196 |
+
default="kandinsky-community/kandinsky-2-2-decoder",
|
| 197 |
+
required=False,
|
| 198 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 199 |
+
)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--pretrained_prior_model_name_or_path",
|
| 202 |
+
type=str,
|
| 203 |
+
default="kandinsky-community/kandinsky-2-2-prior",
|
| 204 |
+
required=False,
|
| 205 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 206 |
+
)
|
| 207 |
+
parser.add_argument(
|
| 208 |
+
"--dataset_name",
|
| 209 |
+
type=str,
|
| 210 |
+
default=None,
|
| 211 |
+
help=(
|
| 212 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
| 213 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
| 214 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
| 215 |
+
),
|
| 216 |
+
)
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--dataset_config_name",
|
| 219 |
+
type=str,
|
| 220 |
+
default=None,
|
| 221 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--train_data_dir",
|
| 225 |
+
type=str,
|
| 226 |
+
default=None,
|
| 227 |
+
help=(
|
| 228 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
| 229 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
| 230 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
| 231 |
+
),
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
| 235 |
+
)
|
| 236 |
+
parser.add_argument(
|
| 237 |
+
"--caption_column",
|
| 238 |
+
type=str,
|
| 239 |
+
default="text",
|
| 240 |
+
help="The column of the dataset containing a caption or a list of captions.",
|
| 241 |
+
)
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--max_train_samples",
|
| 244 |
+
type=int,
|
| 245 |
+
default=None,
|
| 246 |
+
help=(
|
| 247 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 248 |
+
"value if set."
|
| 249 |
+
),
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--validation_prompts",
|
| 253 |
+
type=str,
|
| 254 |
+
default=None,
|
| 255 |
+
nargs="+",
|
| 256 |
+
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
|
| 257 |
+
)
|
| 258 |
+
parser.add_argument(
|
| 259 |
+
"--output_dir",
|
| 260 |
+
type=str,
|
| 261 |
+
default="kandi_2_2-model-finetuned",
|
| 262 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 263 |
+
)
|
| 264 |
+
parser.add_argument(
|
| 265 |
+
"--cache_dir",
|
| 266 |
+
type=str,
|
| 267 |
+
default=None,
|
| 268 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--resolution",
|
| 273 |
+
type=int,
|
| 274 |
+
default=512,
|
| 275 |
+
help=(
|
| 276 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 277 |
+
" resolution"
|
| 278 |
+
),
|
| 279 |
+
)
|
| 280 |
+
parser.add_argument(
|
| 281 |
+
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
| 282 |
+
)
|
| 283 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 284 |
+
parser.add_argument(
|
| 285 |
+
"--max_train_steps",
|
| 286 |
+
type=int,
|
| 287 |
+
default=None,
|
| 288 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 289 |
+
)
|
| 290 |
+
parser.add_argument(
|
| 291 |
+
"--gradient_accumulation_steps",
|
| 292 |
+
type=int,
|
| 293 |
+
default=1,
|
| 294 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 295 |
+
)
|
| 296 |
+
parser.add_argument(
|
| 297 |
+
"--learning_rate",
|
| 298 |
+
type=float,
|
| 299 |
+
default=1e-4,
|
| 300 |
+
help="learning rate",
|
| 301 |
+
)
|
| 302 |
+
parser.add_argument(
|
| 303 |
+
"--lr_scheduler",
|
| 304 |
+
type=str,
|
| 305 |
+
default="constant",
|
| 306 |
+
help=(
|
| 307 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 308 |
+
' "constant", "constant_with_warmup"]'
|
| 309 |
+
),
|
| 310 |
+
)
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 313 |
+
)
|
| 314 |
+
parser.add_argument(
|
| 315 |
+
"--snr_gamma",
|
| 316 |
+
type=float,
|
| 317 |
+
default=None,
|
| 318 |
+
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
| 319 |
+
"More details here: https://huggingface.co/papers/2303.09556.",
|
| 320 |
+
)
|
| 321 |
+
parser.add_argument(
|
| 322 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
"--allow_tf32",
|
| 326 |
+
action="store_true",
|
| 327 |
+
help=(
|
| 328 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 329 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 330 |
+
),
|
| 331 |
+
)
|
| 332 |
+
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
| 333 |
+
parser.add_argument(
|
| 334 |
+
"--dataloader_num_workers",
|
| 335 |
+
type=int,
|
| 336 |
+
default=0,
|
| 337 |
+
help=(
|
| 338 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 339 |
+
),
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 342 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 343 |
+
parser.add_argument(
|
| 344 |
+
"--adam_weight_decay",
|
| 345 |
+
type=float,
|
| 346 |
+
default=0.0,
|
| 347 |
+
required=False,
|
| 348 |
+
help="weight decay_to_use",
|
| 349 |
+
)
|
| 350 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 351 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 352 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 353 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 354 |
+
parser.add_argument(
|
| 355 |
+
"--hub_model_id",
|
| 356 |
+
type=str,
|
| 357 |
+
default=None,
|
| 358 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 359 |
+
)
|
| 360 |
+
parser.add_argument(
|
| 361 |
+
"--logging_dir",
|
| 362 |
+
type=str,
|
| 363 |
+
default="logs",
|
| 364 |
+
help=(
|
| 365 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 366 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 367 |
+
),
|
| 368 |
+
)
|
| 369 |
+
parser.add_argument(
|
| 370 |
+
"--mixed_precision",
|
| 371 |
+
type=str,
|
| 372 |
+
default=None,
|
| 373 |
+
choices=["no", "fp16", "bf16"],
|
| 374 |
+
help=(
|
| 375 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 376 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 377 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 378 |
+
),
|
| 379 |
+
)
|
| 380 |
+
parser.add_argument(
|
| 381 |
+
"--report_to",
|
| 382 |
+
type=str,
|
| 383 |
+
default="tensorboard",
|
| 384 |
+
help=(
|
| 385 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 386 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 387 |
+
),
|
| 388 |
+
)
|
| 389 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 390 |
+
parser.add_argument(
|
| 391 |
+
"--checkpointing_steps",
|
| 392 |
+
type=int,
|
| 393 |
+
default=500,
|
| 394 |
+
help=(
|
| 395 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
| 396 |
+
" training using `--resume_from_checkpoint`."
|
| 397 |
+
),
|
| 398 |
+
)
|
| 399 |
+
parser.add_argument(
|
| 400 |
+
"--checkpoints_total_limit",
|
| 401 |
+
type=int,
|
| 402 |
+
default=None,
|
| 403 |
+
help=("Max number of checkpoints to store."),
|
| 404 |
+
)
|
| 405 |
+
parser.add_argument(
|
| 406 |
+
"--resume_from_checkpoint",
|
| 407 |
+
type=str,
|
| 408 |
+
default=None,
|
| 409 |
+
help=(
|
| 410 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 411 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 412 |
+
),
|
| 413 |
+
)
|
| 414 |
+
parser.add_argument(
|
| 415 |
+
"--validation_epochs",
|
| 416 |
+
type=int,
|
| 417 |
+
default=5,
|
| 418 |
+
help="Run validation every X epochs.",
|
| 419 |
+
)
|
| 420 |
+
parser.add_argument(
|
| 421 |
+
"--tracker_project_name",
|
| 422 |
+
type=str,
|
| 423 |
+
default="text2image-fine-tune",
|
| 424 |
+
help=(
|
| 425 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
| 426 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
| 427 |
+
),
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
args = parser.parse_args()
|
| 431 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 432 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 433 |
+
args.local_rank = env_local_rank
|
| 434 |
+
|
| 435 |
+
# Sanity checks
|
| 436 |
+
if args.dataset_name is None and args.train_data_dir is None:
|
| 437 |
+
raise ValueError("Need either a dataset name or a training folder.")
|
| 438 |
+
|
| 439 |
+
return args
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def main():
|
| 443 |
+
args = parse_args()
|
| 444 |
+
|
| 445 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 446 |
+
raise ValueError(
|
| 447 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 448 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
| 452 |
+
accelerator_project_config = ProjectConfiguration(
|
| 453 |
+
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
| 454 |
+
)
|
| 455 |
+
accelerator = Accelerator(
|
| 456 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 457 |
+
mixed_precision=args.mixed_precision,
|
| 458 |
+
log_with=args.report_to,
|
| 459 |
+
project_config=accelerator_project_config,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# Disable AMP for MPS.
|
| 463 |
+
if torch.backends.mps.is_available():
|
| 464 |
+
accelerator.native_amp = False
|
| 465 |
+
|
| 466 |
+
# Make one log on every process with the configuration for debugging.
|
| 467 |
+
logging.basicConfig(
|
| 468 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 469 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 470 |
+
level=logging.INFO,
|
| 471 |
+
)
|
| 472 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 473 |
+
if accelerator.is_local_main_process:
|
| 474 |
+
datasets.utils.logging.set_verbosity_warning()
|
| 475 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 476 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 477 |
+
else:
|
| 478 |
+
datasets.utils.logging.set_verbosity_error()
|
| 479 |
+
transformers.utils.logging.set_verbosity_error()
|
| 480 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 481 |
+
|
| 482 |
+
# If passed along, set the training seed now.
|
| 483 |
+
if args.seed is not None:
|
| 484 |
+
set_seed(args.seed)
|
| 485 |
+
|
| 486 |
+
# Handle the repository creation
|
| 487 |
+
if accelerator.is_main_process:
|
| 488 |
+
if args.output_dir is not None:
|
| 489 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 490 |
+
|
| 491 |
+
if args.push_to_hub:
|
| 492 |
+
repo_id = create_repo(
|
| 493 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 494 |
+
).repo_id
|
| 495 |
+
|
| 496 |
+
# Load scheduler, image_processor, tokenizer and models.
|
| 497 |
+
noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
|
| 498 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
| 499 |
+
args.pretrained_prior_model_name_or_path, subfolder="image_processor"
|
| 500 |
+
)
|
| 501 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
|
| 502 |
+
|
| 503 |
+
def deepspeed_zero_init_disabled_context_manager():
|
| 504 |
+
"""
|
| 505 |
+
returns either a context list that includes one that will disable zero.Init or an empty context list
|
| 506 |
+
"""
|
| 507 |
+
deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
|
| 508 |
+
if deepspeed_plugin is None:
|
| 509 |
+
return []
|
| 510 |
+
|
| 511 |
+
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
|
| 512 |
+
|
| 513 |
+
weight_dtype = torch.float32
|
| 514 |
+
if accelerator.mixed_precision == "fp16":
|
| 515 |
+
weight_dtype = torch.float16
|
| 516 |
+
elif accelerator.mixed_precision == "bf16":
|
| 517 |
+
weight_dtype = torch.bfloat16
|
| 518 |
+
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
| 519 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
| 520 |
+
args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
|
| 521 |
+
).eval()
|
| 522 |
+
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
| 523 |
+
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
|
| 524 |
+
).eval()
|
| 525 |
+
|
| 526 |
+
prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
| 527 |
+
|
| 528 |
+
# Freeze text_encoder and image_encoder
|
| 529 |
+
text_encoder.requires_grad_(False)
|
| 530 |
+
image_encoder.requires_grad_(False)
|
| 531 |
+
|
| 532 |
+
# Set prior to trainable.
|
| 533 |
+
prior.train()
|
| 534 |
+
|
| 535 |
+
# Create EMA for the prior.
|
| 536 |
+
if args.use_ema:
|
| 537 |
+
ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
| 538 |
+
ema_prior = EMAModel(ema_prior.parameters(), model_cls=PriorTransformer, model_config=ema_prior.config)
|
| 539 |
+
ema_prior.to(accelerator.device)
|
| 540 |
+
|
| 541 |
+
# `accelerate` 0.16.0 will have better support for customized saving
|
| 542 |
+
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
| 543 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 544 |
+
def save_model_hook(models, weights, output_dir):
|
| 545 |
+
if args.use_ema:
|
| 546 |
+
ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
|
| 547 |
+
|
| 548 |
+
for i, model in enumerate(models):
|
| 549 |
+
model.save_pretrained(os.path.join(output_dir, "prior"))
|
| 550 |
+
|
| 551 |
+
# make sure to pop weight so that corresponding model is not saved again
|
| 552 |
+
weights.pop()
|
| 553 |
+
|
| 554 |
+
def load_model_hook(models, input_dir):
|
| 555 |
+
if args.use_ema:
|
| 556 |
+
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), PriorTransformer)
|
| 557 |
+
ema_prior.load_state_dict(load_model.state_dict())
|
| 558 |
+
ema_prior.to(accelerator.device)
|
| 559 |
+
del load_model
|
| 560 |
+
|
| 561 |
+
for i in range(len(models)):
|
| 562 |
+
# pop models so that they are not loaded again
|
| 563 |
+
model = models.pop()
|
| 564 |
+
|
| 565 |
+
# load diffusers style into model
|
| 566 |
+
load_model = PriorTransformer.from_pretrained(input_dir, subfolder="prior")
|
| 567 |
+
model.register_to_config(**load_model.config)
|
| 568 |
+
|
| 569 |
+
model.load_state_dict(load_model.state_dict())
|
| 570 |
+
del load_model
|
| 571 |
+
|
| 572 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 573 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 574 |
+
|
| 575 |
+
if args.allow_tf32:
|
| 576 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 577 |
+
|
| 578 |
+
if args.use_8bit_adam:
|
| 579 |
+
try:
|
| 580 |
+
import bitsandbytes as bnb
|
| 581 |
+
except ImportError:
|
| 582 |
+
raise ImportError(
|
| 583 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
optimizer_cls = bnb.optim.AdamW8bit
|
| 587 |
+
else:
|
| 588 |
+
optimizer_cls = torch.optim.AdamW
|
| 589 |
+
optimizer = optimizer_cls(
|
| 590 |
+
prior.parameters(),
|
| 591 |
+
lr=args.learning_rate,
|
| 592 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 593 |
+
weight_decay=args.adam_weight_decay,
|
| 594 |
+
eps=args.adam_epsilon,
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
| 598 |
+
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
| 599 |
+
|
| 600 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
| 601 |
+
# download the dataset.
|
| 602 |
+
if args.dataset_name is not None:
|
| 603 |
+
# Downloading and loading a dataset from the hub.
|
| 604 |
+
dataset = load_dataset(
|
| 605 |
+
args.dataset_name,
|
| 606 |
+
args.dataset_config_name,
|
| 607 |
+
cache_dir=args.cache_dir,
|
| 608 |
+
)
|
| 609 |
+
else:
|
| 610 |
+
data_files = {}
|
| 611 |
+
if args.train_data_dir is not None:
|
| 612 |
+
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
| 613 |
+
dataset = load_dataset(
|
| 614 |
+
"imagefolder",
|
| 615 |
+
data_files=data_files,
|
| 616 |
+
cache_dir=args.cache_dir,
|
| 617 |
+
)
|
| 618 |
+
# See more about loading custom images at
|
| 619 |
+
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
| 620 |
+
|
| 621 |
+
# Preprocessing the datasets.
|
| 622 |
+
# We need to tokenize inputs and targets.
|
| 623 |
+
column_names = dataset["train"].column_names
|
| 624 |
+
|
| 625 |
+
# 6. Get the column names for input/target.
|
| 626 |
+
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
| 627 |
+
if args.image_column is None:
|
| 628 |
+
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
| 629 |
+
else:
|
| 630 |
+
image_column = args.image_column
|
| 631 |
+
if image_column not in column_names:
|
| 632 |
+
raise ValueError(
|
| 633 |
+
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
| 634 |
+
)
|
| 635 |
+
if args.caption_column is None:
|
| 636 |
+
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
| 637 |
+
else:
|
| 638 |
+
caption_column = args.caption_column
|
| 639 |
+
if caption_column not in column_names:
|
| 640 |
+
raise ValueError(
|
| 641 |
+
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
# Preprocessing the datasets.
|
| 645 |
+
# We need to tokenize input captions and transform the images.
|
| 646 |
+
def tokenize_captions(examples, is_train=True):
|
| 647 |
+
captions = []
|
| 648 |
+
for caption in examples[caption_column]:
|
| 649 |
+
if isinstance(caption, str):
|
| 650 |
+
captions.append(caption)
|
| 651 |
+
elif isinstance(caption, (list, np.ndarray)):
|
| 652 |
+
# take a random caption if there are multiple
|
| 653 |
+
captions.append(random.choice(caption) if is_train else caption[0])
|
| 654 |
+
else:
|
| 655 |
+
raise ValueError(
|
| 656 |
+
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
| 657 |
+
)
|
| 658 |
+
inputs = tokenizer(
|
| 659 |
+
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
| 660 |
+
)
|
| 661 |
+
text_input_ids = inputs.input_ids
|
| 662 |
+
text_mask = inputs.attention_mask.bool()
|
| 663 |
+
return text_input_ids, text_mask
|
| 664 |
+
|
| 665 |
+
def preprocess_train(examples):
|
| 666 |
+
images = [image.convert("RGB") for image in examples[image_column]]
|
| 667 |
+
examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
|
| 668 |
+
examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
|
| 669 |
+
return examples
|
| 670 |
+
|
| 671 |
+
with accelerator.main_process_first():
|
| 672 |
+
if args.max_train_samples is not None:
|
| 673 |
+
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 674 |
+
# Set the training transforms
|
| 675 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
| 676 |
+
|
| 677 |
+
def collate_fn(examples):
|
| 678 |
+
clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
|
| 679 |
+
clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 680 |
+
text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
|
| 681 |
+
text_mask = torch.stack([example["text_mask"] for example in examples])
|
| 682 |
+
return {"clip_pixel_values": clip_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
|
| 683 |
+
|
| 684 |
+
# DataLoaders creation:
|
| 685 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 686 |
+
train_dataset,
|
| 687 |
+
shuffle=True,
|
| 688 |
+
collate_fn=collate_fn,
|
| 689 |
+
batch_size=args.train_batch_size,
|
| 690 |
+
num_workers=args.dataloader_num_workers,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
# Scheduler and math around the number of training steps.
|
| 694 |
+
overrode_max_train_steps = False
|
| 695 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 696 |
+
if args.max_train_steps is None:
|
| 697 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 698 |
+
overrode_max_train_steps = True
|
| 699 |
+
|
| 700 |
+
lr_scheduler = get_scheduler(
|
| 701 |
+
args.lr_scheduler,
|
| 702 |
+
optimizer=optimizer,
|
| 703 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 704 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
clip_mean = prior.clip_mean.clone()
|
| 708 |
+
clip_std = prior.clip_std.clone()
|
| 709 |
+
|
| 710 |
+
prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 711 |
+
prior, optimizer, train_dataloader, lr_scheduler
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 715 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 716 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 717 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 718 |
+
if overrode_max_train_steps:
|
| 719 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 720 |
+
# Afterwards we recalculate our number of training epochs
|
| 721 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 722 |
+
|
| 723 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 724 |
+
# The trackers initializes automatically on the main process.
|
| 725 |
+
if accelerator.is_main_process:
|
| 726 |
+
tracker_config = dict(vars(args))
|
| 727 |
+
tracker_config.pop("validation_prompts")
|
| 728 |
+
accelerator.init_trackers(args.tracker_project_name, tracker_config)
|
| 729 |
+
|
| 730 |
+
# Train!
|
| 731 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 732 |
+
|
| 733 |
+
logger.info("***** Running training *****")
|
| 734 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 735 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 736 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 737 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 738 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 739 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 740 |
+
global_step = 0
|
| 741 |
+
first_epoch = 0
|
| 742 |
+
|
| 743 |
+
# Potentially load in the weights and states from a previous save
|
| 744 |
+
if args.resume_from_checkpoint:
|
| 745 |
+
if args.resume_from_checkpoint != "latest":
|
| 746 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 747 |
+
else:
|
| 748 |
+
# Get the most recent checkpoint
|
| 749 |
+
dirs = os.listdir(args.output_dir)
|
| 750 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 751 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 752 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 753 |
+
|
| 754 |
+
if path is None:
|
| 755 |
+
accelerator.print(
|
| 756 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 757 |
+
)
|
| 758 |
+
args.resume_from_checkpoint = None
|
| 759 |
+
initial_global_step = 0
|
| 760 |
+
else:
|
| 761 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 762 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 763 |
+
global_step = int(path.split("-")[1])
|
| 764 |
+
|
| 765 |
+
initial_global_step = global_step
|
| 766 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 767 |
+
else:
|
| 768 |
+
initial_global_step = 0
|
| 769 |
+
|
| 770 |
+
progress_bar = tqdm(
|
| 771 |
+
range(0, args.max_train_steps),
|
| 772 |
+
initial=initial_global_step,
|
| 773 |
+
desc="Steps",
|
| 774 |
+
# Only show the progress bar once on each machine.
|
| 775 |
+
disable=not accelerator.is_local_main_process,
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
|
| 779 |
+
clip_std = clip_std.to(weight_dtype).to(accelerator.device)
|
| 780 |
+
|
| 781 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 782 |
+
train_loss = 0.0
|
| 783 |
+
for step, batch in enumerate(train_dataloader):
|
| 784 |
+
with accelerator.accumulate(prior):
|
| 785 |
+
# Convert images to latent space
|
| 786 |
+
text_input_ids, text_mask, clip_images = (
|
| 787 |
+
batch["text_input_ids"],
|
| 788 |
+
batch["text_mask"],
|
| 789 |
+
batch["clip_pixel_values"].to(weight_dtype),
|
| 790 |
+
)
|
| 791 |
+
with torch.no_grad():
|
| 792 |
+
text_encoder_output = text_encoder(text_input_ids)
|
| 793 |
+
prompt_embeds = text_encoder_output.text_embeds
|
| 794 |
+
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
| 795 |
+
|
| 796 |
+
image_embeds = image_encoder(clip_images).image_embeds
|
| 797 |
+
# Sample noise that we'll add to the image_embeds
|
| 798 |
+
noise = torch.randn_like(image_embeds)
|
| 799 |
+
bsz = image_embeds.shape[0]
|
| 800 |
+
# Sample a random timestep for each image
|
| 801 |
+
timesteps = torch.randint(
|
| 802 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device
|
| 803 |
+
)
|
| 804 |
+
timesteps = timesteps.long()
|
| 805 |
+
image_embeds = (image_embeds - clip_mean) / clip_std
|
| 806 |
+
noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
|
| 807 |
+
|
| 808 |
+
target = image_embeds
|
| 809 |
+
|
| 810 |
+
# Predict the noise residual and compute loss
|
| 811 |
+
model_pred = prior(
|
| 812 |
+
noisy_latents,
|
| 813 |
+
timestep=timesteps,
|
| 814 |
+
proj_embedding=prompt_embeds,
|
| 815 |
+
encoder_hidden_states=text_encoder_hidden_states,
|
| 816 |
+
attention_mask=text_mask,
|
| 817 |
+
).predicted_image_embedding
|
| 818 |
+
|
| 819 |
+
if args.snr_gamma is None:
|
| 820 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
| 821 |
+
else:
|
| 822 |
+
# Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
|
| 823 |
+
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
| 824 |
+
# This is discussed in Section 4.2 of the same paper.
|
| 825 |
+
snr = compute_snr(noise_scheduler, timesteps)
|
| 826 |
+
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
| 827 |
+
dim=1
|
| 828 |
+
)[0]
|
| 829 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 830 |
+
mse_loss_weights = mse_loss_weights / snr
|
| 831 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 832 |
+
mse_loss_weights = mse_loss_weights / (snr + 1)
|
| 833 |
+
|
| 834 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 835 |
+
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
| 836 |
+
loss = loss.mean()
|
| 837 |
+
|
| 838 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 839 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
| 840 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
| 841 |
+
|
| 842 |
+
# Backpropagate
|
| 843 |
+
accelerator.backward(loss)
|
| 844 |
+
if accelerator.sync_gradients:
|
| 845 |
+
accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
|
| 846 |
+
optimizer.step()
|
| 847 |
+
lr_scheduler.step()
|
| 848 |
+
optimizer.zero_grad()
|
| 849 |
+
|
| 850 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 851 |
+
if accelerator.sync_gradients:
|
| 852 |
+
if args.use_ema:
|
| 853 |
+
ema_prior.step(prior.parameters())
|
| 854 |
+
progress_bar.update(1)
|
| 855 |
+
global_step += 1
|
| 856 |
+
accelerator.log({"train_loss": train_loss}, step=global_step)
|
| 857 |
+
train_loss = 0.0
|
| 858 |
+
|
| 859 |
+
if global_step % args.checkpointing_steps == 0:
|
| 860 |
+
if accelerator.is_main_process:
|
| 861 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 862 |
+
if args.checkpoints_total_limit is not None:
|
| 863 |
+
checkpoints = os.listdir(args.output_dir)
|
| 864 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 865 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 866 |
+
|
| 867 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 868 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 869 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 870 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 871 |
+
|
| 872 |
+
logger.info(
|
| 873 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 874 |
+
)
|
| 875 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 876 |
+
|
| 877 |
+
for removing_checkpoint in removing_checkpoints:
|
| 878 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 879 |
+
shutil.rmtree(removing_checkpoint)
|
| 880 |
+
|
| 881 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 882 |
+
accelerator.save_state(save_path)
|
| 883 |
+
logger.info(f"Saved state to {save_path}")
|
| 884 |
+
|
| 885 |
+
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 886 |
+
progress_bar.set_postfix(**logs)
|
| 887 |
+
|
| 888 |
+
if global_step >= args.max_train_steps:
|
| 889 |
+
break
|
| 890 |
+
|
| 891 |
+
if accelerator.is_main_process:
|
| 892 |
+
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
|
| 893 |
+
if args.use_ema:
|
| 894 |
+
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
|
| 895 |
+
ema_prior.store(prior.parameters())
|
| 896 |
+
ema_prior.copy_to(prior.parameters())
|
| 897 |
+
log_validation(
|
| 898 |
+
image_encoder,
|
| 899 |
+
image_processor,
|
| 900 |
+
text_encoder,
|
| 901 |
+
tokenizer,
|
| 902 |
+
prior,
|
| 903 |
+
args,
|
| 904 |
+
accelerator,
|
| 905 |
+
weight_dtype,
|
| 906 |
+
global_step,
|
| 907 |
+
)
|
| 908 |
+
if args.use_ema:
|
| 909 |
+
# Switch back to the original UNet parameters.
|
| 910 |
+
ema_prior.restore(prior.parameters())
|
| 911 |
+
|
| 912 |
+
# Create the pipeline using the trained modules and save it.
|
| 913 |
+
accelerator.wait_for_everyone()
|
| 914 |
+
if accelerator.is_main_process:
|
| 915 |
+
prior = accelerator.unwrap_model(prior)
|
| 916 |
+
if args.use_ema:
|
| 917 |
+
ema_prior.copy_to(prior.parameters())
|
| 918 |
+
|
| 919 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 920 |
+
args.pretrained_decoder_model_name_or_path,
|
| 921 |
+
prior_image_encoder=image_encoder,
|
| 922 |
+
prior_text_encoder=text_encoder,
|
| 923 |
+
prior_prior=prior,
|
| 924 |
+
)
|
| 925 |
+
pipeline.prior_pipe.save_pretrained(args.output_dir)
|
| 926 |
+
|
| 927 |
+
# Run a final round of inference.
|
| 928 |
+
images = []
|
| 929 |
+
if args.validation_prompts is not None:
|
| 930 |
+
logger.info("Running inference for collecting generated images...")
|
| 931 |
+
pipeline = pipeline.to(accelerator.device)
|
| 932 |
+
pipeline.torch_dtype = weight_dtype
|
| 933 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 934 |
+
|
| 935 |
+
if args.seed is None:
|
| 936 |
+
generator = None
|
| 937 |
+
else:
|
| 938 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
| 939 |
+
|
| 940 |
+
for i in range(len(args.validation_prompts)):
|
| 941 |
+
with torch.autocast("cuda"):
|
| 942 |
+
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
|
| 943 |
+
images.append(image)
|
| 944 |
+
|
| 945 |
+
if args.push_to_hub:
|
| 946 |
+
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
|
| 947 |
+
upload_folder(
|
| 948 |
+
repo_id=repo_id,
|
| 949 |
+
folder_path=args.output_dir,
|
| 950 |
+
commit_message="End of training",
|
| 951 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
accelerator.end_training()
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
if __name__ == "__main__":
|
| 958 |
+
main()
|
diffusers/examples/research_projects/anytext/README.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AnyTextPipeline
|
| 2 |
+
|
| 3 |
+
Project page: https://aigcdesigngroup.github.io/homepage_anytext
|
| 4 |
+
|
| 5 |
+
"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
|
| 6 |
+
|
| 7 |
+
> **Note:** Each text line that needs to be generated should be enclosed in double quotes.
|
| 8 |
+
|
| 9 |
+
For any usage questions, please refer to the [paper](https://huggingface.co/papers/2311.03054).
|
| 10 |
+
|
| 11 |
+
[](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)
|
| 12 |
+
|
| 13 |
+
```py
|
| 14 |
+
# This example requires the `anytext_controlnet.py` file:
|
| 15 |
+
# !git clone --depth 1 https://github.com/huggingface/diffusers.git
|
| 16 |
+
# %cd diffusers/examples/research_projects/anytext
|
| 17 |
+
# Let's choose a font file shared by an HF staff:
|
| 18 |
+
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from diffusers import DiffusionPipeline
|
| 22 |
+
from anytext_controlnet import AnyTextControlNetModel
|
| 23 |
+
from diffusers.utils import load_image
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
|
| 27 |
+
variant="fp16",)
|
| 28 |
+
pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
|
| 29 |
+
controlnet=anytext_controlnet, torch_dtype=torch.float16,
|
| 30 |
+
trust_remote_code=False, # One needs to give permission to run this pipeline's code
|
| 31 |
+
).to("cuda")
|
| 32 |
+
|
| 33 |
+
# generate image
|
| 34 |
+
prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
|
| 35 |
+
draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
|
| 36 |
+
# There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
|
| 37 |
+
image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
|
| 38 |
+
).images[0]
|
| 39 |
+
image
|
| 40 |
+
```
|
diffusers/examples/research_projects/anytext/anytext.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
diffusers/examples/research_projects/anytext/anytext_controlnet.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).
|
| 16 |
+
# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie
|
| 17 |
+
# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license
|
| 18 |
+
#
|
| 19 |
+
# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from torch import nn
|
| 26 |
+
|
| 27 |
+
from diffusers.configuration_utils import register_to_config
|
| 28 |
+
from diffusers.models.controlnets.controlnet import (
|
| 29 |
+
ControlNetModel,
|
| 30 |
+
ControlNetOutput,
|
| 31 |
+
)
|
| 32 |
+
from diffusers.utils import logging
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AnyTextControlNetConditioningEmbedding(nn.Module):
|
| 39 |
+
"""
|
| 40 |
+
Quoting from https://huggingface.co/papers/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
| 41 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
| 42 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
| 43 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
| 44 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
| 45 |
+
model) to encode image-space conditions ... into feature maps ..."
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
conditioning_embedding_channels: int,
|
| 51 |
+
glyph_channels=1,
|
| 52 |
+
position_channels=1,
|
| 53 |
+
):
|
| 54 |
+
super().__init__()
|
| 55 |
+
|
| 56 |
+
self.glyph_block = nn.Sequential(
|
| 57 |
+
nn.Conv2d(glyph_channels, 8, 3, padding=1),
|
| 58 |
+
nn.SiLU(),
|
| 59 |
+
nn.Conv2d(8, 8, 3, padding=1),
|
| 60 |
+
nn.SiLU(),
|
| 61 |
+
nn.Conv2d(8, 16, 3, padding=1, stride=2),
|
| 62 |
+
nn.SiLU(),
|
| 63 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
| 64 |
+
nn.SiLU(),
|
| 65 |
+
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
| 66 |
+
nn.SiLU(),
|
| 67 |
+
nn.Conv2d(32, 32, 3, padding=1),
|
| 68 |
+
nn.SiLU(),
|
| 69 |
+
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
| 70 |
+
nn.SiLU(),
|
| 71 |
+
nn.Conv2d(96, 96, 3, padding=1),
|
| 72 |
+
nn.SiLU(),
|
| 73 |
+
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
| 74 |
+
nn.SiLU(),
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
self.position_block = nn.Sequential(
|
| 78 |
+
nn.Conv2d(position_channels, 8, 3, padding=1),
|
| 79 |
+
nn.SiLU(),
|
| 80 |
+
nn.Conv2d(8, 8, 3, padding=1),
|
| 81 |
+
nn.SiLU(),
|
| 82 |
+
nn.Conv2d(8, 16, 3, padding=1, stride=2),
|
| 83 |
+
nn.SiLU(),
|
| 84 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
| 85 |
+
nn.SiLU(),
|
| 86 |
+
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
| 87 |
+
nn.SiLU(),
|
| 88 |
+
nn.Conv2d(32, 32, 3, padding=1),
|
| 89 |
+
nn.SiLU(),
|
| 90 |
+
nn.Conv2d(32, 64, 3, padding=1, stride=2),
|
| 91 |
+
nn.SiLU(),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)
|
| 95 |
+
|
| 96 |
+
def forward(self, glyphs, positions, text_info):
|
| 97 |
+
glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))
|
| 98 |
+
position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))
|
| 99 |
+
guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))
|
| 100 |
+
|
| 101 |
+
return guided_hint
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class AnyTextControlNetModel(ControlNetModel):
|
| 105 |
+
"""
|
| 106 |
+
A AnyTextControlNetModel model.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
in_channels (`int`, defaults to 4):
|
| 110 |
+
The number of channels in the input sample.
|
| 111 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
| 112 |
+
Whether to flip the sin to cos in the time embedding.
|
| 113 |
+
freq_shift (`int`, defaults to 0):
|
| 114 |
+
The frequency shift to apply to the time embedding.
|
| 115 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
| 116 |
+
The tuple of downsample blocks to use.
|
| 117 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
| 118 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
| 119 |
+
The tuple of output channels for each block.
|
| 120 |
+
layers_per_block (`int`, defaults to 2):
|
| 121 |
+
The number of layers per block.
|
| 122 |
+
downsample_padding (`int`, defaults to 1):
|
| 123 |
+
The padding to use for the downsampling convolution.
|
| 124 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
| 125 |
+
The scale factor to use for the mid block.
|
| 126 |
+
act_fn (`str`, defaults to "silu"):
|
| 127 |
+
The activation function to use.
|
| 128 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 129 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
| 130 |
+
in post-processing.
|
| 131 |
+
norm_eps (`float`, defaults to 1e-5):
|
| 132 |
+
The epsilon to use for the normalization.
|
| 133 |
+
cross_attention_dim (`int`, defaults to 1280):
|
| 134 |
+
The dimension of the cross attention features.
|
| 135 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
| 136 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
| 137 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
| 138 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
| 139 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
| 140 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
| 141 |
+
dimension to `cross_attention_dim`.
|
| 142 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
| 143 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
| 144 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
| 145 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
| 146 |
+
The dimension of the attention heads.
|
| 147 |
+
use_linear_projection (`bool`, defaults to `False`):
|
| 148 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
| 149 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
| 150 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
| 151 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
| 152 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
| 153 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
| 154 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
| 155 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
| 156 |
+
class conditioning with `class_embed_type` equal to `None`.
|
| 157 |
+
upcast_attention (`bool`, defaults to `False`):
|
| 158 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
| 159 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
| 160 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
| 161 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
| 162 |
+
`class_embed_type="projection"`.
|
| 163 |
+
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
| 164 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
| 165 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
| 166 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
| 167 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
| 168 |
+
TODO(Patrick) - unused parameter.
|
| 169 |
+
addition_embed_type_num_heads (`int`, defaults to 64):
|
| 170 |
+
The number of heads to use for the `TextTimeEmbedding` layer.
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
_supports_gradient_checkpointing = True
|
| 174 |
+
|
| 175 |
+
@register_to_config
|
| 176 |
+
def __init__(
|
| 177 |
+
self,
|
| 178 |
+
in_channels: int = 4,
|
| 179 |
+
conditioning_channels: int = 1,
|
| 180 |
+
flip_sin_to_cos: bool = True,
|
| 181 |
+
freq_shift: int = 0,
|
| 182 |
+
down_block_types: Tuple[str, ...] = (
|
| 183 |
+
"CrossAttnDownBlock2D",
|
| 184 |
+
"CrossAttnDownBlock2D",
|
| 185 |
+
"CrossAttnDownBlock2D",
|
| 186 |
+
"DownBlock2D",
|
| 187 |
+
),
|
| 188 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
| 189 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
| 190 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
| 191 |
+
layers_per_block: int = 2,
|
| 192 |
+
downsample_padding: int = 1,
|
| 193 |
+
mid_block_scale_factor: float = 1,
|
| 194 |
+
act_fn: str = "silu",
|
| 195 |
+
norm_num_groups: Optional[int] = 32,
|
| 196 |
+
norm_eps: float = 1e-5,
|
| 197 |
+
cross_attention_dim: int = 1280,
|
| 198 |
+
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
| 199 |
+
encoder_hid_dim: Optional[int] = None,
|
| 200 |
+
encoder_hid_dim_type: Optional[str] = None,
|
| 201 |
+
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
| 202 |
+
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
| 203 |
+
use_linear_projection: bool = False,
|
| 204 |
+
class_embed_type: Optional[str] = None,
|
| 205 |
+
addition_embed_type: Optional[str] = None,
|
| 206 |
+
addition_time_embed_dim: Optional[int] = None,
|
| 207 |
+
num_class_embeds: Optional[int] = None,
|
| 208 |
+
upcast_attention: bool = False,
|
| 209 |
+
resnet_time_scale_shift: str = "default",
|
| 210 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
| 211 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
| 212 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
| 213 |
+
global_pool_conditions: bool = False,
|
| 214 |
+
addition_embed_type_num_heads: int = 64,
|
| 215 |
+
):
|
| 216 |
+
super().__init__(
|
| 217 |
+
in_channels,
|
| 218 |
+
conditioning_channels,
|
| 219 |
+
flip_sin_to_cos,
|
| 220 |
+
freq_shift,
|
| 221 |
+
down_block_types,
|
| 222 |
+
mid_block_type,
|
| 223 |
+
only_cross_attention,
|
| 224 |
+
block_out_channels,
|
| 225 |
+
layers_per_block,
|
| 226 |
+
downsample_padding,
|
| 227 |
+
mid_block_scale_factor,
|
| 228 |
+
act_fn,
|
| 229 |
+
norm_num_groups,
|
| 230 |
+
norm_eps,
|
| 231 |
+
cross_attention_dim,
|
| 232 |
+
transformer_layers_per_block,
|
| 233 |
+
encoder_hid_dim,
|
| 234 |
+
encoder_hid_dim_type,
|
| 235 |
+
attention_head_dim,
|
| 236 |
+
num_attention_heads,
|
| 237 |
+
use_linear_projection,
|
| 238 |
+
class_embed_type,
|
| 239 |
+
addition_embed_type,
|
| 240 |
+
addition_time_embed_dim,
|
| 241 |
+
num_class_embeds,
|
| 242 |
+
upcast_attention,
|
| 243 |
+
resnet_time_scale_shift,
|
| 244 |
+
projection_class_embeddings_input_dim,
|
| 245 |
+
controlnet_conditioning_channel_order,
|
| 246 |
+
conditioning_embedding_out_channels,
|
| 247 |
+
global_pool_conditions,
|
| 248 |
+
addition_embed_type_num_heads,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# control net conditioning embedding
|
| 252 |
+
self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
|
| 253 |
+
conditioning_embedding_channels=block_out_channels[0],
|
| 254 |
+
glyph_channels=conditioning_channels,
|
| 255 |
+
position_channels=conditioning_channels,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
def forward(
|
| 259 |
+
self,
|
| 260 |
+
sample: torch.Tensor,
|
| 261 |
+
timestep: Union[torch.Tensor, float, int],
|
| 262 |
+
encoder_hidden_states: torch.Tensor,
|
| 263 |
+
controlnet_cond: torch.Tensor,
|
| 264 |
+
conditioning_scale: float = 1.0,
|
| 265 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 266 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
| 267 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 268 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
| 269 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 270 |
+
guess_mode: bool = False,
|
| 271 |
+
return_dict: bool = True,
|
| 272 |
+
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
| 273 |
+
"""
|
| 274 |
+
The [`~PromptDiffusionControlNetModel`] forward method.
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
sample (`torch.Tensor`):
|
| 278 |
+
The noisy input tensor.
|
| 279 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
| 280 |
+
The number of timesteps to denoise an input.
|
| 281 |
+
encoder_hidden_states (`torch.Tensor`):
|
| 282 |
+
The encoder hidden states.
|
| 283 |
+
#controlnet_cond (`torch.Tensor`):
|
| 284 |
+
# The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
| 285 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
| 286 |
+
The scale factor for ControlNet outputs.
|
| 287 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
| 288 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
| 289 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
| 290 |
+
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
| 291 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
| 292 |
+
embeddings.
|
| 293 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
| 294 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
| 295 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
| 296 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
| 297 |
+
added_cond_kwargs (`dict`):
|
| 298 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
| 299 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
| 300 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
| 301 |
+
guess_mode (`bool`, defaults to `False`):
|
| 302 |
+
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
| 303 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
| 304 |
+
return_dict (`bool`, defaults to `True`):
|
| 305 |
+
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
| 306 |
+
|
| 307 |
+
Returns:
|
| 308 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
| 309 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
| 310 |
+
returned where the first element is the sample tensor.
|
| 311 |
+
"""
|
| 312 |
+
# check channel order
|
| 313 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
| 314 |
+
|
| 315 |
+
if channel_order == "rgb":
|
| 316 |
+
# in rgb order by default
|
| 317 |
+
...
|
| 318 |
+
# elif channel_order == "bgr":
|
| 319 |
+
# controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
| 320 |
+
else:
|
| 321 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
| 322 |
+
|
| 323 |
+
# prepare attention_mask
|
| 324 |
+
if attention_mask is not None:
|
| 325 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 326 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 327 |
+
|
| 328 |
+
# 1. time
|
| 329 |
+
timesteps = timestep
|
| 330 |
+
if not torch.is_tensor(timesteps):
|
| 331 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 332 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 333 |
+
is_mps = sample.device.type == "mps"
|
| 334 |
+
if isinstance(timestep, float):
|
| 335 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 336 |
+
else:
|
| 337 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 338 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 339 |
+
elif len(timesteps.shape) == 0:
|
| 340 |
+
timesteps = timesteps[None].to(sample.device)
|
| 341 |
+
|
| 342 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 343 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 344 |
+
|
| 345 |
+
t_emb = self.time_proj(timesteps)
|
| 346 |
+
|
| 347 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 348 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 349 |
+
# there might be better ways to encapsulate this.
|
| 350 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
| 351 |
+
|
| 352 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
| 353 |
+
aug_emb = None
|
| 354 |
+
|
| 355 |
+
if self.class_embedding is not None:
|
| 356 |
+
if class_labels is None:
|
| 357 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 358 |
+
|
| 359 |
+
if self.config.class_embed_type == "timestep":
|
| 360 |
+
class_labels = self.time_proj(class_labels)
|
| 361 |
+
|
| 362 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
| 363 |
+
emb = emb + class_emb
|
| 364 |
+
|
| 365 |
+
if self.config.addition_embed_type is not None:
|
| 366 |
+
if self.config.addition_embed_type == "text":
|
| 367 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
| 368 |
+
|
| 369 |
+
elif self.config.addition_embed_type == "text_time":
|
| 370 |
+
if "text_embeds" not in added_cond_kwargs:
|
| 371 |
+
raise ValueError(
|
| 372 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
| 373 |
+
)
|
| 374 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
| 375 |
+
if "time_ids" not in added_cond_kwargs:
|
| 376 |
+
raise ValueError(
|
| 377 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
| 378 |
+
)
|
| 379 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
| 380 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
| 381 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
| 382 |
+
|
| 383 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
| 384 |
+
add_embeds = add_embeds.to(emb.dtype)
|
| 385 |
+
aug_emb = self.add_embedding(add_embeds)
|
| 386 |
+
|
| 387 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
| 388 |
+
|
| 389 |
+
# 2. pre-process
|
| 390 |
+
sample = self.conv_in(sample)
|
| 391 |
+
|
| 392 |
+
controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)
|
| 393 |
+
sample = sample + controlnet_cond
|
| 394 |
+
|
| 395 |
+
# 3. down
|
| 396 |
+
down_block_res_samples = (sample,)
|
| 397 |
+
for downsample_block in self.down_blocks:
|
| 398 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 399 |
+
sample, res_samples = downsample_block(
|
| 400 |
+
hidden_states=sample,
|
| 401 |
+
temb=emb,
|
| 402 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 403 |
+
attention_mask=attention_mask,
|
| 404 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 405 |
+
)
|
| 406 |
+
else:
|
| 407 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 408 |
+
|
| 409 |
+
down_block_res_samples += res_samples
|
| 410 |
+
|
| 411 |
+
# 4. mid
|
| 412 |
+
if self.mid_block is not None:
|
| 413 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
| 414 |
+
sample = self.mid_block(
|
| 415 |
+
sample,
|
| 416 |
+
emb,
|
| 417 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 418 |
+
attention_mask=attention_mask,
|
| 419 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 420 |
+
)
|
| 421 |
+
else:
|
| 422 |
+
sample = self.mid_block(sample, emb)
|
| 423 |
+
|
| 424 |
+
# 5. Control net blocks
|
| 425 |
+
controlnet_down_block_res_samples = ()
|
| 426 |
+
|
| 427 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
| 428 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
| 429 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
| 430 |
+
|
| 431 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
| 432 |
+
|
| 433 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
| 434 |
+
|
| 435 |
+
# 6. scaling
|
| 436 |
+
if guess_mode and not self.config.global_pool_conditions:
|
| 437 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
| 438 |
+
scales = scales * conditioning_scale
|
| 439 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
| 440 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
| 441 |
+
else:
|
| 442 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
| 443 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
| 444 |
+
|
| 445 |
+
if self.config.global_pool_conditions:
|
| 446 |
+
down_block_res_samples = [
|
| 447 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
| 448 |
+
]
|
| 449 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
| 450 |
+
|
| 451 |
+
if not return_dict:
|
| 452 |
+
return (down_block_res_samples, mid_block_res_sample)
|
| 453 |
+
|
| 454 |
+
return ControlNetOutput(
|
| 455 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# Copied from diffusers.models.controlnet.zero_module
|
| 460 |
+
def zero_module(module):
|
| 461 |
+
for p in module.parameters():
|
| 462 |
+
nn.init.zeros_(p)
|
| 463 |
+
return module
|
diffusers/examples/research_projects/anytext/ocr_recog/RNN.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from .RecSVTR import Block
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Swish(nn.Module):
|
| 8 |
+
def __int__(self):
|
| 9 |
+
super(Swish, self).__int__()
|
| 10 |
+
|
| 11 |
+
def forward(self, x):
|
| 12 |
+
return x * torch.sigmoid(x)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Im2Im(nn.Module):
|
| 16 |
+
def __init__(self, in_channels, **kwargs):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.out_channels = in_channels
|
| 19 |
+
|
| 20 |
+
def forward(self, x):
|
| 21 |
+
return x
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Im2Seq(nn.Module):
|
| 25 |
+
def __init__(self, in_channels, **kwargs):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.out_channels = in_channels
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
B, C, H, W = x.shape
|
| 31 |
+
# assert H == 1
|
| 32 |
+
x = x.reshape(B, C, H * W)
|
| 33 |
+
x = x.permute((0, 2, 1))
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class EncoderWithRNN(nn.Module):
|
| 38 |
+
def __init__(self, in_channels, **kwargs):
|
| 39 |
+
super(EncoderWithRNN, self).__init__()
|
| 40 |
+
hidden_size = kwargs.get("hidden_size", 256)
|
| 41 |
+
self.out_channels = hidden_size * 2
|
| 42 |
+
self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)
|
| 43 |
+
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
self.lstm.flatten_parameters()
|
| 46 |
+
x, _ = self.lstm(x)
|
| 47 |
+
return x
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class SequenceEncoder(nn.Module):
|
| 51 |
+
def __init__(self, in_channels, encoder_type="rnn", **kwargs):
|
| 52 |
+
super(SequenceEncoder, self).__init__()
|
| 53 |
+
self.encoder_reshape = Im2Seq(in_channels)
|
| 54 |
+
self.out_channels = self.encoder_reshape.out_channels
|
| 55 |
+
self.encoder_type = encoder_type
|
| 56 |
+
if encoder_type == "reshape":
|
| 57 |
+
self.only_reshape = True
|
| 58 |
+
else:
|
| 59 |
+
support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR}
|
| 60 |
+
assert encoder_type in support_encoder_dict, "{} must in {}".format(
|
| 61 |
+
encoder_type, support_encoder_dict.keys()
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)
|
| 65 |
+
self.out_channels = self.encoder.out_channels
|
| 66 |
+
self.only_reshape = False
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
if self.encoder_type != "svtr":
|
| 70 |
+
x = self.encoder_reshape(x)
|
| 71 |
+
if not self.only_reshape:
|
| 72 |
+
x = self.encoder(x)
|
| 73 |
+
return x
|
| 74 |
+
else:
|
| 75 |
+
x = self.encoder(x)
|
| 76 |
+
x = self.encoder_reshape(x)
|
| 77 |
+
return x
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ConvBNLayer(nn.Module):
|
| 81 |
+
def __init__(
|
| 82 |
+
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
|
| 83 |
+
):
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.conv = nn.Conv2d(
|
| 86 |
+
in_channels=in_channels,
|
| 87 |
+
out_channels=out_channels,
|
| 88 |
+
kernel_size=kernel_size,
|
| 89 |
+
stride=stride,
|
| 90 |
+
padding=padding,
|
| 91 |
+
groups=groups,
|
| 92 |
+
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
| 93 |
+
bias=bias_attr,
|
| 94 |
+
)
|
| 95 |
+
self.norm = nn.BatchNorm2d(out_channels)
|
| 96 |
+
self.act = Swish()
|
| 97 |
+
|
| 98 |
+
def forward(self, inputs):
|
| 99 |
+
out = self.conv(inputs)
|
| 100 |
+
out = self.norm(out)
|
| 101 |
+
out = self.act(out)
|
| 102 |
+
return out
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class EncoderWithSVTR(nn.Module):
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
in_channels,
|
| 109 |
+
dims=64, # XS
|
| 110 |
+
depth=2,
|
| 111 |
+
hidden_dims=120,
|
| 112 |
+
use_guide=False,
|
| 113 |
+
num_heads=8,
|
| 114 |
+
qkv_bias=True,
|
| 115 |
+
mlp_ratio=2.0,
|
| 116 |
+
drop_rate=0.1,
|
| 117 |
+
attn_drop_rate=0.1,
|
| 118 |
+
drop_path=0.0,
|
| 119 |
+
qk_scale=None,
|
| 120 |
+
):
|
| 121 |
+
super(EncoderWithSVTR, self).__init__()
|
| 122 |
+
self.depth = depth
|
| 123 |
+
self.use_guide = use_guide
|
| 124 |
+
self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish")
|
| 125 |
+
self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish")
|
| 126 |
+
|
| 127 |
+
self.svtr_block = nn.ModuleList(
|
| 128 |
+
[
|
| 129 |
+
Block(
|
| 130 |
+
dim=hidden_dims,
|
| 131 |
+
num_heads=num_heads,
|
| 132 |
+
mixer="Global",
|
| 133 |
+
HW=None,
|
| 134 |
+
mlp_ratio=mlp_ratio,
|
| 135 |
+
qkv_bias=qkv_bias,
|
| 136 |
+
qk_scale=qk_scale,
|
| 137 |
+
drop=drop_rate,
|
| 138 |
+
act_layer="swish",
|
| 139 |
+
attn_drop=attn_drop_rate,
|
| 140 |
+
drop_path=drop_path,
|
| 141 |
+
norm_layer="nn.LayerNorm",
|
| 142 |
+
epsilon=1e-05,
|
| 143 |
+
prenorm=False,
|
| 144 |
+
)
|
| 145 |
+
for i in range(depth)
|
| 146 |
+
]
|
| 147 |
+
)
|
| 148 |
+
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
|
| 149 |
+
self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
|
| 150 |
+
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
|
| 151 |
+
self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish")
|
| 152 |
+
|
| 153 |
+
self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
|
| 154 |
+
self.out_channels = dims
|
| 155 |
+
self.apply(self._init_weights)
|
| 156 |
+
|
| 157 |
+
def _init_weights(self, m):
|
| 158 |
+
# weight initialization
|
| 159 |
+
if isinstance(m, nn.Conv2d):
|
| 160 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
| 161 |
+
if m.bias is not None:
|
| 162 |
+
nn.init.zeros_(m.bias)
|
| 163 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 164 |
+
nn.init.ones_(m.weight)
|
| 165 |
+
nn.init.zeros_(m.bias)
|
| 166 |
+
elif isinstance(m, nn.Linear):
|
| 167 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
| 168 |
+
if m.bias is not None:
|
| 169 |
+
nn.init.zeros_(m.bias)
|
| 170 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 171 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
| 172 |
+
if m.bias is not None:
|
| 173 |
+
nn.init.zeros_(m.bias)
|
| 174 |
+
elif isinstance(m, nn.LayerNorm):
|
| 175 |
+
nn.init.ones_(m.weight)
|
| 176 |
+
nn.init.zeros_(m.bias)
|
| 177 |
+
|
| 178 |
+
def forward(self, x):
|
| 179 |
+
# for use guide
|
| 180 |
+
if self.use_guide:
|
| 181 |
+
z = x.clone()
|
| 182 |
+
z.stop_gradient = True
|
| 183 |
+
else:
|
| 184 |
+
z = x
|
| 185 |
+
# for short cut
|
| 186 |
+
h = z
|
| 187 |
+
# reduce dim
|
| 188 |
+
z = self.conv1(z)
|
| 189 |
+
z = self.conv2(z)
|
| 190 |
+
# SVTR global block
|
| 191 |
+
B, C, H, W = z.shape
|
| 192 |
+
z = z.flatten(2).permute(0, 2, 1)
|
| 193 |
+
|
| 194 |
+
for blk in self.svtr_block:
|
| 195 |
+
z = blk(z)
|
| 196 |
+
|
| 197 |
+
z = self.norm(z)
|
| 198 |
+
# last stage
|
| 199 |
+
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
|
| 200 |
+
z = self.conv3(z)
|
| 201 |
+
z = torch.cat((h, z), dim=1)
|
| 202 |
+
z = self.conv1x1(self.conv4(z))
|
| 203 |
+
|
| 204 |
+
return z
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
if __name__ == "__main__":
|
| 208 |
+
svtrRNN = EncoderWithSVTR(56)
|
| 209 |
+
print(svtrRNN)
|
diffusers/examples/research_projects/anytext/ocr_recog/RecCTCHead.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CTCHead(nn.Module):
|
| 5 |
+
def __init__(
|
| 6 |
+
self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs
|
| 7 |
+
):
|
| 8 |
+
super(CTCHead, self).__init__()
|
| 9 |
+
if mid_channels is None:
|
| 10 |
+
self.fc = nn.Linear(
|
| 11 |
+
in_channels,
|
| 12 |
+
out_channels,
|
| 13 |
+
bias=True,
|
| 14 |
+
)
|
| 15 |
+
else:
|
| 16 |
+
self.fc1 = nn.Linear(
|
| 17 |
+
in_channels,
|
| 18 |
+
mid_channels,
|
| 19 |
+
bias=True,
|
| 20 |
+
)
|
| 21 |
+
self.fc2 = nn.Linear(
|
| 22 |
+
mid_channels,
|
| 23 |
+
out_channels,
|
| 24 |
+
bias=True,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
self.out_channels = out_channels
|
| 28 |
+
self.mid_channels = mid_channels
|
| 29 |
+
self.return_feats = return_feats
|
| 30 |
+
|
| 31 |
+
def forward(self, x, labels=None):
|
| 32 |
+
if self.mid_channels is None:
|
| 33 |
+
predicts = self.fc(x)
|
| 34 |
+
else:
|
| 35 |
+
x = self.fc1(x)
|
| 36 |
+
predicts = self.fc2(x)
|
| 37 |
+
|
| 38 |
+
if self.return_feats:
|
| 39 |
+
result = {}
|
| 40 |
+
result["ctc"] = predicts
|
| 41 |
+
result["ctc_neck"] = x
|
| 42 |
+
else:
|
| 43 |
+
result = predicts
|
| 44 |
+
|
| 45 |
+
return result
|
diffusers/examples/research_projects/anytext/ocr_recog/RecModel.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
from .RecCTCHead import CTCHead
|
| 4 |
+
from .RecMv1_enhance import MobileNetV1Enhance
|
| 5 |
+
from .RNN import Im2Im, Im2Seq, SequenceEncoder
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
|
| 9 |
+
neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
|
| 10 |
+
head_dict = {"CTCHead": CTCHead}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RecModel(nn.Module):
|
| 14 |
+
def __init__(self, config):
|
| 15 |
+
super().__init__()
|
| 16 |
+
assert "in_channels" in config, "in_channels must in model config"
|
| 17 |
+
backbone_type = config["backbone"].pop("type")
|
| 18 |
+
assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
|
| 19 |
+
self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])
|
| 20 |
+
|
| 21 |
+
neck_type = config["neck"].pop("type")
|
| 22 |
+
assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
|
| 23 |
+
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"])
|
| 24 |
+
|
| 25 |
+
head_type = config["head"].pop("type")
|
| 26 |
+
assert head_type in head_dict, f"head.type must in {head_dict}"
|
| 27 |
+
self.head = head_dict[head_type](self.neck.out_channels, **config["head"])
|
| 28 |
+
|
| 29 |
+
self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
|
| 30 |
+
|
| 31 |
+
def load_3rd_state_dict(self, _3rd_name, _state):
|
| 32 |
+
self.backbone.load_3rd_state_dict(_3rd_name, _state)
|
| 33 |
+
self.neck.load_3rd_state_dict(_3rd_name, _state)
|
| 34 |
+
self.head.load_3rd_state_dict(_3rd_name, _state)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
import torch
|
| 38 |
+
|
| 39 |
+
x = x.to(torch.float32)
|
| 40 |
+
x = self.backbone(x)
|
| 41 |
+
x = self.neck(x)
|
| 42 |
+
x = self.head(x)
|
| 43 |
+
return x
|
| 44 |
+
|
| 45 |
+
def encode(self, x):
|
| 46 |
+
x = self.backbone(x)
|
| 47 |
+
x = self.neck(x)
|
| 48 |
+
x = self.head.ctc_encoder(x)
|
| 49 |
+
return x
|
diffusers/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from .common import Activation
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ConvBNLayer(nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish"
|
| 11 |
+
):
|
| 12 |
+
super(ConvBNLayer, self).__init__()
|
| 13 |
+
self.act = act
|
| 14 |
+
self._conv = nn.Conv2d(
|
| 15 |
+
in_channels=num_channels,
|
| 16 |
+
out_channels=num_filters,
|
| 17 |
+
kernel_size=filter_size,
|
| 18 |
+
stride=stride,
|
| 19 |
+
padding=padding,
|
| 20 |
+
groups=num_groups,
|
| 21 |
+
bias=False,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
self._batch_norm = nn.BatchNorm2d(
|
| 25 |
+
num_filters,
|
| 26 |
+
)
|
| 27 |
+
if self.act is not None:
|
| 28 |
+
self._act = Activation(act_type=act, inplace=True)
|
| 29 |
+
|
| 30 |
+
def forward(self, inputs):
|
| 31 |
+
y = self._conv(inputs)
|
| 32 |
+
y = self._batch_norm(y)
|
| 33 |
+
if self.act is not None:
|
| 34 |
+
y = self._act(y)
|
| 35 |
+
return y
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DepthwiseSeparable(nn.Module):
|
| 39 |
+
def __init__(
|
| 40 |
+
self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False
|
| 41 |
+
):
|
| 42 |
+
super(DepthwiseSeparable, self).__init__()
|
| 43 |
+
self.use_se = use_se
|
| 44 |
+
self._depthwise_conv = ConvBNLayer(
|
| 45 |
+
num_channels=num_channels,
|
| 46 |
+
num_filters=int(num_filters1 * scale),
|
| 47 |
+
filter_size=dw_size,
|
| 48 |
+
stride=stride,
|
| 49 |
+
padding=padding,
|
| 50 |
+
num_groups=int(num_groups * scale),
|
| 51 |
+
)
|
| 52 |
+
if use_se:
|
| 53 |
+
self._se = SEModule(int(num_filters1 * scale))
|
| 54 |
+
self._pointwise_conv = ConvBNLayer(
|
| 55 |
+
num_channels=int(num_filters1 * scale),
|
| 56 |
+
filter_size=1,
|
| 57 |
+
num_filters=int(num_filters2 * scale),
|
| 58 |
+
stride=1,
|
| 59 |
+
padding=0,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def forward(self, inputs):
|
| 63 |
+
y = self._depthwise_conv(inputs)
|
| 64 |
+
if self.use_se:
|
| 65 |
+
y = self._se(y)
|
| 66 |
+
y = self._pointwise_conv(y)
|
| 67 |
+
return y
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class MobileNetV1Enhance(nn.Module):
|
| 71 |
+
def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.scale = scale
|
| 74 |
+
self.block_list = []
|
| 75 |
+
|
| 76 |
+
self.conv1 = ConvBNLayer(
|
| 77 |
+
num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
conv2_1 = DepthwiseSeparable(
|
| 81 |
+
num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale
|
| 82 |
+
)
|
| 83 |
+
self.block_list.append(conv2_1)
|
| 84 |
+
|
| 85 |
+
conv2_2 = DepthwiseSeparable(
|
| 86 |
+
num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale
|
| 87 |
+
)
|
| 88 |
+
self.block_list.append(conv2_2)
|
| 89 |
+
|
| 90 |
+
conv3_1 = DepthwiseSeparable(
|
| 91 |
+
num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale
|
| 92 |
+
)
|
| 93 |
+
self.block_list.append(conv3_1)
|
| 94 |
+
|
| 95 |
+
conv3_2 = DepthwiseSeparable(
|
| 96 |
+
num_channels=int(128 * scale),
|
| 97 |
+
num_filters1=128,
|
| 98 |
+
num_filters2=256,
|
| 99 |
+
num_groups=128,
|
| 100 |
+
stride=(2, 1),
|
| 101 |
+
scale=scale,
|
| 102 |
+
)
|
| 103 |
+
self.block_list.append(conv3_2)
|
| 104 |
+
|
| 105 |
+
conv4_1 = DepthwiseSeparable(
|
| 106 |
+
num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale
|
| 107 |
+
)
|
| 108 |
+
self.block_list.append(conv4_1)
|
| 109 |
+
|
| 110 |
+
conv4_2 = DepthwiseSeparable(
|
| 111 |
+
num_channels=int(256 * scale),
|
| 112 |
+
num_filters1=256,
|
| 113 |
+
num_filters2=512,
|
| 114 |
+
num_groups=256,
|
| 115 |
+
stride=(2, 1),
|
| 116 |
+
scale=scale,
|
| 117 |
+
)
|
| 118 |
+
self.block_list.append(conv4_2)
|
| 119 |
+
|
| 120 |
+
for _ in range(5):
|
| 121 |
+
conv5 = DepthwiseSeparable(
|
| 122 |
+
num_channels=int(512 * scale),
|
| 123 |
+
num_filters1=512,
|
| 124 |
+
num_filters2=512,
|
| 125 |
+
num_groups=512,
|
| 126 |
+
stride=1,
|
| 127 |
+
dw_size=5,
|
| 128 |
+
padding=2,
|
| 129 |
+
scale=scale,
|
| 130 |
+
use_se=False,
|
| 131 |
+
)
|
| 132 |
+
self.block_list.append(conv5)
|
| 133 |
+
|
| 134 |
+
conv5_6 = DepthwiseSeparable(
|
| 135 |
+
num_channels=int(512 * scale),
|
| 136 |
+
num_filters1=512,
|
| 137 |
+
num_filters2=1024,
|
| 138 |
+
num_groups=512,
|
| 139 |
+
stride=(2, 1),
|
| 140 |
+
dw_size=5,
|
| 141 |
+
padding=2,
|
| 142 |
+
scale=scale,
|
| 143 |
+
use_se=True,
|
| 144 |
+
)
|
| 145 |
+
self.block_list.append(conv5_6)
|
| 146 |
+
|
| 147 |
+
conv6 = DepthwiseSeparable(
|
| 148 |
+
num_channels=int(1024 * scale),
|
| 149 |
+
num_filters1=1024,
|
| 150 |
+
num_filters2=1024,
|
| 151 |
+
num_groups=1024,
|
| 152 |
+
stride=last_conv_stride,
|
| 153 |
+
dw_size=5,
|
| 154 |
+
padding=2,
|
| 155 |
+
use_se=True,
|
| 156 |
+
scale=scale,
|
| 157 |
+
)
|
| 158 |
+
self.block_list.append(conv6)
|
| 159 |
+
|
| 160 |
+
self.block_list = nn.Sequential(*self.block_list)
|
| 161 |
+
if last_pool_type == "avg":
|
| 162 |
+
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
| 163 |
+
else:
|
| 164 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
| 165 |
+
self.out_channels = int(1024 * scale)
|
| 166 |
+
|
| 167 |
+
def forward(self, inputs):
|
| 168 |
+
y = self.conv1(inputs)
|
| 169 |
+
y = self.block_list(y)
|
| 170 |
+
y = self.pool(y)
|
| 171 |
+
return y
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def hardsigmoid(x):
|
| 175 |
+
return F.relu6(x + 3.0, inplace=True) / 6.0
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class SEModule(nn.Module):
|
| 179 |
+
def __init__(self, channel, reduction=4):
|
| 180 |
+
super(SEModule, self).__init__()
|
| 181 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
| 182 |
+
self.conv1 = nn.Conv2d(
|
| 183 |
+
in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True
|
| 184 |
+
)
|
| 185 |
+
self.conv2 = nn.Conv2d(
|
| 186 |
+
in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def forward(self, inputs):
|
| 190 |
+
outputs = self.avg_pool(inputs)
|
| 191 |
+
outputs = self.conv1(outputs)
|
| 192 |
+
outputs = F.relu(outputs)
|
| 193 |
+
outputs = self.conv2(outputs)
|
| 194 |
+
outputs = hardsigmoid(outputs)
|
| 195 |
+
x = torch.mul(inputs, outputs)
|
| 196 |
+
|
| 197 |
+
return x
|
diffusers/examples/research_projects/anytext/ocr_recog/RecSVTR.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.nn import functional
|
| 5 |
+
from torch.nn.init import ones_, trunc_normal_, zeros_
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def drop_path(x, drop_prob=0.0, training=False):
|
| 9 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 10 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 11 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
| 12 |
+
"""
|
| 13 |
+
if drop_prob == 0.0 or not training:
|
| 14 |
+
return x
|
| 15 |
+
keep_prob = torch.tensor(1 - drop_prob)
|
| 16 |
+
shape = (x.size()[0],) + (1,) * (x.ndim - 1)
|
| 17 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
|
| 18 |
+
random_tensor = torch.floor(random_tensor) # binarize
|
| 19 |
+
output = x.divide(keep_prob) * random_tensor
|
| 20 |
+
return output
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Swish(nn.Module):
|
| 24 |
+
def __int__(self):
|
| 25 |
+
super(Swish, self).__int__()
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
return x * torch.sigmoid(x)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ConvBNLayer(nn.Module):
|
| 32 |
+
def __init__(
|
| 33 |
+
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.conv = nn.Conv2d(
|
| 37 |
+
in_channels=in_channels,
|
| 38 |
+
out_channels=out_channels,
|
| 39 |
+
kernel_size=kernel_size,
|
| 40 |
+
stride=stride,
|
| 41 |
+
padding=padding,
|
| 42 |
+
groups=groups,
|
| 43 |
+
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
| 44 |
+
bias=bias_attr,
|
| 45 |
+
)
|
| 46 |
+
self.norm = nn.BatchNorm2d(out_channels)
|
| 47 |
+
self.act = act()
|
| 48 |
+
|
| 49 |
+
def forward(self, inputs):
|
| 50 |
+
out = self.conv(inputs)
|
| 51 |
+
out = self.norm(out)
|
| 52 |
+
out = self.act(out)
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class DropPath(nn.Module):
|
| 57 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 58 |
+
|
| 59 |
+
def __init__(self, drop_prob=None):
|
| 60 |
+
super(DropPath, self).__init__()
|
| 61 |
+
self.drop_prob = drop_prob
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Identity(nn.Module):
|
| 68 |
+
def __init__(self):
|
| 69 |
+
super(Identity, self).__init__()
|
| 70 |
+
|
| 71 |
+
def forward(self, input):
|
| 72 |
+
return input
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class Mlp(nn.Module):
|
| 76 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
| 77 |
+
super().__init__()
|
| 78 |
+
out_features = out_features or in_features
|
| 79 |
+
hidden_features = hidden_features or in_features
|
| 80 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 81 |
+
if isinstance(act_layer, str):
|
| 82 |
+
self.act = Swish()
|
| 83 |
+
else:
|
| 84 |
+
self.act = act_layer()
|
| 85 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 86 |
+
self.drop = nn.Dropout(drop)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
x = self.fc1(x)
|
| 90 |
+
x = self.act(x)
|
| 91 |
+
x = self.drop(x)
|
| 92 |
+
x = self.fc2(x)
|
| 93 |
+
x = self.drop(x)
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ConvMixer(nn.Module):
|
| 98 |
+
def __init__(
|
| 99 |
+
self,
|
| 100 |
+
dim,
|
| 101 |
+
num_heads=8,
|
| 102 |
+
HW=(8, 25),
|
| 103 |
+
local_k=(3, 3),
|
| 104 |
+
):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.HW = HW
|
| 107 |
+
self.dim = dim
|
| 108 |
+
self.local_mixer = nn.Conv2d(
|
| 109 |
+
dim,
|
| 110 |
+
dim,
|
| 111 |
+
local_k,
|
| 112 |
+
1,
|
| 113 |
+
(local_k[0] // 2, local_k[1] // 2),
|
| 114 |
+
groups=num_heads,
|
| 115 |
+
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def forward(self, x):
|
| 119 |
+
h = self.HW[0]
|
| 120 |
+
w = self.HW[1]
|
| 121 |
+
x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
|
| 122 |
+
x = self.local_mixer(x)
|
| 123 |
+
x = x.flatten(2).transpose([0, 2, 1])
|
| 124 |
+
return x
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class Attention(nn.Module):
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
dim,
|
| 131 |
+
num_heads=8,
|
| 132 |
+
mixer="Global",
|
| 133 |
+
HW=(8, 25),
|
| 134 |
+
local_k=(7, 11),
|
| 135 |
+
qkv_bias=False,
|
| 136 |
+
qk_scale=None,
|
| 137 |
+
attn_drop=0.0,
|
| 138 |
+
proj_drop=0.0,
|
| 139 |
+
):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.num_heads = num_heads
|
| 142 |
+
head_dim = dim // num_heads
|
| 143 |
+
self.scale = qk_scale or head_dim**-0.5
|
| 144 |
+
|
| 145 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 146 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 147 |
+
self.proj = nn.Linear(dim, dim)
|
| 148 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 149 |
+
self.HW = HW
|
| 150 |
+
if HW is not None:
|
| 151 |
+
H = HW[0]
|
| 152 |
+
W = HW[1]
|
| 153 |
+
self.N = H * W
|
| 154 |
+
self.C = dim
|
| 155 |
+
if mixer == "Local" and HW is not None:
|
| 156 |
+
hk = local_k[0]
|
| 157 |
+
wk = local_k[1]
|
| 158 |
+
mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
|
| 159 |
+
for h in range(0, H):
|
| 160 |
+
for w in range(0, W):
|
| 161 |
+
mask[h * W + w, h : h + hk, w : w + wk] = 0.0
|
| 162 |
+
mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
|
| 163 |
+
mask_inf = torch.full([H * W, H * W], fill_value=float("-inf"))
|
| 164 |
+
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
|
| 165 |
+
self.mask = mask[None, None, :]
|
| 166 |
+
# self.mask = mask.unsqueeze([0, 1])
|
| 167 |
+
self.mixer = mixer
|
| 168 |
+
|
| 169 |
+
def forward(self, x):
|
| 170 |
+
if self.HW is not None:
|
| 171 |
+
N = self.N
|
| 172 |
+
C = self.C
|
| 173 |
+
else:
|
| 174 |
+
_, N, C = x.shape
|
| 175 |
+
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
|
| 176 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 177 |
+
|
| 178 |
+
attn = q.matmul(k.permute((0, 1, 3, 2)))
|
| 179 |
+
if self.mixer == "Local":
|
| 180 |
+
attn += self.mask
|
| 181 |
+
attn = functional.softmax(attn, dim=-1)
|
| 182 |
+
attn = self.attn_drop(attn)
|
| 183 |
+
|
| 184 |
+
x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
|
| 185 |
+
x = self.proj(x)
|
| 186 |
+
x = self.proj_drop(x)
|
| 187 |
+
return x
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class Block(nn.Module):
|
| 191 |
+
def __init__(
|
| 192 |
+
self,
|
| 193 |
+
dim,
|
| 194 |
+
num_heads,
|
| 195 |
+
mixer="Global",
|
| 196 |
+
local_mixer=(7, 11),
|
| 197 |
+
HW=(8, 25),
|
| 198 |
+
mlp_ratio=4.0,
|
| 199 |
+
qkv_bias=False,
|
| 200 |
+
qk_scale=None,
|
| 201 |
+
drop=0.0,
|
| 202 |
+
attn_drop=0.0,
|
| 203 |
+
drop_path=0.0,
|
| 204 |
+
act_layer=nn.GELU,
|
| 205 |
+
norm_layer="nn.LayerNorm",
|
| 206 |
+
epsilon=1e-6,
|
| 207 |
+
prenorm=True,
|
| 208 |
+
):
|
| 209 |
+
super().__init__()
|
| 210 |
+
if isinstance(norm_layer, str):
|
| 211 |
+
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
|
| 212 |
+
else:
|
| 213 |
+
self.norm1 = norm_layer(dim)
|
| 214 |
+
if mixer == "Global" or mixer == "Local":
|
| 215 |
+
self.mixer = Attention(
|
| 216 |
+
dim,
|
| 217 |
+
num_heads=num_heads,
|
| 218 |
+
mixer=mixer,
|
| 219 |
+
HW=HW,
|
| 220 |
+
local_k=local_mixer,
|
| 221 |
+
qkv_bias=qkv_bias,
|
| 222 |
+
qk_scale=qk_scale,
|
| 223 |
+
attn_drop=attn_drop,
|
| 224 |
+
proj_drop=drop,
|
| 225 |
+
)
|
| 226 |
+
elif mixer == "Conv":
|
| 227 |
+
self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
|
| 228 |
+
else:
|
| 229 |
+
raise TypeError("The mixer must be one of [Global, Local, Conv]")
|
| 230 |
+
|
| 231 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
|
| 232 |
+
if isinstance(norm_layer, str):
|
| 233 |
+
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
|
| 234 |
+
else:
|
| 235 |
+
self.norm2 = norm_layer(dim)
|
| 236 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 237 |
+
self.mlp_ratio = mlp_ratio
|
| 238 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 239 |
+
self.prenorm = prenorm
|
| 240 |
+
|
| 241 |
+
def forward(self, x):
|
| 242 |
+
if self.prenorm:
|
| 243 |
+
x = self.norm1(x + self.drop_path(self.mixer(x)))
|
| 244 |
+
x = self.norm2(x + self.drop_path(self.mlp(x)))
|
| 245 |
+
else:
|
| 246 |
+
x = x + self.drop_path(self.mixer(self.norm1(x)))
|
| 247 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 248 |
+
return x
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class PatchEmbed(nn.Module):
|
| 252 |
+
"""Image to Patch Embedding"""
|
| 253 |
+
|
| 254 |
+
def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
|
| 255 |
+
super().__init__()
|
| 256 |
+
num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
|
| 257 |
+
self.img_size = img_size
|
| 258 |
+
self.num_patches = num_patches
|
| 259 |
+
self.embed_dim = embed_dim
|
| 260 |
+
self.norm = None
|
| 261 |
+
if sub_num == 2:
|
| 262 |
+
self.proj = nn.Sequential(
|
| 263 |
+
ConvBNLayer(
|
| 264 |
+
in_channels=in_channels,
|
| 265 |
+
out_channels=embed_dim // 2,
|
| 266 |
+
kernel_size=3,
|
| 267 |
+
stride=2,
|
| 268 |
+
padding=1,
|
| 269 |
+
act=nn.GELU,
|
| 270 |
+
bias_attr=False,
|
| 271 |
+
),
|
| 272 |
+
ConvBNLayer(
|
| 273 |
+
in_channels=embed_dim // 2,
|
| 274 |
+
out_channels=embed_dim,
|
| 275 |
+
kernel_size=3,
|
| 276 |
+
stride=2,
|
| 277 |
+
padding=1,
|
| 278 |
+
act=nn.GELU,
|
| 279 |
+
bias_attr=False,
|
| 280 |
+
),
|
| 281 |
+
)
|
| 282 |
+
if sub_num == 3:
|
| 283 |
+
self.proj = nn.Sequential(
|
| 284 |
+
ConvBNLayer(
|
| 285 |
+
in_channels=in_channels,
|
| 286 |
+
out_channels=embed_dim // 4,
|
| 287 |
+
kernel_size=3,
|
| 288 |
+
stride=2,
|
| 289 |
+
padding=1,
|
| 290 |
+
act=nn.GELU,
|
| 291 |
+
bias_attr=False,
|
| 292 |
+
),
|
| 293 |
+
ConvBNLayer(
|
| 294 |
+
in_channels=embed_dim // 4,
|
| 295 |
+
out_channels=embed_dim // 2,
|
| 296 |
+
kernel_size=3,
|
| 297 |
+
stride=2,
|
| 298 |
+
padding=1,
|
| 299 |
+
act=nn.GELU,
|
| 300 |
+
bias_attr=False,
|
| 301 |
+
),
|
| 302 |
+
ConvBNLayer(
|
| 303 |
+
in_channels=embed_dim // 2,
|
| 304 |
+
out_channels=embed_dim,
|
| 305 |
+
kernel_size=3,
|
| 306 |
+
stride=2,
|
| 307 |
+
padding=1,
|
| 308 |
+
act=nn.GELU,
|
| 309 |
+
bias_attr=False,
|
| 310 |
+
),
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def forward(self, x):
|
| 314 |
+
B, C, H, W = x.shape
|
| 315 |
+
assert H == self.img_size[0] and W == self.img_size[1], (
|
| 316 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 317 |
+
)
|
| 318 |
+
x = self.proj(x).flatten(2).permute(0, 2, 1)
|
| 319 |
+
return x
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class SubSample(nn.Module):
|
| 323 |
+
def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None):
|
| 324 |
+
super().__init__()
|
| 325 |
+
self.types = types
|
| 326 |
+
if types == "Pool":
|
| 327 |
+
self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
| 328 |
+
self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
| 329 |
+
self.proj = nn.Linear(in_channels, out_channels)
|
| 330 |
+
else:
|
| 331 |
+
self.conv = nn.Conv2d(
|
| 332 |
+
in_channels,
|
| 333 |
+
out_channels,
|
| 334 |
+
kernel_size=3,
|
| 335 |
+
stride=stride,
|
| 336 |
+
padding=1,
|
| 337 |
+
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
| 338 |
+
)
|
| 339 |
+
self.norm = eval(sub_norm)(out_channels)
|
| 340 |
+
if act is not None:
|
| 341 |
+
self.act = act()
|
| 342 |
+
else:
|
| 343 |
+
self.act = None
|
| 344 |
+
|
| 345 |
+
def forward(self, x):
|
| 346 |
+
if self.types == "Pool":
|
| 347 |
+
x1 = self.avgpool(x)
|
| 348 |
+
x2 = self.maxpool(x)
|
| 349 |
+
x = (x1 + x2) * 0.5
|
| 350 |
+
out = self.proj(x.flatten(2).permute((0, 2, 1)))
|
| 351 |
+
else:
|
| 352 |
+
x = self.conv(x)
|
| 353 |
+
out = x.flatten(2).permute((0, 2, 1))
|
| 354 |
+
out = self.norm(out)
|
| 355 |
+
if self.act is not None:
|
| 356 |
+
out = self.act(out)
|
| 357 |
+
|
| 358 |
+
return out
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class SVTRNet(nn.Module):
|
| 362 |
+
def __init__(
|
| 363 |
+
self,
|
| 364 |
+
img_size=[48, 100],
|
| 365 |
+
in_channels=3,
|
| 366 |
+
embed_dim=[64, 128, 256],
|
| 367 |
+
depth=[3, 6, 3],
|
| 368 |
+
num_heads=[2, 4, 8],
|
| 369 |
+
mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
|
| 370 |
+
local_mixer=[[7, 11], [7, 11], [7, 11]],
|
| 371 |
+
patch_merging="Conv", # Conv, Pool, None
|
| 372 |
+
mlp_ratio=4,
|
| 373 |
+
qkv_bias=True,
|
| 374 |
+
qk_scale=None,
|
| 375 |
+
drop_rate=0.0,
|
| 376 |
+
last_drop=0.1,
|
| 377 |
+
attn_drop_rate=0.0,
|
| 378 |
+
drop_path_rate=0.1,
|
| 379 |
+
norm_layer="nn.LayerNorm",
|
| 380 |
+
sub_norm="nn.LayerNorm",
|
| 381 |
+
epsilon=1e-6,
|
| 382 |
+
out_channels=192,
|
| 383 |
+
out_char_num=25,
|
| 384 |
+
block_unit="Block",
|
| 385 |
+
act="nn.GELU",
|
| 386 |
+
last_stage=True,
|
| 387 |
+
sub_num=2,
|
| 388 |
+
prenorm=True,
|
| 389 |
+
use_lenhead=False,
|
| 390 |
+
**kwargs,
|
| 391 |
+
):
|
| 392 |
+
super().__init__()
|
| 393 |
+
self.img_size = img_size
|
| 394 |
+
self.embed_dim = embed_dim
|
| 395 |
+
self.out_channels = out_channels
|
| 396 |
+
self.prenorm = prenorm
|
| 397 |
+
patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging
|
| 398 |
+
self.patch_embed = PatchEmbed(
|
| 399 |
+
img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num
|
| 400 |
+
)
|
| 401 |
+
num_patches = self.patch_embed.num_patches
|
| 402 |
+
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
|
| 403 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
|
| 404 |
+
# self.pos_embed = self.create_parameter(
|
| 405 |
+
# shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
|
| 406 |
+
|
| 407 |
+
# self.add_parameter("pos_embed", self.pos_embed)
|
| 408 |
+
|
| 409 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 410 |
+
Block_unit = eval(block_unit)
|
| 411 |
+
|
| 412 |
+
dpr = np.linspace(0, drop_path_rate, sum(depth))
|
| 413 |
+
self.blocks1 = nn.ModuleList(
|
| 414 |
+
[
|
| 415 |
+
Block_unit(
|
| 416 |
+
dim=embed_dim[0],
|
| 417 |
+
num_heads=num_heads[0],
|
| 418 |
+
mixer=mixer[0 : depth[0]][i],
|
| 419 |
+
HW=self.HW,
|
| 420 |
+
local_mixer=local_mixer[0],
|
| 421 |
+
mlp_ratio=mlp_ratio,
|
| 422 |
+
qkv_bias=qkv_bias,
|
| 423 |
+
qk_scale=qk_scale,
|
| 424 |
+
drop=drop_rate,
|
| 425 |
+
act_layer=eval(act),
|
| 426 |
+
attn_drop=attn_drop_rate,
|
| 427 |
+
drop_path=dpr[0 : depth[0]][i],
|
| 428 |
+
norm_layer=norm_layer,
|
| 429 |
+
epsilon=epsilon,
|
| 430 |
+
prenorm=prenorm,
|
| 431 |
+
)
|
| 432 |
+
for i in range(depth[0])
|
| 433 |
+
]
|
| 434 |
+
)
|
| 435 |
+
if patch_merging is not None:
|
| 436 |
+
self.sub_sample1 = SubSample(
|
| 437 |
+
embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
|
| 438 |
+
)
|
| 439 |
+
HW = [self.HW[0] // 2, self.HW[1]]
|
| 440 |
+
else:
|
| 441 |
+
HW = self.HW
|
| 442 |
+
self.patch_merging = patch_merging
|
| 443 |
+
self.blocks2 = nn.ModuleList(
|
| 444 |
+
[
|
| 445 |
+
Block_unit(
|
| 446 |
+
dim=embed_dim[1],
|
| 447 |
+
num_heads=num_heads[1],
|
| 448 |
+
mixer=mixer[depth[0] : depth[0] + depth[1]][i],
|
| 449 |
+
HW=HW,
|
| 450 |
+
local_mixer=local_mixer[1],
|
| 451 |
+
mlp_ratio=mlp_ratio,
|
| 452 |
+
qkv_bias=qkv_bias,
|
| 453 |
+
qk_scale=qk_scale,
|
| 454 |
+
drop=drop_rate,
|
| 455 |
+
act_layer=eval(act),
|
| 456 |
+
attn_drop=attn_drop_rate,
|
| 457 |
+
drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
|
| 458 |
+
norm_layer=norm_layer,
|
| 459 |
+
epsilon=epsilon,
|
| 460 |
+
prenorm=prenorm,
|
| 461 |
+
)
|
| 462 |
+
for i in range(depth[1])
|
| 463 |
+
]
|
| 464 |
+
)
|
| 465 |
+
if patch_merging is not None:
|
| 466 |
+
self.sub_sample2 = SubSample(
|
| 467 |
+
embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
|
| 468 |
+
)
|
| 469 |
+
HW = [self.HW[0] // 4, self.HW[1]]
|
| 470 |
+
else:
|
| 471 |
+
HW = self.HW
|
| 472 |
+
self.blocks3 = nn.ModuleList(
|
| 473 |
+
[
|
| 474 |
+
Block_unit(
|
| 475 |
+
dim=embed_dim[2],
|
| 476 |
+
num_heads=num_heads[2],
|
| 477 |
+
mixer=mixer[depth[0] + depth[1] :][i],
|
| 478 |
+
HW=HW,
|
| 479 |
+
local_mixer=local_mixer[2],
|
| 480 |
+
mlp_ratio=mlp_ratio,
|
| 481 |
+
qkv_bias=qkv_bias,
|
| 482 |
+
qk_scale=qk_scale,
|
| 483 |
+
drop=drop_rate,
|
| 484 |
+
act_layer=eval(act),
|
| 485 |
+
attn_drop=attn_drop_rate,
|
| 486 |
+
drop_path=dpr[depth[0] + depth[1] :][i],
|
| 487 |
+
norm_layer=norm_layer,
|
| 488 |
+
epsilon=epsilon,
|
| 489 |
+
prenorm=prenorm,
|
| 490 |
+
)
|
| 491 |
+
for i in range(depth[2])
|
| 492 |
+
]
|
| 493 |
+
)
|
| 494 |
+
self.last_stage = last_stage
|
| 495 |
+
if last_stage:
|
| 496 |
+
self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
|
| 497 |
+
self.last_conv = nn.Conv2d(
|
| 498 |
+
in_channels=embed_dim[2],
|
| 499 |
+
out_channels=self.out_channels,
|
| 500 |
+
kernel_size=1,
|
| 501 |
+
stride=1,
|
| 502 |
+
padding=0,
|
| 503 |
+
bias=False,
|
| 504 |
+
)
|
| 505 |
+
self.hardswish = nn.Hardswish()
|
| 506 |
+
self.dropout = nn.Dropout(p=last_drop)
|
| 507 |
+
if not prenorm:
|
| 508 |
+
self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
|
| 509 |
+
self.use_lenhead = use_lenhead
|
| 510 |
+
if use_lenhead:
|
| 511 |
+
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
|
| 512 |
+
self.hardswish_len = nn.Hardswish()
|
| 513 |
+
self.dropout_len = nn.Dropout(p=last_drop)
|
| 514 |
+
|
| 515 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 516 |
+
self.apply(self._init_weights)
|
| 517 |
+
|
| 518 |
+
def _init_weights(self, m):
|
| 519 |
+
if isinstance(m, nn.Linear):
|
| 520 |
+
trunc_normal_(m.weight, std=0.02)
|
| 521 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 522 |
+
zeros_(m.bias)
|
| 523 |
+
elif isinstance(m, nn.LayerNorm):
|
| 524 |
+
zeros_(m.bias)
|
| 525 |
+
ones_(m.weight)
|
| 526 |
+
|
| 527 |
+
def forward_features(self, x):
|
| 528 |
+
x = self.patch_embed(x)
|
| 529 |
+
x = x + self.pos_embed
|
| 530 |
+
x = self.pos_drop(x)
|
| 531 |
+
for blk in self.blocks1:
|
| 532 |
+
x = blk(x)
|
| 533 |
+
if self.patch_merging is not None:
|
| 534 |
+
x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
|
| 535 |
+
for blk in self.blocks2:
|
| 536 |
+
x = blk(x)
|
| 537 |
+
if self.patch_merging is not None:
|
| 538 |
+
x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
|
| 539 |
+
for blk in self.blocks3:
|
| 540 |
+
x = blk(x)
|
| 541 |
+
if not self.prenorm:
|
| 542 |
+
x = self.norm(x)
|
| 543 |
+
return x
|
| 544 |
+
|
| 545 |
+
def forward(self, x):
|
| 546 |
+
x = self.forward_features(x)
|
| 547 |
+
if self.use_lenhead:
|
| 548 |
+
len_x = self.len_conv(x.mean(1))
|
| 549 |
+
len_x = self.dropout_len(self.hardswish_len(len_x))
|
| 550 |
+
if self.last_stage:
|
| 551 |
+
if self.patch_merging is not None:
|
| 552 |
+
h = self.HW[0] // 4
|
| 553 |
+
else:
|
| 554 |
+
h = self.HW[0]
|
| 555 |
+
x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]]))
|
| 556 |
+
x = self.last_conv(x)
|
| 557 |
+
x = self.hardswish(x)
|
| 558 |
+
x = self.dropout(x)
|
| 559 |
+
if self.use_lenhead:
|
| 560 |
+
return x, len_x
|
| 561 |
+
return x
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
if __name__ == "__main__":
|
| 565 |
+
a = torch.rand(1, 3, 48, 100)
|
| 566 |
+
svtr = SVTRNet()
|
| 567 |
+
|
| 568 |
+
out = svtr(a)
|
| 569 |
+
print(svtr)
|
| 570 |
+
print(out.size())
|
diffusers/examples/research_projects/anytext/ocr_recog/common.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Hswish(nn.Module):
|
| 7 |
+
def __init__(self, inplace=True):
|
| 8 |
+
super(Hswish, self).__init__()
|
| 9 |
+
self.inplace = inplace
|
| 10 |
+
|
| 11 |
+
def forward(self, x):
|
| 12 |
+
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# out = max(0, min(1, slop*x+offset))
|
| 16 |
+
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
|
| 17 |
+
class Hsigmoid(nn.Module):
|
| 18 |
+
def __init__(self, inplace=True):
|
| 19 |
+
super(Hsigmoid, self).__init__()
|
| 20 |
+
self.inplace = inplace
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
|
| 24 |
+
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
|
| 25 |
+
return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class GELU(nn.Module):
|
| 29 |
+
def __init__(self, inplace=True):
|
| 30 |
+
super(GELU, self).__init__()
|
| 31 |
+
self.inplace = inplace
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
return torch.nn.functional.gelu(x)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Swish(nn.Module):
|
| 38 |
+
def __init__(self, inplace=True):
|
| 39 |
+
super(Swish, self).__init__()
|
| 40 |
+
self.inplace = inplace
|
| 41 |
+
|
| 42 |
+
def forward(self, x):
|
| 43 |
+
if self.inplace:
|
| 44 |
+
x.mul_(torch.sigmoid(x))
|
| 45 |
+
return x
|
| 46 |
+
else:
|
| 47 |
+
return x * torch.sigmoid(x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Activation(nn.Module):
|
| 51 |
+
def __init__(self, act_type, inplace=True):
|
| 52 |
+
super(Activation, self).__init__()
|
| 53 |
+
act_type = act_type.lower()
|
| 54 |
+
if act_type == "relu":
|
| 55 |
+
self.act = nn.ReLU(inplace=inplace)
|
| 56 |
+
elif act_type == "relu6":
|
| 57 |
+
self.act = nn.ReLU6(inplace=inplace)
|
| 58 |
+
elif act_type == "sigmoid":
|
| 59 |
+
raise NotImplementedError
|
| 60 |
+
elif act_type == "hard_sigmoid":
|
| 61 |
+
self.act = Hsigmoid(inplace)
|
| 62 |
+
elif act_type == "hard_swish":
|
| 63 |
+
self.act = Hswish(inplace=inplace)
|
| 64 |
+
elif act_type == "leakyrelu":
|
| 65 |
+
self.act = nn.LeakyReLU(inplace=inplace)
|
| 66 |
+
elif act_type == "gelu":
|
| 67 |
+
self.act = GELU(inplace=inplace)
|
| 68 |
+
elif act_type == "swish":
|
| 69 |
+
self.act = Swish(inplace=inplace)
|
| 70 |
+
else:
|
| 71 |
+
raise NotImplementedError
|
| 72 |
+
|
| 73 |
+
def forward(self, inputs):
|
| 74 |
+
return self.act(inputs)
|
diffusers/examples/research_projects/anytext/ocr_recog/en_dict.txt
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
0
|
| 2 |
+
1
|
| 3 |
+
2
|
| 4 |
+
3
|
| 5 |
+
4
|
| 6 |
+
5
|
| 7 |
+
6
|
| 8 |
+
7
|
| 9 |
+
8
|
| 10 |
+
9
|
| 11 |
+
:
|
| 12 |
+
;
|
| 13 |
+
<
|
| 14 |
+
=
|
| 15 |
+
>
|
| 16 |
+
?
|
| 17 |
+
@
|
| 18 |
+
A
|
| 19 |
+
B
|
| 20 |
+
C
|
| 21 |
+
D
|
| 22 |
+
E
|
| 23 |
+
F
|
| 24 |
+
G
|
| 25 |
+
H
|
| 26 |
+
I
|
| 27 |
+
J
|
| 28 |
+
K
|
| 29 |
+
L
|
| 30 |
+
M
|
| 31 |
+
N
|
| 32 |
+
O
|
| 33 |
+
P
|
| 34 |
+
Q
|
| 35 |
+
R
|
| 36 |
+
S
|
| 37 |
+
T
|
| 38 |
+
U
|
| 39 |
+
V
|
| 40 |
+
W
|
| 41 |
+
X
|
| 42 |
+
Y
|
| 43 |
+
Z
|
| 44 |
+
[
|
| 45 |
+
\
|
| 46 |
+
]
|
| 47 |
+
^
|
| 48 |
+
_
|
| 49 |
+
`
|
| 50 |
+
a
|
| 51 |
+
b
|
| 52 |
+
c
|
| 53 |
+
d
|
| 54 |
+
e
|
| 55 |
+
f
|
| 56 |
+
g
|
| 57 |
+
h
|
| 58 |
+
i
|
| 59 |
+
j
|
| 60 |
+
k
|
| 61 |
+
l
|
| 62 |
+
m
|
| 63 |
+
n
|
| 64 |
+
o
|
| 65 |
+
p
|
| 66 |
+
q
|
| 67 |
+
r
|
| 68 |
+
s
|
| 69 |
+
t
|
| 70 |
+
u
|
| 71 |
+
v
|
| 72 |
+
w
|
| 73 |
+
x
|
| 74 |
+
y
|
| 75 |
+
z
|
| 76 |
+
{
|
| 77 |
+
|
|
| 78 |
+
}
|
| 79 |
+
~
|
| 80 |
+
!
|
| 81 |
+
"
|
| 82 |
+
#
|
| 83 |
+
$
|
| 84 |
+
%
|
| 85 |
+
&
|
| 86 |
+
'
|
| 87 |
+
(
|
| 88 |
+
)
|
| 89 |
+
*
|
| 90 |
+
+
|
| 91 |
+
,
|
| 92 |
+
-
|
| 93 |
+
.
|
| 94 |
+
/
|
| 95 |
+
|
diffusers/examples/research_projects/consistency_training/README.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Consistency Training
|
| 2 |
+
|
| 3 |
+
`train_cm_ct_unconditional.py` trains a consistency model (CM) from scratch following the consistency training (CT) algorithm introduced in [Consistency Models](https://huggingface.co/papers/2303.01469) and refined in [Improved Techniques for Training Consistency Models](https://huggingface.co/papers/2310.14189). Both unconditional and class-conditional training are supported.
|
| 4 |
+
|
| 5 |
+
A usage example is as follows:
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
accelerate launch examples/research_projects/consistency_training/train_cm_ct_unconditional.py \
|
| 9 |
+
--dataset_name="cifar10" \
|
| 10 |
+
--dataset_image_column_name="img" \
|
| 11 |
+
--output_dir="/path/to/output/dir" \
|
| 12 |
+
--mixed_precision=fp16 \
|
| 13 |
+
--resolution=32 \
|
| 14 |
+
--max_train_steps=1000 --max_train_samples=10000 \
|
| 15 |
+
--dataloader_num_workers=8 \
|
| 16 |
+
--noise_precond_type="cm" --input_precond_type="cm" \
|
| 17 |
+
--train_batch_size=4 \
|
| 18 |
+
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
|
| 19 |
+
--use_8bit_adam \
|
| 20 |
+
--use_ema \
|
| 21 |
+
--validation_steps=100 --eval_batch_size=4 \
|
| 22 |
+
--checkpointing_steps=100 --checkpoints_total_limit=10 \
|
| 23 |
+
--class_conditional --num_classes=10 \
|
| 24 |
+
```
|
diffusers/examples/research_projects/consistency_training/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=0.16.0
|
| 2 |
+
torchvision
|
| 3 |
+
transformers>=4.25.1
|
| 4 |
+
ftfy
|
| 5 |
+
tensorboard
|
| 6 |
+
Jinja2
|
diffusers/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
ADDED
|
@@ -0,0 +1,1438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
"""Script to train a consistency model from scratch via (improved) consistency training."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import gc
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
import os
|
| 22 |
+
import shutil
|
| 23 |
+
from datetime import timedelta
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import accelerate
|
| 27 |
+
import datasets
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
from accelerate import Accelerator, InitProcessGroupKwargs
|
| 31 |
+
from accelerate.logging import get_logger
|
| 32 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 33 |
+
from datasets import load_dataset
|
| 34 |
+
from huggingface_hub import create_repo, upload_folder
|
| 35 |
+
from packaging import version
|
| 36 |
+
from torchvision import transforms
|
| 37 |
+
from tqdm.auto import tqdm
|
| 38 |
+
|
| 39 |
+
import diffusers
|
| 40 |
+
from diffusers import (
|
| 41 |
+
CMStochasticIterativeScheduler,
|
| 42 |
+
ConsistencyModelPipeline,
|
| 43 |
+
UNet2DModel,
|
| 44 |
+
)
|
| 45 |
+
from diffusers.optimization import get_scheduler
|
| 46 |
+
from diffusers.training_utils import EMAModel, resolve_interpolation_mode
|
| 47 |
+
from diffusers.utils import is_tensorboard_available, is_wandb_available
|
| 48 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 49 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if is_wandb_available():
|
| 53 |
+
import wandb
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
| 60 |
+
"""
|
| 61 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
| 62 |
+
|
| 63 |
+
:param arr: the 1-D numpy array.
|
| 64 |
+
:param timesteps: a tensor of indices into the array to extract.
|
| 65 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
| 66 |
+
dimension equal to the length of timesteps.
|
| 67 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
| 68 |
+
"""
|
| 69 |
+
if not isinstance(arr, torch.Tensor):
|
| 70 |
+
arr = torch.from_numpy(arr)
|
| 71 |
+
res = arr[timesteps].float().to(timesteps.device)
|
| 72 |
+
while len(res.shape) < len(broadcast_shape):
|
| 73 |
+
res = res[..., None]
|
| 74 |
+
return res.expand(broadcast_shape)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def append_dims(x, target_dims):
|
| 78 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
| 79 |
+
dims_to_append = target_dims - x.ndim
|
| 80 |
+
if dims_to_append < 0:
|
| 81 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
| 82 |
+
return x[(...,) + (None,) * dims_to_append]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def extract_into_tensor(a, t, x_shape):
|
| 86 |
+
b, *_ = t.shape
|
| 87 |
+
out = a.gather(-1, t)
|
| 88 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_discretization_steps(global_step: int, max_train_steps: int, s_0: int = 10, s_1: int = 1280, constant=False):
|
| 92 |
+
"""
|
| 93 |
+
Calculates the current discretization steps at global step k using the discretization curriculum N(k).
|
| 94 |
+
"""
|
| 95 |
+
if constant:
|
| 96 |
+
return s_0 + 1
|
| 97 |
+
|
| 98 |
+
k_prime = math.floor(max_train_steps / (math.log2(math.floor(s_1 / s_0)) + 1))
|
| 99 |
+
num_discretization_steps = min(s_0 * 2 ** math.floor(global_step / k_prime), s_1) + 1
|
| 100 |
+
|
| 101 |
+
return num_discretization_steps
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_skip_steps(global_step, initial_skip: int = 1):
|
| 105 |
+
# Currently only support constant skip curriculum.
|
| 106 |
+
return initial_skip
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_karras_sigmas(
|
| 110 |
+
num_discretization_steps: int,
|
| 111 |
+
sigma_min: float = 0.002,
|
| 112 |
+
sigma_max: float = 80.0,
|
| 113 |
+
rho: float = 7.0,
|
| 114 |
+
dtype=torch.float32,
|
| 115 |
+
):
|
| 116 |
+
"""
|
| 117 |
+
Calculates the Karras sigmas timestep discretization of [sigma_min, sigma_max].
|
| 118 |
+
"""
|
| 119 |
+
ramp = np.linspace(0, 1, num_discretization_steps)
|
| 120 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 121 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 122 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 123 |
+
# Make sure sigmas are in increasing rather than decreasing order (see section 2 of the iCT paper)
|
| 124 |
+
sigmas = sigmas[::-1].copy()
|
| 125 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=dtype)
|
| 126 |
+
return sigmas
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_discretized_lognormal_weights(noise_levels: torch.Tensor, p_mean: float = -1.1, p_std: float = 2.0):
|
| 130 |
+
"""
|
| 131 |
+
Calculates the unnormalized weights for a 1D array of noise level sigma_i based on the discretized lognormal"
|
| 132 |
+
" distribution used in the iCT paper (given in Equation 10).
|
| 133 |
+
"""
|
| 134 |
+
upper_prob = torch.special.erf((torch.log(noise_levels[1:]) - p_mean) / (math.sqrt(2) * p_std))
|
| 135 |
+
lower_prob = torch.special.erf((torch.log(noise_levels[:-1]) - p_mean) / (math.sqrt(2) * p_std))
|
| 136 |
+
weights = upper_prob - lower_prob
|
| 137 |
+
return weights
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_loss_weighting_schedule(noise_levels: torch.Tensor):
|
| 141 |
+
"""
|
| 142 |
+
Calculates the loss weighting schedule lambda given a set of noise levels.
|
| 143 |
+
"""
|
| 144 |
+
return 1.0 / (noise_levels[1:] - noise_levels[:-1])
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def add_noise(original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
|
| 148 |
+
# Make sure timesteps (Karras sigmas) have the same device and dtype as original_samples
|
| 149 |
+
sigmas = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 150 |
+
while len(sigmas.shape) < len(original_samples.shape):
|
| 151 |
+
sigmas = sigmas.unsqueeze(-1)
|
| 152 |
+
|
| 153 |
+
noisy_samples = original_samples + noise * sigmas
|
| 154 |
+
|
| 155 |
+
return noisy_samples
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_noise_preconditioning(sigmas, noise_precond_type: str = "cm"):
|
| 159 |
+
"""
|
| 160 |
+
Calculates the noise preconditioning function c_noise, which is used to transform the raw Karras sigmas into the
|
| 161 |
+
timestep input for the U-Net.
|
| 162 |
+
"""
|
| 163 |
+
if noise_precond_type == "none":
|
| 164 |
+
return sigmas
|
| 165 |
+
elif noise_precond_type == "edm":
|
| 166 |
+
return 0.25 * torch.log(sigmas)
|
| 167 |
+
elif noise_precond_type == "cm":
|
| 168 |
+
return 1000 * 0.25 * torch.log(sigmas + 1e-44)
|
| 169 |
+
else:
|
| 170 |
+
raise ValueError(
|
| 171 |
+
f"Noise preconditioning type {noise_precond_type} is not current supported. Currently supported noise"
|
| 172 |
+
f" preconditioning types are `none` (which uses the sigmas as is), `edm`, and `cm`."
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_input_preconditioning(sigmas, sigma_data=0.5, input_precond_type: str = "cm"):
|
| 177 |
+
"""
|
| 178 |
+
Calculates the input preconditioning factor c_in, which is used to scale the U-Net image input.
|
| 179 |
+
"""
|
| 180 |
+
if input_precond_type == "none":
|
| 181 |
+
return 1
|
| 182 |
+
elif input_precond_type == "cm":
|
| 183 |
+
return 1.0 / (sigmas**2 + sigma_data**2)
|
| 184 |
+
else:
|
| 185 |
+
raise ValueError(
|
| 186 |
+
f"Input preconditioning type {input_precond_type} is not current supported. Currently supported input"
|
| 187 |
+
f" preconditioning types are `none` (which uses a scaling factor of 1.0) and `cm`."
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=1.0):
|
| 192 |
+
scaled_timestep = timestep_scaling * timestep
|
| 193 |
+
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
| 194 |
+
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
| 195 |
+
return c_skip, c_out
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name="teacher"):
|
| 199 |
+
logger.info("Running validation... ")
|
| 200 |
+
|
| 201 |
+
unet = accelerator.unwrap_model(unet)
|
| 202 |
+
pipeline = ConsistencyModelPipeline(
|
| 203 |
+
unet=unet,
|
| 204 |
+
scheduler=scheduler,
|
| 205 |
+
)
|
| 206 |
+
pipeline = pipeline.to(device=accelerator.device)
|
| 207 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 208 |
+
|
| 209 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 210 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
| 211 |
+
|
| 212 |
+
if args.seed is None:
|
| 213 |
+
generator = None
|
| 214 |
+
else:
|
| 215 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
| 216 |
+
|
| 217 |
+
class_labels = [None]
|
| 218 |
+
if args.class_conditional:
|
| 219 |
+
if args.num_classes is not None:
|
| 220 |
+
class_labels = list(range(args.num_classes))
|
| 221 |
+
else:
|
| 222 |
+
logger.warning(
|
| 223 |
+
"The model is class-conditional but the number of classes is not set. The generated images will be"
|
| 224 |
+
" unconditional rather than class-conditional."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
image_logs = []
|
| 228 |
+
|
| 229 |
+
for class_label in class_labels:
|
| 230 |
+
images = []
|
| 231 |
+
with torch.autocast("cuda"):
|
| 232 |
+
images = pipeline(
|
| 233 |
+
num_inference_steps=1,
|
| 234 |
+
batch_size=args.eval_batch_size,
|
| 235 |
+
class_labels=[class_label] * args.eval_batch_size,
|
| 236 |
+
generator=generator,
|
| 237 |
+
).images
|
| 238 |
+
log = {"images": images}
|
| 239 |
+
if args.class_conditional and class_label is not None:
|
| 240 |
+
log["class_label"] = str(class_label)
|
| 241 |
+
else:
|
| 242 |
+
log["class_label"] = "images"
|
| 243 |
+
image_logs.append(log)
|
| 244 |
+
|
| 245 |
+
for tracker in accelerator.trackers:
|
| 246 |
+
if tracker.name == "tensorboard":
|
| 247 |
+
for log in image_logs:
|
| 248 |
+
images = log["images"]
|
| 249 |
+
class_label = log["class_label"]
|
| 250 |
+
formatted_images = []
|
| 251 |
+
for image in images:
|
| 252 |
+
formatted_images.append(np.asarray(image))
|
| 253 |
+
|
| 254 |
+
formatted_images = np.stack(formatted_images)
|
| 255 |
+
|
| 256 |
+
tracker.writer.add_images(class_label, formatted_images, step, dataformats="NHWC")
|
| 257 |
+
elif tracker.name == "wandb":
|
| 258 |
+
formatted_images = []
|
| 259 |
+
|
| 260 |
+
for log in image_logs:
|
| 261 |
+
images = log["images"]
|
| 262 |
+
class_label = log["class_label"]
|
| 263 |
+
for image in images:
|
| 264 |
+
image = wandb.Image(image, caption=class_label)
|
| 265 |
+
formatted_images.append(image)
|
| 266 |
+
|
| 267 |
+
tracker.log({f"validation/{name}": formatted_images})
|
| 268 |
+
else:
|
| 269 |
+
logger.warning(f"image logging not implemented for {tracker.name}")
|
| 270 |
+
|
| 271 |
+
del pipeline
|
| 272 |
+
gc.collect()
|
| 273 |
+
torch.cuda.empty_cache()
|
| 274 |
+
|
| 275 |
+
return image_logs
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def parse_args():
|
| 279 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 280 |
+
# ------------Model Arguments-----------
|
| 281 |
+
parser.add_argument(
|
| 282 |
+
"--model_config_name_or_path",
|
| 283 |
+
type=str,
|
| 284 |
+
default=None,
|
| 285 |
+
help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
|
| 286 |
+
)
|
| 287 |
+
parser.add_argument(
|
| 288 |
+
"--pretrained_model_name_or_path",
|
| 289 |
+
type=str,
|
| 290 |
+
default=None,
|
| 291 |
+
help=(
|
| 292 |
+
"If initializing the weights from a pretrained model, the path to the pretrained model or model identifier"
|
| 293 |
+
" from huggingface.co/models."
|
| 294 |
+
),
|
| 295 |
+
)
|
| 296 |
+
parser.add_argument(
|
| 297 |
+
"--revision",
|
| 298 |
+
type=str,
|
| 299 |
+
default=None,
|
| 300 |
+
required=False,
|
| 301 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 302 |
+
)
|
| 303 |
+
parser.add_argument(
|
| 304 |
+
"--variant",
|
| 305 |
+
type=str,
|
| 306 |
+
default=None,
|
| 307 |
+
help=(
|
| 308 |
+
"Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. `fp16`,"
|
| 309 |
+
" `non_ema`, etc.",
|
| 310 |
+
),
|
| 311 |
+
)
|
| 312 |
+
# ------------Dataset Arguments-----------
|
| 313 |
+
parser.add_argument(
|
| 314 |
+
"--train_data_dir",
|
| 315 |
+
type=str,
|
| 316 |
+
default=None,
|
| 317 |
+
help=(
|
| 318 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
| 319 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
| 320 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
| 321 |
+
),
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--dataset_name",
|
| 325 |
+
type=str,
|
| 326 |
+
default=None,
|
| 327 |
+
help=(
|
| 328 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
| 329 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
| 330 |
+
" or to a folder containing files that HF Datasets can understand."
|
| 331 |
+
),
|
| 332 |
+
)
|
| 333 |
+
parser.add_argument(
|
| 334 |
+
"--dataset_config_name",
|
| 335 |
+
type=str,
|
| 336 |
+
default=None,
|
| 337 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
| 338 |
+
)
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--dataset_image_column_name",
|
| 341 |
+
type=str,
|
| 342 |
+
default="image",
|
| 343 |
+
help="The name of the image column in the dataset to use for training.",
|
| 344 |
+
)
|
| 345 |
+
parser.add_argument(
|
| 346 |
+
"--dataset_class_label_column_name",
|
| 347 |
+
type=str,
|
| 348 |
+
default="label",
|
| 349 |
+
help="If doing class-conditional training, the name of the class label column in the dataset to use.",
|
| 350 |
+
)
|
| 351 |
+
# ------------Image Processing Arguments-----------
|
| 352 |
+
parser.add_argument(
|
| 353 |
+
"--resolution",
|
| 354 |
+
type=int,
|
| 355 |
+
default=64,
|
| 356 |
+
help=(
|
| 357 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 358 |
+
" resolution"
|
| 359 |
+
),
|
| 360 |
+
)
|
| 361 |
+
parser.add_argument(
|
| 362 |
+
"--interpolation_type",
|
| 363 |
+
type=str,
|
| 364 |
+
default="bilinear",
|
| 365 |
+
help=(
|
| 366 |
+
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
|
| 367 |
+
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
|
| 368 |
+
),
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--center_crop",
|
| 372 |
+
default=False,
|
| 373 |
+
action="store_true",
|
| 374 |
+
help=(
|
| 375 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
| 376 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
| 377 |
+
),
|
| 378 |
+
)
|
| 379 |
+
parser.add_argument(
|
| 380 |
+
"--random_flip",
|
| 381 |
+
default=False,
|
| 382 |
+
action="store_true",
|
| 383 |
+
help="whether to randomly flip images horizontally",
|
| 384 |
+
)
|
| 385 |
+
parser.add_argument(
|
| 386 |
+
"--class_conditional",
|
| 387 |
+
action="store_true",
|
| 388 |
+
help=(
|
| 389 |
+
"Whether to train a class-conditional model. If set, the class labels will be taken from the `label`"
|
| 390 |
+
" column of the provided dataset."
|
| 391 |
+
),
|
| 392 |
+
)
|
| 393 |
+
parser.add_argument(
|
| 394 |
+
"--num_classes",
|
| 395 |
+
type=int,
|
| 396 |
+
default=None,
|
| 397 |
+
help="The number of classes in the training data, if training a class-conditional model.",
|
| 398 |
+
)
|
| 399 |
+
parser.add_argument(
|
| 400 |
+
"--class_embed_type",
|
| 401 |
+
type=str,
|
| 402 |
+
default=None,
|
| 403 |
+
help=(
|
| 404 |
+
"The class embedding type to use. Choose from `None`, `identity`, and `timestep`. If `class_conditional`"
|
| 405 |
+
" and `num_classes` and set, but `class_embed_type` is `None`, a embedding matrix will be used."
|
| 406 |
+
),
|
| 407 |
+
)
|
| 408 |
+
# ------------Dataloader Arguments-----------
|
| 409 |
+
parser.add_argument(
|
| 410 |
+
"--dataloader_num_workers",
|
| 411 |
+
type=int,
|
| 412 |
+
default=0,
|
| 413 |
+
help=(
|
| 414 |
+
"The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
|
| 415 |
+
" process."
|
| 416 |
+
),
|
| 417 |
+
)
|
| 418 |
+
# ------------Training Arguments-----------
|
| 419 |
+
# ----General Training Arguments----
|
| 420 |
+
parser.add_argument(
|
| 421 |
+
"--output_dir",
|
| 422 |
+
type=str,
|
| 423 |
+
default="ddpm-model-64",
|
| 424 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 425 |
+
)
|
| 426 |
+
parser.add_argument("--overwrite_output_dir", action="store_true")
|
| 427 |
+
parser.add_argument(
|
| 428 |
+
"--cache_dir",
|
| 429 |
+
type=str,
|
| 430 |
+
default=None,
|
| 431 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 432 |
+
)
|
| 433 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 434 |
+
# ----Batch Size and Training Length----
|
| 435 |
+
parser.add_argument(
|
| 436 |
+
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
| 437 |
+
)
|
| 438 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 439 |
+
parser.add_argument(
|
| 440 |
+
"--max_train_steps",
|
| 441 |
+
type=int,
|
| 442 |
+
default=None,
|
| 443 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 444 |
+
)
|
| 445 |
+
parser.add_argument(
|
| 446 |
+
"--max_train_samples",
|
| 447 |
+
type=int,
|
| 448 |
+
default=None,
|
| 449 |
+
help=(
|
| 450 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 451 |
+
"value if set."
|
| 452 |
+
),
|
| 453 |
+
)
|
| 454 |
+
# ----Learning Rate----
|
| 455 |
+
parser.add_argument(
|
| 456 |
+
"--learning_rate",
|
| 457 |
+
type=float,
|
| 458 |
+
default=1e-4,
|
| 459 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 460 |
+
)
|
| 461 |
+
parser.add_argument(
|
| 462 |
+
"--scale_lr",
|
| 463 |
+
action="store_true",
|
| 464 |
+
default=False,
|
| 465 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 466 |
+
)
|
| 467 |
+
parser.add_argument(
|
| 468 |
+
"--lr_scheduler",
|
| 469 |
+
type=str,
|
| 470 |
+
default="cosine",
|
| 471 |
+
help=(
|
| 472 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 473 |
+
' "constant", "constant_with_warmup"]'
|
| 474 |
+
),
|
| 475 |
+
)
|
| 476 |
+
parser.add_argument(
|
| 477 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 478 |
+
)
|
| 479 |
+
# ----Optimizer (Adam) Arguments----
|
| 480 |
+
parser.add_argument(
|
| 481 |
+
"--optimizer_type",
|
| 482 |
+
type=str,
|
| 483 |
+
default="adamw",
|
| 484 |
+
help=(
|
| 485 |
+
"The optimizer algorithm to use for training. Choose between `radam` and `adamw`. The iCT paper uses"
|
| 486 |
+
" RAdam."
|
| 487 |
+
),
|
| 488 |
+
)
|
| 489 |
+
parser.add_argument(
|
| 490 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 491 |
+
)
|
| 492 |
+
parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
|
| 493 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 494 |
+
parser.add_argument(
|
| 495 |
+
"--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
|
| 496 |
+
)
|
| 497 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
|
| 498 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 499 |
+
# ----Consistency Training (CT) Specific Arguments----
|
| 500 |
+
parser.add_argument(
|
| 501 |
+
"--prediction_type",
|
| 502 |
+
type=str,
|
| 503 |
+
default="sample",
|
| 504 |
+
choices=["sample"],
|
| 505 |
+
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
| 506 |
+
)
|
| 507 |
+
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
| 508 |
+
parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
|
| 509 |
+
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
| 510 |
+
parser.add_argument(
|
| 511 |
+
"--sigma_min",
|
| 512 |
+
type=float,
|
| 513 |
+
default=0.002,
|
| 514 |
+
help=(
|
| 515 |
+
"The lower boundary for the timestep discretization, which should be set to a small positive value close"
|
| 516 |
+
" to zero to avoid numerical issues when solving the PF-ODE backwards in time."
|
| 517 |
+
),
|
| 518 |
+
)
|
| 519 |
+
parser.add_argument(
|
| 520 |
+
"--sigma_max",
|
| 521 |
+
type=float,
|
| 522 |
+
default=80.0,
|
| 523 |
+
help=(
|
| 524 |
+
"The upper boundary for the timestep discretization, which also determines the variance of the Gaussian"
|
| 525 |
+
" prior."
|
| 526 |
+
),
|
| 527 |
+
)
|
| 528 |
+
parser.add_argument(
|
| 529 |
+
"--rho",
|
| 530 |
+
type=float,
|
| 531 |
+
default=7.0,
|
| 532 |
+
help="The rho parameter for the Karras sigmas timestep dicretization.",
|
| 533 |
+
)
|
| 534 |
+
parser.add_argument(
|
| 535 |
+
"--huber_c",
|
| 536 |
+
type=float,
|
| 537 |
+
default=None,
|
| 538 |
+
help=(
|
| 539 |
+
"The Pseudo-Huber loss parameter c. If not set, this will default to the value recommended in the Improved"
|
| 540 |
+
" Consistency Training (iCT) paper of 0.00054 * sqrt(d), where d is the data dimensionality."
|
| 541 |
+
),
|
| 542 |
+
)
|
| 543 |
+
parser.add_argument(
|
| 544 |
+
"--discretization_s_0",
|
| 545 |
+
type=int,
|
| 546 |
+
default=10,
|
| 547 |
+
help=(
|
| 548 |
+
"The s_0 parameter in the discretization curriculum N(k). This controls the number of training steps after"
|
| 549 |
+
" which the number of discretization steps N will be doubled."
|
| 550 |
+
),
|
| 551 |
+
)
|
| 552 |
+
parser.add_argument(
|
| 553 |
+
"--discretization_s_1",
|
| 554 |
+
type=int,
|
| 555 |
+
default=1280,
|
| 556 |
+
help=(
|
| 557 |
+
"The s_1 parameter in the discretization curriculum N(k). This controls the upper limit to the number of"
|
| 558 |
+
" discretization steps used. Increasing this value will reduce the bias at the cost of higher variance."
|
| 559 |
+
),
|
| 560 |
+
)
|
| 561 |
+
parser.add_argument(
|
| 562 |
+
"--constant_discretization_steps",
|
| 563 |
+
action="store_true",
|
| 564 |
+
help=(
|
| 565 |
+
"Whether to set the discretization curriculum N(k) to be the constant value `discretization_s_0 + 1`. This"
|
| 566 |
+
" is useful for testing when `max_number_steps` is small, when `k_prime` would otherwise be 0, causing"
|
| 567 |
+
" a divide-by-zero error."
|
| 568 |
+
),
|
| 569 |
+
)
|
| 570 |
+
parser.add_argument(
|
| 571 |
+
"--p_mean",
|
| 572 |
+
type=float,
|
| 573 |
+
default=-1.1,
|
| 574 |
+
help=(
|
| 575 |
+
"The mean parameter P_mean for the (discretized) lognormal noise schedule, which controls the probability"
|
| 576 |
+
" of sampling a (discrete) noise level sigma_i."
|
| 577 |
+
),
|
| 578 |
+
)
|
| 579 |
+
parser.add_argument(
|
| 580 |
+
"--p_std",
|
| 581 |
+
type=float,
|
| 582 |
+
default=2.0,
|
| 583 |
+
help=(
|
| 584 |
+
"The standard deviation parameter P_std for the (discretized) noise schedule, which controls the"
|
| 585 |
+
" probability of sampling a (discrete) noise level sigma_i."
|
| 586 |
+
),
|
| 587 |
+
)
|
| 588 |
+
parser.add_argument(
|
| 589 |
+
"--noise_precond_type",
|
| 590 |
+
type=str,
|
| 591 |
+
default="cm",
|
| 592 |
+
help=(
|
| 593 |
+
"The noise preconditioning function to use for transforming the raw Karras sigmas into the timestep"
|
| 594 |
+
" argument of the U-Net. Choose between `none` (the identity function), `edm`, and `cm`."
|
| 595 |
+
),
|
| 596 |
+
)
|
| 597 |
+
parser.add_argument(
|
| 598 |
+
"--input_precond_type",
|
| 599 |
+
type=str,
|
| 600 |
+
default="cm",
|
| 601 |
+
help=(
|
| 602 |
+
"The input preconditioning function to use for scaling the image input of the U-Net. Choose between `none`"
|
| 603 |
+
" (a scaling factor of 1) and `cm`."
|
| 604 |
+
),
|
| 605 |
+
)
|
| 606 |
+
parser.add_argument(
|
| 607 |
+
"--skip_steps",
|
| 608 |
+
type=int,
|
| 609 |
+
default=1,
|
| 610 |
+
help=(
|
| 611 |
+
"The gap in indices between the student and teacher noise levels. In the iCT paper this is always set to"
|
| 612 |
+
" 1, but theoretically this could be greater than 1 and/or altered according to a curriculum throughout"
|
| 613 |
+
" training, much like the number of discretization steps is."
|
| 614 |
+
),
|
| 615 |
+
)
|
| 616 |
+
parser.add_argument(
|
| 617 |
+
"--cast_teacher",
|
| 618 |
+
action="store_true",
|
| 619 |
+
help="Whether to cast the teacher U-Net model to `weight_dtype` or leave it in full precision.",
|
| 620 |
+
)
|
| 621 |
+
# ----Exponential Moving Average (EMA) Arguments----
|
| 622 |
+
parser.add_argument(
|
| 623 |
+
"--use_ema",
|
| 624 |
+
action="store_true",
|
| 625 |
+
help="Whether to use Exponential Moving Average for the final model weights.",
|
| 626 |
+
)
|
| 627 |
+
parser.add_argument(
|
| 628 |
+
"--ema_min_decay",
|
| 629 |
+
type=float,
|
| 630 |
+
default=None,
|
| 631 |
+
help=(
|
| 632 |
+
"The minimum decay magnitude for EMA. If not set, this will default to the value of `ema_max_decay`,"
|
| 633 |
+
" resulting in a constant EMA decay rate."
|
| 634 |
+
),
|
| 635 |
+
)
|
| 636 |
+
parser.add_argument(
|
| 637 |
+
"--ema_max_decay",
|
| 638 |
+
type=float,
|
| 639 |
+
default=0.99993,
|
| 640 |
+
help=(
|
| 641 |
+
"The maximum decay magnitude for EMA. Setting `ema_min_decay` equal to this value will result in a"
|
| 642 |
+
" constant decay rate."
|
| 643 |
+
),
|
| 644 |
+
)
|
| 645 |
+
parser.add_argument(
|
| 646 |
+
"--use_ema_warmup",
|
| 647 |
+
action="store_true",
|
| 648 |
+
help="Whether to use EMA warmup.",
|
| 649 |
+
)
|
| 650 |
+
parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
|
| 651 |
+
parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
|
| 652 |
+
# ----Training Optimization Arguments----
|
| 653 |
+
parser.add_argument(
|
| 654 |
+
"--mixed_precision",
|
| 655 |
+
type=str,
|
| 656 |
+
default="no",
|
| 657 |
+
choices=["no", "fp16", "bf16"],
|
| 658 |
+
help=(
|
| 659 |
+
"Whether to use mixed precision. Choose"
|
| 660 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
| 661 |
+
"and an Nvidia Ampere GPU."
|
| 662 |
+
),
|
| 663 |
+
)
|
| 664 |
+
parser.add_argument(
|
| 665 |
+
"--allow_tf32",
|
| 666 |
+
action="store_true",
|
| 667 |
+
help=(
|
| 668 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 669 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 670 |
+
),
|
| 671 |
+
)
|
| 672 |
+
parser.add_argument(
|
| 673 |
+
"--gradient_checkpointing",
|
| 674 |
+
action="store_true",
|
| 675 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 676 |
+
)
|
| 677 |
+
parser.add_argument(
|
| 678 |
+
"--gradient_accumulation_steps",
|
| 679 |
+
type=int,
|
| 680 |
+
default=1,
|
| 681 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 682 |
+
)
|
| 683 |
+
parser.add_argument(
|
| 684 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 685 |
+
)
|
| 686 |
+
# ----Distributed Training Arguments----
|
| 687 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 688 |
+
# ------------Validation Arguments-----------
|
| 689 |
+
parser.add_argument(
|
| 690 |
+
"--validation_steps",
|
| 691 |
+
type=int,
|
| 692 |
+
default=200,
|
| 693 |
+
help="Run validation every X steps.",
|
| 694 |
+
)
|
| 695 |
+
parser.add_argument(
|
| 696 |
+
"--eval_batch_size",
|
| 697 |
+
type=int,
|
| 698 |
+
default=16,
|
| 699 |
+
help=(
|
| 700 |
+
"The number of images to generate for evaluation. Note that if `class_conditional` and `num_classes` is"
|
| 701 |
+
" set the effective number of images generated per evaluation step is `eval_batch_size * num_classes`."
|
| 702 |
+
),
|
| 703 |
+
)
|
| 704 |
+
parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
|
| 705 |
+
# ------------Validation Arguments-----------
|
| 706 |
+
parser.add_argument(
|
| 707 |
+
"--checkpointing_steps",
|
| 708 |
+
type=int,
|
| 709 |
+
default=500,
|
| 710 |
+
help=(
|
| 711 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
| 712 |
+
" training using `--resume_from_checkpoint`."
|
| 713 |
+
),
|
| 714 |
+
)
|
| 715 |
+
parser.add_argument(
|
| 716 |
+
"--checkpoints_total_limit",
|
| 717 |
+
type=int,
|
| 718 |
+
default=None,
|
| 719 |
+
help=("Max number of checkpoints to store."),
|
| 720 |
+
)
|
| 721 |
+
parser.add_argument(
|
| 722 |
+
"--resume_from_checkpoint",
|
| 723 |
+
type=str,
|
| 724 |
+
default=None,
|
| 725 |
+
help=(
|
| 726 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 727 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 728 |
+
),
|
| 729 |
+
)
|
| 730 |
+
parser.add_argument(
|
| 731 |
+
"--save_model_epochs", type=int, default=10, help="How often to save the model during training."
|
| 732 |
+
)
|
| 733 |
+
# ------------Logging Arguments-----------
|
| 734 |
+
parser.add_argument(
|
| 735 |
+
"--report_to",
|
| 736 |
+
type=str,
|
| 737 |
+
default="tensorboard",
|
| 738 |
+
help=(
|
| 739 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 740 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 741 |
+
),
|
| 742 |
+
)
|
| 743 |
+
parser.add_argument(
|
| 744 |
+
"--logging_dir",
|
| 745 |
+
type=str,
|
| 746 |
+
default="logs",
|
| 747 |
+
help=(
|
| 748 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 749 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 750 |
+
),
|
| 751 |
+
)
|
| 752 |
+
# ------------HuggingFace Hub Arguments-----------
|
| 753 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 754 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 755 |
+
parser.add_argument(
|
| 756 |
+
"--hub_model_id",
|
| 757 |
+
type=str,
|
| 758 |
+
default=None,
|
| 759 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 760 |
+
)
|
| 761 |
+
parser.add_argument(
|
| 762 |
+
"--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
|
| 763 |
+
)
|
| 764 |
+
# ------------Accelerate Arguments-----------
|
| 765 |
+
parser.add_argument(
|
| 766 |
+
"--tracker_project_name",
|
| 767 |
+
type=str,
|
| 768 |
+
default="consistency-training",
|
| 769 |
+
help=(
|
| 770 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
| 771 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
| 772 |
+
),
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
args = parser.parse_args()
|
| 776 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 777 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 778 |
+
args.local_rank = env_local_rank
|
| 779 |
+
|
| 780 |
+
if args.dataset_name is None and args.train_data_dir is None:
|
| 781 |
+
raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
|
| 782 |
+
|
| 783 |
+
return args
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
def main(args):
|
| 787 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
| 788 |
+
|
| 789 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 790 |
+
raise ValueError(
|
| 791 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 792 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 796 |
+
|
| 797 |
+
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200)) # a big number for high resolution or big dataset
|
| 798 |
+
accelerator = Accelerator(
|
| 799 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 800 |
+
mixed_precision=args.mixed_precision,
|
| 801 |
+
log_with=args.report_to,
|
| 802 |
+
project_config=accelerator_project_config,
|
| 803 |
+
kwargs_handlers=[kwargs],
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
if args.report_to == "tensorboard":
|
| 807 |
+
if not is_tensorboard_available():
|
| 808 |
+
raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
|
| 809 |
+
|
| 810 |
+
elif args.report_to == "wandb":
|
| 811 |
+
if not is_wandb_available():
|
| 812 |
+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
| 813 |
+
|
| 814 |
+
# Make one log on every process with the configuration for debugging.
|
| 815 |
+
logging.basicConfig(
|
| 816 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 817 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 818 |
+
level=logging.INFO,
|
| 819 |
+
)
|
| 820 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 821 |
+
if accelerator.is_local_main_process:
|
| 822 |
+
datasets.utils.logging.set_verbosity_warning()
|
| 823 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 824 |
+
else:
|
| 825 |
+
datasets.utils.logging.set_verbosity_error()
|
| 826 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 827 |
+
|
| 828 |
+
# If passed along, set the training seed now.
|
| 829 |
+
if args.seed is not None:
|
| 830 |
+
set_seed(args.seed)
|
| 831 |
+
|
| 832 |
+
# Handle the repository creation
|
| 833 |
+
if accelerator.is_main_process:
|
| 834 |
+
if args.output_dir is not None:
|
| 835 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 836 |
+
|
| 837 |
+
if args.push_to_hub:
|
| 838 |
+
repo_id = create_repo(
|
| 839 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 840 |
+
).repo_id
|
| 841 |
+
|
| 842 |
+
# 1. Initialize the noise scheduler.
|
| 843 |
+
initial_discretization_steps = get_discretization_steps(
|
| 844 |
+
0,
|
| 845 |
+
args.max_train_steps,
|
| 846 |
+
s_0=args.discretization_s_0,
|
| 847 |
+
s_1=args.discretization_s_1,
|
| 848 |
+
constant=args.constant_discretization_steps,
|
| 849 |
+
)
|
| 850 |
+
noise_scheduler = CMStochasticIterativeScheduler(
|
| 851 |
+
num_train_timesteps=initial_discretization_steps,
|
| 852 |
+
sigma_min=args.sigma_min,
|
| 853 |
+
sigma_max=args.sigma_max,
|
| 854 |
+
rho=args.rho,
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
# 2. Initialize the student U-Net model.
|
| 858 |
+
if args.pretrained_model_name_or_path is not None:
|
| 859 |
+
logger.info(f"Loading pretrained U-Net weights from {args.pretrained_model_name_or_path}... ")
|
| 860 |
+
unet = UNet2DModel.from_pretrained(
|
| 861 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
| 862 |
+
)
|
| 863 |
+
elif args.model_config_name_or_path is None:
|
| 864 |
+
# TODO: use default architectures from iCT paper
|
| 865 |
+
if not args.class_conditional and (args.num_classes is not None or args.class_embed_type is not None):
|
| 866 |
+
logger.warning(
|
| 867 |
+
f"`--class_conditional` is set to `False` but `--num_classes` is set to {args.num_classes} and"
|
| 868 |
+
f" `--class_embed_type` is set to {args.class_embed_type}. These values will be overridden to `None`."
|
| 869 |
+
)
|
| 870 |
+
args.num_classes = None
|
| 871 |
+
args.class_embed_type = None
|
| 872 |
+
elif args.class_conditional and args.num_classes is None and args.class_embed_type is None:
|
| 873 |
+
logger.warning(
|
| 874 |
+
"`--class_conditional` is set to `True` but neither `--num_classes` nor `--class_embed_type` is set."
|
| 875 |
+
"`class_conditional` will be overridden to `False`."
|
| 876 |
+
)
|
| 877 |
+
args.class_conditional = False
|
| 878 |
+
unet = UNet2DModel(
|
| 879 |
+
sample_size=args.resolution,
|
| 880 |
+
in_channels=3,
|
| 881 |
+
out_channels=3,
|
| 882 |
+
layers_per_block=2,
|
| 883 |
+
block_out_channels=(128, 128, 256, 256, 512, 512),
|
| 884 |
+
down_block_types=(
|
| 885 |
+
"DownBlock2D",
|
| 886 |
+
"DownBlock2D",
|
| 887 |
+
"DownBlock2D",
|
| 888 |
+
"DownBlock2D",
|
| 889 |
+
"AttnDownBlock2D",
|
| 890 |
+
"DownBlock2D",
|
| 891 |
+
),
|
| 892 |
+
up_block_types=(
|
| 893 |
+
"UpBlock2D",
|
| 894 |
+
"AttnUpBlock2D",
|
| 895 |
+
"UpBlock2D",
|
| 896 |
+
"UpBlock2D",
|
| 897 |
+
"UpBlock2D",
|
| 898 |
+
"UpBlock2D",
|
| 899 |
+
),
|
| 900 |
+
class_embed_type=args.class_embed_type,
|
| 901 |
+
num_class_embeds=args.num_classes,
|
| 902 |
+
)
|
| 903 |
+
else:
|
| 904 |
+
config = UNet2DModel.load_config(args.model_config_name_or_path)
|
| 905 |
+
unet = UNet2DModel.from_config(config)
|
| 906 |
+
unet.train()
|
| 907 |
+
|
| 908 |
+
# Create EMA for the student U-Net model.
|
| 909 |
+
if args.use_ema:
|
| 910 |
+
if args.ema_min_decay is None:
|
| 911 |
+
args.ema_min_decay = args.ema_max_decay
|
| 912 |
+
ema_unet = EMAModel(
|
| 913 |
+
unet.parameters(),
|
| 914 |
+
decay=args.ema_max_decay,
|
| 915 |
+
min_decay=args.ema_min_decay,
|
| 916 |
+
use_ema_warmup=args.use_ema_warmup,
|
| 917 |
+
inv_gamma=args.ema_inv_gamma,
|
| 918 |
+
power=args.ema_power,
|
| 919 |
+
model_cls=UNet2DModel,
|
| 920 |
+
model_config=unet.config,
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
# 3. Initialize the teacher U-Net model from the student U-Net model.
|
| 924 |
+
# Note that following the improved Consistency Training paper, the teacher U-Net is not updated via EMA (e.g. the
|
| 925 |
+
# EMA decay rate is 0.)
|
| 926 |
+
teacher_unet = UNet2DModel.from_config(unet.config)
|
| 927 |
+
teacher_unet.load_state_dict(unet.state_dict())
|
| 928 |
+
teacher_unet.train()
|
| 929 |
+
teacher_unet.requires_grad_(False)
|
| 930 |
+
|
| 931 |
+
# 4. Handle mixed precision and device placement
|
| 932 |
+
weight_dtype = torch.float32
|
| 933 |
+
if accelerator.mixed_precision == "fp16":
|
| 934 |
+
weight_dtype = torch.float16
|
| 935 |
+
args.mixed_precision = accelerator.mixed_precision
|
| 936 |
+
elif accelerator.mixed_precision == "bf16":
|
| 937 |
+
weight_dtype = torch.bfloat16
|
| 938 |
+
args.mixed_precision = accelerator.mixed_precision
|
| 939 |
+
|
| 940 |
+
# Cast teacher_unet to weight_dtype if cast_teacher is set.
|
| 941 |
+
if args.cast_teacher:
|
| 942 |
+
teacher_dtype = weight_dtype
|
| 943 |
+
else:
|
| 944 |
+
teacher_dtype = torch.float32
|
| 945 |
+
|
| 946 |
+
teacher_unet.to(accelerator.device)
|
| 947 |
+
if args.use_ema:
|
| 948 |
+
ema_unet.to(accelerator.device)
|
| 949 |
+
|
| 950 |
+
# 5. Handle saving and loading of checkpoints.
|
| 951 |
+
# `accelerate` 0.16.0 will have better support for customized saving
|
| 952 |
+
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
| 953 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 954 |
+
def save_model_hook(models, weights, output_dir):
|
| 955 |
+
if accelerator.is_main_process:
|
| 956 |
+
teacher_unet.save_pretrained(os.path.join(output_dir, "unet_teacher"))
|
| 957 |
+
if args.use_ema:
|
| 958 |
+
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
| 959 |
+
|
| 960 |
+
for i, model in enumerate(models):
|
| 961 |
+
model.save_pretrained(os.path.join(output_dir, "unet"))
|
| 962 |
+
|
| 963 |
+
# make sure to pop weight so that corresponding model is not saved again
|
| 964 |
+
weights.pop()
|
| 965 |
+
|
| 966 |
+
def load_model_hook(models, input_dir):
|
| 967 |
+
load_model = UNet2DModel.from_pretrained(os.path.join(input_dir, "unet_teacher"))
|
| 968 |
+
teacher_unet.load_state_dict(load_model.state_dict())
|
| 969 |
+
teacher_unet.to(accelerator.device)
|
| 970 |
+
del load_model
|
| 971 |
+
|
| 972 |
+
if args.use_ema:
|
| 973 |
+
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
|
| 974 |
+
ema_unet.load_state_dict(load_model.state_dict())
|
| 975 |
+
ema_unet.to(accelerator.device)
|
| 976 |
+
del load_model
|
| 977 |
+
|
| 978 |
+
for i in range(len(models)):
|
| 979 |
+
# pop models so that they are not loaded again
|
| 980 |
+
model = models.pop()
|
| 981 |
+
|
| 982 |
+
# load diffusers style into model
|
| 983 |
+
load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
|
| 984 |
+
model.register_to_config(**load_model.config)
|
| 985 |
+
|
| 986 |
+
model.load_state_dict(load_model.state_dict())
|
| 987 |
+
del load_model
|
| 988 |
+
|
| 989 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 990 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 991 |
+
|
| 992 |
+
# 6. Enable optimizations
|
| 993 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 994 |
+
if is_xformers_available():
|
| 995 |
+
import xformers
|
| 996 |
+
|
| 997 |
+
xformers_version = version.parse(xformers.__version__)
|
| 998 |
+
if xformers_version == version.parse("0.0.16"):
|
| 999 |
+
logger.warning(
|
| 1000 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 1001 |
+
)
|
| 1002 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 1003 |
+
teacher_unet.enable_xformers_memory_efficient_attention()
|
| 1004 |
+
if args.use_ema:
|
| 1005 |
+
ema_unet.enable_xformers_memory_efficient_attention()
|
| 1006 |
+
else:
|
| 1007 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 1008 |
+
|
| 1009 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 1010 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 1011 |
+
if args.allow_tf32:
|
| 1012 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 1013 |
+
|
| 1014 |
+
if args.gradient_checkpointing:
|
| 1015 |
+
unet.enable_gradient_checkpointing()
|
| 1016 |
+
|
| 1017 |
+
if args.optimizer_type == "radam":
|
| 1018 |
+
optimizer_class = torch.optim.RAdam
|
| 1019 |
+
elif args.optimizer_type == "adamw":
|
| 1020 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model for 16GB GPUs
|
| 1021 |
+
if args.use_8bit_adam:
|
| 1022 |
+
try:
|
| 1023 |
+
import bitsandbytes as bnb
|
| 1024 |
+
except ImportError:
|
| 1025 |
+
raise ImportError(
|
| 1026 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 1027 |
+
)
|
| 1028 |
+
|
| 1029 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 1030 |
+
else:
|
| 1031 |
+
optimizer_class = torch.optim.AdamW
|
| 1032 |
+
else:
|
| 1033 |
+
raise ValueError(
|
| 1034 |
+
f"Optimizer type {args.optimizer_type} is not supported. Currently supported optimizer types are `radam`"
|
| 1035 |
+
f" and `adamw`."
|
| 1036 |
+
)
|
| 1037 |
+
|
| 1038 |
+
# 7. Initialize the optimizer
|
| 1039 |
+
optimizer = optimizer_class(
|
| 1040 |
+
unet.parameters(),
|
| 1041 |
+
lr=args.learning_rate,
|
| 1042 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 1043 |
+
weight_decay=args.adam_weight_decay,
|
| 1044 |
+
eps=args.adam_epsilon,
|
| 1045 |
+
)
|
| 1046 |
+
|
| 1047 |
+
# 8. Dataset creation and data preprocessing
|
| 1048 |
+
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
| 1049 |
+
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
| 1050 |
+
|
| 1051 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
| 1052 |
+
# download the dataset.
|
| 1053 |
+
if args.dataset_name is not None:
|
| 1054 |
+
dataset = load_dataset(
|
| 1055 |
+
args.dataset_name,
|
| 1056 |
+
args.dataset_config_name,
|
| 1057 |
+
cache_dir=args.cache_dir,
|
| 1058 |
+
split="train",
|
| 1059 |
+
)
|
| 1060 |
+
else:
|
| 1061 |
+
dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
|
| 1062 |
+
# See more about loading custom images at
|
| 1063 |
+
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
| 1064 |
+
|
| 1065 |
+
# Preprocessing the datasets and DataLoaders creation.
|
| 1066 |
+
interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
|
| 1067 |
+
augmentations = transforms.Compose(
|
| 1068 |
+
[
|
| 1069 |
+
transforms.Resize(args.resolution, interpolation=interpolation_mode),
|
| 1070 |
+
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
| 1071 |
+
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
| 1072 |
+
transforms.ToTensor(),
|
| 1073 |
+
transforms.Normalize([0.5], [0.5]),
|
| 1074 |
+
]
|
| 1075 |
+
)
|
| 1076 |
+
|
| 1077 |
+
def transform_images(examples):
|
| 1078 |
+
images = [augmentations(image.convert("RGB")) for image in examples[args.dataset_image_column_name]]
|
| 1079 |
+
batch_dict = {"images": images}
|
| 1080 |
+
if args.class_conditional:
|
| 1081 |
+
batch_dict["class_labels"] = examples[args.dataset_class_label_column_name]
|
| 1082 |
+
return batch_dict
|
| 1083 |
+
|
| 1084 |
+
logger.info(f"Dataset size: {len(dataset)}")
|
| 1085 |
+
|
| 1086 |
+
dataset.set_transform(transform_images)
|
| 1087 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 1088 |
+
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
|
| 1089 |
+
)
|
| 1090 |
+
|
| 1091 |
+
# 9. Initialize the learning rate scheduler
|
| 1092 |
+
# Scheduler and math around the number of training steps.
|
| 1093 |
+
overrode_max_train_steps = False
|
| 1094 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 1095 |
+
if args.max_train_steps is None:
|
| 1096 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 1097 |
+
overrode_max_train_steps = True
|
| 1098 |
+
|
| 1099 |
+
lr_scheduler = get_scheduler(
|
| 1100 |
+
args.lr_scheduler,
|
| 1101 |
+
optimizer=optimizer,
|
| 1102 |
+
num_warmup_steps=args.lr_warmup_steps,
|
| 1103 |
+
num_training_steps=args.max_train_steps,
|
| 1104 |
+
)
|
| 1105 |
+
|
| 1106 |
+
# 10. Prepare for training
|
| 1107 |
+
# Prepare everything with our `accelerator`.
|
| 1108 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 1109 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
| 1110 |
+
)
|
| 1111 |
+
|
| 1112 |
+
def recalculate_num_discretization_step_values(discretization_steps, skip_steps):
|
| 1113 |
+
"""
|
| 1114 |
+
Recalculates all quantities depending on the number of discretization steps N.
|
| 1115 |
+
"""
|
| 1116 |
+
noise_scheduler = CMStochasticIterativeScheduler(
|
| 1117 |
+
num_train_timesteps=discretization_steps,
|
| 1118 |
+
sigma_min=args.sigma_min,
|
| 1119 |
+
sigma_max=args.sigma_max,
|
| 1120 |
+
rho=args.rho,
|
| 1121 |
+
)
|
| 1122 |
+
current_timesteps = get_karras_sigmas(discretization_steps, args.sigma_min, args.sigma_max, args.rho)
|
| 1123 |
+
valid_teacher_timesteps_plus_one = current_timesteps[: len(current_timesteps) - skip_steps + 1]
|
| 1124 |
+
# timestep_weights are the unnormalized probabilities of sampling the timestep/noise level at each index
|
| 1125 |
+
timestep_weights = get_discretized_lognormal_weights(
|
| 1126 |
+
valid_teacher_timesteps_plus_one, p_mean=args.p_mean, p_std=args.p_std
|
| 1127 |
+
)
|
| 1128 |
+
# timestep_loss_weights is the timestep-dependent loss weighting schedule lambda(sigma_i)
|
| 1129 |
+
timestep_loss_weights = get_loss_weighting_schedule(valid_teacher_timesteps_plus_one)
|
| 1130 |
+
|
| 1131 |
+
current_timesteps = current_timesteps.to(accelerator.device)
|
| 1132 |
+
timestep_weights = timestep_weights.to(accelerator.device)
|
| 1133 |
+
timestep_loss_weights = timestep_loss_weights.to(accelerator.device)
|
| 1134 |
+
|
| 1135 |
+
return noise_scheduler, current_timesteps, timestep_weights, timestep_loss_weights
|
| 1136 |
+
|
| 1137 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 1138 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 1139 |
+
if overrode_max_train_steps:
|
| 1140 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 1141 |
+
# Afterwards we recalculate our number of training epochs
|
| 1142 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 1143 |
+
|
| 1144 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 1145 |
+
# The trackers initializes automatically on the main process.
|
| 1146 |
+
if accelerator.is_main_process:
|
| 1147 |
+
tracker_config = dict(vars(args))
|
| 1148 |
+
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
|
| 1149 |
+
|
| 1150 |
+
# Function for unwrapping if torch.compile() was used in accelerate.
|
| 1151 |
+
def unwrap_model(model):
|
| 1152 |
+
model = accelerator.unwrap_model(model)
|
| 1153 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 1154 |
+
return model
|
| 1155 |
+
|
| 1156 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 1157 |
+
|
| 1158 |
+
logger.info("***** Running training *****")
|
| 1159 |
+
logger.info(f" Num examples = {len(dataset)}")
|
| 1160 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 1161 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 1162 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 1163 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 1164 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 1165 |
+
|
| 1166 |
+
global_step = 0
|
| 1167 |
+
first_epoch = 0
|
| 1168 |
+
|
| 1169 |
+
# Potentially load in the weights and states from a previous save
|
| 1170 |
+
if args.resume_from_checkpoint:
|
| 1171 |
+
if args.resume_from_checkpoint != "latest":
|
| 1172 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 1173 |
+
else:
|
| 1174 |
+
# Get the most recent checkpoint
|
| 1175 |
+
dirs = os.listdir(args.output_dir)
|
| 1176 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 1177 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 1178 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 1179 |
+
|
| 1180 |
+
if path is None:
|
| 1181 |
+
accelerator.print(
|
| 1182 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 1183 |
+
)
|
| 1184 |
+
args.resume_from_checkpoint = None
|
| 1185 |
+
initial_global_step = 0
|
| 1186 |
+
else:
|
| 1187 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 1188 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 1189 |
+
global_step = int(path.split("-")[1])
|
| 1190 |
+
|
| 1191 |
+
initial_global_step = global_step
|
| 1192 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 1193 |
+
else:
|
| 1194 |
+
initial_global_step = 0
|
| 1195 |
+
|
| 1196 |
+
# Resolve the c parameter for the Pseudo-Huber loss
|
| 1197 |
+
if args.huber_c is None:
|
| 1198 |
+
args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)
|
| 1199 |
+
|
| 1200 |
+
# Get current number of discretization steps N according to our discretization curriculum
|
| 1201 |
+
current_discretization_steps = get_discretization_steps(
|
| 1202 |
+
initial_global_step,
|
| 1203 |
+
args.max_train_steps,
|
| 1204 |
+
s_0=args.discretization_s_0,
|
| 1205 |
+
s_1=args.discretization_s_1,
|
| 1206 |
+
constant=args.constant_discretization_steps,
|
| 1207 |
+
)
|
| 1208 |
+
current_skip_steps = get_skip_steps(initial_global_step, initial_skip=args.skip_steps)
|
| 1209 |
+
if current_skip_steps >= current_discretization_steps:
|
| 1210 |
+
raise ValueError(
|
| 1211 |
+
f"The current skip steps is {current_skip_steps}, but should be smaller than the current number of"
|
| 1212 |
+
f" discretization steps {current_discretization_steps}"
|
| 1213 |
+
)
|
| 1214 |
+
# Recalculate all quantities depending on the number of discretization steps N
|
| 1215 |
+
(
|
| 1216 |
+
noise_scheduler,
|
| 1217 |
+
current_timesteps,
|
| 1218 |
+
timestep_weights,
|
| 1219 |
+
timestep_loss_weights,
|
| 1220 |
+
) = recalculate_num_discretization_step_values(current_discretization_steps, current_skip_steps)
|
| 1221 |
+
|
| 1222 |
+
progress_bar = tqdm(
|
| 1223 |
+
range(0, args.max_train_steps),
|
| 1224 |
+
initial=initial_global_step,
|
| 1225 |
+
desc="Steps",
|
| 1226 |
+
# Only show the progress bar once on each machine.
|
| 1227 |
+
disable=not accelerator.is_local_main_process,
|
| 1228 |
+
)
|
| 1229 |
+
|
| 1230 |
+
# 11. Train!
|
| 1231 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 1232 |
+
unet.train()
|
| 1233 |
+
for step, batch in enumerate(train_dataloader):
|
| 1234 |
+
# 1. Get batch of images from dataloader (sample x ~ p_data(x))
|
| 1235 |
+
clean_images = batch["images"].to(weight_dtype)
|
| 1236 |
+
if args.class_conditional:
|
| 1237 |
+
class_labels = batch["class_labels"]
|
| 1238 |
+
else:
|
| 1239 |
+
class_labels = None
|
| 1240 |
+
bsz = clean_images.shape[0]
|
| 1241 |
+
|
| 1242 |
+
# 2. Sample a random timestep for each image according to the noise schedule.
|
| 1243 |
+
# Sample random indices i ~ p(i), where p(i) is the dicretized lognormal distribution in the iCT paper
|
| 1244 |
+
# NOTE: timestep_indices should be in the range [0, len(current_timesteps) - k - 1] inclusive
|
| 1245 |
+
timestep_indices = torch.multinomial(timestep_weights, bsz, replacement=True).long()
|
| 1246 |
+
teacher_timesteps = current_timesteps[timestep_indices]
|
| 1247 |
+
student_timesteps = current_timesteps[timestep_indices + current_skip_steps]
|
| 1248 |
+
|
| 1249 |
+
# 3. Sample noise and add it to the clean images for both teacher and student unets
|
| 1250 |
+
# Sample noise z ~ N(0, I) that we'll add to the images
|
| 1251 |
+
noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
|
| 1252 |
+
# Add noise to the clean images according to the noise magnitude at each timestep
|
| 1253 |
+
# (this is the forward diffusion process)
|
| 1254 |
+
teacher_noisy_images = add_noise(clean_images, noise, teacher_timesteps)
|
| 1255 |
+
student_noisy_images = add_noise(clean_images, noise, student_timesteps)
|
| 1256 |
+
|
| 1257 |
+
# 4. Calculate preconditioning and scalings for boundary conditions for the consistency model.
|
| 1258 |
+
teacher_rescaled_timesteps = get_noise_preconditioning(teacher_timesteps, args.noise_precond_type)
|
| 1259 |
+
student_rescaled_timesteps = get_noise_preconditioning(student_timesteps, args.noise_precond_type)
|
| 1260 |
+
|
| 1261 |
+
c_in_teacher = get_input_preconditioning(teacher_timesteps, input_precond_type=args.input_precond_type)
|
| 1262 |
+
c_in_student = get_input_preconditioning(student_timesteps, input_precond_type=args.input_precond_type)
|
| 1263 |
+
|
| 1264 |
+
c_skip_teacher, c_out_teacher = scalings_for_boundary_conditions(teacher_timesteps)
|
| 1265 |
+
c_skip_student, c_out_student = scalings_for_boundary_conditions(student_timesteps)
|
| 1266 |
+
|
| 1267 |
+
c_skip_teacher, c_out_teacher, c_in_teacher = [
|
| 1268 |
+
append_dims(x, clean_images.ndim) for x in [c_skip_teacher, c_out_teacher, c_in_teacher]
|
| 1269 |
+
]
|
| 1270 |
+
c_skip_student, c_out_student, c_in_student = [
|
| 1271 |
+
append_dims(x, clean_images.ndim) for x in [c_skip_student, c_out_student, c_in_student]
|
| 1272 |
+
]
|
| 1273 |
+
|
| 1274 |
+
with accelerator.accumulate(unet):
|
| 1275 |
+
# 5. Get the student unet denoising prediction on the student timesteps
|
| 1276 |
+
# Get rng state now to ensure that dropout is synced between the student and teacher models.
|
| 1277 |
+
dropout_state = torch.get_rng_state()
|
| 1278 |
+
student_model_output = unet(
|
| 1279 |
+
c_in_student * student_noisy_images, student_rescaled_timesteps, class_labels=class_labels
|
| 1280 |
+
).sample
|
| 1281 |
+
# NOTE: currently only support prediction_type == sample, so no need to convert model_output
|
| 1282 |
+
student_denoise_output = c_skip_student * student_noisy_images + c_out_student * student_model_output
|
| 1283 |
+
|
| 1284 |
+
# 6. Get the teacher unet denoising prediction on the teacher timesteps
|
| 1285 |
+
with torch.no_grad(), torch.autocast("cuda", dtype=teacher_dtype):
|
| 1286 |
+
torch.set_rng_state(dropout_state)
|
| 1287 |
+
teacher_model_output = teacher_unet(
|
| 1288 |
+
c_in_teacher * teacher_noisy_images, teacher_rescaled_timesteps, class_labels=class_labels
|
| 1289 |
+
).sample
|
| 1290 |
+
# NOTE: currently only support prediction_type == sample, so no need to convert model_output
|
| 1291 |
+
teacher_denoise_output = (
|
| 1292 |
+
c_skip_teacher * teacher_noisy_images + c_out_teacher * teacher_model_output
|
| 1293 |
+
)
|
| 1294 |
+
|
| 1295 |
+
# 7. Calculate the weighted Pseudo-Huber loss
|
| 1296 |
+
if args.prediction_type == "sample":
|
| 1297 |
+
# Note that the loss weights should be those at the (teacher) timestep indices.
|
| 1298 |
+
lambda_t = _extract_into_tensor(
|
| 1299 |
+
timestep_loss_weights, timestep_indices, (bsz,) + (1,) * (clean_images.ndim - 1)
|
| 1300 |
+
)
|
| 1301 |
+
loss = lambda_t * (
|
| 1302 |
+
torch.sqrt(
|
| 1303 |
+
(student_denoise_output.float() - teacher_denoise_output.float()) ** 2 + args.huber_c**2
|
| 1304 |
+
)
|
| 1305 |
+
- args.huber_c
|
| 1306 |
+
)
|
| 1307 |
+
loss = loss.mean()
|
| 1308 |
+
else:
|
| 1309 |
+
raise ValueError(
|
| 1310 |
+
f"Unsupported prediction type: {args.prediction_type}. Currently, only `sample` is supported."
|
| 1311 |
+
)
|
| 1312 |
+
|
| 1313 |
+
# 8. Backpropagate on the consistency training loss
|
| 1314 |
+
accelerator.backward(loss)
|
| 1315 |
+
if accelerator.sync_gradients:
|
| 1316 |
+
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
| 1317 |
+
optimizer.step()
|
| 1318 |
+
lr_scheduler.step()
|
| 1319 |
+
optimizer.zero_grad()
|
| 1320 |
+
|
| 1321 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1322 |
+
if accelerator.sync_gradients:
|
| 1323 |
+
# 9. Update teacher_unet and ema_unet parameters using unet's parameters.
|
| 1324 |
+
teacher_unet.load_state_dict(unet.state_dict())
|
| 1325 |
+
if args.use_ema:
|
| 1326 |
+
ema_unet.step(unet.parameters())
|
| 1327 |
+
progress_bar.update(1)
|
| 1328 |
+
global_step += 1
|
| 1329 |
+
|
| 1330 |
+
if accelerator.is_main_process:
|
| 1331 |
+
# 10. Recalculate quantities depending on the global step, if necessary.
|
| 1332 |
+
new_discretization_steps = get_discretization_steps(
|
| 1333 |
+
global_step,
|
| 1334 |
+
args.max_train_steps,
|
| 1335 |
+
s_0=args.discretization_s_0,
|
| 1336 |
+
s_1=args.discretization_s_1,
|
| 1337 |
+
constant=args.constant_discretization_steps,
|
| 1338 |
+
)
|
| 1339 |
+
current_skip_steps = get_skip_steps(global_step, initial_skip=args.skip_steps)
|
| 1340 |
+
if current_skip_steps >= new_discretization_steps:
|
| 1341 |
+
raise ValueError(
|
| 1342 |
+
f"The current skip steps is {current_skip_steps}, but should be smaller than the current"
|
| 1343 |
+
f" number of discretization steps {new_discretization_steps}."
|
| 1344 |
+
)
|
| 1345 |
+
if new_discretization_steps != current_discretization_steps:
|
| 1346 |
+
(
|
| 1347 |
+
noise_scheduler,
|
| 1348 |
+
current_timesteps,
|
| 1349 |
+
timestep_weights,
|
| 1350 |
+
timestep_loss_weights,
|
| 1351 |
+
) = recalculate_num_discretization_step_values(new_discretization_steps, current_skip_steps)
|
| 1352 |
+
current_discretization_steps = new_discretization_steps
|
| 1353 |
+
|
| 1354 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1355 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 1356 |
+
if args.checkpoints_total_limit is not None:
|
| 1357 |
+
checkpoints = os.listdir(args.output_dir)
|
| 1358 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 1359 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 1360 |
+
|
| 1361 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 1362 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 1363 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 1364 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 1365 |
+
|
| 1366 |
+
logger.info(
|
| 1367 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 1368 |
+
)
|
| 1369 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 1370 |
+
|
| 1371 |
+
for removing_checkpoint in removing_checkpoints:
|
| 1372 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 1373 |
+
shutil.rmtree(removing_checkpoint)
|
| 1374 |
+
|
| 1375 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 1376 |
+
accelerator.save_state(save_path)
|
| 1377 |
+
logger.info(f"Saved state to {save_path}")
|
| 1378 |
+
|
| 1379 |
+
if global_step % args.validation_steps == 0:
|
| 1380 |
+
# NOTE: since we do not use EMA for the teacher model, the teacher parameters and student
|
| 1381 |
+
# parameters are the same at this point in time
|
| 1382 |
+
log_validation(unet, noise_scheduler, args, accelerator, weight_dtype, global_step, "teacher")
|
| 1383 |
+
# teacher_unet.to(dtype=teacher_dtype)
|
| 1384 |
+
|
| 1385 |
+
if args.use_ema:
|
| 1386 |
+
# Store the student unet weights and load the EMA weights.
|
| 1387 |
+
ema_unet.store(unet.parameters())
|
| 1388 |
+
ema_unet.copy_to(unet.parameters())
|
| 1389 |
+
|
| 1390 |
+
log_validation(
|
| 1391 |
+
unet,
|
| 1392 |
+
noise_scheduler,
|
| 1393 |
+
args,
|
| 1394 |
+
accelerator,
|
| 1395 |
+
weight_dtype,
|
| 1396 |
+
global_step,
|
| 1397 |
+
"ema_student",
|
| 1398 |
+
)
|
| 1399 |
+
|
| 1400 |
+
# Restore student unet weights
|
| 1401 |
+
ema_unet.restore(unet.parameters())
|
| 1402 |
+
|
| 1403 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
|
| 1404 |
+
if args.use_ema:
|
| 1405 |
+
logs["ema_decay"] = ema_unet.cur_decay_value
|
| 1406 |
+
progress_bar.set_postfix(**logs)
|
| 1407 |
+
accelerator.log(logs, step=global_step)
|
| 1408 |
+
|
| 1409 |
+
if global_step >= args.max_train_steps:
|
| 1410 |
+
break
|
| 1411 |
+
# progress_bar.close()
|
| 1412 |
+
|
| 1413 |
+
accelerator.wait_for_everyone()
|
| 1414 |
+
if accelerator.is_main_process:
|
| 1415 |
+
unet = unwrap_model(unet)
|
| 1416 |
+
pipeline = ConsistencyModelPipeline(unet=unet, scheduler=noise_scheduler)
|
| 1417 |
+
pipeline.save_pretrained(args.output_dir)
|
| 1418 |
+
|
| 1419 |
+
# If using EMA, save EMA weights as well.
|
| 1420 |
+
if args.use_ema:
|
| 1421 |
+
ema_unet.copy_to(unet.parameters())
|
| 1422 |
+
|
| 1423 |
+
unet.save_pretrained(os.path.join(args.output_dir, "ema_unet"))
|
| 1424 |
+
|
| 1425 |
+
if args.push_to_hub:
|
| 1426 |
+
upload_folder(
|
| 1427 |
+
repo_id=repo_id,
|
| 1428 |
+
folder_path=args.output_dir,
|
| 1429 |
+
commit_message="End of training",
|
| 1430 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 1431 |
+
)
|
| 1432 |
+
|
| 1433 |
+
accelerator.end_training()
|
| 1434 |
+
|
| 1435 |
+
|
| 1436 |
+
if __name__ == "__main__":
|
| 1437 |
+
args = parse_args()
|
| 1438 |
+
main(args)
|
diffusers/examples/research_projects/diffusion_dpo/README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diffusion Model Alignment Using Direct Preference Optimization
|
| 2 |
+
|
| 3 |
+
This directory provides LoRA implementations of Diffusion DPO proposed in [DiffusionModel Alignment Using Direct Preference Optimization](https://huggingface.co/papers/2311.12908) by Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik.
|
| 4 |
+
|
| 5 |
+
We provide implementations for both Stable Diffusion (SD) and Stable Diffusion XL (SDXL). The original checkpoints are available at the URLs below:
|
| 6 |
+
|
| 7 |
+
* [mhdang/dpo-sd1.5-text2image-v1](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1)
|
| 8 |
+
* [mhdang/dpo-sdxl-text2image-v1](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)
|
| 9 |
+
|
| 10 |
+
> 💡 Note: The scripts are highly experimental and were only tested on low-data regimes. Proceed with caution. Feel free to let us know about your findings via GitHub issues.
|
| 11 |
+
|
| 12 |
+
## SD training command
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
accelerate launch train_diffusion_dpo.py \
|
| 16 |
+
--pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5 \
|
| 17 |
+
--output_dir="diffusion-dpo" \
|
| 18 |
+
--mixed_precision="fp16" \
|
| 19 |
+
--dataset_name=kashif/pickascore \
|
| 20 |
+
--resolution=512 \
|
| 21 |
+
--train_batch_size=16 \
|
| 22 |
+
--gradient_accumulation_steps=2 \
|
| 23 |
+
--gradient_checkpointing \
|
| 24 |
+
--use_8bit_adam \
|
| 25 |
+
--rank=8 \
|
| 26 |
+
--learning_rate=1e-5 \
|
| 27 |
+
--report_to="wandb" \
|
| 28 |
+
--lr_scheduler="constant" \
|
| 29 |
+
--lr_warmup_steps=0 \
|
| 30 |
+
--max_train_steps=10000 \
|
| 31 |
+
--checkpointing_steps=2000 \
|
| 32 |
+
--run_validation --validation_steps=200 \
|
| 33 |
+
--seed="0" \
|
| 34 |
+
--report_to="wandb" \
|
| 35 |
+
--push_to_hub
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## SDXL training command
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
accelerate launch train_diffusion_dpo_sdxl.py \
|
| 42 |
+
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
|
| 43 |
+
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
| 44 |
+
--output_dir="diffusion-sdxl-dpo" \
|
| 45 |
+
--mixed_precision="fp16" \
|
| 46 |
+
--dataset_name=kashif/pickascore \
|
| 47 |
+
--train_batch_size=8 \
|
| 48 |
+
--gradient_accumulation_steps=2 \
|
| 49 |
+
--gradient_checkpointing \
|
| 50 |
+
--use_8bit_adam \
|
| 51 |
+
--rank=8 \
|
| 52 |
+
--learning_rate=1e-5 \
|
| 53 |
+
--report_to="wandb" \
|
| 54 |
+
--lr_scheduler="constant" \
|
| 55 |
+
--lr_warmup_steps=0 \
|
| 56 |
+
--max_train_steps=2000 \
|
| 57 |
+
--checkpointing_steps=500 \
|
| 58 |
+
--run_validation --validation_steps=50 \
|
| 59 |
+
--seed="0" \
|
| 60 |
+
--report_to="wandb" \
|
| 61 |
+
--push_to_hub
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## SDXL Turbo training command
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
accelerate launch train_diffusion_dpo_sdxl.py \
|
| 68 |
+
--pretrained_model_name_or_path=stabilityai/sdxl-turbo \
|
| 69 |
+
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
| 70 |
+
--output_dir="diffusion-sdxl-turbo-dpo" \
|
| 71 |
+
--mixed_precision="fp16" \
|
| 72 |
+
--dataset_name=kashif/pickascore \
|
| 73 |
+
--train_batch_size=8 \
|
| 74 |
+
--gradient_accumulation_steps=2 \
|
| 75 |
+
--gradient_checkpointing \
|
| 76 |
+
--use_8bit_adam \
|
| 77 |
+
--rank=8 \
|
| 78 |
+
--learning_rate=1e-5 \
|
| 79 |
+
--report_to="wandb" \
|
| 80 |
+
--lr_scheduler="constant" \
|
| 81 |
+
--lr_warmup_steps=0 \
|
| 82 |
+
--max_train_steps=2000 \
|
| 83 |
+
--checkpointing_steps=500 \
|
| 84 |
+
--run_validation --validation_steps=50 \
|
| 85 |
+
--seed="0" \
|
| 86 |
+
--report_to="wandb" \
|
| 87 |
+
--is_turbo --resolution 512 \
|
| 88 |
+
--push_to_hub
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
## Acknowledgements
|
| 93 |
+
|
| 94 |
+
This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
|
diffusers/examples/research_projects/diffusion_dpo/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=0.16.0
|
| 2 |
+
torchvision
|
| 3 |
+
transformers>=4.25.1
|
| 4 |
+
ftfy
|
| 5 |
+
tensorboard
|
| 6 |
+
Jinja2
|
| 7 |
+
peft
|
| 8 |
+
wandb
|
diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
ADDED
|
@@ -0,0 +1,982 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 bram-w, The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import contextlib
|
| 18 |
+
import io
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
import os
|
| 22 |
+
import shutil
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torch.utils.checkpoint
|
| 29 |
+
import transformers
|
| 30 |
+
import wandb
|
| 31 |
+
from accelerate import Accelerator
|
| 32 |
+
from accelerate.logging import get_logger
|
| 33 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 34 |
+
from datasets import load_dataset
|
| 35 |
+
from huggingface_hub import create_repo, upload_folder
|
| 36 |
+
from packaging import version
|
| 37 |
+
from peft import LoraConfig
|
| 38 |
+
from peft.utils import get_peft_model_state_dict
|
| 39 |
+
from PIL import Image
|
| 40 |
+
from torchvision import transforms
|
| 41 |
+
from tqdm.auto import tqdm
|
| 42 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
| 43 |
+
|
| 44 |
+
import diffusers
|
| 45 |
+
from diffusers import (
|
| 46 |
+
AutoencoderKL,
|
| 47 |
+
DDPMScheduler,
|
| 48 |
+
DiffusionPipeline,
|
| 49 |
+
DPMSolverMultistepScheduler,
|
| 50 |
+
UNet2DConditionModel,
|
| 51 |
+
)
|
| 52 |
+
from diffusers.loaders import StableDiffusionLoraLoaderMixin
|
| 53 |
+
from diffusers.optimization import get_scheduler
|
| 54 |
+
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers
|
| 55 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 59 |
+
check_min_version("0.25.0.dev0")
|
| 60 |
+
|
| 61 |
+
logger = get_logger(__name__)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
VALIDATION_PROMPTS = [
|
| 65 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
|
| 66 |
+
"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 67 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
| 68 |
+
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
| 73 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
| 74 |
+
pretrained_model_name_or_path,
|
| 75 |
+
subfolder="text_encoder",
|
| 76 |
+
revision=revision,
|
| 77 |
+
)
|
| 78 |
+
model_class = text_encoder_config.architectures[0]
|
| 79 |
+
|
| 80 |
+
if model_class == "CLIPTextModel":
|
| 81 |
+
from transformers import CLIPTextModel
|
| 82 |
+
|
| 83 |
+
return CLIPTextModel
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
|
| 89 |
+
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
| 90 |
+
|
| 91 |
+
# create pipeline
|
| 92 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
| 93 |
+
args.pretrained_model_name_or_path,
|
| 94 |
+
revision=args.revision,
|
| 95 |
+
variant=args.variant,
|
| 96 |
+
torch_dtype=weight_dtype,
|
| 97 |
+
)
|
| 98 |
+
if not is_final_validation:
|
| 99 |
+
pipeline.unet = accelerator.unwrap_model(unet)
|
| 100 |
+
else:
|
| 101 |
+
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
|
| 102 |
+
|
| 103 |
+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
| 104 |
+
pipeline = pipeline.to(accelerator.device)
|
| 105 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 106 |
+
|
| 107 |
+
# run inference
|
| 108 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 109 |
+
images = []
|
| 110 |
+
context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
|
| 111 |
+
|
| 112 |
+
for prompt in VALIDATION_PROMPTS:
|
| 113 |
+
with context:
|
| 114 |
+
image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
|
| 115 |
+
images.append(image)
|
| 116 |
+
|
| 117 |
+
tracker_key = "test" if is_final_validation else "validation"
|
| 118 |
+
for tracker in accelerator.trackers:
|
| 119 |
+
if tracker.name == "tensorboard":
|
| 120 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 121 |
+
tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
|
| 122 |
+
if tracker.name == "wandb":
|
| 123 |
+
tracker.log(
|
| 124 |
+
{
|
| 125 |
+
tracker_key: [
|
| 126 |
+
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
|
| 127 |
+
]
|
| 128 |
+
}
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# Also log images without the LoRA params for comparison.
|
| 132 |
+
if is_final_validation:
|
| 133 |
+
pipeline.disable_lora()
|
| 134 |
+
no_lora_images = [
|
| 135 |
+
pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
for tracker in accelerator.trackers:
|
| 139 |
+
if tracker.name == "tensorboard":
|
| 140 |
+
np_images = np.stack([np.asarray(img) for img in no_lora_images])
|
| 141 |
+
tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
|
| 142 |
+
if tracker.name == "wandb":
|
| 143 |
+
tracker.log(
|
| 144 |
+
{
|
| 145 |
+
"test_without_lora": [
|
| 146 |
+
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
|
| 147 |
+
for i, image in enumerate(no_lora_images)
|
| 148 |
+
]
|
| 149 |
+
}
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def parse_args(input_args=None):
|
| 154 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--pretrained_model_name_or_path",
|
| 157 |
+
type=str,
|
| 158 |
+
default=None,
|
| 159 |
+
required=True,
|
| 160 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 161 |
+
)
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--revision",
|
| 164 |
+
type=str,
|
| 165 |
+
default=None,
|
| 166 |
+
required=False,
|
| 167 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 168 |
+
)
|
| 169 |
+
parser.add_argument(
|
| 170 |
+
"--dataset_name",
|
| 171 |
+
type=str,
|
| 172 |
+
default=None,
|
| 173 |
+
help=(
|
| 174 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
| 175 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
| 176 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
| 177 |
+
),
|
| 178 |
+
)
|
| 179 |
+
parser.add_argument(
|
| 180 |
+
"--dataset_split_name",
|
| 181 |
+
type=str,
|
| 182 |
+
default="validation",
|
| 183 |
+
help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--variant",
|
| 187 |
+
type=str,
|
| 188 |
+
default=None,
|
| 189 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--run_validation",
|
| 193 |
+
default=False,
|
| 194 |
+
action="store_true",
|
| 195 |
+
help="Whether to run validation inference in between training and also after training. Helps to track progress.",
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--validation_steps",
|
| 199 |
+
type=int,
|
| 200 |
+
default=200,
|
| 201 |
+
help="Run validation every X steps.",
|
| 202 |
+
)
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--max_train_samples",
|
| 205 |
+
type=int,
|
| 206 |
+
default=None,
|
| 207 |
+
help=(
|
| 208 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 209 |
+
"value if set."
|
| 210 |
+
),
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--output_dir",
|
| 214 |
+
type=str,
|
| 215 |
+
default="diffusion-dpo-lora",
|
| 216 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--cache_dir",
|
| 220 |
+
type=str,
|
| 221 |
+
default=None,
|
| 222 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 225 |
+
parser.add_argument(
|
| 226 |
+
"--resolution",
|
| 227 |
+
type=int,
|
| 228 |
+
default=512,
|
| 229 |
+
help=(
|
| 230 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 231 |
+
" resolution"
|
| 232 |
+
),
|
| 233 |
+
)
|
| 234 |
+
parser.add_argument(
|
| 235 |
+
"--vae_encode_batch_size",
|
| 236 |
+
type=int,
|
| 237 |
+
default=8,
|
| 238 |
+
help="Batch size to use for VAE encoding of the images for efficient processing.",
|
| 239 |
+
)
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
"--no_hflip",
|
| 242 |
+
action="store_true",
|
| 243 |
+
help="whether to randomly flip images horizontally",
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument(
|
| 246 |
+
"--random_crop",
|
| 247 |
+
default=False,
|
| 248 |
+
action="store_true",
|
| 249 |
+
help=(
|
| 250 |
+
"Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
|
| 251 |
+
),
|
| 252 |
+
)
|
| 253 |
+
parser.add_argument(
|
| 254 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
| 255 |
+
)
|
| 256 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 257 |
+
parser.add_argument(
|
| 258 |
+
"--max_train_steps",
|
| 259 |
+
type=int,
|
| 260 |
+
default=None,
|
| 261 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--checkpointing_steps",
|
| 265 |
+
type=int,
|
| 266 |
+
default=500,
|
| 267 |
+
help=(
|
| 268 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
| 269 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
| 270 |
+
" training using `--resume_from_checkpoint`."
|
| 271 |
+
),
|
| 272 |
+
)
|
| 273 |
+
parser.add_argument(
|
| 274 |
+
"--checkpoints_total_limit",
|
| 275 |
+
type=int,
|
| 276 |
+
default=None,
|
| 277 |
+
help=("Max number of checkpoints to store."),
|
| 278 |
+
)
|
| 279 |
+
parser.add_argument(
|
| 280 |
+
"--resume_from_checkpoint",
|
| 281 |
+
type=str,
|
| 282 |
+
default=None,
|
| 283 |
+
help=(
|
| 284 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 285 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 286 |
+
),
|
| 287 |
+
)
|
| 288 |
+
parser.add_argument(
|
| 289 |
+
"--gradient_accumulation_steps",
|
| 290 |
+
type=int,
|
| 291 |
+
default=1,
|
| 292 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 293 |
+
)
|
| 294 |
+
parser.add_argument(
|
| 295 |
+
"--gradient_checkpointing",
|
| 296 |
+
action="store_true",
|
| 297 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 298 |
+
)
|
| 299 |
+
parser.add_argument(
|
| 300 |
+
"--beta_dpo",
|
| 301 |
+
type=int,
|
| 302 |
+
default=2500,
|
| 303 |
+
help="DPO KL Divergence penalty.",
|
| 304 |
+
)
|
| 305 |
+
parser.add_argument(
|
| 306 |
+
"--loss_type",
|
| 307 |
+
type=str,
|
| 308 |
+
default="sigmoid",
|
| 309 |
+
help="DPO loss type. Can be one of 'sigmoid' (default), 'ipo', or 'cpo'",
|
| 310 |
+
)
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--learning_rate",
|
| 313 |
+
type=float,
|
| 314 |
+
default=5e-4,
|
| 315 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 316 |
+
)
|
| 317 |
+
parser.add_argument(
|
| 318 |
+
"--scale_lr",
|
| 319 |
+
action="store_true",
|
| 320 |
+
default=False,
|
| 321 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--lr_scheduler",
|
| 325 |
+
type=str,
|
| 326 |
+
default="constant",
|
| 327 |
+
help=(
|
| 328 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 329 |
+
' "constant", "constant_with_warmup"]'
|
| 330 |
+
),
|
| 331 |
+
)
|
| 332 |
+
parser.add_argument(
|
| 333 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 334 |
+
)
|
| 335 |
+
parser.add_argument(
|
| 336 |
+
"--lr_num_cycles",
|
| 337 |
+
type=int,
|
| 338 |
+
default=1,
|
| 339 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--dataloader_num_workers",
|
| 344 |
+
type=int,
|
| 345 |
+
default=0,
|
| 346 |
+
help=(
|
| 347 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 348 |
+
),
|
| 349 |
+
)
|
| 350 |
+
parser.add_argument(
|
| 351 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 352 |
+
)
|
| 353 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 354 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 355 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 356 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 357 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 358 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 359 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 360 |
+
parser.add_argument(
|
| 361 |
+
"--hub_model_id",
|
| 362 |
+
type=str,
|
| 363 |
+
default=None,
|
| 364 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 365 |
+
)
|
| 366 |
+
parser.add_argument(
|
| 367 |
+
"--logging_dir",
|
| 368 |
+
type=str,
|
| 369 |
+
default="logs",
|
| 370 |
+
help=(
|
| 371 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 372 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 373 |
+
),
|
| 374 |
+
)
|
| 375 |
+
parser.add_argument(
|
| 376 |
+
"--allow_tf32",
|
| 377 |
+
action="store_true",
|
| 378 |
+
help=(
|
| 379 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 380 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 381 |
+
),
|
| 382 |
+
)
|
| 383 |
+
parser.add_argument(
|
| 384 |
+
"--report_to",
|
| 385 |
+
type=str,
|
| 386 |
+
default="tensorboard",
|
| 387 |
+
help=(
|
| 388 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 389 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 390 |
+
),
|
| 391 |
+
)
|
| 392 |
+
parser.add_argument(
|
| 393 |
+
"--mixed_precision",
|
| 394 |
+
type=str,
|
| 395 |
+
default=None,
|
| 396 |
+
choices=["no", "fp16", "bf16"],
|
| 397 |
+
help=(
|
| 398 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 399 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 400 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 401 |
+
),
|
| 402 |
+
)
|
| 403 |
+
parser.add_argument(
|
| 404 |
+
"--prior_generation_precision",
|
| 405 |
+
type=str,
|
| 406 |
+
default=None,
|
| 407 |
+
choices=["no", "fp32", "fp16", "bf16"],
|
| 408 |
+
help=(
|
| 409 |
+
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 410 |
+
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
| 411 |
+
),
|
| 412 |
+
)
|
| 413 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 414 |
+
parser.add_argument(
|
| 415 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 416 |
+
)
|
| 417 |
+
parser.add_argument(
|
| 418 |
+
"--rank",
|
| 419 |
+
type=int,
|
| 420 |
+
default=4,
|
| 421 |
+
help=("The dimension of the LoRA update matrices."),
|
| 422 |
+
)
|
| 423 |
+
parser.add_argument(
|
| 424 |
+
"--tracker_name",
|
| 425 |
+
type=str,
|
| 426 |
+
default="diffusion-dpo-lora",
|
| 427 |
+
help=("The name of the tracker to report results to."),
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
if input_args is not None:
|
| 431 |
+
args = parser.parse_args(input_args)
|
| 432 |
+
else:
|
| 433 |
+
args = parser.parse_args()
|
| 434 |
+
|
| 435 |
+
if args.dataset_name is None:
|
| 436 |
+
raise ValueError("Must provide a `dataset_name`.")
|
| 437 |
+
|
| 438 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 439 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 440 |
+
args.local_rank = env_local_rank
|
| 441 |
+
|
| 442 |
+
return args
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def tokenize_captions(tokenizer, examples):
|
| 446 |
+
max_length = tokenizer.model_max_length
|
| 447 |
+
captions = []
|
| 448 |
+
for caption in examples["caption"]:
|
| 449 |
+
captions.append(caption)
|
| 450 |
+
|
| 451 |
+
text_inputs = tokenizer(
|
| 452 |
+
captions, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
return text_inputs.input_ids
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
@torch.no_grad()
|
| 459 |
+
def encode_prompt(text_encoder, input_ids):
|
| 460 |
+
text_input_ids = input_ids.to(text_encoder.device)
|
| 461 |
+
attention_mask = None
|
| 462 |
+
|
| 463 |
+
prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)
|
| 464 |
+
prompt_embeds = prompt_embeds[0]
|
| 465 |
+
|
| 466 |
+
return prompt_embeds
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def main(args):
|
| 470 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 471 |
+
raise ValueError(
|
| 472 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 473 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 477 |
+
|
| 478 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 479 |
+
|
| 480 |
+
accelerator = Accelerator(
|
| 481 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 482 |
+
mixed_precision=args.mixed_precision,
|
| 483 |
+
log_with=args.report_to,
|
| 484 |
+
project_config=accelerator_project_config,
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
# Disable AMP for MPS.
|
| 488 |
+
if torch.backends.mps.is_available():
|
| 489 |
+
accelerator.native_amp = False
|
| 490 |
+
|
| 491 |
+
# Make one log on every process with the configuration for debugging.
|
| 492 |
+
logging.basicConfig(
|
| 493 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 494 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 495 |
+
level=logging.INFO,
|
| 496 |
+
)
|
| 497 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 498 |
+
if accelerator.is_local_main_process:
|
| 499 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 500 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 501 |
+
else:
|
| 502 |
+
transformers.utils.logging.set_verbosity_error()
|
| 503 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 504 |
+
|
| 505 |
+
# If passed along, set the training seed now.
|
| 506 |
+
if args.seed is not None:
|
| 507 |
+
set_seed(args.seed)
|
| 508 |
+
|
| 509 |
+
# Handle the repository creation
|
| 510 |
+
if accelerator.is_main_process:
|
| 511 |
+
if args.output_dir is not None:
|
| 512 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 513 |
+
|
| 514 |
+
if args.push_to_hub:
|
| 515 |
+
repo_id = create_repo(
|
| 516 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 517 |
+
).repo_id
|
| 518 |
+
|
| 519 |
+
# Load the tokenizer
|
| 520 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 521 |
+
args.pretrained_model_name_or_path,
|
| 522 |
+
subfolder="tokenizer",
|
| 523 |
+
revision=args.revision,
|
| 524 |
+
use_fast=False,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
# import correct text encoder class
|
| 528 |
+
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
| 529 |
+
|
| 530 |
+
# Load scheduler and models
|
| 531 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 532 |
+
text_encoder = text_encoder_cls.from_pretrained(
|
| 533 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
| 534 |
+
)
|
| 535 |
+
vae = AutoencoderKL.from_pretrained(
|
| 536 |
+
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 540 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
vae.requires_grad_(False)
|
| 544 |
+
text_encoder.requires_grad_(False)
|
| 545 |
+
unet.requires_grad_(False)
|
| 546 |
+
|
| 547 |
+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
| 548 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 549 |
+
weight_dtype = torch.float32
|
| 550 |
+
if accelerator.mixed_precision == "fp16":
|
| 551 |
+
weight_dtype = torch.float16
|
| 552 |
+
elif accelerator.mixed_precision == "bf16":
|
| 553 |
+
weight_dtype = torch.bfloat16
|
| 554 |
+
|
| 555 |
+
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
| 556 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 557 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 558 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 559 |
+
|
| 560 |
+
# Set up LoRA.
|
| 561 |
+
unet_lora_config = LoraConfig(
|
| 562 |
+
r=args.rank,
|
| 563 |
+
lora_alpha=args.rank,
|
| 564 |
+
init_lora_weights="gaussian",
|
| 565 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
| 566 |
+
)
|
| 567 |
+
# Add adapter and make sure the trainable params are in float32.
|
| 568 |
+
unet.add_adapter(unet_lora_config)
|
| 569 |
+
if args.mixed_precision == "fp16":
|
| 570 |
+
for param in unet.parameters():
|
| 571 |
+
# only upcast trainable parameters (LoRA) into fp32
|
| 572 |
+
if param.requires_grad:
|
| 573 |
+
param.data = param.to(torch.float32)
|
| 574 |
+
|
| 575 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 576 |
+
if is_xformers_available():
|
| 577 |
+
import xformers
|
| 578 |
+
|
| 579 |
+
xformers_version = version.parse(xformers.__version__)
|
| 580 |
+
if xformers_version == version.parse("0.0.16"):
|
| 581 |
+
logger.warning(
|
| 582 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 583 |
+
)
|
| 584 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 585 |
+
else:
|
| 586 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 587 |
+
|
| 588 |
+
if args.gradient_checkpointing:
|
| 589 |
+
unet.enable_gradient_checkpointing()
|
| 590 |
+
|
| 591 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 592 |
+
def save_model_hook(models, weights, output_dir):
|
| 593 |
+
if accelerator.is_main_process:
|
| 594 |
+
# there are only two options here. Either are just the unet attn processor layers
|
| 595 |
+
# or there are the unet and text encoder atten layers
|
| 596 |
+
unet_lora_layers_to_save = None
|
| 597 |
+
|
| 598 |
+
for model in models:
|
| 599 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
| 600 |
+
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
| 601 |
+
else:
|
| 602 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 603 |
+
|
| 604 |
+
# make sure to pop weight so that corresponding model is not saved again
|
| 605 |
+
weights.pop()
|
| 606 |
+
|
| 607 |
+
StableDiffusionLoraLoaderMixin.save_lora_weights(
|
| 608 |
+
output_dir,
|
| 609 |
+
unet_lora_layers=unet_lora_layers_to_save,
|
| 610 |
+
text_encoder_lora_layers=None,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
def load_model_hook(models, input_dir):
|
| 614 |
+
unet_ = None
|
| 615 |
+
|
| 616 |
+
while len(models) > 0:
|
| 617 |
+
model = models.pop()
|
| 618 |
+
|
| 619 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
| 620 |
+
unet_ = model
|
| 621 |
+
else:
|
| 622 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 623 |
+
|
| 624 |
+
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
| 625 |
+
StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
| 626 |
+
|
| 627 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 628 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 629 |
+
|
| 630 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 631 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 632 |
+
if args.allow_tf32:
|
| 633 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 634 |
+
|
| 635 |
+
if args.scale_lr:
|
| 636 |
+
args.learning_rate = (
|
| 637 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 641 |
+
if args.use_8bit_adam:
|
| 642 |
+
try:
|
| 643 |
+
import bitsandbytes as bnb
|
| 644 |
+
except ImportError:
|
| 645 |
+
raise ImportError(
|
| 646 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 650 |
+
else:
|
| 651 |
+
optimizer_class = torch.optim.AdamW
|
| 652 |
+
|
| 653 |
+
# Optimizer creation
|
| 654 |
+
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
| 655 |
+
optimizer = optimizer_class(
|
| 656 |
+
params_to_optimize,
|
| 657 |
+
lr=args.learning_rate,
|
| 658 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 659 |
+
weight_decay=args.adam_weight_decay,
|
| 660 |
+
eps=args.adam_epsilon,
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# Dataset and DataLoaders creation:
|
| 664 |
+
train_dataset = load_dataset(
|
| 665 |
+
args.dataset_name,
|
| 666 |
+
cache_dir=args.cache_dir,
|
| 667 |
+
split=args.dataset_split_name,
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
train_transforms = transforms.Compose(
|
| 671 |
+
[
|
| 672 |
+
transforms.Resize(int(args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
|
| 673 |
+
transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution),
|
| 674 |
+
transforms.Lambda(lambda x: x) if args.no_hflip else transforms.RandomHorizontalFlip(),
|
| 675 |
+
transforms.ToTensor(),
|
| 676 |
+
transforms.Normalize([0.5], [0.5]),
|
| 677 |
+
]
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
def preprocess_train(examples):
|
| 681 |
+
all_pixel_values = []
|
| 682 |
+
for col_name in ["jpg_0", "jpg_1"]:
|
| 683 |
+
images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
|
| 684 |
+
pixel_values = [train_transforms(image) for image in images]
|
| 685 |
+
all_pixel_values.append(pixel_values)
|
| 686 |
+
|
| 687 |
+
# Double on channel dim, jpg_y then jpg_w
|
| 688 |
+
im_tup_iterator = zip(*all_pixel_values)
|
| 689 |
+
combined_pixel_values = []
|
| 690 |
+
for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
|
| 691 |
+
if label_0 == 0:
|
| 692 |
+
im_tup = im_tup[::-1]
|
| 693 |
+
combined_im = torch.cat(im_tup, dim=0) # no batch dim
|
| 694 |
+
combined_pixel_values.append(combined_im)
|
| 695 |
+
examples["pixel_values"] = combined_pixel_values
|
| 696 |
+
|
| 697 |
+
examples["input_ids"] = tokenize_captions(tokenizer, examples)
|
| 698 |
+
return examples
|
| 699 |
+
|
| 700 |
+
with accelerator.main_process_first():
|
| 701 |
+
if args.max_train_samples is not None:
|
| 702 |
+
train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 703 |
+
# Set the training transforms
|
| 704 |
+
train_dataset = train_dataset.with_transform(preprocess_train)
|
| 705 |
+
|
| 706 |
+
def collate_fn(examples):
|
| 707 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 708 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 709 |
+
final_dict = {"pixel_values": pixel_values}
|
| 710 |
+
final_dict["input_ids"] = torch.stack([example["input_ids"] for example in examples])
|
| 711 |
+
return final_dict
|
| 712 |
+
|
| 713 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 714 |
+
train_dataset,
|
| 715 |
+
batch_size=args.train_batch_size,
|
| 716 |
+
shuffle=True,
|
| 717 |
+
collate_fn=collate_fn,
|
| 718 |
+
num_workers=args.dataloader_num_workers,
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
# Scheduler and math around the number of training steps.
|
| 722 |
+
overrode_max_train_steps = False
|
| 723 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 724 |
+
if args.max_train_steps is None:
|
| 725 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 726 |
+
overrode_max_train_steps = True
|
| 727 |
+
|
| 728 |
+
lr_scheduler = get_scheduler(
|
| 729 |
+
args.lr_scheduler,
|
| 730 |
+
optimizer=optimizer,
|
| 731 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 732 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 733 |
+
num_cycles=args.lr_num_cycles,
|
| 734 |
+
power=args.lr_power,
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 738 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 742 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 743 |
+
if overrode_max_train_steps:
|
| 744 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 745 |
+
# Afterwards we recalculate our number of training epochs
|
| 746 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 747 |
+
|
| 748 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 749 |
+
# The trackers initializes automatically on the main process.
|
| 750 |
+
if accelerator.is_main_process:
|
| 751 |
+
accelerator.init_trackers(args.tracker_name, config=vars(args))
|
| 752 |
+
|
| 753 |
+
# Train!
|
| 754 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 755 |
+
|
| 756 |
+
logger.info("***** Running training *****")
|
| 757 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 758 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 759 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 760 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 761 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 762 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 763 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 764 |
+
global_step = 0
|
| 765 |
+
first_epoch = 0
|
| 766 |
+
|
| 767 |
+
# Potentially load in the weights and states from a previous save
|
| 768 |
+
if args.resume_from_checkpoint:
|
| 769 |
+
if args.resume_from_checkpoint != "latest":
|
| 770 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 771 |
+
else:
|
| 772 |
+
# Get the mos recent checkpoint
|
| 773 |
+
dirs = os.listdir(args.output_dir)
|
| 774 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 775 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 776 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 777 |
+
|
| 778 |
+
if path is None:
|
| 779 |
+
accelerator.print(
|
| 780 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 781 |
+
)
|
| 782 |
+
args.resume_from_checkpoint = None
|
| 783 |
+
initial_global_step = 0
|
| 784 |
+
else:
|
| 785 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 786 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 787 |
+
global_step = int(path.split("-")[1])
|
| 788 |
+
|
| 789 |
+
initial_global_step = global_step
|
| 790 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 791 |
+
else:
|
| 792 |
+
initial_global_step = 0
|
| 793 |
+
|
| 794 |
+
progress_bar = tqdm(
|
| 795 |
+
range(0, args.max_train_steps),
|
| 796 |
+
initial=initial_global_step,
|
| 797 |
+
desc="Steps",
|
| 798 |
+
# Only show the progress bar once on each machine.
|
| 799 |
+
disable=not accelerator.is_local_main_process,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
unet.train()
|
| 803 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 804 |
+
for step, batch in enumerate(train_dataloader):
|
| 805 |
+
with accelerator.accumulate(unet):
|
| 806 |
+
# (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
|
| 807 |
+
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
| 808 |
+
feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
|
| 809 |
+
|
| 810 |
+
latents = []
|
| 811 |
+
for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
|
| 812 |
+
latents.append(
|
| 813 |
+
vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
|
| 814 |
+
)
|
| 815 |
+
latents = torch.cat(latents, dim=0)
|
| 816 |
+
latents = latents * vae.config.scaling_factor
|
| 817 |
+
|
| 818 |
+
# Sample noise that we'll add to the latents
|
| 819 |
+
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
|
| 820 |
+
|
| 821 |
+
# Sample a random timestep for each image
|
| 822 |
+
bsz = latents.shape[0] // 2
|
| 823 |
+
timesteps = torch.randint(
|
| 824 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
|
| 825 |
+
).repeat(2)
|
| 826 |
+
|
| 827 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
| 828 |
+
# (this is the forward diffusion process)
|
| 829 |
+
noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 830 |
+
|
| 831 |
+
# Get the text embedding for conditioning
|
| 832 |
+
encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"]).repeat(2, 1, 1)
|
| 833 |
+
|
| 834 |
+
# Predict the noise residual
|
| 835 |
+
model_pred = unet(
|
| 836 |
+
noisy_model_input,
|
| 837 |
+
timesteps,
|
| 838 |
+
encoder_hidden_states,
|
| 839 |
+
).sample
|
| 840 |
+
|
| 841 |
+
# Get the target for loss depending on the prediction type
|
| 842 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 843 |
+
target = noise
|
| 844 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 845 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 846 |
+
else:
|
| 847 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 848 |
+
|
| 849 |
+
# Compute losses.
|
| 850 |
+
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 851 |
+
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
|
| 852 |
+
model_losses_w, model_losses_l = model_losses.chunk(2)
|
| 853 |
+
|
| 854 |
+
# For logging
|
| 855 |
+
raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
|
| 856 |
+
model_diff = model_losses_w - model_losses_l # These are both LBS (as is t)
|
| 857 |
+
|
| 858 |
+
# Reference model predictions.
|
| 859 |
+
accelerator.unwrap_model(unet).disable_adapters()
|
| 860 |
+
with torch.no_grad():
|
| 861 |
+
ref_preds = unet(
|
| 862 |
+
noisy_model_input,
|
| 863 |
+
timesteps,
|
| 864 |
+
encoder_hidden_states,
|
| 865 |
+
).sample.detach()
|
| 866 |
+
ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
|
| 867 |
+
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
| 868 |
+
|
| 869 |
+
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
| 870 |
+
ref_diff = ref_losses_w - ref_losses_l
|
| 871 |
+
raw_ref_loss = ref_loss.mean()
|
| 872 |
+
|
| 873 |
+
# Re-enable adapters.
|
| 874 |
+
accelerator.unwrap_model(unet).enable_adapters()
|
| 875 |
+
|
| 876 |
+
# Final loss.
|
| 877 |
+
logits = ref_diff - model_diff
|
| 878 |
+
if args.loss_type == "sigmoid":
|
| 879 |
+
loss = -1 * F.logsigmoid(args.beta_dpo * logits).mean()
|
| 880 |
+
elif args.loss_type == "hinge":
|
| 881 |
+
loss = torch.relu(1 - args.beta_dpo * logits).mean()
|
| 882 |
+
elif args.loss_type == "ipo":
|
| 883 |
+
losses = (logits - 1 / (2 * args.beta)) ** 2
|
| 884 |
+
loss = losses.mean()
|
| 885 |
+
else:
|
| 886 |
+
raise ValueError(f"Unknown loss type {args.loss_type}")
|
| 887 |
+
|
| 888 |
+
implicit_acc = (logits > 0).sum().float() / logits.size(0)
|
| 889 |
+
implicit_acc += 0.5 * (logits == 0).sum().float() / logits.size(0)
|
| 890 |
+
|
| 891 |
+
accelerator.backward(loss)
|
| 892 |
+
if accelerator.sync_gradients:
|
| 893 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
| 894 |
+
optimizer.step()
|
| 895 |
+
lr_scheduler.step()
|
| 896 |
+
optimizer.zero_grad()
|
| 897 |
+
|
| 898 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 899 |
+
if accelerator.sync_gradients:
|
| 900 |
+
progress_bar.update(1)
|
| 901 |
+
global_step += 1
|
| 902 |
+
|
| 903 |
+
if accelerator.is_main_process:
|
| 904 |
+
if global_step % args.checkpointing_steps == 0:
|
| 905 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 906 |
+
if args.checkpoints_total_limit is not None:
|
| 907 |
+
checkpoints = os.listdir(args.output_dir)
|
| 908 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 909 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 910 |
+
|
| 911 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 912 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 913 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 914 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 915 |
+
|
| 916 |
+
logger.info(
|
| 917 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 918 |
+
)
|
| 919 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 920 |
+
|
| 921 |
+
for removing_checkpoint in removing_checkpoints:
|
| 922 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 923 |
+
shutil.rmtree(removing_checkpoint)
|
| 924 |
+
|
| 925 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 926 |
+
accelerator.save_state(save_path)
|
| 927 |
+
logger.info(f"Saved state to {save_path}")
|
| 928 |
+
|
| 929 |
+
if args.run_validation and global_step % args.validation_steps == 0:
|
| 930 |
+
log_validation(
|
| 931 |
+
args, unet=unet, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
logs = {
|
| 935 |
+
"loss": loss.detach().item(),
|
| 936 |
+
"raw_model_loss": raw_model_loss.detach().item(),
|
| 937 |
+
"ref_loss": raw_ref_loss.detach().item(),
|
| 938 |
+
"implicit_acc": implicit_acc.detach().item(),
|
| 939 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
| 940 |
+
}
|
| 941 |
+
progress_bar.set_postfix(**logs)
|
| 942 |
+
accelerator.log(logs, step=global_step)
|
| 943 |
+
|
| 944 |
+
if global_step >= args.max_train_steps:
|
| 945 |
+
break
|
| 946 |
+
|
| 947 |
+
# Save the lora layers
|
| 948 |
+
accelerator.wait_for_everyone()
|
| 949 |
+
if accelerator.is_main_process:
|
| 950 |
+
unet = accelerator.unwrap_model(unet)
|
| 951 |
+
unet = unet.to(torch.float32)
|
| 952 |
+
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
| 953 |
+
|
| 954 |
+
StableDiffusionLoraLoaderMixin.save_lora_weights(
|
| 955 |
+
save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=None
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
+
# Final validation?
|
| 959 |
+
if args.run_validation:
|
| 960 |
+
log_validation(
|
| 961 |
+
args,
|
| 962 |
+
unet=None,
|
| 963 |
+
accelerator=accelerator,
|
| 964 |
+
weight_dtype=weight_dtype,
|
| 965 |
+
epoch=epoch,
|
| 966 |
+
is_final_validation=True,
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
if args.push_to_hub:
|
| 970 |
+
upload_folder(
|
| 971 |
+
repo_id=repo_id,
|
| 972 |
+
folder_path=args.output_dir,
|
| 973 |
+
commit_message="End of training",
|
| 974 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
accelerator.end_training()
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
if __name__ == "__main__":
|
| 981 |
+
args = parse_args()
|
| 982 |
+
main(args)
|
diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
ADDED
|
@@ -0,0 +1,1140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 bram-w, The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import contextlib
|
| 18 |
+
import io
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
import os
|
| 22 |
+
import random
|
| 23 |
+
import shutil
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import torch.utils.checkpoint
|
| 30 |
+
import transformers
|
| 31 |
+
import wandb
|
| 32 |
+
from accelerate import Accelerator
|
| 33 |
+
from accelerate.logging import get_logger
|
| 34 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 35 |
+
from datasets import load_dataset
|
| 36 |
+
from huggingface_hub import create_repo, upload_folder
|
| 37 |
+
from packaging import version
|
| 38 |
+
from peft import LoraConfig
|
| 39 |
+
from peft.utils import get_peft_model_state_dict
|
| 40 |
+
from PIL import Image
|
| 41 |
+
from torchvision import transforms
|
| 42 |
+
from torchvision.transforms.functional import crop
|
| 43 |
+
from tqdm.auto import tqdm
|
| 44 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
| 45 |
+
|
| 46 |
+
import diffusers
|
| 47 |
+
from diffusers import (
|
| 48 |
+
AutoencoderKL,
|
| 49 |
+
DDPMScheduler,
|
| 50 |
+
DiffusionPipeline,
|
| 51 |
+
UNet2DConditionModel,
|
| 52 |
+
)
|
| 53 |
+
from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
|
| 54 |
+
from diffusers.optimization import get_scheduler
|
| 55 |
+
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers
|
| 56 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 60 |
+
check_min_version("0.25.0.dev0")
|
| 61 |
+
|
| 62 |
+
logger = get_logger(__name__)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
VALIDATION_PROMPTS = [
|
| 66 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
|
| 67 |
+
"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 68 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
| 69 |
+
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def import_model_class_from_model_name_or_path(
|
| 74 |
+
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
| 75 |
+
):
|
| 76 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
| 77 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 78 |
+
)
|
| 79 |
+
model_class = text_encoder_config.architectures[0]
|
| 80 |
+
|
| 81 |
+
if model_class == "CLIPTextModel":
|
| 82 |
+
from transformers import CLIPTextModel
|
| 83 |
+
|
| 84 |
+
return CLIPTextModel
|
| 85 |
+
elif model_class == "CLIPTextModelWithProjection":
|
| 86 |
+
from transformers import CLIPTextModelWithProjection
|
| 87 |
+
|
| 88 |
+
return CLIPTextModelWithProjection
|
| 89 |
+
else:
|
| 90 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
| 94 |
+
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
| 95 |
+
|
| 96 |
+
if is_final_validation:
|
| 97 |
+
if args.mixed_precision == "fp16":
|
| 98 |
+
vae.to(weight_dtype)
|
| 99 |
+
|
| 100 |
+
# create pipeline
|
| 101 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
| 102 |
+
args.pretrained_model_name_or_path,
|
| 103 |
+
vae=vae,
|
| 104 |
+
revision=args.revision,
|
| 105 |
+
variant=args.variant,
|
| 106 |
+
torch_dtype=weight_dtype,
|
| 107 |
+
)
|
| 108 |
+
if not is_final_validation:
|
| 109 |
+
pipeline.unet = accelerator.unwrap_model(unet)
|
| 110 |
+
else:
|
| 111 |
+
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
|
| 112 |
+
|
| 113 |
+
pipeline = pipeline.to(accelerator.device)
|
| 114 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 115 |
+
|
| 116 |
+
# run inference
|
| 117 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 118 |
+
images = []
|
| 119 |
+
context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
|
| 120 |
+
|
| 121 |
+
guidance_scale = 5.0
|
| 122 |
+
num_inference_steps = 25
|
| 123 |
+
if args.is_turbo:
|
| 124 |
+
guidance_scale = 0.0
|
| 125 |
+
num_inference_steps = 4
|
| 126 |
+
for prompt in VALIDATION_PROMPTS:
|
| 127 |
+
with context:
|
| 128 |
+
image = pipeline(
|
| 129 |
+
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
|
| 130 |
+
).images[0]
|
| 131 |
+
images.append(image)
|
| 132 |
+
|
| 133 |
+
tracker_key = "test" if is_final_validation else "validation"
|
| 134 |
+
for tracker in accelerator.trackers:
|
| 135 |
+
if tracker.name == "tensorboard":
|
| 136 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 137 |
+
tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
|
| 138 |
+
if tracker.name == "wandb":
|
| 139 |
+
tracker.log(
|
| 140 |
+
{
|
| 141 |
+
tracker_key: [
|
| 142 |
+
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
|
| 143 |
+
]
|
| 144 |
+
}
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Also log images without the LoRA params for comparison.
|
| 148 |
+
if is_final_validation:
|
| 149 |
+
pipeline.disable_lora()
|
| 150 |
+
no_lora_images = [
|
| 151 |
+
pipeline(
|
| 152 |
+
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
|
| 153 |
+
).images[0]
|
| 154 |
+
for prompt in VALIDATION_PROMPTS
|
| 155 |
+
]
|
| 156 |
+
|
| 157 |
+
for tracker in accelerator.trackers:
|
| 158 |
+
if tracker.name == "tensorboard":
|
| 159 |
+
np_images = np.stack([np.asarray(img) for img in no_lora_images])
|
| 160 |
+
tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
|
| 161 |
+
if tracker.name == "wandb":
|
| 162 |
+
tracker.log(
|
| 163 |
+
{
|
| 164 |
+
"test_without_lora": [
|
| 165 |
+
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
|
| 166 |
+
for i, image in enumerate(no_lora_images)
|
| 167 |
+
]
|
| 168 |
+
}
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def parse_args(input_args=None):
|
| 173 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 174 |
+
parser.add_argument(
|
| 175 |
+
"--pretrained_model_name_or_path",
|
| 176 |
+
type=str,
|
| 177 |
+
default=None,
|
| 178 |
+
required=True,
|
| 179 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--pretrained_vae_model_name_or_path",
|
| 183 |
+
type=str,
|
| 184 |
+
default=None,
|
| 185 |
+
help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--revision",
|
| 189 |
+
type=str,
|
| 190 |
+
default=None,
|
| 191 |
+
required=False,
|
| 192 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--dataset_name",
|
| 196 |
+
type=str,
|
| 197 |
+
default=None,
|
| 198 |
+
help=(
|
| 199 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
| 200 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
| 201 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
| 202 |
+
),
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--dataset_split_name",
|
| 206 |
+
type=str,
|
| 207 |
+
default="validation",
|
| 208 |
+
help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--variant",
|
| 212 |
+
type=str,
|
| 213 |
+
default=None,
|
| 214 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--run_validation",
|
| 218 |
+
default=False,
|
| 219 |
+
action="store_true",
|
| 220 |
+
help="Whether to run validation inference in between training and also after training. Helps to track progress.",
|
| 221 |
+
)
|
| 222 |
+
parser.add_argument(
|
| 223 |
+
"--validation_steps",
|
| 224 |
+
type=int,
|
| 225 |
+
default=200,
|
| 226 |
+
help="Run validation every X steps.",
|
| 227 |
+
)
|
| 228 |
+
parser.add_argument(
|
| 229 |
+
"--max_train_samples",
|
| 230 |
+
type=int,
|
| 231 |
+
default=None,
|
| 232 |
+
help=(
|
| 233 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 234 |
+
"value if set."
|
| 235 |
+
),
|
| 236 |
+
)
|
| 237 |
+
parser.add_argument(
|
| 238 |
+
"--output_dir",
|
| 239 |
+
type=str,
|
| 240 |
+
default="diffusion-dpo-lora",
|
| 241 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 242 |
+
)
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--cache_dir",
|
| 245 |
+
type=str,
|
| 246 |
+
default=None,
|
| 247 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 250 |
+
parser.add_argument(
|
| 251 |
+
"--resolution",
|
| 252 |
+
type=int,
|
| 253 |
+
default=1024,
|
| 254 |
+
help=(
|
| 255 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 256 |
+
" resolution"
|
| 257 |
+
),
|
| 258 |
+
)
|
| 259 |
+
parser.add_argument(
|
| 260 |
+
"--vae_encode_batch_size",
|
| 261 |
+
type=int,
|
| 262 |
+
default=8,
|
| 263 |
+
help="Batch size to use for VAE encoding of the images for efficient processing.",
|
| 264 |
+
)
|
| 265 |
+
parser.add_argument(
|
| 266 |
+
"--no_hflip",
|
| 267 |
+
action="store_true",
|
| 268 |
+
help="whether to randomly flip images horizontally",
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--random_crop",
|
| 272 |
+
default=False,
|
| 273 |
+
action="store_true",
|
| 274 |
+
help=(
|
| 275 |
+
"Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
|
| 276 |
+
),
|
| 277 |
+
)
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
| 280 |
+
)
|
| 281 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 282 |
+
parser.add_argument(
|
| 283 |
+
"--max_train_steps",
|
| 284 |
+
type=int,
|
| 285 |
+
default=None,
|
| 286 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 287 |
+
)
|
| 288 |
+
parser.add_argument(
|
| 289 |
+
"--checkpointing_steps",
|
| 290 |
+
type=int,
|
| 291 |
+
default=500,
|
| 292 |
+
help=(
|
| 293 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
| 294 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
| 295 |
+
" training using `--resume_from_checkpoint`."
|
| 296 |
+
),
|
| 297 |
+
)
|
| 298 |
+
parser.add_argument(
|
| 299 |
+
"--checkpoints_total_limit",
|
| 300 |
+
type=int,
|
| 301 |
+
default=None,
|
| 302 |
+
help=("Max number of checkpoints to store."),
|
| 303 |
+
)
|
| 304 |
+
parser.add_argument(
|
| 305 |
+
"--resume_from_checkpoint",
|
| 306 |
+
type=str,
|
| 307 |
+
default=None,
|
| 308 |
+
help=(
|
| 309 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 310 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 311 |
+
),
|
| 312 |
+
)
|
| 313 |
+
parser.add_argument(
|
| 314 |
+
"--gradient_accumulation_steps",
|
| 315 |
+
type=int,
|
| 316 |
+
default=1,
|
| 317 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 318 |
+
)
|
| 319 |
+
parser.add_argument(
|
| 320 |
+
"--gradient_checkpointing",
|
| 321 |
+
action="store_true",
|
| 322 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
"--beta_dpo",
|
| 326 |
+
type=int,
|
| 327 |
+
default=5000,
|
| 328 |
+
help="DPO KL Divergence penalty.",
|
| 329 |
+
)
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--learning_rate",
|
| 332 |
+
type=float,
|
| 333 |
+
default=5e-4,
|
| 334 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 335 |
+
)
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--scale_lr",
|
| 338 |
+
action="store_true",
|
| 339 |
+
default=False,
|
| 340 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 341 |
+
)
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--lr_scheduler",
|
| 344 |
+
type=str,
|
| 345 |
+
default="constant",
|
| 346 |
+
help=(
|
| 347 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 348 |
+
' "constant", "constant_with_warmup"]'
|
| 349 |
+
),
|
| 350 |
+
)
|
| 351 |
+
parser.add_argument(
|
| 352 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 353 |
+
)
|
| 354 |
+
parser.add_argument(
|
| 355 |
+
"--lr_num_cycles",
|
| 356 |
+
type=int,
|
| 357 |
+
default=1,
|
| 358 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
| 359 |
+
)
|
| 360 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
| 361 |
+
parser.add_argument(
|
| 362 |
+
"--dataloader_num_workers",
|
| 363 |
+
type=int,
|
| 364 |
+
default=0,
|
| 365 |
+
help=(
|
| 366 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 367 |
+
),
|
| 368 |
+
)
|
| 369 |
+
parser.add_argument(
|
| 370 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 371 |
+
)
|
| 372 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 373 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 374 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 375 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 376 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 377 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 378 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 379 |
+
parser.add_argument(
|
| 380 |
+
"--hub_model_id",
|
| 381 |
+
type=str,
|
| 382 |
+
default=None,
|
| 383 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 384 |
+
)
|
| 385 |
+
parser.add_argument(
|
| 386 |
+
"--logging_dir",
|
| 387 |
+
type=str,
|
| 388 |
+
default="logs",
|
| 389 |
+
help=(
|
| 390 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 391 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 392 |
+
),
|
| 393 |
+
)
|
| 394 |
+
parser.add_argument(
|
| 395 |
+
"--allow_tf32",
|
| 396 |
+
action="store_true",
|
| 397 |
+
help=(
|
| 398 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 399 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 400 |
+
),
|
| 401 |
+
)
|
| 402 |
+
parser.add_argument(
|
| 403 |
+
"--report_to",
|
| 404 |
+
type=str,
|
| 405 |
+
default="tensorboard",
|
| 406 |
+
help=(
|
| 407 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 408 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 409 |
+
),
|
| 410 |
+
)
|
| 411 |
+
parser.add_argument(
|
| 412 |
+
"--mixed_precision",
|
| 413 |
+
type=str,
|
| 414 |
+
default=None,
|
| 415 |
+
choices=["no", "fp16", "bf16"],
|
| 416 |
+
help=(
|
| 417 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 418 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 419 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 420 |
+
),
|
| 421 |
+
)
|
| 422 |
+
parser.add_argument(
|
| 423 |
+
"--prior_generation_precision",
|
| 424 |
+
type=str,
|
| 425 |
+
default=None,
|
| 426 |
+
choices=["no", "fp32", "fp16", "bf16"],
|
| 427 |
+
help=(
|
| 428 |
+
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 429 |
+
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
| 430 |
+
),
|
| 431 |
+
)
|
| 432 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 433 |
+
parser.add_argument(
|
| 434 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 435 |
+
)
|
| 436 |
+
parser.add_argument(
|
| 437 |
+
"--is_turbo",
|
| 438 |
+
action="store_true",
|
| 439 |
+
help=("Use if tuning SDXL Turbo instead of SDXL"),
|
| 440 |
+
)
|
| 441 |
+
parser.add_argument(
|
| 442 |
+
"--rank",
|
| 443 |
+
type=int,
|
| 444 |
+
default=4,
|
| 445 |
+
help=("The dimension of the LoRA update matrices."),
|
| 446 |
+
)
|
| 447 |
+
parser.add_argument(
|
| 448 |
+
"--tracker_name",
|
| 449 |
+
type=str,
|
| 450 |
+
default="diffusion-dpo-lora-sdxl",
|
| 451 |
+
help=("The name of the tracker to report results to."),
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
if input_args is not None:
|
| 455 |
+
args = parser.parse_args(input_args)
|
| 456 |
+
else:
|
| 457 |
+
args = parser.parse_args()
|
| 458 |
+
|
| 459 |
+
if args.dataset_name is None:
|
| 460 |
+
raise ValueError("Must provide a `dataset_name`.")
|
| 461 |
+
|
| 462 |
+
if args.is_turbo:
|
| 463 |
+
assert "turbo" in args.pretrained_model_name_or_path
|
| 464 |
+
|
| 465 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 466 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 467 |
+
args.local_rank = env_local_rank
|
| 468 |
+
|
| 469 |
+
return args
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def tokenize_captions(tokenizers, examples):
|
| 473 |
+
captions = []
|
| 474 |
+
for caption in examples["caption"]:
|
| 475 |
+
captions.append(caption)
|
| 476 |
+
|
| 477 |
+
tokens_one = tokenizers[0](
|
| 478 |
+
captions, truncation=True, padding="max_length", max_length=tokenizers[0].model_max_length, return_tensors="pt"
|
| 479 |
+
).input_ids
|
| 480 |
+
tokens_two = tokenizers[1](
|
| 481 |
+
captions, truncation=True, padding="max_length", max_length=tokenizers[1].model_max_length, return_tensors="pt"
|
| 482 |
+
).input_ids
|
| 483 |
+
|
| 484 |
+
return tokens_one, tokens_two
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
@torch.no_grad()
|
| 488 |
+
def encode_prompt(text_encoders, text_input_ids_list):
|
| 489 |
+
prompt_embeds_list = []
|
| 490 |
+
|
| 491 |
+
for i, text_encoder in enumerate(text_encoders):
|
| 492 |
+
text_input_ids = text_input_ids_list[i]
|
| 493 |
+
|
| 494 |
+
prompt_embeds = text_encoder(
|
| 495 |
+
text_input_ids.to(text_encoder.device),
|
| 496 |
+
output_hidden_states=True,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
| 500 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 501 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 502 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 503 |
+
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
| 504 |
+
prompt_embeds_list.append(prompt_embeds)
|
| 505 |
+
|
| 506 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
| 507 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
| 508 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def main(args):
|
| 512 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 513 |
+
raise ValueError(
|
| 514 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 515 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 519 |
+
|
| 520 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 521 |
+
|
| 522 |
+
accelerator = Accelerator(
|
| 523 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 524 |
+
mixed_precision=args.mixed_precision,
|
| 525 |
+
log_with=args.report_to,
|
| 526 |
+
project_config=accelerator_project_config,
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
# Disable AMP for MPS.
|
| 530 |
+
if torch.backends.mps.is_available():
|
| 531 |
+
accelerator.native_amp = False
|
| 532 |
+
|
| 533 |
+
# Make one log on every process with the configuration for debugging.
|
| 534 |
+
logging.basicConfig(
|
| 535 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 536 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 537 |
+
level=logging.INFO,
|
| 538 |
+
)
|
| 539 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 540 |
+
if accelerator.is_local_main_process:
|
| 541 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 542 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 543 |
+
else:
|
| 544 |
+
transformers.utils.logging.set_verbosity_error()
|
| 545 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 546 |
+
|
| 547 |
+
# If passed along, set the training seed now.
|
| 548 |
+
if args.seed is not None:
|
| 549 |
+
set_seed(args.seed)
|
| 550 |
+
|
| 551 |
+
# Handle the repository creation
|
| 552 |
+
if accelerator.is_main_process:
|
| 553 |
+
if args.output_dir is not None:
|
| 554 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 555 |
+
|
| 556 |
+
if args.push_to_hub:
|
| 557 |
+
repo_id = create_repo(
|
| 558 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 559 |
+
).repo_id
|
| 560 |
+
|
| 561 |
+
# Load the tokenizers
|
| 562 |
+
tokenizer_one = AutoTokenizer.from_pretrained(
|
| 563 |
+
args.pretrained_model_name_or_path,
|
| 564 |
+
subfolder="tokenizer",
|
| 565 |
+
revision=args.revision,
|
| 566 |
+
use_fast=False,
|
| 567 |
+
)
|
| 568 |
+
tokenizer_two = AutoTokenizer.from_pretrained(
|
| 569 |
+
args.pretrained_model_name_or_path,
|
| 570 |
+
subfolder="tokenizer_2",
|
| 571 |
+
revision=args.revision,
|
| 572 |
+
use_fast=False,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
# import correct text encoder classes
|
| 576 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(
|
| 577 |
+
args.pretrained_model_name_or_path, args.revision
|
| 578 |
+
)
|
| 579 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(
|
| 580 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
# Load scheduler and models
|
| 584 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 585 |
+
|
| 586 |
+
def enforce_zero_terminal_snr(scheduler):
|
| 587 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
|
| 588 |
+
# Original implementation https://huggingface.co/papers/2305.08891
|
| 589 |
+
# Turbo needs zero terminal SNR
|
| 590 |
+
# Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
|
| 591 |
+
# Convert betas to alphas_bar_sqrt
|
| 592 |
+
alphas = 1 - scheduler.betas
|
| 593 |
+
alphas_bar = alphas.cumprod(0)
|
| 594 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
| 595 |
+
|
| 596 |
+
# Store old values.
|
| 597 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
| 598 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
| 599 |
+
# Shift so last timestep is zero.
|
| 600 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
| 601 |
+
# Scale so first timestep is back to old value.
|
| 602 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
| 603 |
+
|
| 604 |
+
alphas_bar = alphas_bar_sqrt**2
|
| 605 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
| 606 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
| 607 |
+
|
| 608 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 609 |
+
scheduler.alphas_cumprod = alphas_cumprod
|
| 610 |
+
return
|
| 611 |
+
|
| 612 |
+
if args.is_turbo:
|
| 613 |
+
enforce_zero_terminal_snr(noise_scheduler)
|
| 614 |
+
|
| 615 |
+
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
| 616 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
| 617 |
+
)
|
| 618 |
+
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
| 619 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
| 620 |
+
)
|
| 621 |
+
vae_path = (
|
| 622 |
+
args.pretrained_model_name_or_path
|
| 623 |
+
if args.pretrained_vae_model_name_or_path is None
|
| 624 |
+
else args.pretrained_vae_model_name_or_path
|
| 625 |
+
)
|
| 626 |
+
vae = AutoencoderKL.from_pretrained(
|
| 627 |
+
vae_path,
|
| 628 |
+
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
| 629 |
+
revision=args.revision,
|
| 630 |
+
variant=args.variant,
|
| 631 |
+
)
|
| 632 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 633 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
# We only train the additional adapter LoRA layers
|
| 637 |
+
vae.requires_grad_(False)
|
| 638 |
+
text_encoder_one.requires_grad_(False)
|
| 639 |
+
text_encoder_two.requires_grad_(False)
|
| 640 |
+
unet.requires_grad_(False)
|
| 641 |
+
|
| 642 |
+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
| 643 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 644 |
+
weight_dtype = torch.float32
|
| 645 |
+
if accelerator.mixed_precision == "fp16":
|
| 646 |
+
weight_dtype = torch.float16
|
| 647 |
+
elif accelerator.mixed_precision == "bf16":
|
| 648 |
+
weight_dtype = torch.bfloat16
|
| 649 |
+
|
| 650 |
+
# Move unet and text_encoders to device and cast to weight_dtype
|
| 651 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 652 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 653 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 654 |
+
|
| 655 |
+
# The VAE is always in float32 to avoid NaN losses.
|
| 656 |
+
vae.to(accelerator.device, dtype=torch.float32)
|
| 657 |
+
|
| 658 |
+
# Set up LoRA.
|
| 659 |
+
unet_lora_config = LoraConfig(
|
| 660 |
+
r=args.rank,
|
| 661 |
+
lora_alpha=args.rank,
|
| 662 |
+
init_lora_weights="gaussian",
|
| 663 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
| 664 |
+
)
|
| 665 |
+
# Add adapter and make sure the trainable params are in float32.
|
| 666 |
+
unet.add_adapter(unet_lora_config)
|
| 667 |
+
if args.mixed_precision == "fp16":
|
| 668 |
+
for param in unet.parameters():
|
| 669 |
+
# only upcast trainable parameters (LoRA) into fp32
|
| 670 |
+
if param.requires_grad:
|
| 671 |
+
param.data = param.to(torch.float32)
|
| 672 |
+
|
| 673 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 674 |
+
if is_xformers_available():
|
| 675 |
+
import xformers
|
| 676 |
+
|
| 677 |
+
xformers_version = version.parse(xformers.__version__)
|
| 678 |
+
if xformers_version == version.parse("0.0.16"):
|
| 679 |
+
logger.warning(
|
| 680 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 681 |
+
)
|
| 682 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 683 |
+
else:
|
| 684 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 685 |
+
|
| 686 |
+
if args.gradient_checkpointing:
|
| 687 |
+
unet.enable_gradient_checkpointing()
|
| 688 |
+
|
| 689 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 690 |
+
def save_model_hook(models, weights, output_dir):
|
| 691 |
+
if accelerator.is_main_process:
|
| 692 |
+
# there are only two options here. Either are just the unet attn processor layers
|
| 693 |
+
# or there are the unet and text encoder atten layers
|
| 694 |
+
unet_lora_layers_to_save = None
|
| 695 |
+
|
| 696 |
+
for model in models:
|
| 697 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
| 698 |
+
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
| 699 |
+
else:
|
| 700 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 701 |
+
|
| 702 |
+
# make sure to pop weight so that corresponding model is not saved again
|
| 703 |
+
weights.pop()
|
| 704 |
+
|
| 705 |
+
StableDiffusionXLLoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save)
|
| 706 |
+
|
| 707 |
+
def load_model_hook(models, input_dir):
|
| 708 |
+
unet_ = None
|
| 709 |
+
|
| 710 |
+
while len(models) > 0:
|
| 711 |
+
model = models.pop()
|
| 712 |
+
|
| 713 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
| 714 |
+
unet_ = model
|
| 715 |
+
else:
|
| 716 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 717 |
+
|
| 718 |
+
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
|
| 719 |
+
StableDiffusionXLLoraLoaderMixin.load_lora_into_unet(
|
| 720 |
+
lora_state_dict, network_alphas=network_alphas, unet=unet_
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 724 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 725 |
+
|
| 726 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 727 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 728 |
+
if args.allow_tf32:
|
| 729 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 730 |
+
|
| 731 |
+
if args.scale_lr:
|
| 732 |
+
args.learning_rate = (
|
| 733 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 737 |
+
if args.use_8bit_adam:
|
| 738 |
+
try:
|
| 739 |
+
import bitsandbytes as bnb
|
| 740 |
+
except ImportError:
|
| 741 |
+
raise ImportError(
|
| 742 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 746 |
+
else:
|
| 747 |
+
optimizer_class = torch.optim.AdamW
|
| 748 |
+
|
| 749 |
+
# Optimizer creation
|
| 750 |
+
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
| 751 |
+
optimizer = optimizer_class(
|
| 752 |
+
params_to_optimize,
|
| 753 |
+
lr=args.learning_rate,
|
| 754 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 755 |
+
weight_decay=args.adam_weight_decay,
|
| 756 |
+
eps=args.adam_epsilon,
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
# Dataset and DataLoaders creation:
|
| 760 |
+
train_dataset = load_dataset(
|
| 761 |
+
args.dataset_name,
|
| 762 |
+
cache_dir=args.cache_dir,
|
| 763 |
+
split=args.dataset_split_name,
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
# Preprocessing the datasets.
|
| 767 |
+
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
| 768 |
+
train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
|
| 769 |
+
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
| 770 |
+
to_tensor = transforms.ToTensor()
|
| 771 |
+
normalize = transforms.Normalize([0.5], [0.5])
|
| 772 |
+
|
| 773 |
+
def preprocess_train(examples):
|
| 774 |
+
all_pixel_values = []
|
| 775 |
+
images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples["jpg_0"]]
|
| 776 |
+
original_sizes = [(image.height, image.width) for image in images]
|
| 777 |
+
crop_top_lefts = []
|
| 778 |
+
|
| 779 |
+
for col_name in ["jpg_0", "jpg_1"]:
|
| 780 |
+
images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
|
| 781 |
+
if col_name == "jpg_1":
|
| 782 |
+
# Need to bring down the image to the same resolution.
|
| 783 |
+
# This seems like the simplest reasonable approach.
|
| 784 |
+
# "::-1" because PIL resize takes (width, height).
|
| 785 |
+
images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)]
|
| 786 |
+
pixel_values = [to_tensor(image) for image in images]
|
| 787 |
+
all_pixel_values.append(pixel_values)
|
| 788 |
+
|
| 789 |
+
# Double on channel dim, jpg_y then jpg_w
|
| 790 |
+
im_tup_iterator = zip(*all_pixel_values)
|
| 791 |
+
combined_pixel_values = []
|
| 792 |
+
for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
|
| 793 |
+
if label_0 == 0:
|
| 794 |
+
im_tup = im_tup[::-1]
|
| 795 |
+
|
| 796 |
+
combined_im = torch.cat(im_tup, dim=0) # no batch dim
|
| 797 |
+
|
| 798 |
+
# Resize.
|
| 799 |
+
combined_im = train_resize(combined_im)
|
| 800 |
+
|
| 801 |
+
# Flipping.
|
| 802 |
+
if not args.no_hflip and random.random() < 0.5:
|
| 803 |
+
combined_im = train_flip(combined_im)
|
| 804 |
+
|
| 805 |
+
# Cropping.
|
| 806 |
+
if not args.random_crop:
|
| 807 |
+
y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
|
| 808 |
+
x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
|
| 809 |
+
combined_im = train_crop(combined_im)
|
| 810 |
+
else:
|
| 811 |
+
y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
|
| 812 |
+
combined_im = crop(combined_im, y1, x1, h, w)
|
| 813 |
+
|
| 814 |
+
crop_top_left = (y1, x1)
|
| 815 |
+
crop_top_lefts.append(crop_top_left)
|
| 816 |
+
combined_im = normalize(combined_im)
|
| 817 |
+
combined_pixel_values.append(combined_im)
|
| 818 |
+
|
| 819 |
+
examples["pixel_values"] = combined_pixel_values
|
| 820 |
+
examples["original_sizes"] = original_sizes
|
| 821 |
+
examples["crop_top_lefts"] = crop_top_lefts
|
| 822 |
+
tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples)
|
| 823 |
+
examples["input_ids_one"] = tokens_one
|
| 824 |
+
examples["input_ids_two"] = tokens_two
|
| 825 |
+
return examples
|
| 826 |
+
|
| 827 |
+
with accelerator.main_process_first():
|
| 828 |
+
if args.max_train_samples is not None:
|
| 829 |
+
train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 830 |
+
# Set the training transforms
|
| 831 |
+
train_dataset = train_dataset.with_transform(preprocess_train)
|
| 832 |
+
|
| 833 |
+
def collate_fn(examples):
|
| 834 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 835 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 836 |
+
original_sizes = [example["original_sizes"] for example in examples]
|
| 837 |
+
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
|
| 838 |
+
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
|
| 839 |
+
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
|
| 840 |
+
|
| 841 |
+
return {
|
| 842 |
+
"pixel_values": pixel_values,
|
| 843 |
+
"input_ids_one": input_ids_one,
|
| 844 |
+
"input_ids_two": input_ids_two,
|
| 845 |
+
"original_sizes": original_sizes,
|
| 846 |
+
"crop_top_lefts": crop_top_lefts,
|
| 847 |
+
}
|
| 848 |
+
|
| 849 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 850 |
+
train_dataset,
|
| 851 |
+
batch_size=args.train_batch_size,
|
| 852 |
+
shuffle=True,
|
| 853 |
+
collate_fn=collate_fn,
|
| 854 |
+
num_workers=args.dataloader_num_workers,
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
# Scheduler and math around the number of training steps.
|
| 858 |
+
overrode_max_train_steps = False
|
| 859 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 860 |
+
if args.max_train_steps is None:
|
| 861 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 862 |
+
overrode_max_train_steps = True
|
| 863 |
+
|
| 864 |
+
lr_scheduler = get_scheduler(
|
| 865 |
+
args.lr_scheduler,
|
| 866 |
+
optimizer=optimizer,
|
| 867 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 868 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 869 |
+
num_cycles=args.lr_num_cycles,
|
| 870 |
+
power=args.lr_power,
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 874 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 878 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 879 |
+
if overrode_max_train_steps:
|
| 880 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 881 |
+
# Afterwards we recalculate our number of training epochs
|
| 882 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 883 |
+
|
| 884 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 885 |
+
# The trackers initializes automatically on the main process.
|
| 886 |
+
if accelerator.is_main_process:
|
| 887 |
+
accelerator.init_trackers(args.tracker_name, config=vars(args))
|
| 888 |
+
|
| 889 |
+
# Train!
|
| 890 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 891 |
+
|
| 892 |
+
logger.info("***** Running training *****")
|
| 893 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 894 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 895 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 896 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 897 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 898 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 899 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 900 |
+
global_step = 0
|
| 901 |
+
first_epoch = 0
|
| 902 |
+
|
| 903 |
+
# Potentially load in the weights and states from a previous save
|
| 904 |
+
if args.resume_from_checkpoint:
|
| 905 |
+
if args.resume_from_checkpoint != "latest":
|
| 906 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 907 |
+
else:
|
| 908 |
+
# Get the mos recent checkpoint
|
| 909 |
+
dirs = os.listdir(args.output_dir)
|
| 910 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 911 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 912 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 913 |
+
|
| 914 |
+
if path is None:
|
| 915 |
+
accelerator.print(
|
| 916 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 917 |
+
)
|
| 918 |
+
args.resume_from_checkpoint = None
|
| 919 |
+
initial_global_step = 0
|
| 920 |
+
else:
|
| 921 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 922 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 923 |
+
global_step = int(path.split("-")[1])
|
| 924 |
+
|
| 925 |
+
initial_global_step = global_step
|
| 926 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 927 |
+
else:
|
| 928 |
+
initial_global_step = 0
|
| 929 |
+
|
| 930 |
+
progress_bar = tqdm(
|
| 931 |
+
range(0, args.max_train_steps),
|
| 932 |
+
initial=initial_global_step,
|
| 933 |
+
desc="Steps",
|
| 934 |
+
# Only show the progress bar once on each machine.
|
| 935 |
+
disable=not accelerator.is_local_main_process,
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
unet.train()
|
| 939 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 940 |
+
for step, batch in enumerate(train_dataloader):
|
| 941 |
+
with accelerator.accumulate(unet):
|
| 942 |
+
# (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
|
| 943 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 944 |
+
feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
|
| 945 |
+
|
| 946 |
+
latents = []
|
| 947 |
+
for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
|
| 948 |
+
latents.append(
|
| 949 |
+
vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
|
| 950 |
+
)
|
| 951 |
+
latents = torch.cat(latents, dim=0)
|
| 952 |
+
latents = latents * vae.config.scaling_factor
|
| 953 |
+
if args.pretrained_vae_model_name_or_path is None:
|
| 954 |
+
latents = latents.to(weight_dtype)
|
| 955 |
+
|
| 956 |
+
# Sample noise that we'll add to the latents
|
| 957 |
+
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
|
| 958 |
+
|
| 959 |
+
# Sample a random timestep for each image
|
| 960 |
+
bsz = latents.shape[0] // 2
|
| 961 |
+
timesteps = torch.randint(
|
| 962 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
|
| 963 |
+
).repeat(2)
|
| 964 |
+
if args.is_turbo:
|
| 965 |
+
# Learn a 4 timestep schedule
|
| 966 |
+
timesteps_0_to_3 = timesteps % 4
|
| 967 |
+
timesteps = 250 * timesteps_0_to_3 + 249
|
| 968 |
+
|
| 969 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
| 970 |
+
# (this is the forward diffusion process)
|
| 971 |
+
noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 972 |
+
|
| 973 |
+
# time ids
|
| 974 |
+
def compute_time_ids(original_size, crops_coords_top_left):
|
| 975 |
+
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
| 976 |
+
target_size = (args.resolution, args.resolution)
|
| 977 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
| 978 |
+
add_time_ids = torch.tensor([add_time_ids])
|
| 979 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
| 980 |
+
return add_time_ids
|
| 981 |
+
|
| 982 |
+
add_time_ids = torch.cat(
|
| 983 |
+
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
|
| 984 |
+
).repeat(2, 1)
|
| 985 |
+
|
| 986 |
+
# Get the text embedding for conditioning
|
| 987 |
+
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
| 988 |
+
[text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
|
| 989 |
+
)
|
| 990 |
+
prompt_embeds = prompt_embeds.repeat(2, 1, 1)
|
| 991 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
|
| 992 |
+
|
| 993 |
+
# Predict the noise residual
|
| 994 |
+
model_pred = unet(
|
| 995 |
+
noisy_model_input,
|
| 996 |
+
timesteps,
|
| 997 |
+
prompt_embeds,
|
| 998 |
+
added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
|
| 999 |
+
).sample
|
| 1000 |
+
|
| 1001 |
+
# Get the target for loss depending on the prediction type
|
| 1002 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 1003 |
+
target = noise
|
| 1004 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 1005 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 1006 |
+
else:
|
| 1007 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 1008 |
+
|
| 1009 |
+
# Compute losses.
|
| 1010 |
+
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 1011 |
+
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
|
| 1012 |
+
model_losses_w, model_losses_l = model_losses.chunk(2)
|
| 1013 |
+
|
| 1014 |
+
# For logging
|
| 1015 |
+
raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
|
| 1016 |
+
model_diff = model_losses_w - model_losses_l # These are both LBS (as is t)
|
| 1017 |
+
|
| 1018 |
+
# Reference model predictions.
|
| 1019 |
+
accelerator.unwrap_model(unet).disable_adapters()
|
| 1020 |
+
with torch.no_grad():
|
| 1021 |
+
ref_preds = unet(
|
| 1022 |
+
noisy_model_input,
|
| 1023 |
+
timesteps,
|
| 1024 |
+
prompt_embeds,
|
| 1025 |
+
added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
|
| 1026 |
+
).sample
|
| 1027 |
+
ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
|
| 1028 |
+
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
| 1029 |
+
|
| 1030 |
+
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
| 1031 |
+
ref_diff = ref_losses_w - ref_losses_l
|
| 1032 |
+
raw_ref_loss = ref_loss.mean()
|
| 1033 |
+
|
| 1034 |
+
# Re-enable adapters.
|
| 1035 |
+
accelerator.unwrap_model(unet).enable_adapters()
|
| 1036 |
+
|
| 1037 |
+
# Final loss.
|
| 1038 |
+
scale_term = -0.5 * args.beta_dpo
|
| 1039 |
+
inside_term = scale_term * (model_diff - ref_diff)
|
| 1040 |
+
loss = -1 * F.logsigmoid(inside_term).mean()
|
| 1041 |
+
|
| 1042 |
+
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
|
| 1043 |
+
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
|
| 1044 |
+
|
| 1045 |
+
accelerator.backward(loss)
|
| 1046 |
+
if accelerator.sync_gradients:
|
| 1047 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
| 1048 |
+
optimizer.step()
|
| 1049 |
+
lr_scheduler.step()
|
| 1050 |
+
optimizer.zero_grad()
|
| 1051 |
+
|
| 1052 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1053 |
+
if accelerator.sync_gradients:
|
| 1054 |
+
progress_bar.update(1)
|
| 1055 |
+
global_step += 1
|
| 1056 |
+
|
| 1057 |
+
if accelerator.is_main_process:
|
| 1058 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1059 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 1060 |
+
if args.checkpoints_total_limit is not None:
|
| 1061 |
+
checkpoints = os.listdir(args.output_dir)
|
| 1062 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 1063 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 1064 |
+
|
| 1065 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 1066 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 1067 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 1068 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 1069 |
+
|
| 1070 |
+
logger.info(
|
| 1071 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 1072 |
+
)
|
| 1073 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 1074 |
+
|
| 1075 |
+
for removing_checkpoint in removing_checkpoints:
|
| 1076 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 1077 |
+
shutil.rmtree(removing_checkpoint)
|
| 1078 |
+
|
| 1079 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 1080 |
+
accelerator.save_state(save_path)
|
| 1081 |
+
logger.info(f"Saved state to {save_path}")
|
| 1082 |
+
|
| 1083 |
+
if args.run_validation and global_step % args.validation_steps == 0:
|
| 1084 |
+
log_validation(
|
| 1085 |
+
args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
|
| 1086 |
+
)
|
| 1087 |
+
|
| 1088 |
+
logs = {
|
| 1089 |
+
"loss": loss.detach().item(),
|
| 1090 |
+
"raw_model_loss": raw_model_loss.detach().item(),
|
| 1091 |
+
"ref_loss": raw_ref_loss.detach().item(),
|
| 1092 |
+
"implicit_acc": implicit_acc.detach().item(),
|
| 1093 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
| 1094 |
+
}
|
| 1095 |
+
progress_bar.set_postfix(**logs)
|
| 1096 |
+
accelerator.log(logs, step=global_step)
|
| 1097 |
+
|
| 1098 |
+
if global_step >= args.max_train_steps:
|
| 1099 |
+
break
|
| 1100 |
+
|
| 1101 |
+
# Save the lora layers
|
| 1102 |
+
accelerator.wait_for_everyone()
|
| 1103 |
+
if accelerator.is_main_process:
|
| 1104 |
+
unet = accelerator.unwrap_model(unet)
|
| 1105 |
+
unet = unet.to(torch.float32)
|
| 1106 |
+
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
| 1107 |
+
|
| 1108 |
+
StableDiffusionXLLoraLoaderMixin.save_lora_weights(
|
| 1109 |
+
save_directory=args.output_dir,
|
| 1110 |
+
unet_lora_layers=unet_lora_state_dict,
|
| 1111 |
+
text_encoder_lora_layers=None,
|
| 1112 |
+
text_encoder_2_lora_layers=None,
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
# Final validation?
|
| 1116 |
+
if args.run_validation:
|
| 1117 |
+
log_validation(
|
| 1118 |
+
args,
|
| 1119 |
+
unet=None,
|
| 1120 |
+
vae=vae,
|
| 1121 |
+
accelerator=accelerator,
|
| 1122 |
+
weight_dtype=weight_dtype,
|
| 1123 |
+
epoch=epoch,
|
| 1124 |
+
is_final_validation=True,
|
| 1125 |
+
)
|
| 1126 |
+
|
| 1127 |
+
if args.push_to_hub:
|
| 1128 |
+
upload_folder(
|
| 1129 |
+
repo_id=repo_id,
|
| 1130 |
+
folder_path=args.output_dir,
|
| 1131 |
+
commit_message="End of training",
|
| 1132 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 1133 |
+
)
|
| 1134 |
+
|
| 1135 |
+
accelerator.end_training()
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
if __name__ == "__main__":
|
| 1139 |
+
args = parse_args()
|
| 1140 |
+
main(args)
|
diffusers/examples/research_projects/diffusion_orpo/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
This project has a new home now: [https://mapo-t2i.github.io/](https://mapo-t2i.github.io/). We formally studied the use of ORPO in the context of diffusion models and open-sourced our codebase, models, and datasets. We released our paper too!
|
diffusers/examples/research_projects/diffusion_orpo/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets
|
| 2 |
+
accelerate
|
| 3 |
+
transformers
|
| 4 |
+
torchvision
|
| 5 |
+
wandb
|
| 6 |
+
peft
|
| 7 |
+
webdataset
|
diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
ADDED
|
@@ -0,0 +1,1092 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import contextlib
|
| 18 |
+
import io
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
import os
|
| 22 |
+
import random
|
| 23 |
+
import shutil
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import torch.utils.checkpoint
|
| 30 |
+
import transformers
|
| 31 |
+
import wandb
|
| 32 |
+
from accelerate import Accelerator
|
| 33 |
+
from accelerate.logging import get_logger
|
| 34 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 35 |
+
from datasets import load_dataset
|
| 36 |
+
from huggingface_hub import create_repo, upload_folder
|
| 37 |
+
from packaging import version
|
| 38 |
+
from peft import LoraConfig, set_peft_model_state_dict
|
| 39 |
+
from peft.utils import get_peft_model_state_dict
|
| 40 |
+
from PIL import Image
|
| 41 |
+
from torchvision import transforms
|
| 42 |
+
from torchvision.transforms.functional import crop
|
| 43 |
+
from tqdm.auto import tqdm
|
| 44 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
| 45 |
+
|
| 46 |
+
import diffusers
|
| 47 |
+
from diffusers import (
|
| 48 |
+
AutoencoderKL,
|
| 49 |
+
DDPMScheduler,
|
| 50 |
+
DiffusionPipeline,
|
| 51 |
+
UNet2DConditionModel,
|
| 52 |
+
)
|
| 53 |
+
from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
|
| 54 |
+
from diffusers.optimization import get_scheduler
|
| 55 |
+
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
|
| 56 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 60 |
+
check_min_version("0.25.0.dev0")
|
| 61 |
+
|
| 62 |
+
logger = get_logger(__name__)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
VALIDATION_PROMPTS = [
|
| 66 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
|
| 67 |
+
"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 68 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
| 69 |
+
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def import_model_class_from_model_name_or_path(
|
| 74 |
+
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
| 75 |
+
):
|
| 76 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
| 77 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 78 |
+
)
|
| 79 |
+
model_class = text_encoder_config.architectures[0]
|
| 80 |
+
|
| 81 |
+
if model_class == "CLIPTextModel":
|
| 82 |
+
from transformers import CLIPTextModel
|
| 83 |
+
|
| 84 |
+
return CLIPTextModel
|
| 85 |
+
elif model_class == "CLIPTextModelWithProjection":
|
| 86 |
+
from transformers import CLIPTextModelWithProjection
|
| 87 |
+
|
| 88 |
+
return CLIPTextModelWithProjection
|
| 89 |
+
else:
|
| 90 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
| 94 |
+
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
| 95 |
+
|
| 96 |
+
if is_final_validation:
|
| 97 |
+
if args.mixed_precision == "fp16":
|
| 98 |
+
vae.to(weight_dtype)
|
| 99 |
+
|
| 100 |
+
# create pipeline
|
| 101 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
| 102 |
+
args.pretrained_model_name_or_path,
|
| 103 |
+
vae=vae,
|
| 104 |
+
revision=args.revision,
|
| 105 |
+
variant=args.variant,
|
| 106 |
+
torch_dtype=weight_dtype,
|
| 107 |
+
)
|
| 108 |
+
if not is_final_validation:
|
| 109 |
+
pipeline.unet = accelerator.unwrap_model(unet)
|
| 110 |
+
else:
|
| 111 |
+
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
|
| 112 |
+
|
| 113 |
+
pipeline = pipeline.to(accelerator.device)
|
| 114 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 115 |
+
|
| 116 |
+
# run inference
|
| 117 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 118 |
+
images = []
|
| 119 |
+
context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
|
| 120 |
+
|
| 121 |
+
guidance_scale = 5.0
|
| 122 |
+
num_inference_steps = 25
|
| 123 |
+
for prompt in VALIDATION_PROMPTS:
|
| 124 |
+
with context:
|
| 125 |
+
image = pipeline(
|
| 126 |
+
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
|
| 127 |
+
).images[0]
|
| 128 |
+
images.append(image)
|
| 129 |
+
|
| 130 |
+
tracker_key = "test" if is_final_validation else "validation"
|
| 131 |
+
for tracker in accelerator.trackers:
|
| 132 |
+
if tracker.name == "tensorboard":
|
| 133 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 134 |
+
tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
|
| 135 |
+
if tracker.name == "wandb":
|
| 136 |
+
tracker.log(
|
| 137 |
+
{
|
| 138 |
+
tracker_key: [
|
| 139 |
+
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
|
| 140 |
+
]
|
| 141 |
+
}
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Also log images without the LoRA params for comparison.
|
| 145 |
+
if is_final_validation:
|
| 146 |
+
pipeline.disable_lora()
|
| 147 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 148 |
+
no_lora_images = [
|
| 149 |
+
pipeline(
|
| 150 |
+
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
|
| 151 |
+
).images[0]
|
| 152 |
+
for prompt in VALIDATION_PROMPTS
|
| 153 |
+
]
|
| 154 |
+
|
| 155 |
+
for tracker in accelerator.trackers:
|
| 156 |
+
if tracker.name == "tensorboard":
|
| 157 |
+
np_images = np.stack([np.asarray(img) for img in no_lora_images])
|
| 158 |
+
tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
|
| 159 |
+
if tracker.name == "wandb":
|
| 160 |
+
tracker.log(
|
| 161 |
+
{
|
| 162 |
+
"test_without_lora": [
|
| 163 |
+
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
|
| 164 |
+
for i, image in enumerate(no_lora_images)
|
| 165 |
+
]
|
| 166 |
+
}
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def parse_args(input_args=None):
|
| 171 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--pretrained_model_name_or_path",
|
| 174 |
+
type=str,
|
| 175 |
+
default=None,
|
| 176 |
+
required=True,
|
| 177 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 178 |
+
)
|
| 179 |
+
parser.add_argument(
|
| 180 |
+
"--pretrained_vae_model_name_or_path",
|
| 181 |
+
type=str,
|
| 182 |
+
default=None,
|
| 183 |
+
help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--revision",
|
| 187 |
+
type=str,
|
| 188 |
+
default=None,
|
| 189 |
+
required=False,
|
| 190 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--dataset_name",
|
| 194 |
+
type=str,
|
| 195 |
+
default=None,
|
| 196 |
+
help=(
|
| 197 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
| 198 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
| 199 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
| 200 |
+
),
|
| 201 |
+
)
|
| 202 |
+
parser.add_argument(
|
| 203 |
+
"--dataset_split_name",
|
| 204 |
+
type=str,
|
| 205 |
+
default="validation",
|
| 206 |
+
help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--variant",
|
| 210 |
+
type=str,
|
| 211 |
+
default=None,
|
| 212 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
| 213 |
+
)
|
| 214 |
+
parser.add_argument(
|
| 215 |
+
"--run_validation",
|
| 216 |
+
default=False,
|
| 217 |
+
action="store_true",
|
| 218 |
+
help="Whether to run validation inference in between training and also after training. Helps to track progress.",
|
| 219 |
+
)
|
| 220 |
+
parser.add_argument(
|
| 221 |
+
"--validation_steps",
|
| 222 |
+
type=int,
|
| 223 |
+
default=200,
|
| 224 |
+
help="Run validation every X steps.",
|
| 225 |
+
)
|
| 226 |
+
parser.add_argument(
|
| 227 |
+
"--max_train_samples",
|
| 228 |
+
type=int,
|
| 229 |
+
default=None,
|
| 230 |
+
help=(
|
| 231 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 232 |
+
"value if set."
|
| 233 |
+
),
|
| 234 |
+
)
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
"--output_dir",
|
| 237 |
+
type=str,
|
| 238 |
+
default="diffusion-orpo-lora-sdxl",
|
| 239 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 240 |
+
)
|
| 241 |
+
parser.add_argument(
|
| 242 |
+
"--cache_dir",
|
| 243 |
+
type=str,
|
| 244 |
+
default=None,
|
| 245 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 246 |
+
)
|
| 247 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 248 |
+
parser.add_argument(
|
| 249 |
+
"--resolution",
|
| 250 |
+
type=int,
|
| 251 |
+
default=1024,
|
| 252 |
+
help=(
|
| 253 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 254 |
+
" resolution"
|
| 255 |
+
),
|
| 256 |
+
)
|
| 257 |
+
parser.add_argument(
|
| 258 |
+
"--vae_encode_batch_size",
|
| 259 |
+
type=int,
|
| 260 |
+
default=8,
|
| 261 |
+
help="Batch size to use for VAE encoding of the images for efficient processing.",
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--no_hflip",
|
| 265 |
+
action="store_true",
|
| 266 |
+
help="whether to randomly flip images horizontally",
|
| 267 |
+
)
|
| 268 |
+
parser.add_argument(
|
| 269 |
+
"--random_crop",
|
| 270 |
+
default=False,
|
| 271 |
+
action="store_true",
|
| 272 |
+
help=(
|
| 273 |
+
"Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
|
| 274 |
+
),
|
| 275 |
+
)
|
| 276 |
+
parser.add_argument(
|
| 277 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
| 278 |
+
)
|
| 279 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 280 |
+
parser.add_argument(
|
| 281 |
+
"--max_train_steps",
|
| 282 |
+
type=int,
|
| 283 |
+
default=None,
|
| 284 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 285 |
+
)
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
"--checkpointing_steps",
|
| 288 |
+
type=int,
|
| 289 |
+
default=500,
|
| 290 |
+
help=(
|
| 291 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
| 292 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
| 293 |
+
" training using `--resume_from_checkpoint`."
|
| 294 |
+
),
|
| 295 |
+
)
|
| 296 |
+
parser.add_argument(
|
| 297 |
+
"--checkpoints_total_limit",
|
| 298 |
+
type=int,
|
| 299 |
+
default=None,
|
| 300 |
+
help=("Max number of checkpoints to store."),
|
| 301 |
+
)
|
| 302 |
+
parser.add_argument(
|
| 303 |
+
"--resume_from_checkpoint",
|
| 304 |
+
type=str,
|
| 305 |
+
default=None,
|
| 306 |
+
help=(
|
| 307 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 308 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 309 |
+
),
|
| 310 |
+
)
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--gradient_accumulation_steps",
|
| 313 |
+
type=int,
|
| 314 |
+
default=1,
|
| 315 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 316 |
+
)
|
| 317 |
+
parser.add_argument(
|
| 318 |
+
"--gradient_checkpointing",
|
| 319 |
+
action="store_true",
|
| 320 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 321 |
+
)
|
| 322 |
+
parser.add_argument(
|
| 323 |
+
"--beta_orpo",
|
| 324 |
+
type=float,
|
| 325 |
+
default=0.1,
|
| 326 |
+
help="ORPO contribution factor.",
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--learning_rate",
|
| 330 |
+
type=float,
|
| 331 |
+
default=5e-4,
|
| 332 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--scale_lr",
|
| 336 |
+
action="store_true",
|
| 337 |
+
default=False,
|
| 338 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 339 |
+
)
|
| 340 |
+
parser.add_argument(
|
| 341 |
+
"--lr_scheduler",
|
| 342 |
+
type=str,
|
| 343 |
+
default="constant",
|
| 344 |
+
help=(
|
| 345 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 346 |
+
' "constant", "constant_with_warmup"]'
|
| 347 |
+
),
|
| 348 |
+
)
|
| 349 |
+
parser.add_argument(
|
| 350 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 351 |
+
)
|
| 352 |
+
parser.add_argument(
|
| 353 |
+
"--lr_num_cycles",
|
| 354 |
+
type=int,
|
| 355 |
+
default=1,
|
| 356 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
| 357 |
+
)
|
| 358 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
| 359 |
+
parser.add_argument(
|
| 360 |
+
"--dataloader_num_workers",
|
| 361 |
+
type=int,
|
| 362 |
+
default=0,
|
| 363 |
+
help=(
|
| 364 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 365 |
+
),
|
| 366 |
+
)
|
| 367 |
+
parser.add_argument(
|
| 368 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 371 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 372 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 373 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 374 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 375 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 376 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 377 |
+
parser.add_argument(
|
| 378 |
+
"--hub_model_id",
|
| 379 |
+
type=str,
|
| 380 |
+
default=None,
|
| 381 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 382 |
+
)
|
| 383 |
+
parser.add_argument(
|
| 384 |
+
"--logging_dir",
|
| 385 |
+
type=str,
|
| 386 |
+
default="logs",
|
| 387 |
+
help=(
|
| 388 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 389 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 390 |
+
),
|
| 391 |
+
)
|
| 392 |
+
parser.add_argument(
|
| 393 |
+
"--allow_tf32",
|
| 394 |
+
action="store_true",
|
| 395 |
+
help=(
|
| 396 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 397 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 398 |
+
),
|
| 399 |
+
)
|
| 400 |
+
parser.add_argument(
|
| 401 |
+
"--report_to",
|
| 402 |
+
type=str,
|
| 403 |
+
default="tensorboard",
|
| 404 |
+
help=(
|
| 405 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 406 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 407 |
+
),
|
| 408 |
+
)
|
| 409 |
+
parser.add_argument(
|
| 410 |
+
"--mixed_precision",
|
| 411 |
+
type=str,
|
| 412 |
+
default=None,
|
| 413 |
+
choices=["no", "fp16", "bf16"],
|
| 414 |
+
help=(
|
| 415 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 416 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 417 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 418 |
+
),
|
| 419 |
+
)
|
| 420 |
+
parser.add_argument(
|
| 421 |
+
"--prior_generation_precision",
|
| 422 |
+
type=str,
|
| 423 |
+
default=None,
|
| 424 |
+
choices=["no", "fp32", "fp16", "bf16"],
|
| 425 |
+
help=(
|
| 426 |
+
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 427 |
+
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
| 428 |
+
),
|
| 429 |
+
)
|
| 430 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 431 |
+
parser.add_argument(
|
| 432 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 433 |
+
)
|
| 434 |
+
parser.add_argument(
|
| 435 |
+
"--rank",
|
| 436 |
+
type=int,
|
| 437 |
+
default=4,
|
| 438 |
+
help=("The dimension of the LoRA update matrices."),
|
| 439 |
+
)
|
| 440 |
+
parser.add_argument(
|
| 441 |
+
"--tracker_name",
|
| 442 |
+
type=str,
|
| 443 |
+
default="diffusion-orpo-lora-sdxl",
|
| 444 |
+
help=("The name of the tracker to report results to."),
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
if input_args is not None:
|
| 448 |
+
args = parser.parse_args(input_args)
|
| 449 |
+
else:
|
| 450 |
+
args = parser.parse_args()
|
| 451 |
+
|
| 452 |
+
if args.dataset_name is None:
|
| 453 |
+
raise ValueError("Must provide a `dataset_name`.")
|
| 454 |
+
|
| 455 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 456 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 457 |
+
args.local_rank = env_local_rank
|
| 458 |
+
|
| 459 |
+
return args
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def tokenize_captions(tokenizers, examples):
|
| 463 |
+
captions = []
|
| 464 |
+
for caption in examples["caption"]:
|
| 465 |
+
captions.append(caption)
|
| 466 |
+
|
| 467 |
+
tokens_one = tokenizers[0](
|
| 468 |
+
captions, truncation=True, padding="max_length", max_length=tokenizers[0].model_max_length, return_tensors="pt"
|
| 469 |
+
).input_ids
|
| 470 |
+
tokens_two = tokenizers[1](
|
| 471 |
+
captions, truncation=True, padding="max_length", max_length=tokenizers[1].model_max_length, return_tensors="pt"
|
| 472 |
+
).input_ids
|
| 473 |
+
|
| 474 |
+
return tokens_one, tokens_two
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
@torch.no_grad()
|
| 478 |
+
def encode_prompt(text_encoders, text_input_ids_list):
|
| 479 |
+
prompt_embeds_list = []
|
| 480 |
+
|
| 481 |
+
for i, text_encoder in enumerate(text_encoders):
|
| 482 |
+
text_input_ids = text_input_ids_list[i]
|
| 483 |
+
|
| 484 |
+
prompt_embeds = text_encoder(
|
| 485 |
+
text_input_ids.to(text_encoder.device),
|
| 486 |
+
output_hidden_states=True,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
| 490 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 491 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 492 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 493 |
+
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
| 494 |
+
prompt_embeds_list.append(prompt_embeds)
|
| 495 |
+
|
| 496 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
| 497 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
| 498 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def main(args):
|
| 502 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 503 |
+
raise ValueError(
|
| 504 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 505 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 509 |
+
|
| 510 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 511 |
+
|
| 512 |
+
accelerator = Accelerator(
|
| 513 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 514 |
+
mixed_precision=args.mixed_precision,
|
| 515 |
+
log_with=args.report_to,
|
| 516 |
+
project_config=accelerator_project_config,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
# Disable AMP for MPS.
|
| 520 |
+
if torch.backends.mps.is_available():
|
| 521 |
+
accelerator.native_amp = False
|
| 522 |
+
|
| 523 |
+
# Make one log on every process with the configuration for debugging.
|
| 524 |
+
logging.basicConfig(
|
| 525 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 526 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 527 |
+
level=logging.INFO,
|
| 528 |
+
)
|
| 529 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 530 |
+
if accelerator.is_local_main_process:
|
| 531 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 532 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 533 |
+
else:
|
| 534 |
+
transformers.utils.logging.set_verbosity_error()
|
| 535 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 536 |
+
|
| 537 |
+
# If passed along, set the training seed now.
|
| 538 |
+
if args.seed is not None:
|
| 539 |
+
set_seed(args.seed)
|
| 540 |
+
|
| 541 |
+
# Handle the repository creation
|
| 542 |
+
if accelerator.is_main_process:
|
| 543 |
+
if args.output_dir is not None:
|
| 544 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 545 |
+
|
| 546 |
+
if args.push_to_hub:
|
| 547 |
+
repo_id = create_repo(
|
| 548 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 549 |
+
).repo_id
|
| 550 |
+
|
| 551 |
+
# Load the tokenizers
|
| 552 |
+
tokenizer_one = AutoTokenizer.from_pretrained(
|
| 553 |
+
args.pretrained_model_name_or_path,
|
| 554 |
+
subfolder="tokenizer",
|
| 555 |
+
revision=args.revision,
|
| 556 |
+
use_fast=False,
|
| 557 |
+
)
|
| 558 |
+
tokenizer_two = AutoTokenizer.from_pretrained(
|
| 559 |
+
args.pretrained_model_name_or_path,
|
| 560 |
+
subfolder="tokenizer_2",
|
| 561 |
+
revision=args.revision,
|
| 562 |
+
use_fast=False,
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
# import correct text encoder classes
|
| 566 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(
|
| 567 |
+
args.pretrained_model_name_or_path, args.revision
|
| 568 |
+
)
|
| 569 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(
|
| 570 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
# Load scheduler and models
|
| 574 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 575 |
+
|
| 576 |
+
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
| 577 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
| 578 |
+
)
|
| 579 |
+
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
| 580 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
| 581 |
+
)
|
| 582 |
+
vae_path = (
|
| 583 |
+
args.pretrained_model_name_or_path
|
| 584 |
+
if args.pretrained_vae_model_name_or_path is None
|
| 585 |
+
else args.pretrained_vae_model_name_or_path
|
| 586 |
+
)
|
| 587 |
+
vae = AutoencoderKL.from_pretrained(
|
| 588 |
+
vae_path,
|
| 589 |
+
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
| 590 |
+
revision=args.revision,
|
| 591 |
+
variant=args.variant,
|
| 592 |
+
)
|
| 593 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 594 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
# We only train the additional adapter LoRA layers
|
| 598 |
+
vae.requires_grad_(False)
|
| 599 |
+
text_encoder_one.requires_grad_(False)
|
| 600 |
+
text_encoder_two.requires_grad_(False)
|
| 601 |
+
unet.requires_grad_(False)
|
| 602 |
+
|
| 603 |
+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
| 604 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 605 |
+
weight_dtype = torch.float32
|
| 606 |
+
if accelerator.mixed_precision == "fp16":
|
| 607 |
+
weight_dtype = torch.float16
|
| 608 |
+
elif accelerator.mixed_precision == "bf16":
|
| 609 |
+
weight_dtype = torch.bfloat16
|
| 610 |
+
|
| 611 |
+
# Move unet and text_encoders to device and cast to weight_dtype
|
| 612 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 613 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 614 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 615 |
+
|
| 616 |
+
# The VAE is always in float32 to avoid NaN losses.
|
| 617 |
+
vae.to(accelerator.device, dtype=torch.float32)
|
| 618 |
+
|
| 619 |
+
# Set up LoRA.
|
| 620 |
+
unet_lora_config = LoraConfig(
|
| 621 |
+
r=args.rank,
|
| 622 |
+
lora_alpha=args.rank,
|
| 623 |
+
init_lora_weights="gaussian",
|
| 624 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
| 625 |
+
)
|
| 626 |
+
# Add adapter and make sure the trainable params are in float32.
|
| 627 |
+
unet.add_adapter(unet_lora_config)
|
| 628 |
+
if args.mixed_precision == "fp16":
|
| 629 |
+
for param in unet.parameters():
|
| 630 |
+
# only upcast trainable parameters (LoRA) into fp32
|
| 631 |
+
if param.requires_grad:
|
| 632 |
+
param.data = param.to(torch.float32)
|
| 633 |
+
|
| 634 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 635 |
+
if is_xformers_available():
|
| 636 |
+
import xformers
|
| 637 |
+
|
| 638 |
+
xformers_version = version.parse(xformers.__version__)
|
| 639 |
+
if xformers_version == version.parse("0.0.16"):
|
| 640 |
+
logger.warning(
|
| 641 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 642 |
+
)
|
| 643 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 644 |
+
else:
|
| 645 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 646 |
+
|
| 647 |
+
if args.gradient_checkpointing:
|
| 648 |
+
unet.enable_gradient_checkpointing()
|
| 649 |
+
|
| 650 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 651 |
+
def save_model_hook(models, weights, output_dir):
|
| 652 |
+
if accelerator.is_main_process:
|
| 653 |
+
# there are only two options here. Either are just the unet attn processor layers
|
| 654 |
+
# or there are the unet and text encoder atten layers
|
| 655 |
+
unet_lora_layers_to_save = None
|
| 656 |
+
|
| 657 |
+
for model in models:
|
| 658 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
| 659 |
+
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
| 660 |
+
else:
|
| 661 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 662 |
+
|
| 663 |
+
# make sure to pop weight so that corresponding model is not saved again
|
| 664 |
+
weights.pop()
|
| 665 |
+
|
| 666 |
+
StableDiffusionXLLoraLoaderMixin.save_lora_weights(
|
| 667 |
+
output_dir,
|
| 668 |
+
unet_lora_layers=unet_lora_layers_to_save,
|
| 669 |
+
text_encoder_lora_layers=None,
|
| 670 |
+
text_encoder_2_lora_layers=None,
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
def load_model_hook(models, input_dir):
|
| 674 |
+
unet_ = None
|
| 675 |
+
|
| 676 |
+
while len(models) > 0:
|
| 677 |
+
model = models.pop()
|
| 678 |
+
|
| 679 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
| 680 |
+
unet_ = model
|
| 681 |
+
else:
|
| 682 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 683 |
+
|
| 684 |
+
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
|
| 685 |
+
|
| 686 |
+
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
| 687 |
+
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
| 688 |
+
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
| 689 |
+
if incompatible_keys is not None:
|
| 690 |
+
# check only for unexpected keys
|
| 691 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
| 692 |
+
if unexpected_keys:
|
| 693 |
+
logger.warning(
|
| 694 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
| 695 |
+
f" {unexpected_keys}. "
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 699 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 700 |
+
|
| 701 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 702 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 703 |
+
if args.allow_tf32:
|
| 704 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 705 |
+
|
| 706 |
+
if args.scale_lr:
|
| 707 |
+
args.learning_rate = (
|
| 708 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 712 |
+
if args.use_8bit_adam:
|
| 713 |
+
try:
|
| 714 |
+
import bitsandbytes as bnb
|
| 715 |
+
except ImportError:
|
| 716 |
+
raise ImportError(
|
| 717 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 721 |
+
else:
|
| 722 |
+
optimizer_class = torch.optim.AdamW
|
| 723 |
+
|
| 724 |
+
# Optimizer creation
|
| 725 |
+
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
| 726 |
+
optimizer = optimizer_class(
|
| 727 |
+
params_to_optimize,
|
| 728 |
+
lr=args.learning_rate,
|
| 729 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 730 |
+
weight_decay=args.adam_weight_decay,
|
| 731 |
+
eps=args.adam_epsilon,
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
# Dataset and DataLoaders creation:
|
| 735 |
+
train_dataset = load_dataset(
|
| 736 |
+
args.dataset_name,
|
| 737 |
+
cache_dir=args.cache_dir,
|
| 738 |
+
split=args.dataset_split_name,
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
# Preprocessing the datasets.
|
| 742 |
+
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
| 743 |
+
train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
|
| 744 |
+
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
| 745 |
+
to_tensor = transforms.ToTensor()
|
| 746 |
+
normalize = transforms.Normalize([0.5], [0.5])
|
| 747 |
+
|
| 748 |
+
def preprocess_train(examples):
|
| 749 |
+
all_pixel_values = []
|
| 750 |
+
images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples["jpg_0"]]
|
| 751 |
+
original_sizes = [(image.height, image.width) for image in images]
|
| 752 |
+
crop_top_lefts = []
|
| 753 |
+
|
| 754 |
+
for col_name in ["jpg_0", "jpg_1"]:
|
| 755 |
+
images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
|
| 756 |
+
if col_name == "jpg_1":
|
| 757 |
+
# Need to bring down the image to the same resolution.
|
| 758 |
+
# This seems like the simplest reasonable approach.
|
| 759 |
+
# "::-1" because PIL resize takes (width, height).
|
| 760 |
+
images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)]
|
| 761 |
+
pixel_values = [to_tensor(image) for image in images]
|
| 762 |
+
all_pixel_values.append(pixel_values)
|
| 763 |
+
|
| 764 |
+
# Double on channel dim, jpg_y then jpg_w
|
| 765 |
+
im_tup_iterator = zip(*all_pixel_values)
|
| 766 |
+
combined_pixel_values = []
|
| 767 |
+
for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
|
| 768 |
+
# We randomize selection and rejection.
|
| 769 |
+
if label_0 == 0.5:
|
| 770 |
+
if random.random() < 0.5:
|
| 771 |
+
label_0 = 0
|
| 772 |
+
else:
|
| 773 |
+
label_0 = 1
|
| 774 |
+
|
| 775 |
+
if label_0 == 0:
|
| 776 |
+
im_tup = im_tup[::-1]
|
| 777 |
+
|
| 778 |
+
combined_im = torch.cat(im_tup, dim=0) # no batch dim
|
| 779 |
+
|
| 780 |
+
# Resize.
|
| 781 |
+
combined_im = train_resize(combined_im)
|
| 782 |
+
|
| 783 |
+
# Flipping.
|
| 784 |
+
if not args.no_hflip and random.random() < 0.5:
|
| 785 |
+
combined_im = train_flip(combined_im)
|
| 786 |
+
|
| 787 |
+
# Cropping.
|
| 788 |
+
if not args.random_crop:
|
| 789 |
+
y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
|
| 790 |
+
x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
|
| 791 |
+
combined_im = train_crop(combined_im)
|
| 792 |
+
else:
|
| 793 |
+
y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
|
| 794 |
+
combined_im = crop(combined_im, y1, x1, h, w)
|
| 795 |
+
|
| 796 |
+
crop_top_left = (y1, x1)
|
| 797 |
+
crop_top_lefts.append(crop_top_left)
|
| 798 |
+
combined_im = normalize(combined_im)
|
| 799 |
+
combined_pixel_values.append(combined_im)
|
| 800 |
+
|
| 801 |
+
examples["pixel_values"] = combined_pixel_values
|
| 802 |
+
examples["original_sizes"] = original_sizes
|
| 803 |
+
examples["crop_top_lefts"] = crop_top_lefts
|
| 804 |
+
tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples)
|
| 805 |
+
examples["input_ids_one"] = tokens_one
|
| 806 |
+
examples["input_ids_two"] = tokens_two
|
| 807 |
+
return examples
|
| 808 |
+
|
| 809 |
+
with accelerator.main_process_first():
|
| 810 |
+
if args.max_train_samples is not None:
|
| 811 |
+
train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
|
| 812 |
+
# Set the training transforms
|
| 813 |
+
train_dataset = train_dataset.with_transform(preprocess_train)
|
| 814 |
+
|
| 815 |
+
def collate_fn(examples):
|
| 816 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 817 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 818 |
+
original_sizes = [example["original_sizes"] for example in examples]
|
| 819 |
+
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
|
| 820 |
+
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
|
| 821 |
+
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
|
| 822 |
+
|
| 823 |
+
return {
|
| 824 |
+
"pixel_values": pixel_values,
|
| 825 |
+
"input_ids_one": input_ids_one,
|
| 826 |
+
"input_ids_two": input_ids_two,
|
| 827 |
+
"original_sizes": original_sizes,
|
| 828 |
+
"crop_top_lefts": crop_top_lefts,
|
| 829 |
+
}
|
| 830 |
+
|
| 831 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 832 |
+
train_dataset,
|
| 833 |
+
batch_size=args.train_batch_size,
|
| 834 |
+
shuffle=True,
|
| 835 |
+
collate_fn=collate_fn,
|
| 836 |
+
num_workers=args.dataloader_num_workers,
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
# Scheduler and math around the number of training steps.
|
| 840 |
+
overrode_max_train_steps = False
|
| 841 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 842 |
+
if args.max_train_steps is None:
|
| 843 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 844 |
+
overrode_max_train_steps = True
|
| 845 |
+
|
| 846 |
+
lr_scheduler = get_scheduler(
|
| 847 |
+
args.lr_scheduler,
|
| 848 |
+
optimizer=optimizer,
|
| 849 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 850 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 851 |
+
num_cycles=args.lr_num_cycles,
|
| 852 |
+
power=args.lr_power,
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 856 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 860 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 861 |
+
if overrode_max_train_steps:
|
| 862 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 863 |
+
# Afterwards we recalculate our number of training epochs
|
| 864 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 865 |
+
|
| 866 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 867 |
+
# The trackers initializes automatically on the main process.
|
| 868 |
+
if accelerator.is_main_process:
|
| 869 |
+
accelerator.init_trackers(args.tracker_name, config=vars(args))
|
| 870 |
+
|
| 871 |
+
# Train!
|
| 872 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 873 |
+
|
| 874 |
+
logger.info("***** Running training *****")
|
| 875 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 876 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 877 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 878 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 879 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 880 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 881 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 882 |
+
global_step = 0
|
| 883 |
+
first_epoch = 0
|
| 884 |
+
|
| 885 |
+
# Potentially load in the weights and states from a previous save
|
| 886 |
+
if args.resume_from_checkpoint:
|
| 887 |
+
if args.resume_from_checkpoint != "latest":
|
| 888 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 889 |
+
else:
|
| 890 |
+
# Get the mos recent checkpoint
|
| 891 |
+
dirs = os.listdir(args.output_dir)
|
| 892 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 893 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 894 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 895 |
+
|
| 896 |
+
if path is None:
|
| 897 |
+
accelerator.print(
|
| 898 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 899 |
+
)
|
| 900 |
+
args.resume_from_checkpoint = None
|
| 901 |
+
initial_global_step = 0
|
| 902 |
+
else:
|
| 903 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 904 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 905 |
+
global_step = int(path.split("-")[1])
|
| 906 |
+
|
| 907 |
+
initial_global_step = global_step
|
| 908 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 909 |
+
else:
|
| 910 |
+
initial_global_step = 0
|
| 911 |
+
|
| 912 |
+
progress_bar = tqdm(
|
| 913 |
+
range(0, args.max_train_steps),
|
| 914 |
+
initial=initial_global_step,
|
| 915 |
+
desc="Steps",
|
| 916 |
+
# Only show the progress bar once on each machine.
|
| 917 |
+
disable=not accelerator.is_local_main_process,
|
| 918 |
+
)
|
| 919 |
+
|
| 920 |
+
unet.train()
|
| 921 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 922 |
+
for step, batch in enumerate(train_dataloader):
|
| 923 |
+
with accelerator.accumulate(unet):
|
| 924 |
+
# (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
|
| 925 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 926 |
+
feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
|
| 927 |
+
|
| 928 |
+
latents = []
|
| 929 |
+
for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
|
| 930 |
+
latents.append(
|
| 931 |
+
vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
|
| 932 |
+
)
|
| 933 |
+
latents = torch.cat(latents, dim=0)
|
| 934 |
+
latents = latents * vae.config.scaling_factor
|
| 935 |
+
if args.pretrained_vae_model_name_or_path is None:
|
| 936 |
+
latents = latents.to(weight_dtype)
|
| 937 |
+
|
| 938 |
+
# Sample noise that we'll add to the latents
|
| 939 |
+
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
|
| 940 |
+
|
| 941 |
+
# Sample a random timestep for each image
|
| 942 |
+
bsz = latents.shape[0] // 2
|
| 943 |
+
timesteps = torch.randint(
|
| 944 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
|
| 945 |
+
).repeat(2)
|
| 946 |
+
|
| 947 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
| 948 |
+
# (this is the forward diffusion process)
|
| 949 |
+
noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 950 |
+
|
| 951 |
+
# time ids
|
| 952 |
+
def compute_time_ids(original_size, crops_coords_top_left):
|
| 953 |
+
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
| 954 |
+
target_size = (args.resolution, args.resolution)
|
| 955 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
| 956 |
+
add_time_ids = torch.tensor([add_time_ids])
|
| 957 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
| 958 |
+
return add_time_ids
|
| 959 |
+
|
| 960 |
+
add_time_ids = torch.cat(
|
| 961 |
+
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
|
| 962 |
+
).repeat(2, 1)
|
| 963 |
+
|
| 964 |
+
# Get the text embedding for conditioning
|
| 965 |
+
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
| 966 |
+
[text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
|
| 967 |
+
)
|
| 968 |
+
prompt_embeds = prompt_embeds.repeat(2, 1, 1)
|
| 969 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
|
| 970 |
+
|
| 971 |
+
# Predict the noise residual
|
| 972 |
+
model_pred = unet(
|
| 973 |
+
noisy_model_input,
|
| 974 |
+
timesteps,
|
| 975 |
+
prompt_embeds,
|
| 976 |
+
added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
|
| 977 |
+
).sample
|
| 978 |
+
|
| 979 |
+
# Get the target for loss depending on the prediction type
|
| 980 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 981 |
+
target = noise
|
| 982 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 983 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 984 |
+
else:
|
| 985 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 986 |
+
|
| 987 |
+
# ODDS ratio loss.
|
| 988 |
+
# In the diffusion formulation, we're assuming that the MSE loss
|
| 989 |
+
# approximates the logp.
|
| 990 |
+
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 991 |
+
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
|
| 992 |
+
model_losses_w, model_losses_l = model_losses.chunk(2)
|
| 993 |
+
log_odds = model_losses_w - model_losses_l
|
| 994 |
+
|
| 995 |
+
# Ratio loss.
|
| 996 |
+
ratio = F.logsigmoid(log_odds)
|
| 997 |
+
ratio_losses = args.beta_orpo * ratio
|
| 998 |
+
|
| 999 |
+
# Full ORPO loss
|
| 1000 |
+
loss = model_losses_w.mean() - ratio_losses.mean()
|
| 1001 |
+
|
| 1002 |
+
# Backprop.
|
| 1003 |
+
accelerator.backward(loss)
|
| 1004 |
+
if accelerator.sync_gradients:
|
| 1005 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
| 1006 |
+
optimizer.step()
|
| 1007 |
+
lr_scheduler.step()
|
| 1008 |
+
optimizer.zero_grad()
|
| 1009 |
+
|
| 1010 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1011 |
+
if accelerator.sync_gradients:
|
| 1012 |
+
progress_bar.update(1)
|
| 1013 |
+
global_step += 1
|
| 1014 |
+
|
| 1015 |
+
if accelerator.is_main_process:
|
| 1016 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1017 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 1018 |
+
if args.checkpoints_total_limit is not None:
|
| 1019 |
+
checkpoints = os.listdir(args.output_dir)
|
| 1020 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 1021 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 1022 |
+
|
| 1023 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 1024 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 1025 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 1026 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 1027 |
+
|
| 1028 |
+
logger.info(
|
| 1029 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 1030 |
+
)
|
| 1031 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 1032 |
+
|
| 1033 |
+
for removing_checkpoint in removing_checkpoints:
|
| 1034 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 1035 |
+
shutil.rmtree(removing_checkpoint)
|
| 1036 |
+
|
| 1037 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 1038 |
+
accelerator.save_state(save_path)
|
| 1039 |
+
logger.info(f"Saved state to {save_path}")
|
| 1040 |
+
|
| 1041 |
+
if args.run_validation and global_step % args.validation_steps == 0:
|
| 1042 |
+
log_validation(
|
| 1043 |
+
args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
|
| 1044 |
+
)
|
| 1045 |
+
|
| 1046 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 1047 |
+
progress_bar.set_postfix(**logs)
|
| 1048 |
+
accelerator.log(logs, step=global_step)
|
| 1049 |
+
|
| 1050 |
+
if global_step >= args.max_train_steps:
|
| 1051 |
+
break
|
| 1052 |
+
|
| 1053 |
+
# Save the lora layers
|
| 1054 |
+
accelerator.wait_for_everyone()
|
| 1055 |
+
if accelerator.is_main_process:
|
| 1056 |
+
unet = accelerator.unwrap_model(unet)
|
| 1057 |
+
unet = unet.to(torch.float32)
|
| 1058 |
+
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
| 1059 |
+
|
| 1060 |
+
StableDiffusionXLLoraLoaderMixin.save_lora_weights(
|
| 1061 |
+
save_directory=args.output_dir,
|
| 1062 |
+
unet_lora_layers=unet_lora_state_dict,
|
| 1063 |
+
text_encoder_lora_layers=None,
|
| 1064 |
+
text_encoder_2_lora_layers=None,
|
| 1065 |
+
)
|
| 1066 |
+
|
| 1067 |
+
# Final validation?
|
| 1068 |
+
if args.run_validation:
|
| 1069 |
+
log_validation(
|
| 1070 |
+
args,
|
| 1071 |
+
unet=None,
|
| 1072 |
+
vae=vae,
|
| 1073 |
+
accelerator=accelerator,
|
| 1074 |
+
weight_dtype=weight_dtype,
|
| 1075 |
+
epoch=epoch,
|
| 1076 |
+
is_final_validation=True,
|
| 1077 |
+
)
|
| 1078 |
+
|
| 1079 |
+
if args.push_to_hub:
|
| 1080 |
+
upload_folder(
|
| 1081 |
+
repo_id=repo_id,
|
| 1082 |
+
folder_path=args.output_dir,
|
| 1083 |
+
commit_message="End of training",
|
| 1084 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 1085 |
+
)
|
| 1086 |
+
|
| 1087 |
+
accelerator.end_training()
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
if __name__ == "__main__":
|
| 1091 |
+
args = parse_args()
|
| 1092 |
+
main(args)
|
diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
ADDED
|
@@ -0,0 +1,1095 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import contextlib
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import random
|
| 22 |
+
import shutil
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torch.utils.checkpoint
|
| 29 |
+
import transformers
|
| 30 |
+
import wandb
|
| 31 |
+
import webdataset as wds
|
| 32 |
+
from accelerate import Accelerator
|
| 33 |
+
from accelerate.logging import get_logger
|
| 34 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 35 |
+
from huggingface_hub import create_repo, upload_folder
|
| 36 |
+
from packaging import version
|
| 37 |
+
from peft import LoraConfig, set_peft_model_state_dict
|
| 38 |
+
from peft.utils import get_peft_model_state_dict
|
| 39 |
+
from torchvision import transforms
|
| 40 |
+
from torchvision.transforms.functional import crop
|
| 41 |
+
from tqdm.auto import tqdm
|
| 42 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
| 43 |
+
|
| 44 |
+
import diffusers
|
| 45 |
+
from diffusers import (
|
| 46 |
+
AutoencoderKL,
|
| 47 |
+
DDPMScheduler,
|
| 48 |
+
DiffusionPipeline,
|
| 49 |
+
UNet2DConditionModel,
|
| 50 |
+
)
|
| 51 |
+
from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
|
| 52 |
+
from diffusers.optimization import get_scheduler
|
| 53 |
+
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
|
| 54 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 58 |
+
check_min_version("0.25.0.dev0")
|
| 59 |
+
|
| 60 |
+
logger = get_logger(__name__)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
VALIDATION_PROMPTS = [
|
| 64 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
|
| 65 |
+
"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 66 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
| 67 |
+
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
|
| 68 |
+
]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def import_model_class_from_model_name_or_path(
|
| 72 |
+
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
| 73 |
+
):
|
| 74 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
| 75 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 76 |
+
)
|
| 77 |
+
model_class = text_encoder_config.architectures[0]
|
| 78 |
+
|
| 79 |
+
if model_class == "CLIPTextModel":
|
| 80 |
+
from transformers import CLIPTextModel
|
| 81 |
+
|
| 82 |
+
return CLIPTextModel
|
| 83 |
+
elif model_class == "CLIPTextModelWithProjection":
|
| 84 |
+
from transformers import CLIPTextModelWithProjection
|
| 85 |
+
|
| 86 |
+
return CLIPTextModelWithProjection
|
| 87 |
+
else:
|
| 88 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
|
| 92 |
+
logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")
|
| 93 |
+
|
| 94 |
+
if is_final_validation:
|
| 95 |
+
if args.mixed_precision == "fp16":
|
| 96 |
+
vae.to(weight_dtype)
|
| 97 |
+
|
| 98 |
+
# create pipeline
|
| 99 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
| 100 |
+
args.pretrained_model_name_or_path,
|
| 101 |
+
vae=vae,
|
| 102 |
+
revision=args.revision,
|
| 103 |
+
variant=args.variant,
|
| 104 |
+
torch_dtype=weight_dtype,
|
| 105 |
+
)
|
| 106 |
+
if not is_final_validation:
|
| 107 |
+
pipeline.unet = accelerator.unwrap_model(unet)
|
| 108 |
+
else:
|
| 109 |
+
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
|
| 110 |
+
|
| 111 |
+
pipeline = pipeline.to(accelerator.device)
|
| 112 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 113 |
+
|
| 114 |
+
# run inference
|
| 115 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 116 |
+
images = []
|
| 117 |
+
context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
|
| 118 |
+
|
| 119 |
+
guidance_scale = 5.0
|
| 120 |
+
num_inference_steps = 25
|
| 121 |
+
for prompt in VALIDATION_PROMPTS:
|
| 122 |
+
with context:
|
| 123 |
+
image = pipeline(
|
| 124 |
+
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
|
| 125 |
+
).images[0]
|
| 126 |
+
images.append(image)
|
| 127 |
+
|
| 128 |
+
tracker_key = "test" if is_final_validation else "validation"
|
| 129 |
+
for tracker in accelerator.trackers:
|
| 130 |
+
if tracker.name == "tensorboard":
|
| 131 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 132 |
+
tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
|
| 133 |
+
if tracker.name == "wandb":
|
| 134 |
+
tracker.log(
|
| 135 |
+
{
|
| 136 |
+
tracker_key: [
|
| 137 |
+
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
|
| 138 |
+
]
|
| 139 |
+
}
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# Also log images without the LoRA params for comparison.
|
| 143 |
+
if is_final_validation:
|
| 144 |
+
pipeline.disable_lora()
|
| 145 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 146 |
+
no_lora_images = [
|
| 147 |
+
pipeline(
|
| 148 |
+
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
|
| 149 |
+
).images[0]
|
| 150 |
+
for prompt in VALIDATION_PROMPTS
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
for tracker in accelerator.trackers:
|
| 154 |
+
if tracker.name == "tensorboard":
|
| 155 |
+
np_images = np.stack([np.asarray(img) for img in no_lora_images])
|
| 156 |
+
tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
|
| 157 |
+
if tracker.name == "wandb":
|
| 158 |
+
tracker.log(
|
| 159 |
+
{
|
| 160 |
+
"test_without_lora": [
|
| 161 |
+
wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
|
| 162 |
+
for i, image in enumerate(no_lora_images)
|
| 163 |
+
]
|
| 164 |
+
}
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def parse_args(input_args=None):
|
| 169 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--pretrained_model_name_or_path",
|
| 172 |
+
type=str,
|
| 173 |
+
default=None,
|
| 174 |
+
required=True,
|
| 175 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
"--pretrained_vae_model_name_or_path",
|
| 179 |
+
type=str,
|
| 180 |
+
default=None,
|
| 181 |
+
help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--revision",
|
| 185 |
+
type=str,
|
| 186 |
+
default=None,
|
| 187 |
+
required=False,
|
| 188 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--dataset_path",
|
| 192 |
+
type=str,
|
| 193 |
+
default="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -",
|
| 194 |
+
)
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--num_train_examples",
|
| 197 |
+
type=int,
|
| 198 |
+
default=1001352,
|
| 199 |
+
)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--variant",
|
| 202 |
+
type=str,
|
| 203 |
+
default=None,
|
| 204 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
| 205 |
+
)
|
| 206 |
+
parser.add_argument(
|
| 207 |
+
"--run_validation",
|
| 208 |
+
default=False,
|
| 209 |
+
action="store_true",
|
| 210 |
+
help="Whether to run validation inference in between training and also after training. Helps to track progress.",
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--validation_steps",
|
| 214 |
+
type=int,
|
| 215 |
+
default=200,
|
| 216 |
+
help="Run validation every X steps.",
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--output_dir",
|
| 220 |
+
type=str,
|
| 221 |
+
default="diffusion-orpo-lora-sdxl",
|
| 222 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--cache_dir",
|
| 226 |
+
type=str,
|
| 227 |
+
default=None,
|
| 228 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 231 |
+
parser.add_argument(
|
| 232 |
+
"--resolution",
|
| 233 |
+
type=int,
|
| 234 |
+
default=1024,
|
| 235 |
+
help=(
|
| 236 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 237 |
+
" resolution"
|
| 238 |
+
),
|
| 239 |
+
)
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
"--vae_encode_batch_size",
|
| 242 |
+
type=int,
|
| 243 |
+
default=8,
|
| 244 |
+
help="Batch size to use for VAE encoding of the images for efficient processing.",
|
| 245 |
+
)
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
"--no_hflip",
|
| 248 |
+
action="store_true",
|
| 249 |
+
help="whether to randomly flip images horizontally",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--random_crop",
|
| 253 |
+
default=False,
|
| 254 |
+
action="store_true",
|
| 255 |
+
help=(
|
| 256 |
+
"Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
|
| 257 |
+
),
|
| 258 |
+
)
|
| 259 |
+
parser.add_argument("--global_batch_size", type=int, default=64, help="Total batch size.")
|
| 260 |
+
parser.add_argument(
|
| 261 |
+
"--per_gpu_batch_size", type=int, default=8, help="Number of samples in a batch for a single GPU."
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 264 |
+
parser.add_argument(
|
| 265 |
+
"--max_train_steps",
|
| 266 |
+
type=int,
|
| 267 |
+
default=None,
|
| 268 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--checkpointing_steps",
|
| 272 |
+
type=int,
|
| 273 |
+
default=500,
|
| 274 |
+
help=(
|
| 275 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
| 276 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
| 277 |
+
" training using `--resume_from_checkpoint`."
|
| 278 |
+
),
|
| 279 |
+
)
|
| 280 |
+
parser.add_argument(
|
| 281 |
+
"--checkpoints_total_limit",
|
| 282 |
+
type=int,
|
| 283 |
+
default=None,
|
| 284 |
+
help=("Max number of checkpoints to store."),
|
| 285 |
+
)
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
"--resume_from_checkpoint",
|
| 288 |
+
type=str,
|
| 289 |
+
default=None,
|
| 290 |
+
help=(
|
| 291 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 292 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 293 |
+
),
|
| 294 |
+
)
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--gradient_accumulation_steps",
|
| 297 |
+
type=int,
|
| 298 |
+
default=1,
|
| 299 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 300 |
+
)
|
| 301 |
+
parser.add_argument(
|
| 302 |
+
"--gradient_checkpointing",
|
| 303 |
+
action="store_true",
|
| 304 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 305 |
+
)
|
| 306 |
+
parser.add_argument(
|
| 307 |
+
"--beta_orpo",
|
| 308 |
+
type=float,
|
| 309 |
+
default=0.1,
|
| 310 |
+
help="ORPO contribution factor.",
|
| 311 |
+
)
|
| 312 |
+
parser.add_argument(
|
| 313 |
+
"--learning_rate",
|
| 314 |
+
type=float,
|
| 315 |
+
default=5e-4,
|
| 316 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--scale_lr",
|
| 320 |
+
action="store_true",
|
| 321 |
+
default=False,
|
| 322 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
"--lr_scheduler",
|
| 326 |
+
type=str,
|
| 327 |
+
default="constant",
|
| 328 |
+
help=(
|
| 329 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 330 |
+
' "constant", "constant_with_warmup"]'
|
| 331 |
+
),
|
| 332 |
+
)
|
| 333 |
+
parser.add_argument(
|
| 334 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 335 |
+
)
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--lr_num_cycles",
|
| 338 |
+
type=int,
|
| 339 |
+
default=1,
|
| 340 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
| 341 |
+
)
|
| 342 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
| 343 |
+
parser.add_argument(
|
| 344 |
+
"--dataloader_num_workers",
|
| 345 |
+
type=int,
|
| 346 |
+
default=0,
|
| 347 |
+
help=(
|
| 348 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 349 |
+
),
|
| 350 |
+
)
|
| 351 |
+
parser.add_argument(
|
| 352 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 353 |
+
)
|
| 354 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 355 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 356 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 357 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 358 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 359 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 360 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 361 |
+
parser.add_argument(
|
| 362 |
+
"--hub_model_id",
|
| 363 |
+
type=str,
|
| 364 |
+
default=None,
|
| 365 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 366 |
+
)
|
| 367 |
+
parser.add_argument(
|
| 368 |
+
"--logging_dir",
|
| 369 |
+
type=str,
|
| 370 |
+
default="logs",
|
| 371 |
+
help=(
|
| 372 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 373 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 374 |
+
),
|
| 375 |
+
)
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--allow_tf32",
|
| 378 |
+
action="store_true",
|
| 379 |
+
help=(
|
| 380 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 381 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 382 |
+
),
|
| 383 |
+
)
|
| 384 |
+
parser.add_argument(
|
| 385 |
+
"--report_to",
|
| 386 |
+
type=str,
|
| 387 |
+
default="tensorboard",
|
| 388 |
+
help=(
|
| 389 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 390 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 391 |
+
),
|
| 392 |
+
)
|
| 393 |
+
parser.add_argument(
|
| 394 |
+
"--mixed_precision",
|
| 395 |
+
type=str,
|
| 396 |
+
default=None,
|
| 397 |
+
choices=["no", "fp16", "bf16"],
|
| 398 |
+
help=(
|
| 399 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 400 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 401 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 402 |
+
),
|
| 403 |
+
)
|
| 404 |
+
parser.add_argument(
|
| 405 |
+
"--prior_generation_precision",
|
| 406 |
+
type=str,
|
| 407 |
+
default=None,
|
| 408 |
+
choices=["no", "fp32", "fp16", "bf16"],
|
| 409 |
+
help=(
|
| 410 |
+
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 411 |
+
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
| 412 |
+
),
|
| 413 |
+
)
|
| 414 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 415 |
+
parser.add_argument(
|
| 416 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 417 |
+
)
|
| 418 |
+
parser.add_argument(
|
| 419 |
+
"--rank",
|
| 420 |
+
type=int,
|
| 421 |
+
default=4,
|
| 422 |
+
help=("The dimension of the LoRA update matrices."),
|
| 423 |
+
)
|
| 424 |
+
parser.add_argument(
|
| 425 |
+
"--tracker_name",
|
| 426 |
+
type=str,
|
| 427 |
+
default="diffusion-orpo-lora-sdxl",
|
| 428 |
+
help=("The name of the tracker to report results to."),
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
if input_args is not None:
|
| 432 |
+
args = parser.parse_args(input_args)
|
| 433 |
+
else:
|
| 434 |
+
args = parser.parse_args()
|
| 435 |
+
|
| 436 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 437 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 438 |
+
args.local_rank = env_local_rank
|
| 439 |
+
|
| 440 |
+
return args
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def tokenize_captions(tokenizers, sample):
|
| 444 |
+
tokens_one = tokenizers[0](
|
| 445 |
+
sample["original_prompt"],
|
| 446 |
+
truncation=True,
|
| 447 |
+
padding="max_length",
|
| 448 |
+
max_length=tokenizers[0].model_max_length,
|
| 449 |
+
return_tensors="pt",
|
| 450 |
+
).input_ids
|
| 451 |
+
tokens_two = tokenizers[1](
|
| 452 |
+
sample["original_prompt"],
|
| 453 |
+
truncation=True,
|
| 454 |
+
padding="max_length",
|
| 455 |
+
max_length=tokenizers[1].model_max_length,
|
| 456 |
+
return_tensors="pt",
|
| 457 |
+
).input_ids
|
| 458 |
+
|
| 459 |
+
return tokens_one, tokens_two
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
@torch.no_grad()
|
| 463 |
+
def encode_prompt(text_encoders, text_input_ids_list):
|
| 464 |
+
prompt_embeds_list = []
|
| 465 |
+
|
| 466 |
+
for i, text_encoder in enumerate(text_encoders):
|
| 467 |
+
text_input_ids = text_input_ids_list[i]
|
| 468 |
+
|
| 469 |
+
prompt_embeds = text_encoder(
|
| 470 |
+
text_input_ids.to(text_encoder.device),
|
| 471 |
+
output_hidden_states=True,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
| 475 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 476 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 477 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 478 |
+
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
| 479 |
+
prompt_embeds_list.append(prompt_embeds)
|
| 480 |
+
|
| 481 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
| 482 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
| 483 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def get_dataset(args):
|
| 487 |
+
dataset = (
|
| 488 |
+
wds.WebDataset(args.dataset_path, resampled=True, handler=wds.warn_and_continue)
|
| 489 |
+
.shuffle(690, handler=wds.warn_and_continue)
|
| 490 |
+
.decode("pil", handler=wds.warn_and_continue)
|
| 491 |
+
.rename(
|
| 492 |
+
original_prompt="original_prompt.txt",
|
| 493 |
+
jpg_0="jpg_0.jpg",
|
| 494 |
+
jpg_1="jpg_1.jpg",
|
| 495 |
+
label_0="label_0.txt",
|
| 496 |
+
label_1="label_1.txt",
|
| 497 |
+
handler=wds.warn_and_continue,
|
| 498 |
+
)
|
| 499 |
+
)
|
| 500 |
+
return dataset
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def get_loader(args, tokenizer_one, tokenizer_two):
|
| 504 |
+
# 1,001,352
|
| 505 |
+
num_batches = math.ceil(args.num_train_examples / args.global_batch_size)
|
| 506 |
+
num_worker_batches = math.ceil(
|
| 507 |
+
args.num_train_examples / (args.global_batch_size * args.dataloader_num_workers)
|
| 508 |
+
) # per dataloader worker
|
| 509 |
+
num_batches = num_worker_batches * args.dataloader_num_workers
|
| 510 |
+
num_samples = num_batches * args.global_batch_size
|
| 511 |
+
|
| 512 |
+
dataset = get_dataset(args)
|
| 513 |
+
|
| 514 |
+
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
| 515 |
+
train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
|
| 516 |
+
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
| 517 |
+
to_tensor = transforms.ToTensor()
|
| 518 |
+
normalize = transforms.Normalize([0.5], [0.5])
|
| 519 |
+
|
| 520 |
+
def preprocess_images(sample):
|
| 521 |
+
jpg_0_image = sample["jpg_0"]
|
| 522 |
+
original_size = (jpg_0_image.height, jpg_0_image.width)
|
| 523 |
+
crop_top_left = []
|
| 524 |
+
|
| 525 |
+
jpg_1_image = sample["jpg_1"]
|
| 526 |
+
# Need to bring down the image to the same resolution.
|
| 527 |
+
# This seems like the simplest reasonable approach.
|
| 528 |
+
# "::-1" because PIL resize takes (width, height).
|
| 529 |
+
jpg_1_image = jpg_1_image.resize(original_size[::-1])
|
| 530 |
+
|
| 531 |
+
# We randomize selection and rejection.
|
| 532 |
+
label_0 = sample["label_0"]
|
| 533 |
+
if sample["label_0"] == 0.5:
|
| 534 |
+
if random.random() < 0.5:
|
| 535 |
+
label_0 = 0
|
| 536 |
+
else:
|
| 537 |
+
label_0 = 1
|
| 538 |
+
|
| 539 |
+
# Double on channel dim, jpg_y then jpg_w
|
| 540 |
+
if label_0 == 0:
|
| 541 |
+
pixel_values = torch.cat([to_tensor(image) for image in [jpg_1_image, jpg_0_image]])
|
| 542 |
+
else:
|
| 543 |
+
pixel_values = torch.cat([to_tensor(image) for image in [jpg_0_image, jpg_1_image]])
|
| 544 |
+
|
| 545 |
+
# Resize.
|
| 546 |
+
combined_im = train_resize(pixel_values)
|
| 547 |
+
|
| 548 |
+
# Flipping.
|
| 549 |
+
if not args.no_hflip and random.random() < 0.5:
|
| 550 |
+
combined_im = train_flip(combined_im)
|
| 551 |
+
|
| 552 |
+
# Cropping.
|
| 553 |
+
if not args.random_crop:
|
| 554 |
+
y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
|
| 555 |
+
x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
|
| 556 |
+
combined_im = train_crop(combined_im)
|
| 557 |
+
else:
|
| 558 |
+
y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
|
| 559 |
+
combined_im = crop(combined_im, y1, x1, h, w)
|
| 560 |
+
|
| 561 |
+
crop_top_left = (y1, x1)
|
| 562 |
+
combined_im = normalize(combined_im)
|
| 563 |
+
tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], sample)
|
| 564 |
+
|
| 565 |
+
return {
|
| 566 |
+
"pixel_values": combined_im,
|
| 567 |
+
"original_size": original_size,
|
| 568 |
+
"crop_top_left": crop_top_left,
|
| 569 |
+
"tokens_one": tokens_one,
|
| 570 |
+
"tokens_two": tokens_two,
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
def collate_fn(samples):
|
| 574 |
+
pixel_values = torch.stack([sample["pixel_values"] for sample in samples])
|
| 575 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 576 |
+
|
| 577 |
+
original_sizes = [example["original_size"] for example in samples]
|
| 578 |
+
crop_top_lefts = [example["crop_top_left"] for example in samples]
|
| 579 |
+
input_ids_one = torch.stack([example["tokens_one"] for example in samples])
|
| 580 |
+
input_ids_two = torch.stack([example["tokens_two"] for example in samples])
|
| 581 |
+
|
| 582 |
+
return {
|
| 583 |
+
"pixel_values": pixel_values,
|
| 584 |
+
"input_ids_one": input_ids_one,
|
| 585 |
+
"input_ids_two": input_ids_two,
|
| 586 |
+
"original_sizes": original_sizes,
|
| 587 |
+
"crop_top_lefts": crop_top_lefts,
|
| 588 |
+
}
|
| 589 |
+
|
| 590 |
+
dataset = dataset.map(preprocess_images, handler=wds.warn_and_continue)
|
| 591 |
+
dataset = dataset.batched(args.per_gpu_batch_size, partial=False, collation_fn=collate_fn)
|
| 592 |
+
dataset = dataset.with_epoch(num_worker_batches)
|
| 593 |
+
|
| 594 |
+
dataloader = wds.WebLoader(
|
| 595 |
+
dataset,
|
| 596 |
+
batch_size=None,
|
| 597 |
+
shuffle=False,
|
| 598 |
+
num_workers=args.dataloader_num_workers,
|
| 599 |
+
pin_memory=True,
|
| 600 |
+
persistent_workers=True,
|
| 601 |
+
)
|
| 602 |
+
# add meta-data to dataloader instance for convenience
|
| 603 |
+
dataloader.num_batches = num_batches
|
| 604 |
+
dataloader.num_samples = num_samples
|
| 605 |
+
return dataloader
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def main(args):
|
| 609 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 610 |
+
raise ValueError(
|
| 611 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 612 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 616 |
+
|
| 617 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 618 |
+
|
| 619 |
+
accelerator = Accelerator(
|
| 620 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 621 |
+
mixed_precision=args.mixed_precision,
|
| 622 |
+
log_with=args.report_to,
|
| 623 |
+
project_config=accelerator_project_config,
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
# Disable AMP for MPS.
|
| 627 |
+
if torch.backends.mps.is_available():
|
| 628 |
+
accelerator.native_amp = False
|
| 629 |
+
|
| 630 |
+
# Make one log on every process with the configuration for debugging.
|
| 631 |
+
logging.basicConfig(
|
| 632 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 633 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 634 |
+
level=logging.INFO,
|
| 635 |
+
)
|
| 636 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 637 |
+
if accelerator.is_local_main_process:
|
| 638 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 639 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 640 |
+
else:
|
| 641 |
+
transformers.utils.logging.set_verbosity_error()
|
| 642 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 643 |
+
|
| 644 |
+
# If passed along, set the training seed now.
|
| 645 |
+
if args.seed is not None:
|
| 646 |
+
set_seed(args.seed)
|
| 647 |
+
|
| 648 |
+
# Handle the repository creation
|
| 649 |
+
if accelerator.is_main_process:
|
| 650 |
+
if args.output_dir is not None:
|
| 651 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 652 |
+
|
| 653 |
+
if args.push_to_hub:
|
| 654 |
+
repo_id = create_repo(
|
| 655 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 656 |
+
).repo_id
|
| 657 |
+
|
| 658 |
+
# Load the tokenizers
|
| 659 |
+
tokenizer_one = AutoTokenizer.from_pretrained(
|
| 660 |
+
args.pretrained_model_name_or_path,
|
| 661 |
+
subfolder="tokenizer",
|
| 662 |
+
revision=args.revision,
|
| 663 |
+
use_fast=False,
|
| 664 |
+
)
|
| 665 |
+
tokenizer_two = AutoTokenizer.from_pretrained(
|
| 666 |
+
args.pretrained_model_name_or_path,
|
| 667 |
+
subfolder="tokenizer_2",
|
| 668 |
+
revision=args.revision,
|
| 669 |
+
use_fast=False,
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
# import correct text encoder classes
|
| 673 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(
|
| 674 |
+
args.pretrained_model_name_or_path, args.revision
|
| 675 |
+
)
|
| 676 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(
|
| 677 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
# Load scheduler and models
|
| 681 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 682 |
+
|
| 683 |
+
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
| 684 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
| 685 |
+
)
|
| 686 |
+
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
| 687 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
| 688 |
+
)
|
| 689 |
+
vae_path = (
|
| 690 |
+
args.pretrained_model_name_or_path
|
| 691 |
+
if args.pretrained_vae_model_name_or_path is None
|
| 692 |
+
else args.pretrained_vae_model_name_or_path
|
| 693 |
+
)
|
| 694 |
+
vae = AutoencoderKL.from_pretrained(
|
| 695 |
+
vae_path,
|
| 696 |
+
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
| 697 |
+
revision=args.revision,
|
| 698 |
+
variant=args.variant,
|
| 699 |
+
)
|
| 700 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 701 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
# We only train the additional adapter LoRA layers
|
| 705 |
+
vae.requires_grad_(False)
|
| 706 |
+
text_encoder_one.requires_grad_(False)
|
| 707 |
+
text_encoder_two.requires_grad_(False)
|
| 708 |
+
unet.requires_grad_(False)
|
| 709 |
+
|
| 710 |
+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
| 711 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 712 |
+
weight_dtype = torch.float32
|
| 713 |
+
if accelerator.mixed_precision == "fp16":
|
| 714 |
+
weight_dtype = torch.float16
|
| 715 |
+
elif accelerator.mixed_precision == "bf16":
|
| 716 |
+
weight_dtype = torch.bfloat16
|
| 717 |
+
|
| 718 |
+
# Move unet and text_encoders to device and cast to weight_dtype
|
| 719 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 720 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 721 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 722 |
+
|
| 723 |
+
# The VAE is always in float32 to avoid NaN losses.
|
| 724 |
+
vae.to(accelerator.device, dtype=torch.float32)
|
| 725 |
+
|
| 726 |
+
# Set up LoRA.
|
| 727 |
+
unet_lora_config = LoraConfig(
|
| 728 |
+
r=args.rank,
|
| 729 |
+
lora_alpha=args.rank,
|
| 730 |
+
init_lora_weights="gaussian",
|
| 731 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
| 732 |
+
)
|
| 733 |
+
# Add adapter and make sure the trainable params are in float32.
|
| 734 |
+
unet.add_adapter(unet_lora_config)
|
| 735 |
+
if args.mixed_precision == "fp16":
|
| 736 |
+
for param in unet.parameters():
|
| 737 |
+
# only upcast trainable parameters (LoRA) into fp32
|
| 738 |
+
if param.requires_grad:
|
| 739 |
+
param.data = param.to(torch.float32)
|
| 740 |
+
|
| 741 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 742 |
+
if is_xformers_available():
|
| 743 |
+
import xformers
|
| 744 |
+
|
| 745 |
+
xformers_version = version.parse(xformers.__version__)
|
| 746 |
+
if xformers_version == version.parse("0.0.16"):
|
| 747 |
+
logger.warning(
|
| 748 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 749 |
+
)
|
| 750 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 751 |
+
else:
|
| 752 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 753 |
+
|
| 754 |
+
if args.gradient_checkpointing:
|
| 755 |
+
unet.enable_gradient_checkpointing()
|
| 756 |
+
|
| 757 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 758 |
+
def save_model_hook(models, weights, output_dir):
|
| 759 |
+
if accelerator.is_main_process:
|
| 760 |
+
# there are only two options here. Either are just the unet attn processor layers
|
| 761 |
+
# or there are the unet and text encoder atten layers
|
| 762 |
+
unet_lora_layers_to_save = None
|
| 763 |
+
|
| 764 |
+
for model in models:
|
| 765 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
| 766 |
+
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
| 767 |
+
else:
|
| 768 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 769 |
+
|
| 770 |
+
# make sure to pop weight so that corresponding model is not saved again
|
| 771 |
+
weights.pop()
|
| 772 |
+
|
| 773 |
+
StableDiffusionXLLoraLoaderMixin.save_lora_weights(
|
| 774 |
+
output_dir,
|
| 775 |
+
unet_lora_layers=unet_lora_layers_to_save,
|
| 776 |
+
text_encoder_lora_layers=None,
|
| 777 |
+
text_encoder_2_lora_layers=None,
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
def load_model_hook(models, input_dir):
|
| 781 |
+
unet_ = None
|
| 782 |
+
|
| 783 |
+
while len(models) > 0:
|
| 784 |
+
model = models.pop()
|
| 785 |
+
|
| 786 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
| 787 |
+
unet_ = model
|
| 788 |
+
else:
|
| 789 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 790 |
+
|
| 791 |
+
lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
|
| 792 |
+
|
| 793 |
+
unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
| 794 |
+
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
| 795 |
+
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
| 796 |
+
if incompatible_keys is not None:
|
| 797 |
+
# check only for unexpected keys
|
| 798 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
| 799 |
+
if unexpected_keys:
|
| 800 |
+
logger.warning(
|
| 801 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
| 802 |
+
f" {unexpected_keys}. "
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 806 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 807 |
+
|
| 808 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 809 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 810 |
+
if args.allow_tf32:
|
| 811 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 812 |
+
|
| 813 |
+
if args.scale_lr:
|
| 814 |
+
args.learning_rate = (
|
| 815 |
+
args.learning_rate * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 819 |
+
if args.use_8bit_adam:
|
| 820 |
+
try:
|
| 821 |
+
import bitsandbytes as bnb
|
| 822 |
+
except ImportError:
|
| 823 |
+
raise ImportError(
|
| 824 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 828 |
+
else:
|
| 829 |
+
optimizer_class = torch.optim.AdamW
|
| 830 |
+
|
| 831 |
+
# Optimizer creation
|
| 832 |
+
params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
| 833 |
+
optimizer = optimizer_class(
|
| 834 |
+
params_to_optimize,
|
| 835 |
+
lr=args.learning_rate,
|
| 836 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 837 |
+
weight_decay=args.adam_weight_decay,
|
| 838 |
+
eps=args.adam_epsilon,
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
# Dataset and DataLoaders creation:
|
| 842 |
+
args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
|
| 843 |
+
train_dataloader = get_loader(args, tokenizer_one=tokenizer_one, tokenizer_two=tokenizer_two)
|
| 844 |
+
|
| 845 |
+
# Scheduler and math around the number of training steps.
|
| 846 |
+
overrode_max_train_steps = False
|
| 847 |
+
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
| 848 |
+
if args.max_train_steps is None:
|
| 849 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 850 |
+
overrode_max_train_steps = True
|
| 851 |
+
|
| 852 |
+
lr_scheduler = get_scheduler(
|
| 853 |
+
args.lr_scheduler,
|
| 854 |
+
optimizer=optimizer,
|
| 855 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 856 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 857 |
+
num_cycles=args.lr_num_cycles,
|
| 858 |
+
power=args.lr_power,
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
|
| 862 |
+
|
| 863 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 864 |
+
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
| 865 |
+
if overrode_max_train_steps:
|
| 866 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 867 |
+
# Afterwards we recalculate our number of training epochs
|
| 868 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 869 |
+
|
| 870 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 871 |
+
# The trackers initializes automatically on the main process.
|
| 872 |
+
if accelerator.is_main_process:
|
| 873 |
+
accelerator.init_trackers(args.tracker_name, config=vars(args))
|
| 874 |
+
|
| 875 |
+
# Train!
|
| 876 |
+
total_batch_size = args.per_gpu_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 877 |
+
|
| 878 |
+
logger.info("***** Running training *****")
|
| 879 |
+
logger.info(f" Num examples = {train_dataloader.num_samples}")
|
| 880 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 881 |
+
logger.info(f" Instantaneous batch size per device = {args.per_gpu_batch_size}")
|
| 882 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 883 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 884 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 885 |
+
global_step = 0
|
| 886 |
+
first_epoch = 0
|
| 887 |
+
|
| 888 |
+
# Potentially load in the weights and states from a previous save
|
| 889 |
+
if args.resume_from_checkpoint:
|
| 890 |
+
if args.resume_from_checkpoint != "latest":
|
| 891 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 892 |
+
else:
|
| 893 |
+
# Get the mos recent checkpoint
|
| 894 |
+
dirs = os.listdir(args.output_dir)
|
| 895 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 896 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 897 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 898 |
+
|
| 899 |
+
if path is None:
|
| 900 |
+
accelerator.print(
|
| 901 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 902 |
+
)
|
| 903 |
+
args.resume_from_checkpoint = None
|
| 904 |
+
initial_global_step = 0
|
| 905 |
+
else:
|
| 906 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 907 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 908 |
+
global_step = int(path.split("-")[1])
|
| 909 |
+
|
| 910 |
+
initial_global_step = global_step
|
| 911 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 912 |
+
else:
|
| 913 |
+
initial_global_step = 0
|
| 914 |
+
|
| 915 |
+
progress_bar = tqdm(
|
| 916 |
+
range(0, args.max_train_steps),
|
| 917 |
+
initial=initial_global_step,
|
| 918 |
+
desc="Steps",
|
| 919 |
+
# Only show the progress bar once on each machine.
|
| 920 |
+
disable=not accelerator.is_local_main_process,
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
unet.train()
|
| 924 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 925 |
+
for step, batch in enumerate(train_dataloader):
|
| 926 |
+
with accelerator.accumulate(unet):
|
| 927 |
+
# (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
|
| 928 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype, device=accelerator.device, non_blocking=True)
|
| 929 |
+
feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
|
| 930 |
+
|
| 931 |
+
latents = []
|
| 932 |
+
for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
|
| 933 |
+
latents.append(
|
| 934 |
+
vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
|
| 935 |
+
)
|
| 936 |
+
latents = torch.cat(latents, dim=0)
|
| 937 |
+
latents = latents * vae.config.scaling_factor
|
| 938 |
+
if args.pretrained_vae_model_name_or_path is None:
|
| 939 |
+
latents = latents.to(weight_dtype)
|
| 940 |
+
|
| 941 |
+
# Sample noise that we'll add to the latents
|
| 942 |
+
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
|
| 943 |
+
|
| 944 |
+
# Sample a random timestep for each image
|
| 945 |
+
bsz = latents.shape[0] // 2
|
| 946 |
+
timesteps = torch.randint(
|
| 947 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
|
| 948 |
+
).repeat(2)
|
| 949 |
+
|
| 950 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
| 951 |
+
# (this is the forward diffusion process)
|
| 952 |
+
noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 953 |
+
|
| 954 |
+
# time ids
|
| 955 |
+
def compute_time_ids(original_size, crops_coords_top_left):
|
| 956 |
+
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
| 957 |
+
target_size = (args.resolution, args.resolution)
|
| 958 |
+
add_time_ids = list(tuple(original_size) + tuple(crops_coords_top_left) + target_size)
|
| 959 |
+
add_time_ids = torch.tensor([add_time_ids])
|
| 960 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
| 961 |
+
return add_time_ids
|
| 962 |
+
|
| 963 |
+
add_time_ids = torch.cat(
|
| 964 |
+
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
|
| 965 |
+
).repeat(2, 1)
|
| 966 |
+
|
| 967 |
+
# Get the text embedding for conditioning
|
| 968 |
+
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
| 969 |
+
[text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
|
| 970 |
+
)
|
| 971 |
+
prompt_embeds = prompt_embeds.repeat(2, 1, 1)
|
| 972 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
|
| 973 |
+
|
| 974 |
+
# Predict the noise residual
|
| 975 |
+
model_pred = unet(
|
| 976 |
+
noisy_model_input,
|
| 977 |
+
timesteps,
|
| 978 |
+
prompt_embeds,
|
| 979 |
+
added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
|
| 980 |
+
).sample
|
| 981 |
+
|
| 982 |
+
# Get the target for loss depending on the prediction type
|
| 983 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 984 |
+
target = noise
|
| 985 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 986 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 987 |
+
else:
|
| 988 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 989 |
+
|
| 990 |
+
# ODDS ratio loss.
|
| 991 |
+
# In the diffusion formulation, we're assuming that the MSE loss
|
| 992 |
+
# approximates the logp.
|
| 993 |
+
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
| 994 |
+
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
|
| 995 |
+
model_losses_w, model_losses_l = model_losses.chunk(2)
|
| 996 |
+
log_odds = model_losses_w - model_losses_l
|
| 997 |
+
|
| 998 |
+
# Ratio loss.
|
| 999 |
+
ratio = F.logsigmoid(log_odds)
|
| 1000 |
+
ratio_losses = args.beta_orpo * ratio
|
| 1001 |
+
|
| 1002 |
+
# Full ORPO loss
|
| 1003 |
+
loss = model_losses_w.mean() - ratio_losses.mean()
|
| 1004 |
+
|
| 1005 |
+
# Backprop.
|
| 1006 |
+
accelerator.backward(loss)
|
| 1007 |
+
if accelerator.sync_gradients:
|
| 1008 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
| 1009 |
+
optimizer.step()
|
| 1010 |
+
lr_scheduler.step()
|
| 1011 |
+
optimizer.zero_grad()
|
| 1012 |
+
|
| 1013 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1014 |
+
if accelerator.sync_gradients:
|
| 1015 |
+
progress_bar.update(1)
|
| 1016 |
+
global_step += 1
|
| 1017 |
+
|
| 1018 |
+
if accelerator.is_main_process:
|
| 1019 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1020 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 1021 |
+
if args.checkpoints_total_limit is not None:
|
| 1022 |
+
checkpoints = os.listdir(args.output_dir)
|
| 1023 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 1024 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 1025 |
+
|
| 1026 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 1027 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 1028 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 1029 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 1030 |
+
|
| 1031 |
+
logger.info(
|
| 1032 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 1033 |
+
)
|
| 1034 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 1035 |
+
|
| 1036 |
+
for removing_checkpoint in removing_checkpoints:
|
| 1037 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 1038 |
+
shutil.rmtree(removing_checkpoint)
|
| 1039 |
+
|
| 1040 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 1041 |
+
accelerator.save_state(save_path)
|
| 1042 |
+
logger.info(f"Saved state to {save_path}")
|
| 1043 |
+
|
| 1044 |
+
if args.run_validation and global_step % args.validation_steps == 0:
|
| 1045 |
+
log_validation(
|
| 1046 |
+
args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
|
| 1047 |
+
)
|
| 1048 |
+
|
| 1049 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 1050 |
+
progress_bar.set_postfix(**logs)
|
| 1051 |
+
accelerator.log(logs, step=global_step)
|
| 1052 |
+
|
| 1053 |
+
if global_step >= args.max_train_steps:
|
| 1054 |
+
break
|
| 1055 |
+
|
| 1056 |
+
# Save the lora layers
|
| 1057 |
+
accelerator.wait_for_everyone()
|
| 1058 |
+
if accelerator.is_main_process:
|
| 1059 |
+
unet = accelerator.unwrap_model(unet)
|
| 1060 |
+
unet = unet.to(torch.float32)
|
| 1061 |
+
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
| 1062 |
+
|
| 1063 |
+
StableDiffusionXLLoraLoaderMixin.save_lora_weights(
|
| 1064 |
+
save_directory=args.output_dir,
|
| 1065 |
+
unet_lora_layers=unet_lora_state_dict,
|
| 1066 |
+
text_encoder_lora_layers=None,
|
| 1067 |
+
text_encoder_2_lora_layers=None,
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
# Final validation?
|
| 1071 |
+
if args.run_validation:
|
| 1072 |
+
log_validation(
|
| 1073 |
+
args,
|
| 1074 |
+
unet=None,
|
| 1075 |
+
vae=vae,
|
| 1076 |
+
accelerator=accelerator,
|
| 1077 |
+
weight_dtype=weight_dtype,
|
| 1078 |
+
epoch=epoch,
|
| 1079 |
+
is_final_validation=True,
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
if args.push_to_hub:
|
| 1083 |
+
upload_folder(
|
| 1084 |
+
repo_id=repo_id,
|
| 1085 |
+
folder_path=args.output_dir,
|
| 1086 |
+
commit_message="End of training",
|
| 1087 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 1088 |
+
)
|
| 1089 |
+
|
| 1090 |
+
accelerator.end_training()
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
if __name__ == "__main__":
|
| 1094 |
+
args = parse_args()
|
| 1095 |
+
main(args)
|
diffusers/examples/research_projects/dreambooth_inpaint/README.md
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dreambooth for the inpainting model
|
| 2 |
+
|
| 3 |
+
This script was added by @thedarkzeno .
|
| 4 |
+
|
| 5 |
+
Please note that this script is not actively maintained, you can open an issue and tag @thedarkzeno or @patil-suraj though.
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
| 9 |
+
export INSTANCE_DIR="path-to-instance-images"
|
| 10 |
+
export OUTPUT_DIR="path-to-save-model"
|
| 11 |
+
|
| 12 |
+
accelerate launch train_dreambooth_inpaint.py \
|
| 13 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
| 14 |
+
--instance_data_dir=$INSTANCE_DIR \
|
| 15 |
+
--output_dir=$OUTPUT_DIR \
|
| 16 |
+
--instance_prompt="a photo of sks dog" \
|
| 17 |
+
--resolution=512 \
|
| 18 |
+
--train_batch_size=1 \
|
| 19 |
+
--gradient_accumulation_steps=1 \
|
| 20 |
+
--learning_rate=5e-6 \
|
| 21 |
+
--lr_scheduler="constant" \
|
| 22 |
+
--lr_warmup_steps=0 \
|
| 23 |
+
--max_train_steps=400
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### Training with prior-preservation loss
|
| 27 |
+
|
| 28 |
+
Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
|
| 29 |
+
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
| 33 |
+
export INSTANCE_DIR="path-to-instance-images"
|
| 34 |
+
export CLASS_DIR="path-to-class-images"
|
| 35 |
+
export OUTPUT_DIR="path-to-save-model"
|
| 36 |
+
|
| 37 |
+
accelerate launch train_dreambooth_inpaint.py \
|
| 38 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
| 39 |
+
--instance_data_dir=$INSTANCE_DIR \
|
| 40 |
+
--class_data_dir=$CLASS_DIR \
|
| 41 |
+
--output_dir=$OUTPUT_DIR \
|
| 42 |
+
--with_prior_preservation --prior_loss_weight=1.0 \
|
| 43 |
+
--instance_prompt="a photo of sks dog" \
|
| 44 |
+
--class_prompt="a photo of dog" \
|
| 45 |
+
--resolution=512 \
|
| 46 |
+
--train_batch_size=1 \
|
| 47 |
+
--gradient_accumulation_steps=1 \
|
| 48 |
+
--learning_rate=5e-6 \
|
| 49 |
+
--lr_scheduler="constant" \
|
| 50 |
+
--lr_warmup_steps=0 \
|
| 51 |
+
--num_class_images=200 \
|
| 52 |
+
--max_train_steps=800
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
### Training with gradient checkpointing and 8-bit optimizer:
|
| 57 |
+
|
| 58 |
+
With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run train dreambooth on a 16GB GPU.
|
| 59 |
+
|
| 60 |
+
To install `bitandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
| 64 |
+
export INSTANCE_DIR="path-to-instance-images"
|
| 65 |
+
export CLASS_DIR="path-to-class-images"
|
| 66 |
+
export OUTPUT_DIR="path-to-save-model"
|
| 67 |
+
|
| 68 |
+
accelerate launch train_dreambooth_inpaint.py \
|
| 69 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
| 70 |
+
--instance_data_dir=$INSTANCE_DIR \
|
| 71 |
+
--class_data_dir=$CLASS_DIR \
|
| 72 |
+
--output_dir=$OUTPUT_DIR \
|
| 73 |
+
--with_prior_preservation --prior_loss_weight=1.0 \
|
| 74 |
+
--instance_prompt="a photo of sks dog" \
|
| 75 |
+
--class_prompt="a photo of dog" \
|
| 76 |
+
--resolution=512 \
|
| 77 |
+
--train_batch_size=1 \
|
| 78 |
+
--gradient_accumulation_steps=2 --gradient_checkpointing \
|
| 79 |
+
--use_8bit_adam \
|
| 80 |
+
--learning_rate=5e-6 \
|
| 81 |
+
--lr_scheduler="constant" \
|
| 82 |
+
--lr_warmup_steps=0 \
|
| 83 |
+
--num_class_images=200 \
|
| 84 |
+
--max_train_steps=800
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Fine-tune text encoder with the UNet.
|
| 88 |
+
|
| 89 |
+
The script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces.
|
| 90 |
+
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
|
| 91 |
+
|
| 92 |
+
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
| 96 |
+
export INSTANCE_DIR="path-to-instance-images"
|
| 97 |
+
export CLASS_DIR="path-to-class-images"
|
| 98 |
+
export OUTPUT_DIR="path-to-save-model"
|
| 99 |
+
|
| 100 |
+
accelerate launch train_dreambooth_inpaint.py \
|
| 101 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
| 102 |
+
--train_text_encoder \
|
| 103 |
+
--instance_data_dir=$INSTANCE_DIR \
|
| 104 |
+
--class_data_dir=$CLASS_DIR \
|
| 105 |
+
--output_dir=$OUTPUT_DIR \
|
| 106 |
+
--with_prior_preservation --prior_loss_weight=1.0 \
|
| 107 |
+
--instance_prompt="a photo of sks dog" \
|
| 108 |
+
--class_prompt="a photo of dog" \
|
| 109 |
+
--resolution=512 \
|
| 110 |
+
--train_batch_size=1 \
|
| 111 |
+
--use_8bit_adam \
|
| 112 |
+
--gradient_checkpointing \
|
| 113 |
+
--learning_rate=2e-6 \
|
| 114 |
+
--lr_scheduler="constant" \
|
| 115 |
+
--lr_warmup_steps=0 \
|
| 116 |
+
--num_class_images=200 \
|
| 117 |
+
--max_train_steps=800
|
| 118 |
+
```
|
diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diffusers==0.9.0
|
| 2 |
+
accelerate>=0.16.0
|
| 3 |
+
torchvision
|
| 4 |
+
transformers>=4.21.0
|
| 5 |
+
ftfy
|
| 6 |
+
tensorboard
|
| 7 |
+
Jinja2
|
diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
ADDED
|
@@ -0,0 +1,812 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import itertools
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torch.utils.checkpoint
|
| 12 |
+
from accelerate import Accelerator
|
| 13 |
+
from accelerate.logging import get_logger
|
| 14 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 15 |
+
from huggingface_hub import create_repo, upload_folder
|
| 16 |
+
from huggingface_hub.utils import insecure_hashlib
|
| 17 |
+
from PIL import Image, ImageDraw
|
| 18 |
+
from torch.utils.data import Dataset
|
| 19 |
+
from torchvision import transforms
|
| 20 |
+
from tqdm.auto import tqdm
|
| 21 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 22 |
+
|
| 23 |
+
from diffusers import (
|
| 24 |
+
AutoencoderKL,
|
| 25 |
+
DDPMScheduler,
|
| 26 |
+
StableDiffusionInpaintPipeline,
|
| 27 |
+
StableDiffusionPipeline,
|
| 28 |
+
UNet2DConditionModel,
|
| 29 |
+
)
|
| 30 |
+
from diffusers.optimization import get_scheduler
|
| 31 |
+
from diffusers.utils import check_min_version
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 35 |
+
check_min_version("0.13.0.dev0")
|
| 36 |
+
|
| 37 |
+
logger = get_logger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def prepare_mask_and_masked_image(image, mask):
|
| 41 |
+
image = np.array(image.convert("RGB"))
|
| 42 |
+
image = image[None].transpose(0, 3, 1, 2)
|
| 43 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
| 44 |
+
|
| 45 |
+
mask = np.array(mask.convert("L"))
|
| 46 |
+
mask = mask.astype(np.float32) / 255.0
|
| 47 |
+
mask = mask[None, None]
|
| 48 |
+
mask[mask < 0.5] = 0
|
| 49 |
+
mask[mask >= 0.5] = 1
|
| 50 |
+
mask = torch.from_numpy(mask)
|
| 51 |
+
|
| 52 |
+
masked_image = image * (mask < 0.5)
|
| 53 |
+
|
| 54 |
+
return mask, masked_image
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# generate random masks
|
| 58 |
+
def random_mask(im_shape, ratio=1, mask_full_image=False):
|
| 59 |
+
mask = Image.new("L", im_shape, 0)
|
| 60 |
+
draw = ImageDraw.Draw(mask)
|
| 61 |
+
size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
|
| 62 |
+
# use this to always mask the whole image
|
| 63 |
+
if mask_full_image:
|
| 64 |
+
size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
|
| 65 |
+
limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
|
| 66 |
+
center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
|
| 67 |
+
draw_type = random.randint(0, 1)
|
| 68 |
+
if draw_type == 0 or mask_full_image:
|
| 69 |
+
draw.rectangle(
|
| 70 |
+
(center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
|
| 71 |
+
fill=255,
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
draw.ellipse(
|
| 75 |
+
(center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
|
| 76 |
+
fill=255,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
return mask
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def parse_args():
|
| 83 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--pretrained_model_name_or_path",
|
| 86 |
+
type=str,
|
| 87 |
+
default=None,
|
| 88 |
+
required=True,
|
| 89 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"--tokenizer_name",
|
| 93 |
+
type=str,
|
| 94 |
+
default=None,
|
| 95 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument(
|
| 98 |
+
"--instance_data_dir",
|
| 99 |
+
type=str,
|
| 100 |
+
default=None,
|
| 101 |
+
required=True,
|
| 102 |
+
help="A folder containing the training data of instance images.",
|
| 103 |
+
)
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--class_data_dir",
|
| 106 |
+
type=str,
|
| 107 |
+
default=None,
|
| 108 |
+
required=False,
|
| 109 |
+
help="A folder containing the training data of class images.",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--instance_prompt",
|
| 113 |
+
type=str,
|
| 114 |
+
default=None,
|
| 115 |
+
help="The prompt with identifier specifying the instance",
|
| 116 |
+
)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--class_prompt",
|
| 119 |
+
type=str,
|
| 120 |
+
default=None,
|
| 121 |
+
help="The prompt to specify images in the same class as provided instance images.",
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--with_prior_preservation",
|
| 125 |
+
default=False,
|
| 126 |
+
action="store_true",
|
| 127 |
+
help="Flag to add prior preservation loss.",
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--num_class_images",
|
| 132 |
+
type=int,
|
| 133 |
+
default=100,
|
| 134 |
+
help=(
|
| 135 |
+
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
| 136 |
+
" sampled with class_prompt."
|
| 137 |
+
),
|
| 138 |
+
)
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
"--output_dir",
|
| 141 |
+
type=str,
|
| 142 |
+
default="text-inversion-model",
|
| 143 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 144 |
+
)
|
| 145 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--resolution",
|
| 148 |
+
type=int,
|
| 149 |
+
default=512,
|
| 150 |
+
help=(
|
| 151 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 152 |
+
" resolution"
|
| 153 |
+
),
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--center_crop",
|
| 157 |
+
default=False,
|
| 158 |
+
action="store_true",
|
| 159 |
+
help=(
|
| 160 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
| 161 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
| 162 |
+
),
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
| 167 |
+
)
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
| 170 |
+
)
|
| 171 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--max_train_steps",
|
| 174 |
+
type=int,
|
| 175 |
+
default=None,
|
| 176 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 177 |
+
)
|
| 178 |
+
parser.add_argument(
|
| 179 |
+
"--gradient_accumulation_steps",
|
| 180 |
+
type=int,
|
| 181 |
+
default=1,
|
| 182 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 183 |
+
)
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--gradient_checkpointing",
|
| 186 |
+
action="store_true",
|
| 187 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 188 |
+
)
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--learning_rate",
|
| 191 |
+
type=float,
|
| 192 |
+
default=5e-6,
|
| 193 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 194 |
+
)
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--scale_lr",
|
| 197 |
+
action="store_true",
|
| 198 |
+
default=False,
|
| 199 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 200 |
+
)
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--lr_scheduler",
|
| 203 |
+
type=str,
|
| 204 |
+
default="constant",
|
| 205 |
+
help=(
|
| 206 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 207 |
+
' "constant", "constant_with_warmup"]'
|
| 208 |
+
),
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 212 |
+
)
|
| 213 |
+
parser.add_argument(
|
| 214 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 217 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 218 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 219 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 220 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 221 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 222 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--hub_model_id",
|
| 225 |
+
type=str,
|
| 226 |
+
default=None,
|
| 227 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--logging_dir",
|
| 231 |
+
type=str,
|
| 232 |
+
default="logs",
|
| 233 |
+
help=(
|
| 234 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 235 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 236 |
+
),
|
| 237 |
+
)
|
| 238 |
+
parser.add_argument(
|
| 239 |
+
"--mixed_precision",
|
| 240 |
+
type=str,
|
| 241 |
+
default="no",
|
| 242 |
+
choices=["no", "fp16", "bf16"],
|
| 243 |
+
help=(
|
| 244 |
+
"Whether to use mixed precision. Choose"
|
| 245 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
| 246 |
+
"and an Nvidia Ampere GPU."
|
| 247 |
+
),
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 250 |
+
parser.add_argument(
|
| 251 |
+
"--checkpointing_steps",
|
| 252 |
+
type=int,
|
| 253 |
+
default=500,
|
| 254 |
+
help=(
|
| 255 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
| 256 |
+
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
|
| 257 |
+
" using `--resume_from_checkpoint`."
|
| 258 |
+
),
|
| 259 |
+
)
|
| 260 |
+
parser.add_argument(
|
| 261 |
+
"--checkpoints_total_limit",
|
| 262 |
+
type=int,
|
| 263 |
+
default=None,
|
| 264 |
+
help=(
|
| 265 |
+
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
| 266 |
+
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
| 267 |
+
" for more docs"
|
| 268 |
+
),
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--resume_from_checkpoint",
|
| 272 |
+
type=str,
|
| 273 |
+
default=None,
|
| 274 |
+
help=(
|
| 275 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 276 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 277 |
+
),
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
args = parser.parse_args()
|
| 281 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 282 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 283 |
+
args.local_rank = env_local_rank
|
| 284 |
+
|
| 285 |
+
if args.instance_data_dir is None:
|
| 286 |
+
raise ValueError("You must specify a train data directory.")
|
| 287 |
+
|
| 288 |
+
if args.with_prior_preservation:
|
| 289 |
+
if args.class_data_dir is None:
|
| 290 |
+
raise ValueError("You must specify a data directory for class images.")
|
| 291 |
+
if args.class_prompt is None:
|
| 292 |
+
raise ValueError("You must specify prompt for class images.")
|
| 293 |
+
|
| 294 |
+
return args
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class DreamBoothDataset(Dataset):
|
| 298 |
+
"""
|
| 299 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
| 300 |
+
It pre-processes the images and the tokenizes prompts.
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
def __init__(
|
| 304 |
+
self,
|
| 305 |
+
instance_data_root,
|
| 306 |
+
instance_prompt,
|
| 307 |
+
tokenizer,
|
| 308 |
+
class_data_root=None,
|
| 309 |
+
class_prompt=None,
|
| 310 |
+
size=512,
|
| 311 |
+
center_crop=False,
|
| 312 |
+
):
|
| 313 |
+
self.size = size
|
| 314 |
+
self.center_crop = center_crop
|
| 315 |
+
self.tokenizer = tokenizer
|
| 316 |
+
|
| 317 |
+
self.instance_data_root = Path(instance_data_root)
|
| 318 |
+
if not self.instance_data_root.exists():
|
| 319 |
+
raise ValueError("Instance images root doesn't exists.")
|
| 320 |
+
|
| 321 |
+
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
| 322 |
+
self.num_instance_images = len(self.instance_images_path)
|
| 323 |
+
self.instance_prompt = instance_prompt
|
| 324 |
+
self._length = self.num_instance_images
|
| 325 |
+
|
| 326 |
+
if class_data_root is not None:
|
| 327 |
+
self.class_data_root = Path(class_data_root)
|
| 328 |
+
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
| 329 |
+
self.class_images_path = list(self.class_data_root.iterdir())
|
| 330 |
+
self.num_class_images = len(self.class_images_path)
|
| 331 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
| 332 |
+
self.class_prompt = class_prompt
|
| 333 |
+
else:
|
| 334 |
+
self.class_data_root = None
|
| 335 |
+
|
| 336 |
+
self.image_transforms_resize_and_crop = transforms.Compose(
|
| 337 |
+
[
|
| 338 |
+
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
| 339 |
+
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
| 340 |
+
]
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
self.image_transforms = transforms.Compose(
|
| 344 |
+
[
|
| 345 |
+
transforms.ToTensor(),
|
| 346 |
+
transforms.Normalize([0.5], [0.5]),
|
| 347 |
+
]
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
def __len__(self):
|
| 351 |
+
return self._length
|
| 352 |
+
|
| 353 |
+
def __getitem__(self, index):
|
| 354 |
+
example = {}
|
| 355 |
+
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
|
| 356 |
+
if not instance_image.mode == "RGB":
|
| 357 |
+
instance_image = instance_image.convert("RGB")
|
| 358 |
+
instance_image = self.image_transforms_resize_and_crop(instance_image)
|
| 359 |
+
|
| 360 |
+
example["PIL_images"] = instance_image
|
| 361 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
| 362 |
+
|
| 363 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
| 364 |
+
self.instance_prompt,
|
| 365 |
+
padding="do_not_pad",
|
| 366 |
+
truncation=True,
|
| 367 |
+
max_length=self.tokenizer.model_max_length,
|
| 368 |
+
).input_ids
|
| 369 |
+
|
| 370 |
+
if self.class_data_root:
|
| 371 |
+
class_image = Image.open(self.class_images_path[index % self.num_class_images])
|
| 372 |
+
if not class_image.mode == "RGB":
|
| 373 |
+
class_image = class_image.convert("RGB")
|
| 374 |
+
class_image = self.image_transforms_resize_and_crop(class_image)
|
| 375 |
+
example["class_images"] = self.image_transforms(class_image)
|
| 376 |
+
example["class_PIL_images"] = class_image
|
| 377 |
+
example["class_prompt_ids"] = self.tokenizer(
|
| 378 |
+
self.class_prompt,
|
| 379 |
+
padding="do_not_pad",
|
| 380 |
+
truncation=True,
|
| 381 |
+
max_length=self.tokenizer.model_max_length,
|
| 382 |
+
).input_ids
|
| 383 |
+
|
| 384 |
+
return example
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class PromptDataset(Dataset):
|
| 388 |
+
"""A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
|
| 389 |
+
|
| 390 |
+
def __init__(self, prompt, num_samples):
|
| 391 |
+
self.prompt = prompt
|
| 392 |
+
self.num_samples = num_samples
|
| 393 |
+
|
| 394 |
+
def __len__(self):
|
| 395 |
+
return self.num_samples
|
| 396 |
+
|
| 397 |
+
def __getitem__(self, index):
|
| 398 |
+
example = {}
|
| 399 |
+
example["prompt"] = self.prompt
|
| 400 |
+
example["index"] = index
|
| 401 |
+
return example
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def main():
|
| 405 |
+
args = parse_args()
|
| 406 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 407 |
+
|
| 408 |
+
project_config = ProjectConfiguration(
|
| 409 |
+
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
accelerator = Accelerator(
|
| 413 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 414 |
+
mixed_precision=args.mixed_precision,
|
| 415 |
+
log_with="tensorboard",
|
| 416 |
+
project_config=project_config,
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
| 420 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
| 421 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
| 422 |
+
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
| 423 |
+
raise ValueError(
|
| 424 |
+
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
| 425 |
+
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
if args.seed is not None:
|
| 429 |
+
set_seed(args.seed)
|
| 430 |
+
|
| 431 |
+
if args.with_prior_preservation:
|
| 432 |
+
class_images_dir = Path(args.class_data_dir)
|
| 433 |
+
if not class_images_dir.exists():
|
| 434 |
+
class_images_dir.mkdir(parents=True)
|
| 435 |
+
cur_class_images = len(list(class_images_dir.iterdir()))
|
| 436 |
+
|
| 437 |
+
if cur_class_images < args.num_class_images:
|
| 438 |
+
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
| 439 |
+
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 440 |
+
args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
|
| 441 |
+
)
|
| 442 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 443 |
+
|
| 444 |
+
num_new_images = args.num_class_images - cur_class_images
|
| 445 |
+
logger.info(f"Number of class images to sample: {num_new_images}.")
|
| 446 |
+
|
| 447 |
+
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
| 448 |
+
sample_dataloader = torch.utils.data.DataLoader(
|
| 449 |
+
sample_dataset, batch_size=args.sample_batch_size, num_workers=1
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
| 453 |
+
pipeline.to(accelerator.device)
|
| 454 |
+
transform_to_pil = transforms.ToPILImage()
|
| 455 |
+
for example in tqdm(
|
| 456 |
+
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
|
| 457 |
+
):
|
| 458 |
+
bsz = len(example["prompt"])
|
| 459 |
+
fake_images = torch.rand((3, args.resolution, args.resolution))
|
| 460 |
+
transform_to_pil = transforms.ToPILImage()
|
| 461 |
+
fake_pil_images = transform_to_pil(fake_images)
|
| 462 |
+
|
| 463 |
+
fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)
|
| 464 |
+
|
| 465 |
+
images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
|
| 466 |
+
|
| 467 |
+
for i, image in enumerate(images):
|
| 468 |
+
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
|
| 469 |
+
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
| 470 |
+
image.save(image_filename)
|
| 471 |
+
|
| 472 |
+
del pipeline
|
| 473 |
+
if torch.cuda.is_available():
|
| 474 |
+
torch.cuda.empty_cache()
|
| 475 |
+
|
| 476 |
+
# Handle the repository creation
|
| 477 |
+
if accelerator.is_main_process:
|
| 478 |
+
if args.output_dir is not None:
|
| 479 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 480 |
+
|
| 481 |
+
if args.push_to_hub:
|
| 482 |
+
repo_id = create_repo(
|
| 483 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 484 |
+
).repo_id
|
| 485 |
+
|
| 486 |
+
# Load the tokenizer
|
| 487 |
+
if args.tokenizer_name:
|
| 488 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
| 489 |
+
elif args.pretrained_model_name_or_path:
|
| 490 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
| 491 |
+
|
| 492 |
+
# Load models and create wrapper for stable diffusion
|
| 493 |
+
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
| 494 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
| 495 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
| 496 |
+
|
| 497 |
+
vae.requires_grad_(False)
|
| 498 |
+
if not args.train_text_encoder:
|
| 499 |
+
text_encoder.requires_grad_(False)
|
| 500 |
+
|
| 501 |
+
if args.gradient_checkpointing:
|
| 502 |
+
unet.enable_gradient_checkpointing()
|
| 503 |
+
if args.train_text_encoder:
|
| 504 |
+
text_encoder.gradient_checkpointing_enable()
|
| 505 |
+
|
| 506 |
+
if args.scale_lr:
|
| 507 |
+
args.learning_rate = (
|
| 508 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 512 |
+
if args.use_8bit_adam:
|
| 513 |
+
try:
|
| 514 |
+
import bitsandbytes as bnb
|
| 515 |
+
except ImportError:
|
| 516 |
+
raise ImportError(
|
| 517 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 521 |
+
else:
|
| 522 |
+
optimizer_class = torch.optim.AdamW
|
| 523 |
+
|
| 524 |
+
params_to_optimize = (
|
| 525 |
+
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
|
| 526 |
+
)
|
| 527 |
+
optimizer = optimizer_class(
|
| 528 |
+
params_to_optimize,
|
| 529 |
+
lr=args.learning_rate,
|
| 530 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 531 |
+
weight_decay=args.adam_weight_decay,
|
| 532 |
+
eps=args.adam_epsilon,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 536 |
+
|
| 537 |
+
train_dataset = DreamBoothDataset(
|
| 538 |
+
instance_data_root=args.instance_data_dir,
|
| 539 |
+
instance_prompt=args.instance_prompt,
|
| 540 |
+
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
| 541 |
+
class_prompt=args.class_prompt,
|
| 542 |
+
tokenizer=tokenizer,
|
| 543 |
+
size=args.resolution,
|
| 544 |
+
center_crop=args.center_crop,
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
def collate_fn(examples):
|
| 548 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
| 549 |
+
pixel_values = [example["instance_images"] for example in examples]
|
| 550 |
+
|
| 551 |
+
# Concat class and instance examples for prior preservation.
|
| 552 |
+
# We do this to avoid doing two forward passes.
|
| 553 |
+
if args.with_prior_preservation:
|
| 554 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
| 555 |
+
pixel_values += [example["class_images"] for example in examples]
|
| 556 |
+
pior_pil = [example["class_PIL_images"] for example in examples]
|
| 557 |
+
|
| 558 |
+
masks = []
|
| 559 |
+
masked_images = []
|
| 560 |
+
for example in examples:
|
| 561 |
+
pil_image = example["PIL_images"]
|
| 562 |
+
# generate a random mask
|
| 563 |
+
mask = random_mask(pil_image.size, 1, False)
|
| 564 |
+
# prepare mask and masked image
|
| 565 |
+
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
|
| 566 |
+
|
| 567 |
+
masks.append(mask)
|
| 568 |
+
masked_images.append(masked_image)
|
| 569 |
+
|
| 570 |
+
if args.with_prior_preservation:
|
| 571 |
+
for pil_image in pior_pil:
|
| 572 |
+
# generate a random mask
|
| 573 |
+
mask = random_mask(pil_image.size, 1, False)
|
| 574 |
+
# prepare mask and masked image
|
| 575 |
+
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
|
| 576 |
+
|
| 577 |
+
masks.append(mask)
|
| 578 |
+
masked_images.append(masked_image)
|
| 579 |
+
|
| 580 |
+
pixel_values = torch.stack(pixel_values)
|
| 581 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 582 |
+
|
| 583 |
+
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
| 584 |
+
masks = torch.stack(masks)
|
| 585 |
+
masked_images = torch.stack(masked_images)
|
| 586 |
+
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
|
| 587 |
+
return batch
|
| 588 |
+
|
| 589 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 590 |
+
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
# Scheduler and math around the number of training steps.
|
| 594 |
+
overrode_max_train_steps = False
|
| 595 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 596 |
+
if args.max_train_steps is None:
|
| 597 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 598 |
+
overrode_max_train_steps = True
|
| 599 |
+
|
| 600 |
+
lr_scheduler = get_scheduler(
|
| 601 |
+
args.lr_scheduler,
|
| 602 |
+
optimizer=optimizer,
|
| 603 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 604 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
if args.train_text_encoder:
|
| 608 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 609 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
| 610 |
+
)
|
| 611 |
+
else:
|
| 612 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 613 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
| 614 |
+
)
|
| 615 |
+
accelerator.register_for_checkpointing(lr_scheduler)
|
| 616 |
+
|
| 617 |
+
weight_dtype = torch.float32
|
| 618 |
+
if args.mixed_precision == "fp16":
|
| 619 |
+
weight_dtype = torch.float16
|
| 620 |
+
elif args.mixed_precision == "bf16":
|
| 621 |
+
weight_dtype = torch.bfloat16
|
| 622 |
+
|
| 623 |
+
# Move text_encode and vae to gpu.
|
| 624 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
| 625 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
| 626 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 627 |
+
if not args.train_text_encoder:
|
| 628 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 629 |
+
|
| 630 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 631 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 632 |
+
if overrode_max_train_steps:
|
| 633 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 634 |
+
# Afterwards we recalculate our number of training epochs
|
| 635 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 636 |
+
|
| 637 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 638 |
+
# The trackers initializes automatically on the main process.
|
| 639 |
+
if accelerator.is_main_process:
|
| 640 |
+
accelerator.init_trackers("dreambooth", config=vars(args))
|
| 641 |
+
|
| 642 |
+
# Train!
|
| 643 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 644 |
+
|
| 645 |
+
logger.info("***** Running training *****")
|
| 646 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 647 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 648 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 649 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 650 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 651 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 652 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 653 |
+
global_step = 0
|
| 654 |
+
first_epoch = 0
|
| 655 |
+
|
| 656 |
+
if args.resume_from_checkpoint:
|
| 657 |
+
if args.resume_from_checkpoint != "latest":
|
| 658 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 659 |
+
else:
|
| 660 |
+
# Get the most recent checkpoint
|
| 661 |
+
dirs = os.listdir(args.output_dir)
|
| 662 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 663 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 664 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 665 |
+
|
| 666 |
+
if path is None:
|
| 667 |
+
accelerator.print(
|
| 668 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 669 |
+
)
|
| 670 |
+
args.resume_from_checkpoint = None
|
| 671 |
+
else:
|
| 672 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 673 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 674 |
+
global_step = int(path.split("-")[1])
|
| 675 |
+
|
| 676 |
+
resume_global_step = global_step * args.gradient_accumulation_steps
|
| 677 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 678 |
+
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
| 679 |
+
|
| 680 |
+
# Only show the progress bar once on each machine.
|
| 681 |
+
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
| 682 |
+
progress_bar.set_description("Steps")
|
| 683 |
+
|
| 684 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 685 |
+
unet.train()
|
| 686 |
+
for step, batch in enumerate(train_dataloader):
|
| 687 |
+
# Skip steps until we reach the resumed step
|
| 688 |
+
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
| 689 |
+
if step % args.gradient_accumulation_steps == 0:
|
| 690 |
+
progress_bar.update(1)
|
| 691 |
+
continue
|
| 692 |
+
|
| 693 |
+
with accelerator.accumulate(unet):
|
| 694 |
+
# Convert images to latent space
|
| 695 |
+
|
| 696 |
+
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
| 697 |
+
latents = latents * vae.config.scaling_factor
|
| 698 |
+
|
| 699 |
+
# Convert masked images to latent space
|
| 700 |
+
masked_latents = vae.encode(
|
| 701 |
+
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
|
| 702 |
+
).latent_dist.sample()
|
| 703 |
+
masked_latents = masked_latents * vae.config.scaling_factor
|
| 704 |
+
|
| 705 |
+
masks = batch["masks"]
|
| 706 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
| 707 |
+
mask = torch.stack(
|
| 708 |
+
[
|
| 709 |
+
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
|
| 710 |
+
for mask in masks
|
| 711 |
+
]
|
| 712 |
+
)
|
| 713 |
+
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
|
| 714 |
+
|
| 715 |
+
# Sample noise that we'll add to the latents
|
| 716 |
+
noise = torch.randn_like(latents)
|
| 717 |
+
bsz = latents.shape[0]
|
| 718 |
+
# Sample a random timestep for each image
|
| 719 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 720 |
+
timesteps = timesteps.long()
|
| 721 |
+
|
| 722 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 723 |
+
# (this is the forward diffusion process)
|
| 724 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 725 |
+
|
| 726 |
+
# concatenate the noised latents with the mask and the masked latents
|
| 727 |
+
latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
|
| 728 |
+
|
| 729 |
+
# Get the text embedding for conditioning
|
| 730 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
| 731 |
+
|
| 732 |
+
# Predict the noise residual
|
| 733 |
+
noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
|
| 734 |
+
|
| 735 |
+
# Get the target for loss depending on the prediction type
|
| 736 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 737 |
+
target = noise
|
| 738 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 739 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 740 |
+
else:
|
| 741 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 742 |
+
|
| 743 |
+
if args.with_prior_preservation:
|
| 744 |
+
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
|
| 745 |
+
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
|
| 746 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
| 747 |
+
|
| 748 |
+
# Compute instance loss
|
| 749 |
+
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
|
| 750 |
+
|
| 751 |
+
# Compute prior loss
|
| 752 |
+
prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
|
| 753 |
+
|
| 754 |
+
# Add the prior loss to the instance loss.
|
| 755 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
| 756 |
+
else:
|
| 757 |
+
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
| 758 |
+
|
| 759 |
+
accelerator.backward(loss)
|
| 760 |
+
if accelerator.sync_gradients:
|
| 761 |
+
params_to_clip = (
|
| 762 |
+
itertools.chain(unet.parameters(), text_encoder.parameters())
|
| 763 |
+
if args.train_text_encoder
|
| 764 |
+
else unet.parameters()
|
| 765 |
+
)
|
| 766 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 767 |
+
optimizer.step()
|
| 768 |
+
lr_scheduler.step()
|
| 769 |
+
optimizer.zero_grad()
|
| 770 |
+
|
| 771 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 772 |
+
if accelerator.sync_gradients:
|
| 773 |
+
progress_bar.update(1)
|
| 774 |
+
global_step += 1
|
| 775 |
+
|
| 776 |
+
if global_step % args.checkpointing_steps == 0:
|
| 777 |
+
if accelerator.is_main_process:
|
| 778 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 779 |
+
accelerator.save_state(save_path)
|
| 780 |
+
logger.info(f"Saved state to {save_path}")
|
| 781 |
+
|
| 782 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 783 |
+
progress_bar.set_postfix(**logs)
|
| 784 |
+
accelerator.log(logs, step=global_step)
|
| 785 |
+
|
| 786 |
+
if global_step >= args.max_train_steps:
|
| 787 |
+
break
|
| 788 |
+
|
| 789 |
+
accelerator.wait_for_everyone()
|
| 790 |
+
|
| 791 |
+
# Create the pipeline using using the trained modules and save it.
|
| 792 |
+
if accelerator.is_main_process:
|
| 793 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
| 794 |
+
args.pretrained_model_name_or_path,
|
| 795 |
+
unet=accelerator.unwrap_model(unet),
|
| 796 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
| 797 |
+
)
|
| 798 |
+
pipeline.save_pretrained(args.output_dir)
|
| 799 |
+
|
| 800 |
+
if args.push_to_hub:
|
| 801 |
+
upload_folder(
|
| 802 |
+
repo_id=repo_id,
|
| 803 |
+
folder_path=args.output_dir,
|
| 804 |
+
commit_message="End of training",
|
| 805 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
accelerator.end_training()
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
if __name__ == "__main__":
|
| 812 |
+
main()
|
diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py
ADDED
|
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.utils.checkpoint
|
| 11 |
+
from accelerate import Accelerator
|
| 12 |
+
from accelerate.logging import get_logger
|
| 13 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 14 |
+
from huggingface_hub import create_repo, upload_folder
|
| 15 |
+
from huggingface_hub.utils import insecure_hashlib
|
| 16 |
+
from PIL import Image, ImageDraw
|
| 17 |
+
from torch.utils.data import Dataset
|
| 18 |
+
from torchvision import transforms
|
| 19 |
+
from tqdm.auto import tqdm
|
| 20 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 21 |
+
|
| 22 |
+
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
|
| 23 |
+
from diffusers.loaders import AttnProcsLayers
|
| 24 |
+
from diffusers.models.attention_processor import LoRAAttnProcessor
|
| 25 |
+
from diffusers.optimization import get_scheduler
|
| 26 |
+
from diffusers.utils import check_min_version
|
| 27 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 31 |
+
check_min_version("0.13.0.dev0")
|
| 32 |
+
|
| 33 |
+
logger = get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def prepare_mask_and_masked_image(image, mask):
|
| 37 |
+
image = np.array(image.convert("RGB"))
|
| 38 |
+
image = image[None].transpose(0, 3, 1, 2)
|
| 39 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
| 40 |
+
|
| 41 |
+
mask = np.array(mask.convert("L"))
|
| 42 |
+
mask = mask.astype(np.float32) / 255.0
|
| 43 |
+
mask = mask[None, None]
|
| 44 |
+
mask[mask < 0.5] = 0
|
| 45 |
+
mask[mask >= 0.5] = 1
|
| 46 |
+
mask = torch.from_numpy(mask)
|
| 47 |
+
|
| 48 |
+
masked_image = image * (mask < 0.5)
|
| 49 |
+
|
| 50 |
+
return mask, masked_image
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# generate random masks
|
| 54 |
+
def random_mask(im_shape, ratio=1, mask_full_image=False):
|
| 55 |
+
mask = Image.new("L", im_shape, 0)
|
| 56 |
+
draw = ImageDraw.Draw(mask)
|
| 57 |
+
size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
|
| 58 |
+
# use this to always mask the whole image
|
| 59 |
+
if mask_full_image:
|
| 60 |
+
size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
|
| 61 |
+
limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
|
| 62 |
+
center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
|
| 63 |
+
draw_type = random.randint(0, 1)
|
| 64 |
+
if draw_type == 0 or mask_full_image:
|
| 65 |
+
draw.rectangle(
|
| 66 |
+
(center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
|
| 67 |
+
fill=255,
|
| 68 |
+
)
|
| 69 |
+
else:
|
| 70 |
+
draw.ellipse(
|
| 71 |
+
(center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
|
| 72 |
+
fill=255,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
return mask
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def parse_args():
|
| 79 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--pretrained_model_name_or_path",
|
| 82 |
+
type=str,
|
| 83 |
+
default=None,
|
| 84 |
+
required=True,
|
| 85 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--tokenizer_name",
|
| 89 |
+
type=str,
|
| 90 |
+
default=None,
|
| 91 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
| 92 |
+
)
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"--instance_data_dir",
|
| 95 |
+
type=str,
|
| 96 |
+
default=None,
|
| 97 |
+
required=True,
|
| 98 |
+
help="A folder containing the training data of instance images.",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--class_data_dir",
|
| 102 |
+
type=str,
|
| 103 |
+
default=None,
|
| 104 |
+
required=False,
|
| 105 |
+
help="A folder containing the training data of class images.",
|
| 106 |
+
)
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--instance_prompt",
|
| 109 |
+
type=str,
|
| 110 |
+
default=None,
|
| 111 |
+
help="The prompt with identifier specifying the instance",
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"--class_prompt",
|
| 115 |
+
type=str,
|
| 116 |
+
default=None,
|
| 117 |
+
help="The prompt to specify images in the same class as provided instance images.",
|
| 118 |
+
)
|
| 119 |
+
parser.add_argument(
|
| 120 |
+
"--with_prior_preservation",
|
| 121 |
+
default=False,
|
| 122 |
+
action="store_true",
|
| 123 |
+
help="Flag to add prior preservation loss.",
|
| 124 |
+
)
|
| 125 |
+
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--num_class_images",
|
| 128 |
+
type=int,
|
| 129 |
+
default=100,
|
| 130 |
+
help=(
|
| 131 |
+
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
| 132 |
+
" sampled with class_prompt."
|
| 133 |
+
),
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"--output_dir",
|
| 137 |
+
type=str,
|
| 138 |
+
default="dreambooth-inpaint-model",
|
| 139 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--resolution",
|
| 144 |
+
type=int,
|
| 145 |
+
default=512,
|
| 146 |
+
help=(
|
| 147 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 148 |
+
" resolution"
|
| 149 |
+
),
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--center_crop",
|
| 153 |
+
default=False,
|
| 154 |
+
action="store_true",
|
| 155 |
+
help=(
|
| 156 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
| 157 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
| 158 |
+
),
|
| 159 |
+
)
|
| 160 |
+
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--max_train_steps",
|
| 170 |
+
type=int,
|
| 171 |
+
default=None,
|
| 172 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 173 |
+
)
|
| 174 |
+
parser.add_argument(
|
| 175 |
+
"--gradient_accumulation_steps",
|
| 176 |
+
type=int,
|
| 177 |
+
default=1,
|
| 178 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 179 |
+
)
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--gradient_checkpointing",
|
| 182 |
+
action="store_true",
|
| 183 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--learning_rate",
|
| 187 |
+
type=float,
|
| 188 |
+
default=5e-6,
|
| 189 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--scale_lr",
|
| 193 |
+
action="store_true",
|
| 194 |
+
default=False,
|
| 195 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--lr_scheduler",
|
| 199 |
+
type=str,
|
| 200 |
+
default="constant",
|
| 201 |
+
help=(
|
| 202 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 203 |
+
' "constant", "constant_with_warmup"]'
|
| 204 |
+
),
|
| 205 |
+
)
|
| 206 |
+
parser.add_argument(
|
| 207 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 208 |
+
)
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 213 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 214 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 215 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 216 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 217 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 218 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--hub_model_id",
|
| 221 |
+
type=str,
|
| 222 |
+
default=None,
|
| 223 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 224 |
+
)
|
| 225 |
+
parser.add_argument(
|
| 226 |
+
"--logging_dir",
|
| 227 |
+
type=str,
|
| 228 |
+
default="logs",
|
| 229 |
+
help=(
|
| 230 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 231 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 232 |
+
),
|
| 233 |
+
)
|
| 234 |
+
parser.add_argument(
|
| 235 |
+
"--mixed_precision",
|
| 236 |
+
type=str,
|
| 237 |
+
default="no",
|
| 238 |
+
choices=["no", "fp16", "bf16"],
|
| 239 |
+
help=(
|
| 240 |
+
"Whether to use mixed precision. Choose"
|
| 241 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
| 242 |
+
"and an Nvidia Ampere GPU."
|
| 243 |
+
),
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
"--checkpointing_steps",
|
| 248 |
+
type=int,
|
| 249 |
+
default=500,
|
| 250 |
+
help=(
|
| 251 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
| 252 |
+
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
|
| 253 |
+
" using `--resume_from_checkpoint`."
|
| 254 |
+
),
|
| 255 |
+
)
|
| 256 |
+
parser.add_argument(
|
| 257 |
+
"--checkpoints_total_limit",
|
| 258 |
+
type=int,
|
| 259 |
+
default=None,
|
| 260 |
+
help=(
|
| 261 |
+
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
| 262 |
+
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
| 263 |
+
" for more docs"
|
| 264 |
+
),
|
| 265 |
+
)
|
| 266 |
+
parser.add_argument(
|
| 267 |
+
"--resume_from_checkpoint",
|
| 268 |
+
type=str,
|
| 269 |
+
default=None,
|
| 270 |
+
help=(
|
| 271 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 272 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 273 |
+
),
|
| 274 |
+
)
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
args = parser.parse_args()
|
| 280 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 281 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 282 |
+
args.local_rank = env_local_rank
|
| 283 |
+
|
| 284 |
+
if args.instance_data_dir is None:
|
| 285 |
+
raise ValueError("You must specify a train data directory.")
|
| 286 |
+
|
| 287 |
+
if args.with_prior_preservation:
|
| 288 |
+
if args.class_data_dir is None:
|
| 289 |
+
raise ValueError("You must specify a data directory for class images.")
|
| 290 |
+
if args.class_prompt is None:
|
| 291 |
+
raise ValueError("You must specify prompt for class images.")
|
| 292 |
+
|
| 293 |
+
return args
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class DreamBoothDataset(Dataset):
|
| 297 |
+
"""
|
| 298 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
| 299 |
+
It pre-processes the images and the tokenizes prompts.
|
| 300 |
+
"""
|
| 301 |
+
|
| 302 |
+
def __init__(
|
| 303 |
+
self,
|
| 304 |
+
instance_data_root,
|
| 305 |
+
instance_prompt,
|
| 306 |
+
tokenizer,
|
| 307 |
+
class_data_root=None,
|
| 308 |
+
class_prompt=None,
|
| 309 |
+
size=512,
|
| 310 |
+
center_crop=False,
|
| 311 |
+
):
|
| 312 |
+
self.size = size
|
| 313 |
+
self.center_crop = center_crop
|
| 314 |
+
self.tokenizer = tokenizer
|
| 315 |
+
|
| 316 |
+
self.instance_data_root = Path(instance_data_root)
|
| 317 |
+
if not self.instance_data_root.exists():
|
| 318 |
+
raise ValueError("Instance images root doesn't exists.")
|
| 319 |
+
|
| 320 |
+
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
| 321 |
+
self.num_instance_images = len(self.instance_images_path)
|
| 322 |
+
self.instance_prompt = instance_prompt
|
| 323 |
+
self._length = self.num_instance_images
|
| 324 |
+
|
| 325 |
+
if class_data_root is not None:
|
| 326 |
+
self.class_data_root = Path(class_data_root)
|
| 327 |
+
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
| 328 |
+
self.class_images_path = list(self.class_data_root.iterdir())
|
| 329 |
+
self.num_class_images = len(self.class_images_path)
|
| 330 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
| 331 |
+
self.class_prompt = class_prompt
|
| 332 |
+
else:
|
| 333 |
+
self.class_data_root = None
|
| 334 |
+
|
| 335 |
+
self.image_transforms_resize_and_crop = transforms.Compose(
|
| 336 |
+
[
|
| 337 |
+
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
| 338 |
+
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
| 339 |
+
]
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
self.image_transforms = transforms.Compose(
|
| 343 |
+
[
|
| 344 |
+
transforms.ToTensor(),
|
| 345 |
+
transforms.Normalize([0.5], [0.5]),
|
| 346 |
+
]
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
def __len__(self):
|
| 350 |
+
return self._length
|
| 351 |
+
|
| 352 |
+
def __getitem__(self, index):
|
| 353 |
+
example = {}
|
| 354 |
+
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
|
| 355 |
+
if not instance_image.mode == "RGB":
|
| 356 |
+
instance_image = instance_image.convert("RGB")
|
| 357 |
+
instance_image = self.image_transforms_resize_and_crop(instance_image)
|
| 358 |
+
|
| 359 |
+
example["PIL_images"] = instance_image
|
| 360 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
| 361 |
+
|
| 362 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
| 363 |
+
self.instance_prompt,
|
| 364 |
+
padding="do_not_pad",
|
| 365 |
+
truncation=True,
|
| 366 |
+
max_length=self.tokenizer.model_max_length,
|
| 367 |
+
).input_ids
|
| 368 |
+
|
| 369 |
+
if self.class_data_root:
|
| 370 |
+
class_image = Image.open(self.class_images_path[index % self.num_class_images])
|
| 371 |
+
if not class_image.mode == "RGB":
|
| 372 |
+
class_image = class_image.convert("RGB")
|
| 373 |
+
class_image = self.image_transforms_resize_and_crop(class_image)
|
| 374 |
+
example["class_images"] = self.image_transforms(class_image)
|
| 375 |
+
example["class_PIL_images"] = class_image
|
| 376 |
+
example["class_prompt_ids"] = self.tokenizer(
|
| 377 |
+
self.class_prompt,
|
| 378 |
+
padding="do_not_pad",
|
| 379 |
+
truncation=True,
|
| 380 |
+
max_length=self.tokenizer.model_max_length,
|
| 381 |
+
).input_ids
|
| 382 |
+
|
| 383 |
+
return example
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class PromptDataset(Dataset):
|
| 387 |
+
"""A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
|
| 388 |
+
|
| 389 |
+
def __init__(self, prompt, num_samples):
|
| 390 |
+
self.prompt = prompt
|
| 391 |
+
self.num_samples = num_samples
|
| 392 |
+
|
| 393 |
+
def __len__(self):
|
| 394 |
+
return self.num_samples
|
| 395 |
+
|
| 396 |
+
def __getitem__(self, index):
|
| 397 |
+
example = {}
|
| 398 |
+
example["prompt"] = self.prompt
|
| 399 |
+
example["index"] = index
|
| 400 |
+
return example
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def main():
|
| 404 |
+
args = parse_args()
|
| 405 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 406 |
+
|
| 407 |
+
accelerator_project_config = ProjectConfiguration(
|
| 408 |
+
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
accelerator = Accelerator(
|
| 412 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 413 |
+
mixed_precision=args.mixed_precision,
|
| 414 |
+
log_with="tensorboard",
|
| 415 |
+
project_config=accelerator_project_config,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
| 419 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
| 420 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
| 421 |
+
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
| 422 |
+
raise ValueError(
|
| 423 |
+
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
| 424 |
+
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
if args.seed is not None:
|
| 428 |
+
set_seed(args.seed)
|
| 429 |
+
|
| 430 |
+
if args.with_prior_preservation:
|
| 431 |
+
class_images_dir = Path(args.class_data_dir)
|
| 432 |
+
if not class_images_dir.exists():
|
| 433 |
+
class_images_dir.mkdir(parents=True)
|
| 434 |
+
cur_class_images = len(list(class_images_dir.iterdir()))
|
| 435 |
+
|
| 436 |
+
if cur_class_images < args.num_class_images:
|
| 437 |
+
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
| 438 |
+
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 439 |
+
args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
|
| 440 |
+
)
|
| 441 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 442 |
+
|
| 443 |
+
num_new_images = args.num_class_images - cur_class_images
|
| 444 |
+
logger.info(f"Number of class images to sample: {num_new_images}.")
|
| 445 |
+
|
| 446 |
+
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
| 447 |
+
sample_dataloader = torch.utils.data.DataLoader(
|
| 448 |
+
sample_dataset, batch_size=args.sample_batch_size, num_workers=1
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
| 452 |
+
pipeline.to(accelerator.device)
|
| 453 |
+
transform_to_pil = transforms.ToPILImage()
|
| 454 |
+
for example in tqdm(
|
| 455 |
+
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
|
| 456 |
+
):
|
| 457 |
+
bsz = len(example["prompt"])
|
| 458 |
+
fake_images = torch.rand((3, args.resolution, args.resolution))
|
| 459 |
+
transform_to_pil = transforms.ToPILImage()
|
| 460 |
+
fake_pil_images = transform_to_pil(fake_images)
|
| 461 |
+
|
| 462 |
+
fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)
|
| 463 |
+
|
| 464 |
+
images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
|
| 465 |
+
|
| 466 |
+
for i, image in enumerate(images):
|
| 467 |
+
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
|
| 468 |
+
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
| 469 |
+
image.save(image_filename)
|
| 470 |
+
|
| 471 |
+
del pipeline
|
| 472 |
+
if torch.cuda.is_available():
|
| 473 |
+
torch.cuda.empty_cache()
|
| 474 |
+
|
| 475 |
+
# Handle the repository creation
|
| 476 |
+
if accelerator.is_main_process:
|
| 477 |
+
if args.output_dir is not None:
|
| 478 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 479 |
+
|
| 480 |
+
if args.push_to_hub:
|
| 481 |
+
repo_id = create_repo(
|
| 482 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 483 |
+
).repo_id
|
| 484 |
+
|
| 485 |
+
# Load the tokenizer
|
| 486 |
+
if args.tokenizer_name:
|
| 487 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
| 488 |
+
elif args.pretrained_model_name_or_path:
|
| 489 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
| 490 |
+
|
| 491 |
+
# Load models and create wrapper for stable diffusion
|
| 492 |
+
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
| 493 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
| 494 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
| 495 |
+
|
| 496 |
+
# We only train the additional adapter LoRA layers
|
| 497 |
+
vae.requires_grad_(False)
|
| 498 |
+
text_encoder.requires_grad_(False)
|
| 499 |
+
unet.requires_grad_(False)
|
| 500 |
+
|
| 501 |
+
weight_dtype = torch.float32
|
| 502 |
+
if args.mixed_precision == "fp16":
|
| 503 |
+
weight_dtype = torch.float16
|
| 504 |
+
elif args.mixed_precision == "bf16":
|
| 505 |
+
weight_dtype = torch.bfloat16
|
| 506 |
+
|
| 507 |
+
# Move text_encode and vae to gpu.
|
| 508 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
| 509 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
| 510 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 511 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 512 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 513 |
+
|
| 514 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 515 |
+
if is_xformers_available():
|
| 516 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 517 |
+
else:
|
| 518 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 519 |
+
|
| 520 |
+
# now we will add new LoRA weights to the attention layers
|
| 521 |
+
# It's important to realize here how many attention weights will be added and of which sizes
|
| 522 |
+
# The sizes of the attention layers consist only of two different variables:
|
| 523 |
+
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
|
| 524 |
+
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
|
| 525 |
+
|
| 526 |
+
# Let's first see how many attention processors we will have to set.
|
| 527 |
+
# For Stable Diffusion, it should be equal to:
|
| 528 |
+
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
|
| 529 |
+
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
|
| 530 |
+
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
|
| 531 |
+
# => 32 layers
|
| 532 |
+
|
| 533 |
+
# Set correct lora layers
|
| 534 |
+
lora_attn_procs = {}
|
| 535 |
+
for name in unet.attn_processors.keys():
|
| 536 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
| 537 |
+
if name.startswith("mid_block"):
|
| 538 |
+
hidden_size = unet.config.block_out_channels[-1]
|
| 539 |
+
elif name.startswith("up_blocks"):
|
| 540 |
+
block_id = int(name[len("up_blocks.")])
|
| 541 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
| 542 |
+
elif name.startswith("down_blocks"):
|
| 543 |
+
block_id = int(name[len("down_blocks.")])
|
| 544 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
| 545 |
+
|
| 546 |
+
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
| 547 |
+
|
| 548 |
+
unet.set_attn_processor(lora_attn_procs)
|
| 549 |
+
lora_layers = AttnProcsLayers(unet.attn_processors)
|
| 550 |
+
|
| 551 |
+
accelerator.register_for_checkpointing(lora_layers)
|
| 552 |
+
|
| 553 |
+
if args.scale_lr:
|
| 554 |
+
args.learning_rate = (
|
| 555 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 559 |
+
if args.use_8bit_adam:
|
| 560 |
+
try:
|
| 561 |
+
import bitsandbytes as bnb
|
| 562 |
+
except ImportError:
|
| 563 |
+
raise ImportError(
|
| 564 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 568 |
+
else:
|
| 569 |
+
optimizer_class = torch.optim.AdamW
|
| 570 |
+
|
| 571 |
+
optimizer = optimizer_class(
|
| 572 |
+
lora_layers.parameters(),
|
| 573 |
+
lr=args.learning_rate,
|
| 574 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 575 |
+
weight_decay=args.adam_weight_decay,
|
| 576 |
+
eps=args.adam_epsilon,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 580 |
+
|
| 581 |
+
train_dataset = DreamBoothDataset(
|
| 582 |
+
instance_data_root=args.instance_data_dir,
|
| 583 |
+
instance_prompt=args.instance_prompt,
|
| 584 |
+
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
| 585 |
+
class_prompt=args.class_prompt,
|
| 586 |
+
tokenizer=tokenizer,
|
| 587 |
+
size=args.resolution,
|
| 588 |
+
center_crop=args.center_crop,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
def collate_fn(examples):
|
| 592 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
| 593 |
+
pixel_values = [example["instance_images"] for example in examples]
|
| 594 |
+
|
| 595 |
+
# Concat class and instance examples for prior preservation.
|
| 596 |
+
# We do this to avoid doing two forward passes.
|
| 597 |
+
if args.with_prior_preservation:
|
| 598 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
| 599 |
+
pixel_values += [example["class_images"] for example in examples]
|
| 600 |
+
pior_pil = [example["class_PIL_images"] for example in examples]
|
| 601 |
+
|
| 602 |
+
masks = []
|
| 603 |
+
masked_images = []
|
| 604 |
+
for example in examples:
|
| 605 |
+
pil_image = example["PIL_images"]
|
| 606 |
+
# generate a random mask
|
| 607 |
+
mask = random_mask(pil_image.size, 1, False)
|
| 608 |
+
# prepare mask and masked image
|
| 609 |
+
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
|
| 610 |
+
|
| 611 |
+
masks.append(mask)
|
| 612 |
+
masked_images.append(masked_image)
|
| 613 |
+
|
| 614 |
+
if args.with_prior_preservation:
|
| 615 |
+
for pil_image in pior_pil:
|
| 616 |
+
# generate a random mask
|
| 617 |
+
mask = random_mask(pil_image.size, 1, False)
|
| 618 |
+
# prepare mask and masked image
|
| 619 |
+
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
|
| 620 |
+
|
| 621 |
+
masks.append(mask)
|
| 622 |
+
masked_images.append(masked_image)
|
| 623 |
+
|
| 624 |
+
pixel_values = torch.stack(pixel_values)
|
| 625 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 626 |
+
|
| 627 |
+
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
| 628 |
+
masks = torch.stack(masks)
|
| 629 |
+
masked_images = torch.stack(masked_images)
|
| 630 |
+
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
|
| 631 |
+
return batch
|
| 632 |
+
|
| 633 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 634 |
+
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
# Scheduler and math around the number of training steps.
|
| 638 |
+
overrode_max_train_steps = False
|
| 639 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 640 |
+
if args.max_train_steps is None:
|
| 641 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 642 |
+
overrode_max_train_steps = True
|
| 643 |
+
|
| 644 |
+
lr_scheduler = get_scheduler(
|
| 645 |
+
args.lr_scheduler,
|
| 646 |
+
optimizer=optimizer,
|
| 647 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 648 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
# Prepare everything with our `accelerator`.
|
| 652 |
+
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 653 |
+
lora_layers, optimizer, train_dataloader, lr_scheduler
|
| 654 |
+
)
|
| 655 |
+
# accelerator.register_for_checkpointing(lr_scheduler)
|
| 656 |
+
|
| 657 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 658 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 659 |
+
if overrode_max_train_steps:
|
| 660 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 661 |
+
# Afterwards we recalculate our number of training epochs
|
| 662 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 663 |
+
|
| 664 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 665 |
+
# The trackers initializes automatically on the main process.
|
| 666 |
+
if accelerator.is_main_process:
|
| 667 |
+
accelerator.init_trackers("dreambooth-inpaint-lora", config=vars(args))
|
| 668 |
+
|
| 669 |
+
# Train!
|
| 670 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 671 |
+
|
| 672 |
+
logger.info("***** Running training *****")
|
| 673 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 674 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 675 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 676 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 677 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 678 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 679 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 680 |
+
global_step = 0
|
| 681 |
+
first_epoch = 0
|
| 682 |
+
|
| 683 |
+
if args.resume_from_checkpoint:
|
| 684 |
+
if args.resume_from_checkpoint != "latest":
|
| 685 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 686 |
+
else:
|
| 687 |
+
# Get the most recent checkpoint
|
| 688 |
+
dirs = os.listdir(args.output_dir)
|
| 689 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 690 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 691 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 692 |
+
|
| 693 |
+
if path is None:
|
| 694 |
+
accelerator.print(
|
| 695 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 696 |
+
)
|
| 697 |
+
args.resume_from_checkpoint = None
|
| 698 |
+
else:
|
| 699 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 700 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 701 |
+
global_step = int(path.split("-")[1])
|
| 702 |
+
|
| 703 |
+
resume_global_step = global_step * args.gradient_accumulation_steps
|
| 704 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 705 |
+
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
| 706 |
+
|
| 707 |
+
# Only show the progress bar once on each machine.
|
| 708 |
+
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
| 709 |
+
progress_bar.set_description("Steps")
|
| 710 |
+
|
| 711 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 712 |
+
unet.train()
|
| 713 |
+
for step, batch in enumerate(train_dataloader):
|
| 714 |
+
# Skip steps until we reach the resumed step
|
| 715 |
+
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
| 716 |
+
if step % args.gradient_accumulation_steps == 0:
|
| 717 |
+
progress_bar.update(1)
|
| 718 |
+
continue
|
| 719 |
+
|
| 720 |
+
with accelerator.accumulate(unet):
|
| 721 |
+
# Convert images to latent space
|
| 722 |
+
|
| 723 |
+
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
| 724 |
+
latents = latents * vae.config.scaling_factor
|
| 725 |
+
|
| 726 |
+
# Convert masked images to latent space
|
| 727 |
+
masked_latents = vae.encode(
|
| 728 |
+
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
|
| 729 |
+
).latent_dist.sample()
|
| 730 |
+
masked_latents = masked_latents * vae.config.scaling_factor
|
| 731 |
+
|
| 732 |
+
masks = batch["masks"]
|
| 733 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
| 734 |
+
mask = torch.stack(
|
| 735 |
+
[
|
| 736 |
+
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
|
| 737 |
+
for mask in masks
|
| 738 |
+
]
|
| 739 |
+
).to(dtype=weight_dtype)
|
| 740 |
+
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
|
| 741 |
+
|
| 742 |
+
# Sample noise that we'll add to the latents
|
| 743 |
+
noise = torch.randn_like(latents)
|
| 744 |
+
bsz = latents.shape[0]
|
| 745 |
+
# Sample a random timestep for each image
|
| 746 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 747 |
+
timesteps = timesteps.long()
|
| 748 |
+
|
| 749 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 750 |
+
# (this is the forward diffusion process)
|
| 751 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 752 |
+
|
| 753 |
+
# concatenate the noised latents with the mask and the masked latents
|
| 754 |
+
latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
|
| 755 |
+
|
| 756 |
+
# Get the text embedding for conditioning
|
| 757 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
| 758 |
+
|
| 759 |
+
# Predict the noise residual
|
| 760 |
+
noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
|
| 761 |
+
|
| 762 |
+
# Get the target for loss depending on the prediction type
|
| 763 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 764 |
+
target = noise
|
| 765 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 766 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 767 |
+
else:
|
| 768 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 769 |
+
|
| 770 |
+
if args.with_prior_preservation:
|
| 771 |
+
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
|
| 772 |
+
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
|
| 773 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
| 774 |
+
|
| 775 |
+
# Compute instance loss
|
| 776 |
+
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
|
| 777 |
+
|
| 778 |
+
# Compute prior loss
|
| 779 |
+
prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
|
| 780 |
+
|
| 781 |
+
# Add the prior loss to the instance loss.
|
| 782 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
| 783 |
+
else:
|
| 784 |
+
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
| 785 |
+
|
| 786 |
+
accelerator.backward(loss)
|
| 787 |
+
if accelerator.sync_gradients:
|
| 788 |
+
params_to_clip = lora_layers.parameters()
|
| 789 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 790 |
+
optimizer.step()
|
| 791 |
+
lr_scheduler.step()
|
| 792 |
+
optimizer.zero_grad()
|
| 793 |
+
|
| 794 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 795 |
+
if accelerator.sync_gradients:
|
| 796 |
+
progress_bar.update(1)
|
| 797 |
+
global_step += 1
|
| 798 |
+
|
| 799 |
+
if global_step % args.checkpointing_steps == 0:
|
| 800 |
+
if accelerator.is_main_process:
|
| 801 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 802 |
+
accelerator.save_state(save_path)
|
| 803 |
+
logger.info(f"Saved state to {save_path}")
|
| 804 |
+
|
| 805 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 806 |
+
progress_bar.set_postfix(**logs)
|
| 807 |
+
accelerator.log(logs, step=global_step)
|
| 808 |
+
|
| 809 |
+
if global_step >= args.max_train_steps:
|
| 810 |
+
break
|
| 811 |
+
|
| 812 |
+
accelerator.wait_for_everyone()
|
| 813 |
+
|
| 814 |
+
# Save the lora layers
|
| 815 |
+
if accelerator.is_main_process:
|
| 816 |
+
unet = unet.to(torch.float32)
|
| 817 |
+
unet.save_attn_procs(args.output_dir)
|
| 818 |
+
|
| 819 |
+
if args.push_to_hub:
|
| 820 |
+
upload_folder(
|
| 821 |
+
repo_id=repo_id,
|
| 822 |
+
folder_path=args.output_dir,
|
| 823 |
+
commit_message="End of training",
|
| 824 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
accelerator.end_training()
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
if __name__ == "__main__":
|
| 831 |
+
main()
|
diffusers/examples/research_projects/flux_lora_quantization/README.md
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## LoRA fine-tuning Flux.1 Dev with quantization
|
| 2 |
+
|
| 3 |
+
> [!NOTE]
|
| 4 |
+
> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further.
|
| 5 |
+
|
| 6 |
+
This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow:
|
| 7 |
+
|
| 8 |
+
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
|
| 9 |
+
* Even though optional, we load the T5-xxl in NF4 to further reduce the memory foot-print.
|
| 10 |
+
* `train_dreambooth_lora_flux_miniature.py` takes care of training:
|
| 11 |
+
* Since we already precomputed the text embeddings, we don't load the text encoders.
|
| 12 |
+
* We load the VAE and use it to precompute the image latents and we then delete it.
|
| 13 |
+
* Load the Flux transformer, quantize it with the [NF4 datatype](https://huggingface.co/papers/2305.14314) through `bitsandbytes`, prepare it for 4bit training.
|
| 14 |
+
* Add LoRA adapter layers to it and then ensure they are kept in FP32 precision.
|
| 15 |
+
* Train!
|
| 16 |
+
|
| 17 |
+
To run training in a memory-optimized manner, we additionally use:
|
| 18 |
+
|
| 19 |
+
* 8Bit Adam
|
| 20 |
+
* Gradient checkpointing
|
| 21 |
+
|
| 22 |
+
We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow.
|
| 23 |
+
|
| 24 |
+
## Training
|
| 25 |
+
|
| 26 |
+
Ensure you have installed the required libraries:
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
pip install -U transformers accelerate bitsandbytes peft datasets
|
| 30 |
+
pip install git+https://github.com/huggingface/diffusers -U
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
Now, compute the text embeddings:
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
python compute_embeddings.py
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
huggingface-cli
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Then launch:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
accelerate launch --config_file=accelerate.yaml \
|
| 49 |
+
train_dreambooth_lora_flux_miniature.py \
|
| 50 |
+
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
| 51 |
+
--data_df_path="embeddings.parquet" \
|
| 52 |
+
--output_dir="yarn_art_lora_flux_nf4" \
|
| 53 |
+
--mixed_precision="fp16" \
|
| 54 |
+
--use_8bit_adam \
|
| 55 |
+
--weighting_scheme="none" \
|
| 56 |
+
--resolution=1024 \
|
| 57 |
+
--train_batch_size=1 \
|
| 58 |
+
--repeats=1 \
|
| 59 |
+
--learning_rate=1e-4 \
|
| 60 |
+
--guidance_scale=1 \
|
| 61 |
+
--report_to="wandb" \
|
| 62 |
+
--gradient_accumulation_steps=4 \
|
| 63 |
+
--gradient_checkpointing \
|
| 64 |
+
--lr_scheduler="constant" \
|
| 65 |
+
--lr_warmup_steps=0 \
|
| 66 |
+
--cache_latents \
|
| 67 |
+
--rank=4 \
|
| 68 |
+
--max_train_steps=700 \
|
| 69 |
+
--seed="0"
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
We can directly pass a quantized checkpoint path, too:
|
| 73 |
+
|
| 74 |
+
```diff
|
| 75 |
+
+ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`.
|
| 79 |
+
|
| 80 |
+
We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed:
|
| 81 |
+
|
| 82 |
+
```bash
|
| 83 |
+
pip install -Uq deepspeed
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
And then launch:
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
accelerate launch --config_file=ds2.yaml \
|
| 90 |
+
train_dreambooth_lora_flux_miniature.py \
|
| 91 |
+
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
| 92 |
+
--data_df_path="embeddings.parquet" \
|
| 93 |
+
--output_dir="yarn_art_lora_flux_nf4" \
|
| 94 |
+
--mixed_precision="no" \
|
| 95 |
+
--use_8bit_adam \
|
| 96 |
+
--weighting_scheme="none" \
|
| 97 |
+
--resolution=1024 \
|
| 98 |
+
--train_batch_size=1 \
|
| 99 |
+
--repeats=1 \
|
| 100 |
+
--learning_rate=1e-4 \
|
| 101 |
+
--guidance_scale=1 \
|
| 102 |
+
--report_to="wandb" \
|
| 103 |
+
--gradient_accumulation_steps=4 \
|
| 104 |
+
--gradient_checkpointing \
|
| 105 |
+
--lr_scheduler="constant" \
|
| 106 |
+
--lr_warmup_steps=0 \
|
| 107 |
+
--cache_latents \
|
| 108 |
+
--rank=4 \
|
| 109 |
+
--max_train_steps=700 \
|
| 110 |
+
--seed="0"
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
## Inference
|
| 114 |
+
|
| 115 |
+
When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example:
|
| 116 |
+
|
| 117 |
+
1. First, load the original model and merge the LoRA params into it:
|
| 118 |
+
|
| 119 |
+
```py
|
| 120 |
+
from diffusers import FluxPipeline
|
| 121 |
+
import torch
|
| 122 |
+
|
| 123 |
+
ckpt_id = "black-forest-labs/FLUX.1-dev"
|
| 124 |
+
pipeline = FluxPipeline.from_pretrained(
|
| 125 |
+
ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
|
| 126 |
+
)
|
| 127 |
+
pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
|
| 128 |
+
pipeline.fuse_lora()
|
| 129 |
+
pipeline.unload_lora_weights()
|
| 130 |
+
|
| 131 |
+
pipeline.transformer.save_pretrained("fused_transformer")
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
2. Quantize the model and run inference
|
| 135 |
+
|
| 136 |
+
```py
|
| 137 |
+
from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
|
| 138 |
+
import torch
|
| 139 |
+
|
| 140 |
+
ckpt_id = "black-forest-labs/FLUX.1-dev"
|
| 141 |
+
bnb_4bit_compute_dtype = torch.float16
|
| 142 |
+
nf4_config = BitsAndBytesConfig(
|
| 143 |
+
load_in_4bit=True,
|
| 144 |
+
bnb_4bit_quant_type="nf4",
|
| 145 |
+
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
| 146 |
+
)
|
| 147 |
+
transformer = FluxTransformer2DModel.from_pretrained(
|
| 148 |
+
"fused_transformer",
|
| 149 |
+
quantization_config=nf4_config,
|
| 150 |
+
torch_dtype=bnb_4bit_compute_dtype,
|
| 151 |
+
)
|
| 152 |
+
pipeline = AutoPipelineForText2Image.from_pretrained(
|
| 153 |
+
ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
|
| 154 |
+
)
|
| 155 |
+
pipeline.enable_model_cpu_offload()
|
| 156 |
+
|
| 157 |
+
image = pipeline(
|
| 158 |
+
"a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
|
| 159 |
+
).images[0]
|
| 160 |
+
image.save("yarn_merged.png")
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
| Dequantize, merge, quantize | Merging directly into quantized model |
|
| 164 |
+
|-------|-------|
|
| 165 |
+
|  |  |
|
| 166 |
+
|
| 167 |
+
As we can notice the first column result follows the style more closely.
|
diffusers/examples/research_projects/flux_lora_quantization/accelerate.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: NO
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: true
|
| 6 |
+
gpu_ids: all
|
| 7 |
+
machine_rank: 0
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: bf16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 1
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_env: []
|
| 15 |
+
tpu_use_cluster: false
|
| 16 |
+
tpu_use_sudo: false
|
| 17 |
+
use_cpu: false
|
diffusers/examples/research_projects/flux_lora_quantization/compute_embeddings.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import torch
|
| 21 |
+
from datasets import load_dataset
|
| 22 |
+
from huggingface_hub.utils import insecure_hashlib
|
| 23 |
+
from tqdm.auto import tqdm
|
| 24 |
+
from transformers import T5EncoderModel
|
| 25 |
+
|
| 26 |
+
from diffusers import FluxPipeline
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
MAX_SEQ_LENGTH = 77
|
| 30 |
+
OUTPUT_PATH = "embeddings.parquet"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def generate_image_hash(image):
|
| 34 |
+
return insecure_hashlib.sha256(image.tobytes()).hexdigest()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def load_flux_dev_pipeline():
|
| 38 |
+
id = "black-forest-labs/FLUX.1-dev"
|
| 39 |
+
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
|
| 40 |
+
pipeline = FluxPipeline.from_pretrained(
|
| 41 |
+
id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
|
| 42 |
+
)
|
| 43 |
+
return pipeline
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@torch.no_grad()
|
| 47 |
+
def compute_embeddings(pipeline, prompts, max_sequence_length):
|
| 48 |
+
all_prompt_embeds = []
|
| 49 |
+
all_pooled_prompt_embeds = []
|
| 50 |
+
all_text_ids = []
|
| 51 |
+
for prompt in tqdm(prompts, desc="Encoding prompts."):
|
| 52 |
+
(
|
| 53 |
+
prompt_embeds,
|
| 54 |
+
pooled_prompt_embeds,
|
| 55 |
+
text_ids,
|
| 56 |
+
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
|
| 57 |
+
all_prompt_embeds.append(prompt_embeds)
|
| 58 |
+
all_pooled_prompt_embeds.append(pooled_prompt_embeds)
|
| 59 |
+
all_text_ids.append(text_ids)
|
| 60 |
+
|
| 61 |
+
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
|
| 62 |
+
print(f"Max memory allocated: {max_memory:.3f} GB")
|
| 63 |
+
return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def run(args):
|
| 67 |
+
dataset = load_dataset("Norod78/Yarn-art-style", split="train")
|
| 68 |
+
image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
|
| 69 |
+
all_prompts = list(image_prompts.values())
|
| 70 |
+
print(f"{len(all_prompts)=}")
|
| 71 |
+
|
| 72 |
+
pipeline = load_flux_dev_pipeline()
|
| 73 |
+
all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
|
| 74 |
+
pipeline, all_prompts, args.max_sequence_length
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
data = []
|
| 78 |
+
for i, (image_hash, _) in enumerate(image_prompts.items()):
|
| 79 |
+
data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
|
| 80 |
+
print(f"{len(data)=}")
|
| 81 |
+
|
| 82 |
+
# Create a DataFrame
|
| 83 |
+
embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
|
| 84 |
+
df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
|
| 85 |
+
print(f"{len(df)=}")
|
| 86 |
+
|
| 87 |
+
# Convert embedding lists to arrays (for proper storage in parquet)
|
| 88 |
+
for col in embedding_cols:
|
| 89 |
+
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
|
| 90 |
+
|
| 91 |
+
# Save the dataframe to a parquet file
|
| 92 |
+
df.to_parquet(args.output_path)
|
| 93 |
+
print(f"Data successfully serialized to {args.output_path}")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
parser = argparse.ArgumentParser()
|
| 98 |
+
parser.add_argument(
|
| 99 |
+
"--max_sequence_length",
|
| 100 |
+
type=int,
|
| 101 |
+
default=MAX_SEQ_LENGTH,
|
| 102 |
+
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
|
| 103 |
+
)
|
| 104 |
+
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
|
| 105 |
+
args = parser.parse_args()
|
| 106 |
+
|
| 107 |
+
run(args)
|
diffusers/examples/research_projects/flux_lora_quantization/ds2.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
gradient_accumulation_steps: 1
|
| 5 |
+
gradient_clipping: 1.0
|
| 6 |
+
offload_optimizer_device: cpu
|
| 7 |
+
offload_param_device: cpu
|
| 8 |
+
zero3_init_flag: false
|
| 9 |
+
zero_stage: 2
|
| 10 |
+
distributed_type: DEEPSPEED
|
| 11 |
+
downcast_bf16: 'no'
|
| 12 |
+
enable_cpu_affinity: false
|
| 13 |
+
machine_rank: 0
|
| 14 |
+
main_training_function: main
|
| 15 |
+
mixed_precision: 'no'
|
| 16 |
+
num_machines: 1
|
| 17 |
+
num_processes: 1
|
| 18 |
+
rdzv_backend: static
|
| 19 |
+
same_network: true
|
| 20 |
+
tpu_env: []
|
| 21 |
+
tpu_use_cluster: false
|
| 22 |
+
tpu_use_sudo: false
|
| 23 |
+
use_cpu: false
|
diffusers/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
ADDED
|
@@ -0,0 +1,1200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import copy
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import random
|
| 22 |
+
import shutil
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import pandas as pd
|
| 27 |
+
import torch
|
| 28 |
+
import torch.utils.checkpoint
|
| 29 |
+
import transformers
|
| 30 |
+
from accelerate import Accelerator, DistributedType
|
| 31 |
+
from accelerate.logging import get_logger
|
| 32 |
+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
| 33 |
+
from datasets import load_dataset
|
| 34 |
+
from huggingface_hub import create_repo, upload_folder
|
| 35 |
+
from huggingface_hub.utils import insecure_hashlib
|
| 36 |
+
from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
|
| 37 |
+
from peft.utils import get_peft_model_state_dict
|
| 38 |
+
from PIL.ImageOps import exif_transpose
|
| 39 |
+
from torch.utils.data import Dataset
|
| 40 |
+
from torchvision import transforms
|
| 41 |
+
from torchvision.transforms.functional import crop
|
| 42 |
+
from tqdm.auto import tqdm
|
| 43 |
+
|
| 44 |
+
import diffusers
|
| 45 |
+
from diffusers import (
|
| 46 |
+
AutoencoderKL,
|
| 47 |
+
BitsAndBytesConfig,
|
| 48 |
+
FlowMatchEulerDiscreteScheduler,
|
| 49 |
+
FluxPipeline,
|
| 50 |
+
FluxTransformer2DModel,
|
| 51 |
+
)
|
| 52 |
+
from diffusers.optimization import get_scheduler
|
| 53 |
+
from diffusers.training_utils import (
|
| 54 |
+
cast_training_params,
|
| 55 |
+
compute_density_for_timestep_sampling,
|
| 56 |
+
compute_loss_weighting_for_sd3,
|
| 57 |
+
free_memory,
|
| 58 |
+
)
|
| 59 |
+
from diffusers.utils import (
|
| 60 |
+
check_min_version,
|
| 61 |
+
convert_unet_state_dict_to_peft,
|
| 62 |
+
is_wandb_available,
|
| 63 |
+
)
|
| 64 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
| 65 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if is_wandb_available():
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 72 |
+
check_min_version("0.31.0.dev0")
|
| 73 |
+
|
| 74 |
+
logger = get_logger(__name__)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def save_model_card(
|
| 78 |
+
repo_id: str,
|
| 79 |
+
base_model: str = None,
|
| 80 |
+
instance_prompt=None,
|
| 81 |
+
repo_folder=None,
|
| 82 |
+
quantization_config=None,
|
| 83 |
+
):
|
| 84 |
+
widget_dict = []
|
| 85 |
+
|
| 86 |
+
model_description = f"""
|
| 87 |
+
# Flux DreamBooth LoRA - {repo_id}
|
| 88 |
+
|
| 89 |
+
<Gallery />
|
| 90 |
+
|
| 91 |
+
## Model description
|
| 92 |
+
|
| 93 |
+
These are {repo_id} DreamBooth LoRA weights for {base_model}.
|
| 94 |
+
|
| 95 |
+
The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).
|
| 96 |
+
|
| 97 |
+
Was LoRA for the text encoder enabled? False.
|
| 98 |
+
|
| 99 |
+
Quantization config:
|
| 100 |
+
|
| 101 |
+
```yaml
|
| 102 |
+
{quantization_config}
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Trigger words
|
| 106 |
+
|
| 107 |
+
You should use `{instance_prompt}` to trigger the image generation.
|
| 108 |
+
|
| 109 |
+
## Download model
|
| 110 |
+
|
| 111 |
+
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
|
| 112 |
+
|
| 113 |
+
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
| 114 |
+
|
| 115 |
+
## Usage
|
| 116 |
+
|
| 117 |
+
TODO
|
| 118 |
+
|
| 119 |
+
## License
|
| 120 |
+
|
| 121 |
+
Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
|
| 122 |
+
"""
|
| 123 |
+
model_card = load_or_create_model_card(
|
| 124 |
+
repo_id_or_path=repo_id,
|
| 125 |
+
from_training=True,
|
| 126 |
+
license="other",
|
| 127 |
+
base_model=base_model,
|
| 128 |
+
prompt=instance_prompt,
|
| 129 |
+
model_description=model_description,
|
| 130 |
+
widget=widget_dict,
|
| 131 |
+
)
|
| 132 |
+
tags = [
|
| 133 |
+
"text-to-image",
|
| 134 |
+
"diffusers-training",
|
| 135 |
+
"diffusers",
|
| 136 |
+
"lora",
|
| 137 |
+
"flux",
|
| 138 |
+
"flux-diffusers",
|
| 139 |
+
"template:sd-lora",
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
model_card = populate_model_card(model_card, tags=tags)
|
| 143 |
+
model_card.save(os.path.join(repo_folder, "README.md"))
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def parse_args(input_args=None):
|
| 147 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--pretrained_model_name_or_path",
|
| 150 |
+
type=str,
|
| 151 |
+
default=None,
|
| 152 |
+
required=True,
|
| 153 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--quantized_model_path",
|
| 157 |
+
type=str,
|
| 158 |
+
default=None,
|
| 159 |
+
help="Path to the quantized model.",
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--revision",
|
| 163 |
+
type=str,
|
| 164 |
+
default=None,
|
| 165 |
+
required=False,
|
| 166 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 167 |
+
)
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--variant",
|
| 170 |
+
type=str,
|
| 171 |
+
default=None,
|
| 172 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
| 173 |
+
)
|
| 174 |
+
parser.add_argument(
|
| 175 |
+
"--data_df_path",
|
| 176 |
+
type=str,
|
| 177 |
+
default=None,
|
| 178 |
+
help=("Path to the parquet file serialized with compute_embeddings.py."),
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--cache_dir",
|
| 183 |
+
type=str,
|
| 184 |
+
default=None,
|
| 185 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
|
| 188 |
+
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--max_sequence_length",
|
| 191 |
+
type=int,
|
| 192 |
+
default=77,
|
| 193 |
+
help="Used for reading the embeddings. Needs to be the same as used during `compute_embeddings.py`.",
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"--rank",
|
| 198 |
+
type=int,
|
| 199 |
+
default=4,
|
| 200 |
+
help=("The dimension of the LoRA update matrices."),
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--output_dir",
|
| 205 |
+
type=str,
|
| 206 |
+
default="flux-dreambooth-lora-nf4",
|
| 207 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 208 |
+
)
|
| 209 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--resolution",
|
| 212 |
+
type=int,
|
| 213 |
+
default=512,
|
| 214 |
+
help=(
|
| 215 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 216 |
+
" resolution"
|
| 217 |
+
),
|
| 218 |
+
)
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--center_crop",
|
| 221 |
+
default=False,
|
| 222 |
+
action="store_true",
|
| 223 |
+
help=(
|
| 224 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
| 225 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
parser.add_argument(
|
| 229 |
+
"--random_flip",
|
| 230 |
+
action="store_true",
|
| 231 |
+
help="whether to randomly flip images horizontally",
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
| 235 |
+
)
|
| 236 |
+
parser.add_argument(
|
| 237 |
+
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
| 238 |
+
)
|
| 239 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
"--max_train_steps",
|
| 242 |
+
type=int,
|
| 243 |
+
default=None,
|
| 244 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 245 |
+
)
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
"--checkpointing_steps",
|
| 248 |
+
type=int,
|
| 249 |
+
default=500,
|
| 250 |
+
help=(
|
| 251 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
| 252 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
| 253 |
+
" training using `--resume_from_checkpoint`."
|
| 254 |
+
),
|
| 255 |
+
)
|
| 256 |
+
parser.add_argument(
|
| 257 |
+
"--checkpoints_total_limit",
|
| 258 |
+
type=int,
|
| 259 |
+
default=None,
|
| 260 |
+
help=("Max number of checkpoints to store."),
|
| 261 |
+
)
|
| 262 |
+
parser.add_argument(
|
| 263 |
+
"--resume_from_checkpoint",
|
| 264 |
+
type=str,
|
| 265 |
+
default=None,
|
| 266 |
+
help=(
|
| 267 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
| 268 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
| 269 |
+
),
|
| 270 |
+
)
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--gradient_accumulation_steps",
|
| 273 |
+
type=int,
|
| 274 |
+
default=1,
|
| 275 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 276 |
+
)
|
| 277 |
+
parser.add_argument(
|
| 278 |
+
"--gradient_checkpointing",
|
| 279 |
+
action="store_true",
|
| 280 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
| 281 |
+
)
|
| 282 |
+
parser.add_argument(
|
| 283 |
+
"--learning_rate",
|
| 284 |
+
type=float,
|
| 285 |
+
default=1e-4,
|
| 286 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
parser.add_argument(
|
| 290 |
+
"--guidance_scale",
|
| 291 |
+
type=float,
|
| 292 |
+
default=3.5,
|
| 293 |
+
help="the FLUX.1 dev variant is a guidance distilled model",
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
parser.add_argument(
|
| 297 |
+
"--scale_lr",
|
| 298 |
+
action="store_true",
|
| 299 |
+
default=False,
|
| 300 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 301 |
+
)
|
| 302 |
+
parser.add_argument(
|
| 303 |
+
"--lr_scheduler",
|
| 304 |
+
type=str,
|
| 305 |
+
default="constant",
|
| 306 |
+
help=(
|
| 307 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 308 |
+
' "constant", "constant_with_warmup"]'
|
| 309 |
+
),
|
| 310 |
+
)
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 313 |
+
)
|
| 314 |
+
parser.add_argument(
|
| 315 |
+
"--lr_num_cycles",
|
| 316 |
+
type=int,
|
| 317 |
+
default=1,
|
| 318 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
| 319 |
+
)
|
| 320 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
| 321 |
+
parser.add_argument(
|
| 322 |
+
"--dataloader_num_workers",
|
| 323 |
+
type=int,
|
| 324 |
+
default=0,
|
| 325 |
+
help=(
|
| 326 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 327 |
+
),
|
| 328 |
+
)
|
| 329 |
+
parser.add_argument(
|
| 330 |
+
"--weighting_scheme",
|
| 331 |
+
type=str,
|
| 332 |
+
default="none",
|
| 333 |
+
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
|
| 334 |
+
help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
|
| 335 |
+
)
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
| 338 |
+
)
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
|
| 341 |
+
)
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--mode_scale",
|
| 344 |
+
type=float,
|
| 345 |
+
default=1.29,
|
| 346 |
+
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
| 347 |
+
)
|
| 348 |
+
parser.add_argument(
|
| 349 |
+
"--optimizer",
|
| 350 |
+
type=str,
|
| 351 |
+
default="AdamW",
|
| 352 |
+
choices=["AdamW", "Prodigy", "AdEMAMix"],
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
parser.add_argument(
|
| 356 |
+
"--use_8bit_adam",
|
| 357 |
+
action="store_true",
|
| 358 |
+
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
|
| 359 |
+
)
|
| 360 |
+
parser.add_argument(
|
| 361 |
+
"--use_8bit_ademamix",
|
| 362 |
+
action="store_true",
|
| 363 |
+
help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
parser.add_argument(
|
| 367 |
+
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
|
| 368 |
+
)
|
| 369 |
+
parser.add_argument(
|
| 370 |
+
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
|
| 371 |
+
)
|
| 372 |
+
parser.add_argument(
|
| 373 |
+
"--prodigy_beta3",
|
| 374 |
+
type=float,
|
| 375 |
+
default=None,
|
| 376 |
+
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
|
| 377 |
+
"uses the value of square root of beta2. Ignored if optimizer is adamW",
|
| 378 |
+
)
|
| 379 |
+
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
| 380 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
|
| 381 |
+
|
| 382 |
+
parser.add_argument(
|
| 383 |
+
"--adam_epsilon",
|
| 384 |
+
type=float,
|
| 385 |
+
default=1e-08,
|
| 386 |
+
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
parser.add_argument(
|
| 390 |
+
"--prodigy_use_bias_correction",
|
| 391 |
+
type=bool,
|
| 392 |
+
default=True,
|
| 393 |
+
help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
|
| 394 |
+
)
|
| 395 |
+
parser.add_argument(
|
| 396 |
+
"--prodigy_safeguard_warmup",
|
| 397 |
+
type=bool,
|
| 398 |
+
default=True,
|
| 399 |
+
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
|
| 400 |
+
"Ignored if optimizer is adamW",
|
| 401 |
+
)
|
| 402 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 403 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 404 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 405 |
+
parser.add_argument(
|
| 406 |
+
"--hub_model_id",
|
| 407 |
+
type=str,
|
| 408 |
+
default=None,
|
| 409 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 410 |
+
)
|
| 411 |
+
parser.add_argument(
|
| 412 |
+
"--logging_dir",
|
| 413 |
+
type=str,
|
| 414 |
+
default="logs",
|
| 415 |
+
help=(
|
| 416 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 417 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 418 |
+
),
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
parser.add_argument(
|
| 422 |
+
"--cache_latents",
|
| 423 |
+
action="store_true",
|
| 424 |
+
default=False,
|
| 425 |
+
help="Cache the VAE latents",
|
| 426 |
+
)
|
| 427 |
+
parser.add_argument(
|
| 428 |
+
"--report_to",
|
| 429 |
+
type=str,
|
| 430 |
+
default="tensorboard",
|
| 431 |
+
help=(
|
| 432 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 433 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 434 |
+
),
|
| 435 |
+
)
|
| 436 |
+
parser.add_argument(
|
| 437 |
+
"--mixed_precision",
|
| 438 |
+
type=str,
|
| 439 |
+
default=None,
|
| 440 |
+
choices=["no", "fp16", "bf16"],
|
| 441 |
+
help=(
|
| 442 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 443 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 444 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 445 |
+
),
|
| 446 |
+
)
|
| 447 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 448 |
+
|
| 449 |
+
if input_args is not None:
|
| 450 |
+
args = parser.parse_args(input_args)
|
| 451 |
+
else:
|
| 452 |
+
args = parser.parse_args()
|
| 453 |
+
|
| 454 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 455 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 456 |
+
args.local_rank = env_local_rank
|
| 457 |
+
|
| 458 |
+
return args
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class DreamBoothDataset(Dataset):
|
| 462 |
+
def __init__(
|
| 463 |
+
self,
|
| 464 |
+
data_df_path,
|
| 465 |
+
dataset_name,
|
| 466 |
+
size=1024,
|
| 467 |
+
max_sequence_length=77,
|
| 468 |
+
center_crop=False,
|
| 469 |
+
):
|
| 470 |
+
# Logistics
|
| 471 |
+
self.size = size
|
| 472 |
+
self.center_crop = center_crop
|
| 473 |
+
self.max_sequence_length = max_sequence_length
|
| 474 |
+
|
| 475 |
+
self.data_df_path = Path(data_df_path)
|
| 476 |
+
if not self.data_df_path.exists():
|
| 477 |
+
raise ValueError("`data_df_path` doesn't exists.")
|
| 478 |
+
|
| 479 |
+
# Load images.
|
| 480 |
+
dataset = load_dataset(dataset_name, split="train")
|
| 481 |
+
instance_images = [sample["image"] for sample in dataset]
|
| 482 |
+
image_hashes = [self.generate_image_hash(image) for image in instance_images]
|
| 483 |
+
self.instance_images = instance_images
|
| 484 |
+
self.image_hashes = image_hashes
|
| 485 |
+
|
| 486 |
+
# Image transformations
|
| 487 |
+
self.pixel_values = self.apply_image_transformations(
|
| 488 |
+
instance_images=instance_images, size=size, center_crop=center_crop
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# Map hashes to embeddings.
|
| 492 |
+
self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)
|
| 493 |
+
|
| 494 |
+
self.num_instance_images = len(instance_images)
|
| 495 |
+
self._length = self.num_instance_images
|
| 496 |
+
|
| 497 |
+
def __len__(self):
|
| 498 |
+
return self._length
|
| 499 |
+
|
| 500 |
+
def __getitem__(self, index):
|
| 501 |
+
example = {}
|
| 502 |
+
instance_image = self.pixel_values[index % self.num_instance_images]
|
| 503 |
+
image_hash = self.image_hashes[index % self.num_instance_images]
|
| 504 |
+
prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[image_hash]
|
| 505 |
+
example["instance_images"] = instance_image
|
| 506 |
+
example["prompt_embeds"] = prompt_embeds
|
| 507 |
+
example["pooled_prompt_embeds"] = pooled_prompt_embeds
|
| 508 |
+
example["text_ids"] = text_ids
|
| 509 |
+
return example
|
| 510 |
+
|
| 511 |
+
def apply_image_transformations(self, instance_images, size, center_crop):
|
| 512 |
+
pixel_values = []
|
| 513 |
+
|
| 514 |
+
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
| 515 |
+
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
| 516 |
+
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
| 517 |
+
train_transforms = transforms.Compose(
|
| 518 |
+
[
|
| 519 |
+
transforms.ToTensor(),
|
| 520 |
+
transforms.Normalize([0.5], [0.5]),
|
| 521 |
+
]
|
| 522 |
+
)
|
| 523 |
+
for image in instance_images:
|
| 524 |
+
image = exif_transpose(image)
|
| 525 |
+
if not image.mode == "RGB":
|
| 526 |
+
image = image.convert("RGB")
|
| 527 |
+
image = train_resize(image)
|
| 528 |
+
if args.random_flip and random.random() < 0.5:
|
| 529 |
+
# flip
|
| 530 |
+
image = train_flip(image)
|
| 531 |
+
if args.center_crop:
|
| 532 |
+
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
| 533 |
+
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
| 534 |
+
image = train_crop(image)
|
| 535 |
+
else:
|
| 536 |
+
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
| 537 |
+
image = crop(image, y1, x1, h, w)
|
| 538 |
+
image = train_transforms(image)
|
| 539 |
+
pixel_values.append(image)
|
| 540 |
+
|
| 541 |
+
return pixel_values
|
| 542 |
+
|
| 543 |
+
def convert_to_torch_tensor(self, embeddings: list):
|
| 544 |
+
prompt_embeds = embeddings[0]
|
| 545 |
+
pooled_prompt_embeds = embeddings[1]
|
| 546 |
+
text_ids = embeddings[2]
|
| 547 |
+
prompt_embeds = np.array(prompt_embeds).reshape(self.max_sequence_length, 4096)
|
| 548 |
+
pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(768)
|
| 549 |
+
text_ids = np.array(text_ids).reshape(77, 3)
|
| 550 |
+
return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds), torch.from_numpy(text_ids)
|
| 551 |
+
|
| 552 |
+
def map_image_hash_embedding(self, data_df_path):
|
| 553 |
+
hashes_df = pd.read_parquet(data_df_path)
|
| 554 |
+
data_dict = {}
|
| 555 |
+
for i, row in hashes_df.iterrows():
|
| 556 |
+
embeddings = [row["prompt_embeds"], row["pooled_prompt_embeds"], row["text_ids"]]
|
| 557 |
+
prompt_embeds, pooled_prompt_embeds, text_ids = self.convert_to_torch_tensor(embeddings=embeddings)
|
| 558 |
+
data_dict.update({row["image_hash"]: (prompt_embeds, pooled_prompt_embeds, text_ids)})
|
| 559 |
+
return data_dict
|
| 560 |
+
|
| 561 |
+
def generate_image_hash(self, image):
|
| 562 |
+
return insecure_hashlib.sha256(image.tobytes()).hexdigest()
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def collate_fn(examples):
|
| 566 |
+
pixel_values = [example["instance_images"] for example in examples]
|
| 567 |
+
prompt_embeds = [example["prompt_embeds"] for example in examples]
|
| 568 |
+
pooled_prompt_embeds = [example["pooled_prompt_embeds"] for example in examples]
|
| 569 |
+
text_ids = [example["text_ids"] for example in examples]
|
| 570 |
+
|
| 571 |
+
pixel_values = torch.stack(pixel_values)
|
| 572 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 573 |
+
prompt_embeds = torch.stack(prompt_embeds)
|
| 574 |
+
pooled_prompt_embeds = torch.stack(pooled_prompt_embeds)
|
| 575 |
+
text_ids = torch.stack(text_ids)[0] # just 2D tensor
|
| 576 |
+
|
| 577 |
+
batch = {
|
| 578 |
+
"pixel_values": pixel_values,
|
| 579 |
+
"prompt_embeds": prompt_embeds,
|
| 580 |
+
"pooled_prompt_embeds": pooled_prompt_embeds,
|
| 581 |
+
"text_ids": text_ids,
|
| 582 |
+
}
|
| 583 |
+
return batch
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def main(args):
|
| 587 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 588 |
+
raise ValueError(
|
| 589 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 590 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
| 594 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
| 595 |
+
raise ValueError(
|
| 596 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 600 |
+
|
| 601 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 602 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 603 |
+
accelerator = Accelerator(
|
| 604 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 605 |
+
mixed_precision=args.mixed_precision,
|
| 606 |
+
log_with=args.report_to,
|
| 607 |
+
project_config=accelerator_project_config,
|
| 608 |
+
kwargs_handlers=[kwargs],
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
# Disable AMP for MPS.
|
| 612 |
+
if torch.backends.mps.is_available():
|
| 613 |
+
accelerator.native_amp = False
|
| 614 |
+
|
| 615 |
+
if args.report_to == "wandb":
|
| 616 |
+
if not is_wandb_available():
|
| 617 |
+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
| 618 |
+
|
| 619 |
+
# Make one log on every process with the configuration for debugging.
|
| 620 |
+
logging.basicConfig(
|
| 621 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 622 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 623 |
+
level=logging.INFO,
|
| 624 |
+
)
|
| 625 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 626 |
+
if accelerator.is_local_main_process:
|
| 627 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 628 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 629 |
+
else:
|
| 630 |
+
transformers.utils.logging.set_verbosity_error()
|
| 631 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 632 |
+
|
| 633 |
+
# If passed along, set the training seed now.
|
| 634 |
+
if args.seed is not None:
|
| 635 |
+
set_seed(args.seed)
|
| 636 |
+
|
| 637 |
+
# Handle the repository creation
|
| 638 |
+
if accelerator.is_main_process:
|
| 639 |
+
if args.output_dir is not None:
|
| 640 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 641 |
+
|
| 642 |
+
if args.push_to_hub:
|
| 643 |
+
repo_id = create_repo(
|
| 644 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
| 645 |
+
exist_ok=True,
|
| 646 |
+
).repo_id
|
| 647 |
+
|
| 648 |
+
# Load scheduler and models
|
| 649 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 650 |
+
args.pretrained_model_name_or_path, subfolder="scheduler"
|
| 651 |
+
)
|
| 652 |
+
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
| 653 |
+
vae = AutoencoderKL.from_pretrained(
|
| 654 |
+
args.pretrained_model_name_or_path,
|
| 655 |
+
subfolder="vae",
|
| 656 |
+
revision=args.revision,
|
| 657 |
+
variant=args.variant,
|
| 658 |
+
)
|
| 659 |
+
bnb_4bit_compute_dtype = torch.float32
|
| 660 |
+
if args.mixed_precision == "fp16":
|
| 661 |
+
bnb_4bit_compute_dtype = torch.float16
|
| 662 |
+
elif args.mixed_precision == "bf16":
|
| 663 |
+
bnb_4bit_compute_dtype = torch.bfloat16
|
| 664 |
+
if args.quantized_model_path is not None:
|
| 665 |
+
transformer = FluxTransformer2DModel.from_pretrained(
|
| 666 |
+
args.quantized_model_path,
|
| 667 |
+
subfolder="transformer",
|
| 668 |
+
revision=args.revision,
|
| 669 |
+
variant=args.variant,
|
| 670 |
+
torch_dtype=bnb_4bit_compute_dtype,
|
| 671 |
+
)
|
| 672 |
+
else:
|
| 673 |
+
nf4_config = BitsAndBytesConfig(
|
| 674 |
+
load_in_4bit=True,
|
| 675 |
+
bnb_4bit_quant_type="nf4",
|
| 676 |
+
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
| 677 |
+
)
|
| 678 |
+
transformer = FluxTransformer2DModel.from_pretrained(
|
| 679 |
+
args.pretrained_model_name_or_path,
|
| 680 |
+
subfolder="transformer",
|
| 681 |
+
revision=args.revision,
|
| 682 |
+
variant=args.variant,
|
| 683 |
+
quantization_config=nf4_config,
|
| 684 |
+
torch_dtype=bnb_4bit_compute_dtype,
|
| 685 |
+
)
|
| 686 |
+
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
|
| 687 |
+
|
| 688 |
+
# We only train the additional adapter LoRA layers
|
| 689 |
+
transformer.requires_grad_(False)
|
| 690 |
+
vae.requires_grad_(False)
|
| 691 |
+
|
| 692 |
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
| 693 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 694 |
+
weight_dtype = torch.float32
|
| 695 |
+
if accelerator.mixed_precision == "fp16":
|
| 696 |
+
weight_dtype = torch.float16
|
| 697 |
+
elif accelerator.mixed_precision == "bf16":
|
| 698 |
+
weight_dtype = torch.bfloat16
|
| 699 |
+
|
| 700 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
| 701 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
| 702 |
+
raise ValueError(
|
| 703 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 707 |
+
if args.gradient_checkpointing:
|
| 708 |
+
transformer.enable_gradient_checkpointing()
|
| 709 |
+
|
| 710 |
+
# now we will add new LoRA weights to the attention layers
|
| 711 |
+
transformer_lora_config = LoraConfig(
|
| 712 |
+
r=args.rank,
|
| 713 |
+
lora_alpha=args.rank,
|
| 714 |
+
init_lora_weights="gaussian",
|
| 715 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
| 716 |
+
)
|
| 717 |
+
transformer.add_adapter(transformer_lora_config)
|
| 718 |
+
|
| 719 |
+
def unwrap_model(model):
|
| 720 |
+
model = accelerator.unwrap_model(model)
|
| 721 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 722 |
+
return model
|
| 723 |
+
|
| 724 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 725 |
+
def save_model_hook(models, weights, output_dir):
|
| 726 |
+
if accelerator.is_main_process:
|
| 727 |
+
transformer_lora_layers_to_save = None
|
| 728 |
+
|
| 729 |
+
for model in models:
|
| 730 |
+
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
| 731 |
+
model = unwrap_model(model)
|
| 732 |
+
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
| 733 |
+
else:
|
| 734 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 735 |
+
|
| 736 |
+
# make sure to pop weight so that corresponding model is not saved again
|
| 737 |
+
if weights:
|
| 738 |
+
weights.pop()
|
| 739 |
+
|
| 740 |
+
FluxPipeline.save_lora_weights(
|
| 741 |
+
output_dir,
|
| 742 |
+
transformer_lora_layers=transformer_lora_layers_to_save,
|
| 743 |
+
text_encoder_lora_layers=None,
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
def load_model_hook(models, input_dir):
|
| 747 |
+
transformer_ = None
|
| 748 |
+
|
| 749 |
+
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
| 750 |
+
while len(models) > 0:
|
| 751 |
+
model = models.pop()
|
| 752 |
+
|
| 753 |
+
if isinstance(model, type(unwrap_model(transformer))):
|
| 754 |
+
transformer_ = model
|
| 755 |
+
else:
|
| 756 |
+
raise ValueError(f"unexpected save model: {model.__class__}")
|
| 757 |
+
else:
|
| 758 |
+
if args.quantized_model_path is not None:
|
| 759 |
+
transformer_ = FluxTransformer2DModel.from_pretrained(
|
| 760 |
+
args.quantized_model_path,
|
| 761 |
+
subfolder="transformer",
|
| 762 |
+
revision=args.revision,
|
| 763 |
+
variant=args.variant,
|
| 764 |
+
torch_dtype=bnb_4bit_compute_dtype,
|
| 765 |
+
)
|
| 766 |
+
else:
|
| 767 |
+
nf4_config = BitsAndBytesConfig(
|
| 768 |
+
load_in_4bit=True,
|
| 769 |
+
bnb_4bit_quant_type="nf4",
|
| 770 |
+
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
| 771 |
+
)
|
| 772 |
+
transformer_ = FluxTransformer2DModel.from_pretrained(
|
| 773 |
+
args.pretrained_model_name_or_path,
|
| 774 |
+
subfolder="transformer",
|
| 775 |
+
revision=args.revision,
|
| 776 |
+
variant=args.variant,
|
| 777 |
+
quantization_config=nf4_config,
|
| 778 |
+
torch_dtype=bnb_4bit_compute_dtype,
|
| 779 |
+
)
|
| 780 |
+
transformer_ = prepare_model_for_kbit_training(transformer_, use_gradient_checkpointing=False)
|
| 781 |
+
transformer_.add_adapter(transformer_lora_config)
|
| 782 |
+
|
| 783 |
+
lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
|
| 784 |
+
|
| 785 |
+
transformer_state_dict = {
|
| 786 |
+
f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
| 787 |
+
}
|
| 788 |
+
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
| 789 |
+
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
| 790 |
+
if incompatible_keys is not None:
|
| 791 |
+
# check only for unexpected keys
|
| 792 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
| 793 |
+
if unexpected_keys:
|
| 794 |
+
logger.warning(
|
| 795 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
| 796 |
+
f" {unexpected_keys}. "
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
# Make sure the trainable params are in float32. This is again needed since the base models
|
| 800 |
+
# are in `weight_dtype`. More details:
|
| 801 |
+
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
| 802 |
+
if args.mixed_precision == "fp16":
|
| 803 |
+
models = [transformer_]
|
| 804 |
+
# only upcast trainable parameters (LoRA) into fp32
|
| 805 |
+
cast_training_params(models)
|
| 806 |
+
|
| 807 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 808 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 809 |
+
|
| 810 |
+
if args.scale_lr:
|
| 811 |
+
args.learning_rate = (
|
| 812 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
# Make sure the trainable params are in float32.
|
| 816 |
+
if args.mixed_precision == "fp16":
|
| 817 |
+
models = [transformer]
|
| 818 |
+
# only upcast trainable parameters (LoRA) into fp32
|
| 819 |
+
cast_training_params(models, dtype=torch.float32)
|
| 820 |
+
|
| 821 |
+
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
| 822 |
+
|
| 823 |
+
# Optimization parameters
|
| 824 |
+
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
|
| 825 |
+
params_to_optimize = [transformer_parameters_with_lr]
|
| 826 |
+
|
| 827 |
+
# Optimizer creation
|
| 828 |
+
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
|
| 829 |
+
logger.warning(
|
| 830 |
+
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
|
| 831 |
+
f"set to {args.optimizer.lower()}"
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
|
| 835 |
+
logger.warning(
|
| 836 |
+
f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
|
| 837 |
+
f"set to {args.optimizer.lower()}"
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
if args.optimizer.lower() == "adamw":
|
| 841 |
+
if args.use_8bit_adam:
|
| 842 |
+
try:
|
| 843 |
+
import bitsandbytes as bnb
|
| 844 |
+
except ImportError:
|
| 845 |
+
raise ImportError(
|
| 846 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 850 |
+
else:
|
| 851 |
+
optimizer_class = torch.optim.AdamW
|
| 852 |
+
|
| 853 |
+
optimizer = optimizer_class(
|
| 854 |
+
params_to_optimize,
|
| 855 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 856 |
+
weight_decay=args.adam_weight_decay,
|
| 857 |
+
eps=args.adam_epsilon,
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
elif args.optimizer.lower() == "ademamix":
|
| 861 |
+
try:
|
| 862 |
+
import bitsandbytes as bnb
|
| 863 |
+
except ImportError:
|
| 864 |
+
raise ImportError(
|
| 865 |
+
"To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
|
| 866 |
+
)
|
| 867 |
+
if args.use_8bit_ademamix:
|
| 868 |
+
optimizer_class = bnb.optim.AdEMAMix8bit
|
| 869 |
+
else:
|
| 870 |
+
optimizer_class = bnb.optim.AdEMAMix
|
| 871 |
+
|
| 872 |
+
optimizer = optimizer_class(params_to_optimize)
|
| 873 |
+
|
| 874 |
+
if args.optimizer.lower() == "prodigy":
|
| 875 |
+
try:
|
| 876 |
+
import prodigyopt
|
| 877 |
+
except ImportError:
|
| 878 |
+
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
|
| 879 |
+
|
| 880 |
+
optimizer_class = prodigyopt.Prodigy
|
| 881 |
+
|
| 882 |
+
if args.learning_rate <= 0.1:
|
| 883 |
+
logger.warning(
|
| 884 |
+
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
optimizer = optimizer_class(
|
| 888 |
+
params_to_optimize,
|
| 889 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 890 |
+
beta3=args.prodigy_beta3,
|
| 891 |
+
weight_decay=args.adam_weight_decay,
|
| 892 |
+
eps=args.adam_epsilon,
|
| 893 |
+
decouple=args.prodigy_decouple,
|
| 894 |
+
use_bias_correction=args.prodigy_use_bias_correction,
|
| 895 |
+
safeguard_warmup=args.prodigy_safeguard_warmup,
|
| 896 |
+
)
|
| 897 |
+
|
| 898 |
+
# Dataset and DataLoaders creation:
|
| 899 |
+
train_dataset = DreamBoothDataset(
|
| 900 |
+
data_df_path=args.data_df_path,
|
| 901 |
+
dataset_name="Norod78/Yarn-art-style",
|
| 902 |
+
size=args.resolution,
|
| 903 |
+
max_sequence_length=args.max_sequence_length,
|
| 904 |
+
center_crop=args.center_crop,
|
| 905 |
+
)
|
| 906 |
+
|
| 907 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 908 |
+
train_dataset,
|
| 909 |
+
batch_size=args.train_batch_size,
|
| 910 |
+
shuffle=True,
|
| 911 |
+
collate_fn=collate_fn,
|
| 912 |
+
num_workers=args.dataloader_num_workers,
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
vae_config_shift_factor = vae.config.shift_factor
|
| 916 |
+
vae_config_scaling_factor = vae.config.scaling_factor
|
| 917 |
+
vae_config_block_out_channels = vae.config.block_out_channels
|
| 918 |
+
if args.cache_latents:
|
| 919 |
+
latents_cache = []
|
| 920 |
+
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
| 921 |
+
with torch.no_grad():
|
| 922 |
+
batch["pixel_values"] = batch["pixel_values"].to(
|
| 923 |
+
accelerator.device, non_blocking=True, dtype=weight_dtype
|
| 924 |
+
)
|
| 925 |
+
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
|
| 926 |
+
|
| 927 |
+
del vae
|
| 928 |
+
free_memory()
|
| 929 |
+
|
| 930 |
+
# Scheduler and math around the number of training steps.
|
| 931 |
+
overrode_max_train_steps = False
|
| 932 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 933 |
+
if args.max_train_steps is None:
|
| 934 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 935 |
+
overrode_max_train_steps = True
|
| 936 |
+
|
| 937 |
+
lr_scheduler = get_scheduler(
|
| 938 |
+
args.lr_scheduler,
|
| 939 |
+
optimizer=optimizer,
|
| 940 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 941 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 942 |
+
num_cycles=args.lr_num_cycles,
|
| 943 |
+
power=args.lr_power,
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
# Prepare everything with our `accelerator`.
|
| 947 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 948 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 952 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 953 |
+
if overrode_max_train_steps:
|
| 954 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 955 |
+
# Afterwards we recalculate our number of training epochs
|
| 956 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 957 |
+
|
| 958 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 959 |
+
# The trackers initializes automatically on the main process.
|
| 960 |
+
if accelerator.is_main_process:
|
| 961 |
+
tracker_name = "dreambooth-flux-dev-lora-nf4"
|
| 962 |
+
accelerator.init_trackers(tracker_name, config=vars(args))
|
| 963 |
+
|
| 964 |
+
# Train!
|
| 965 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 966 |
+
|
| 967 |
+
logger.info("***** Running training *****")
|
| 968 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 969 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 970 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 971 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 972 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 973 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 974 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 975 |
+
global_step = 0
|
| 976 |
+
first_epoch = 0
|
| 977 |
+
|
| 978 |
+
# Potentially load in the weights and states from a previous save
|
| 979 |
+
if args.resume_from_checkpoint:
|
| 980 |
+
if args.resume_from_checkpoint != "latest":
|
| 981 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 982 |
+
else:
|
| 983 |
+
# Get the mos recent checkpoint
|
| 984 |
+
dirs = os.listdir(args.output_dir)
|
| 985 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 986 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 987 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 988 |
+
|
| 989 |
+
if path is None:
|
| 990 |
+
accelerator.print(
|
| 991 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 992 |
+
)
|
| 993 |
+
args.resume_from_checkpoint = None
|
| 994 |
+
initial_global_step = 0
|
| 995 |
+
else:
|
| 996 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 997 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 998 |
+
global_step = int(path.split("-")[1])
|
| 999 |
+
|
| 1000 |
+
initial_global_step = global_step
|
| 1001 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 1002 |
+
|
| 1003 |
+
else:
|
| 1004 |
+
initial_global_step = 0
|
| 1005 |
+
|
| 1006 |
+
progress_bar = tqdm(
|
| 1007 |
+
range(0, args.max_train_steps),
|
| 1008 |
+
initial=initial_global_step,
|
| 1009 |
+
desc="Steps",
|
| 1010 |
+
# Only show the progress bar once on each machine.
|
| 1011 |
+
disable=not accelerator.is_local_main_process,
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
| 1015 |
+
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
|
| 1016 |
+
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
|
| 1017 |
+
timesteps = timesteps.to(accelerator.device)
|
| 1018 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
| 1019 |
+
|
| 1020 |
+
sigma = sigmas[step_indices].flatten()
|
| 1021 |
+
while len(sigma.shape) < n_dim:
|
| 1022 |
+
sigma = sigma.unsqueeze(-1)
|
| 1023 |
+
return sigma
|
| 1024 |
+
|
| 1025 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 1026 |
+
transformer.train()
|
| 1027 |
+
|
| 1028 |
+
for step, batch in enumerate(train_dataloader):
|
| 1029 |
+
models_to_accumulate = [transformer]
|
| 1030 |
+
with accelerator.accumulate(models_to_accumulate):
|
| 1031 |
+
# Convert images to latent space
|
| 1032 |
+
if args.cache_latents:
|
| 1033 |
+
model_input = latents_cache[step].sample()
|
| 1034 |
+
else:
|
| 1035 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 1036 |
+
model_input = vae.encode(pixel_values).latent_dist.sample()
|
| 1037 |
+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 1038 |
+
model_input = model_input.to(dtype=weight_dtype)
|
| 1039 |
+
|
| 1040 |
+
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
|
| 1041 |
+
|
| 1042 |
+
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
|
| 1043 |
+
model_input.shape[0],
|
| 1044 |
+
model_input.shape[2] // 2,
|
| 1045 |
+
model_input.shape[3] // 2,
|
| 1046 |
+
accelerator.device,
|
| 1047 |
+
weight_dtype,
|
| 1048 |
+
)
|
| 1049 |
+
# Sample noise that we'll add to the latents
|
| 1050 |
+
noise = torch.randn_like(model_input)
|
| 1051 |
+
bsz = model_input.shape[0]
|
| 1052 |
+
|
| 1053 |
+
# Sample a random timestep for each image
|
| 1054 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
| 1055 |
+
u = compute_density_for_timestep_sampling(
|
| 1056 |
+
weighting_scheme=args.weighting_scheme,
|
| 1057 |
+
batch_size=bsz,
|
| 1058 |
+
logit_mean=args.logit_mean,
|
| 1059 |
+
logit_std=args.logit_std,
|
| 1060 |
+
mode_scale=args.mode_scale,
|
| 1061 |
+
)
|
| 1062 |
+
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
| 1063 |
+
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
| 1064 |
+
|
| 1065 |
+
# Add noise according to flow matching.
|
| 1066 |
+
# zt = (1 - texp) * x + texp * z1
|
| 1067 |
+
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
| 1068 |
+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
| 1069 |
+
|
| 1070 |
+
packed_noisy_model_input = FluxPipeline._pack_latents(
|
| 1071 |
+
noisy_model_input,
|
| 1072 |
+
batch_size=model_input.shape[0],
|
| 1073 |
+
num_channels_latents=model_input.shape[1],
|
| 1074 |
+
height=model_input.shape[2],
|
| 1075 |
+
width=model_input.shape[3],
|
| 1076 |
+
)
|
| 1077 |
+
|
| 1078 |
+
# handle guidance
|
| 1079 |
+
if unwrap_model(transformer).config.guidance_embeds:
|
| 1080 |
+
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
| 1081 |
+
guidance = guidance.expand(model_input.shape[0])
|
| 1082 |
+
else:
|
| 1083 |
+
guidance = None
|
| 1084 |
+
|
| 1085 |
+
# Predict the noise
|
| 1086 |
+
prompt_embeds = batch["prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype)
|
| 1087 |
+
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype)
|
| 1088 |
+
text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype)
|
| 1089 |
+
model_pred = transformer(
|
| 1090 |
+
hidden_states=packed_noisy_model_input,
|
| 1091 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
| 1092 |
+
timestep=timesteps / 1000,
|
| 1093 |
+
guidance=guidance,
|
| 1094 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1095 |
+
encoder_hidden_states=prompt_embeds,
|
| 1096 |
+
txt_ids=text_ids,
|
| 1097 |
+
img_ids=latent_image_ids,
|
| 1098 |
+
return_dict=False,
|
| 1099 |
+
)[0]
|
| 1100 |
+
model_pred = FluxPipeline._unpack_latents(
|
| 1101 |
+
model_pred,
|
| 1102 |
+
height=model_input.shape[2] * vae_scale_factor,
|
| 1103 |
+
width=model_input.shape[3] * vae_scale_factor,
|
| 1104 |
+
vae_scale_factor=vae_scale_factor,
|
| 1105 |
+
)
|
| 1106 |
+
|
| 1107 |
+
# these weighting schemes use a uniform timestep sampling
|
| 1108 |
+
# and instead post-weight the loss
|
| 1109 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
| 1110 |
+
|
| 1111 |
+
# flow matching loss
|
| 1112 |
+
target = noise - model_input
|
| 1113 |
+
|
| 1114 |
+
# Compute regular loss.
|
| 1115 |
+
loss = torch.mean(
|
| 1116 |
+
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
| 1117 |
+
1,
|
| 1118 |
+
)
|
| 1119 |
+
loss = loss.mean()
|
| 1120 |
+
accelerator.backward(loss)
|
| 1121 |
+
|
| 1122 |
+
if accelerator.sync_gradients:
|
| 1123 |
+
params_to_clip = transformer.parameters()
|
| 1124 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 1125 |
+
|
| 1126 |
+
optimizer.step()
|
| 1127 |
+
lr_scheduler.step()
|
| 1128 |
+
optimizer.zero_grad()
|
| 1129 |
+
|
| 1130 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1131 |
+
if accelerator.sync_gradients:
|
| 1132 |
+
progress_bar.update(1)
|
| 1133 |
+
global_step += 1
|
| 1134 |
+
|
| 1135 |
+
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
| 1136 |
+
if global_step % args.checkpointing_steps == 0:
|
| 1137 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
| 1138 |
+
if args.checkpoints_total_limit is not None:
|
| 1139 |
+
checkpoints = os.listdir(args.output_dir)
|
| 1140 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 1141 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 1142 |
+
|
| 1143 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
| 1144 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 1145 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 1146 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 1147 |
+
|
| 1148 |
+
logger.info(
|
| 1149 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 1150 |
+
)
|
| 1151 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 1152 |
+
|
| 1153 |
+
for removing_checkpoint in removing_checkpoints:
|
| 1154 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 1155 |
+
shutil.rmtree(removing_checkpoint)
|
| 1156 |
+
|
| 1157 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 1158 |
+
accelerator.save_state(save_path)
|
| 1159 |
+
logger.info(f"Saved state to {save_path}")
|
| 1160 |
+
|
| 1161 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 1162 |
+
progress_bar.set_postfix(**logs)
|
| 1163 |
+
accelerator.log(logs, step=global_step)
|
| 1164 |
+
|
| 1165 |
+
if global_step >= args.max_train_steps:
|
| 1166 |
+
break
|
| 1167 |
+
|
| 1168 |
+
# Save the lora layers
|
| 1169 |
+
accelerator.wait_for_everyone()
|
| 1170 |
+
if accelerator.is_main_process:
|
| 1171 |
+
transformer = unwrap_model(transformer)
|
| 1172 |
+
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
| 1173 |
+
|
| 1174 |
+
FluxPipeline.save_lora_weights(
|
| 1175 |
+
save_directory=args.output_dir,
|
| 1176 |
+
transformer_lora_layers=transformer_lora_layers,
|
| 1177 |
+
text_encoder_lora_layers=None,
|
| 1178 |
+
)
|
| 1179 |
+
|
| 1180 |
+
if args.push_to_hub:
|
| 1181 |
+
save_model_card(
|
| 1182 |
+
repo_id,
|
| 1183 |
+
base_model=args.pretrained_model_name_or_path,
|
| 1184 |
+
instance_prompt=None,
|
| 1185 |
+
repo_folder=args.output_dir,
|
| 1186 |
+
quantization_config=transformer.config["quantization_config"],
|
| 1187 |
+
)
|
| 1188 |
+
upload_folder(
|
| 1189 |
+
repo_id=repo_id,
|
| 1190 |
+
folder_path=args.output_dir,
|
| 1191 |
+
commit_message="End of training",
|
| 1192 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 1193 |
+
)
|
| 1194 |
+
|
| 1195 |
+
accelerator.end_training()
|
| 1196 |
+
|
| 1197 |
+
|
| 1198 |
+
if __name__ == "__main__":
|
| 1199 |
+
args = parse_args()
|
| 1200 |
+
main(args)
|
diffusers/examples/research_projects/intel_opts/README.md
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Diffusers examples with Intel optimizations
|
| 2 |
+
|
| 3 |
+
**This research project is not actively maintained by the diffusers team. For any questions or comments, please make sure to tag @hshen14 .**
|
| 4 |
+
|
| 5 |
+
This aims to provide diffusers examples with Intel optimizations such as Bfloat16 for training/fine-tuning acceleration and 8-bit integer (INT8) for inference acceleration on Intel platforms.
|
| 6 |
+
|
| 7 |
+
## Accelerating the fine-tuning for textual inversion
|
| 8 |
+
|
| 9 |
+
We accelerate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.
|
| 10 |
+
|
| 11 |
+
## Accelerating the inference for Stable Diffusion using Bfloat16
|
| 12 |
+
|
| 13 |
+
We start the inference acceleration with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is generally designed to support standard Stable Diffusion models with Bfloat16 support.
|
| 14 |
+
```bash
|
| 15 |
+
pip install diffusers transformers accelerate scipy safetensors
|
| 16 |
+
|
| 17 |
+
export KMP_BLOCKTIME=1
|
| 18 |
+
export KMP_SETTINGS=1
|
| 19 |
+
export KMP_AFFINITY=granularity=fine,compact,1,0
|
| 20 |
+
|
| 21 |
+
# Intel OpenMP
|
| 22 |
+
export OMP_NUM_THREADS=< Cores to use >
|
| 23 |
+
export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libiomp5.so
|
| 24 |
+
# Jemalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support.
|
| 25 |
+
export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libjemalloc.so
|
| 26 |
+
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:9000000000"
|
| 27 |
+
|
| 28 |
+
# Launch with default DDIM
|
| 29 |
+
numactl --membind <node N> -C <cpu list> python python inference_bf16.py
|
| 30 |
+
# Launch with DPMSolverMultistepScheduler
|
| 31 |
+
numactl --membind <node N> -C <cpu list> python python inference_bf16.py --dpm
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Accelerating the inference for Stable Diffusion using INT8
|
| 35 |
+
|
| 36 |
+
Coming soon ...
|
diffusers/examples/research_projects/intel_opts/inference_bf16.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import intel_extension_for_pytorch as ipex
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser("Stable Diffusion script with intel optimization", add_help=False)
|
| 10 |
+
parser.add_argument("--dpm", action="store_true", help="Enable DPMSolver or not")
|
| 11 |
+
parser.add_argument("--steps", default=None, type=int, help="Num inference steps")
|
| 12 |
+
args = parser.parse_args()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
device = "cpu"
|
| 16 |
+
prompt = "a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brightly buildings"
|
| 17 |
+
|
| 18 |
+
model_id = "path-to-your-trained-model"
|
| 19 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
| 20 |
+
if args.dpm:
|
| 21 |
+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
| 22 |
+
pipe = pipe.to(device)
|
| 23 |
+
|
| 24 |
+
# to channels last
|
| 25 |
+
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
|
| 26 |
+
pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
|
| 27 |
+
pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
|
| 28 |
+
if pipe.requires_safety_checker:
|
| 29 |
+
pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last)
|
| 30 |
+
|
| 31 |
+
# optimize with ipex
|
| 32 |
+
sample = torch.randn(2, 4, 64, 64)
|
| 33 |
+
timestep = torch.rand(1) * 999
|
| 34 |
+
encoder_hidden_status = torch.randn(2, 77, 768)
|
| 35 |
+
input_example = (sample, timestep, encoder_hidden_status)
|
| 36 |
+
try:
|
| 37 |
+
pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
|
| 38 |
+
except Exception:
|
| 39 |
+
pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)
|
| 40 |
+
pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)
|
| 41 |
+
pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
|
| 42 |
+
if pipe.requires_safety_checker:
|
| 43 |
+
pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
|
| 44 |
+
|
| 45 |
+
# compute
|
| 46 |
+
seed = 666
|
| 47 |
+
generator = torch.Generator(device).manual_seed(seed)
|
| 48 |
+
generate_kwargs = {"generator": generator}
|
| 49 |
+
if args.steps is not None:
|
| 50 |
+
generate_kwargs["num_inference_steps"] = args.steps
|
| 51 |
+
|
| 52 |
+
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
| 53 |
+
image = pipe(prompt, **generate_kwargs).images[0]
|
| 54 |
+
|
| 55 |
+
# save image
|
| 56 |
+
image.save("generated.png")
|
diffusers/examples/research_projects/intel_opts/textual_inversion/README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Textual Inversion fine-tuning example
|
| 2 |
+
|
| 3 |
+
[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
|
| 4 |
+
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
| 5 |
+
|
| 6 |
+
## Training with Intel Extension for PyTorch
|
| 7 |
+
|
| 8 |
+
Intel Extension for PyTorch provides the optimizations for faster training and inference on CPUs. You can leverage the training example "textual_inversion.py". Follow the [instructions](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) to get the model and [dataset](https://huggingface.co/sd-concepts-library/dicoo2) before running the script.
|
| 9 |
+
|
| 10 |
+
The example supports both single node and multi-node distributed training:
|
| 11 |
+
|
| 12 |
+
### Single node training
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
| 16 |
+
export DATA_DIR="path-to-dir-containing-dicoo-images"
|
| 17 |
+
|
| 18 |
+
python textual_inversion.py \
|
| 19 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
| 20 |
+
--train_data_dir=$DATA_DIR \
|
| 21 |
+
--learnable_property="object" \
|
| 22 |
+
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
| 23 |
+
--seed=7 \
|
| 24 |
+
--resolution=512 \
|
| 25 |
+
--train_batch_size=1 \
|
| 26 |
+
--gradient_accumulation_steps=1 \
|
| 27 |
+
--max_train_steps=3000 \
|
| 28 |
+
--learning_rate=2.5e-03 --scale_lr \
|
| 29 |
+
--output_dir="textual_inversion_dicoo"
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Note: Bfloat16 is available on Intel Xeon Scalable Processors Cooper Lake or Sapphire Rapids. You may not get performance speedup without Bfloat16 support.
|
| 33 |
+
|
| 34 |
+
### Multi-node distributed training
|
| 35 |
+
|
| 36 |
+
Before running the scripts, make sure to install the library's training dependencies successfully:
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
| 44 |
+
export DATA_DIR="path-to-dir-containing-dicoo-images"
|
| 45 |
+
|
| 46 |
+
oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
|
| 47 |
+
source $oneccl_bindings_for_pytorch_path/env/setvars.sh
|
| 48 |
+
|
| 49 |
+
python -m intel_extension_for_pytorch.cpu.launch --distributed \
|
| 50 |
+
--hostfile hostfile --nnodes 2 --nproc_per_node 2 textual_inversion.py \
|
| 51 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
| 52 |
+
--train_data_dir=$DATA_DIR \
|
| 53 |
+
--learnable_property="object" \
|
| 54 |
+
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
| 55 |
+
--seed=7 \
|
| 56 |
+
--resolution=512 \
|
| 57 |
+
--train_batch_size=1 \
|
| 58 |
+
--gradient_accumulation_steps=1 \
|
| 59 |
+
--max_train_steps=750 \
|
| 60 |
+
--learning_rate=2.5e-03 --scale_lr \
|
| 61 |
+
--output_dir="textual_inversion_dicoo"
|
| 62 |
+
```
|
| 63 |
+
The above is a simple distributed training usage on 2 nodes with 2 processes on each node. Add the right hostname or ip address in the "hostfile" and make sure these 2 nodes are reachable from each other. For more details, please refer to the [user guide](https://github.com/intel/torch-ccl).
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
### Reference
|
| 67 |
+
|
| 68 |
+
We publish a [Medium blog](https://medium.com/intel-analytics-software/personalized-stable-diffusion-with-few-shot-fine-tuning-on-a-single-cpu-f01a3316b13) on how to create your own Stable Diffusion model on CPUs using textual inversion. Try it out now, if you have interests.
|
diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=0.16.0
|
| 2 |
+
torchvision
|
| 3 |
+
transformers>=4.21.0
|
| 4 |
+
ftfy
|
| 5 |
+
tensorboard
|
| 6 |
+
Jinja2
|
| 7 |
+
intel_extension_for_pytorch>=1.13
|
diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
ADDED
|
@@ -0,0 +1,646 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import itertools
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import intel_extension_for_pytorch as ipex
|
| 9 |
+
import numpy as np
|
| 10 |
+
import PIL
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.checkpoint
|
| 14 |
+
from accelerate import Accelerator
|
| 15 |
+
from accelerate.logging import get_logger
|
| 16 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 17 |
+
from huggingface_hub import create_repo, upload_folder
|
| 18 |
+
|
| 19 |
+
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
| 20 |
+
from packaging import version
|
| 21 |
+
from PIL import Image
|
| 22 |
+
from torch.utils.data import Dataset
|
| 23 |
+
from torchvision import transforms
|
| 24 |
+
from tqdm.auto import tqdm
|
| 25 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
| 26 |
+
|
| 27 |
+
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
| 28 |
+
from diffusers.optimization import get_scheduler
|
| 29 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
| 30 |
+
from diffusers.utils import check_min_version
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
| 34 |
+
PIL_INTERPOLATION = {
|
| 35 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
| 36 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
| 37 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
| 38 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
| 39 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
| 40 |
+
}
|
| 41 |
+
else:
|
| 42 |
+
PIL_INTERPOLATION = {
|
| 43 |
+
"linear": PIL.Image.LINEAR,
|
| 44 |
+
"bilinear": PIL.Image.BILINEAR,
|
| 45 |
+
"bicubic": PIL.Image.BICUBIC,
|
| 46 |
+
"lanczos": PIL.Image.LANCZOS,
|
| 47 |
+
"nearest": PIL.Image.NEAREST,
|
| 48 |
+
}
|
| 49 |
+
# ------------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 53 |
+
check_min_version("0.13.0.dev0")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
logger = get_logger(__name__)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
|
| 60 |
+
logger.info("Saving embeddings")
|
| 61 |
+
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
| 62 |
+
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
|
| 63 |
+
torch.save(learned_embeds_dict, save_path)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def parse_args():
|
| 67 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--save_steps",
|
| 70 |
+
type=int,
|
| 71 |
+
default=500,
|
| 72 |
+
help="Save learned_embeds.bin every X updates steps.",
|
| 73 |
+
)
|
| 74 |
+
parser.add_argument(
|
| 75 |
+
"--only_save_embeds",
|
| 76 |
+
action="store_true",
|
| 77 |
+
default=False,
|
| 78 |
+
help="Save only the embeddings for the new concept.",
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--pretrained_model_name_or_path",
|
| 82 |
+
type=str,
|
| 83 |
+
default=None,
|
| 84 |
+
required=True,
|
| 85 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--revision",
|
| 89 |
+
type=str,
|
| 90 |
+
default=None,
|
| 91 |
+
required=False,
|
| 92 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--tokenizer_name",
|
| 96 |
+
type=str,
|
| 97 |
+
default=None,
|
| 98 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
|
| 102 |
+
)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--placeholder_token",
|
| 105 |
+
type=str,
|
| 106 |
+
default=None,
|
| 107 |
+
required=True,
|
| 108 |
+
help="A token to use as a placeholder for the concept.",
|
| 109 |
+
)
|
| 110 |
+
parser.add_argument(
|
| 111 |
+
"--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
|
| 114 |
+
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--output_dir",
|
| 117 |
+
type=str,
|
| 118 |
+
default="text-inversion-model",
|
| 119 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--resolution",
|
| 124 |
+
type=int,
|
| 125 |
+
default=512,
|
| 126 |
+
help=(
|
| 127 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 128 |
+
" resolution"
|
| 129 |
+
),
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
|
| 133 |
+
)
|
| 134 |
+
parser.add_argument(
|
| 135 |
+
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
| 136 |
+
)
|
| 137 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--max_train_steps",
|
| 140 |
+
type=int,
|
| 141 |
+
default=5000,
|
| 142 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 143 |
+
)
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--gradient_accumulation_steps",
|
| 146 |
+
type=int,
|
| 147 |
+
default=1,
|
| 148 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--learning_rate",
|
| 152 |
+
type=float,
|
| 153 |
+
default=1e-4,
|
| 154 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 155 |
+
)
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--scale_lr",
|
| 158 |
+
action="store_true",
|
| 159 |
+
default=True,
|
| 160 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 161 |
+
)
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--lr_scheduler",
|
| 164 |
+
type=str,
|
| 165 |
+
default="constant",
|
| 166 |
+
help=(
|
| 167 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 168 |
+
' "constant", "constant_with_warmup"]'
|
| 169 |
+
),
|
| 170 |
+
)
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 173 |
+
)
|
| 174 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 175 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 176 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 177 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 178 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 179 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--hub_model_id",
|
| 182 |
+
type=str,
|
| 183 |
+
default=None,
|
| 184 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 185 |
+
)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--logging_dir",
|
| 188 |
+
type=str,
|
| 189 |
+
default="logs",
|
| 190 |
+
help=(
|
| 191 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 192 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 193 |
+
),
|
| 194 |
+
)
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--mixed_precision",
|
| 197 |
+
type=str,
|
| 198 |
+
default="no",
|
| 199 |
+
choices=["no", "fp16", "bf16"],
|
| 200 |
+
help=(
|
| 201 |
+
"Whether to use mixed precision. Choose"
|
| 202 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
| 203 |
+
"and an Nvidia Ampere GPU."
|
| 204 |
+
),
|
| 205 |
+
)
|
| 206 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 207 |
+
|
| 208 |
+
args = parser.parse_args()
|
| 209 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 210 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 211 |
+
args.local_rank = env_local_rank
|
| 212 |
+
|
| 213 |
+
if args.train_data_dir is None:
|
| 214 |
+
raise ValueError("You must specify a train data directory.")
|
| 215 |
+
|
| 216 |
+
return args
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
imagenet_templates_small = [
|
| 220 |
+
"a photo of a {}",
|
| 221 |
+
"a rendering of a {}",
|
| 222 |
+
"a cropped photo of the {}",
|
| 223 |
+
"the photo of a {}",
|
| 224 |
+
"a photo of a clean {}",
|
| 225 |
+
"a photo of a dirty {}",
|
| 226 |
+
"a dark photo of the {}",
|
| 227 |
+
"a photo of my {}",
|
| 228 |
+
"a photo of the cool {}",
|
| 229 |
+
"a close-up photo of a {}",
|
| 230 |
+
"a bright photo of the {}",
|
| 231 |
+
"a cropped photo of a {}",
|
| 232 |
+
"a photo of the {}",
|
| 233 |
+
"a good photo of the {}",
|
| 234 |
+
"a photo of one {}",
|
| 235 |
+
"a close-up photo of the {}",
|
| 236 |
+
"a rendition of the {}",
|
| 237 |
+
"a photo of the clean {}",
|
| 238 |
+
"a rendition of a {}",
|
| 239 |
+
"a photo of a nice {}",
|
| 240 |
+
"a good photo of a {}",
|
| 241 |
+
"a photo of the nice {}",
|
| 242 |
+
"a photo of the small {}",
|
| 243 |
+
"a photo of the weird {}",
|
| 244 |
+
"a photo of the large {}",
|
| 245 |
+
"a photo of a cool {}",
|
| 246 |
+
"a photo of a small {}",
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
imagenet_style_templates_small = [
|
| 250 |
+
"a painting in the style of {}",
|
| 251 |
+
"a rendering in the style of {}",
|
| 252 |
+
"a cropped painting in the style of {}",
|
| 253 |
+
"the painting in the style of {}",
|
| 254 |
+
"a clean painting in the style of {}",
|
| 255 |
+
"a dirty painting in the style of {}",
|
| 256 |
+
"a dark painting in the style of {}",
|
| 257 |
+
"a picture in the style of {}",
|
| 258 |
+
"a cool painting in the style of {}",
|
| 259 |
+
"a close-up painting in the style of {}",
|
| 260 |
+
"a bright painting in the style of {}",
|
| 261 |
+
"a cropped painting in the style of {}",
|
| 262 |
+
"a good painting in the style of {}",
|
| 263 |
+
"a close-up painting in the style of {}",
|
| 264 |
+
"a rendition in the style of {}",
|
| 265 |
+
"a nice painting in the style of {}",
|
| 266 |
+
"a small painting in the style of {}",
|
| 267 |
+
"a weird painting in the style of {}",
|
| 268 |
+
"a large painting in the style of {}",
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class TextualInversionDataset(Dataset):
|
| 273 |
+
def __init__(
|
| 274 |
+
self,
|
| 275 |
+
data_root,
|
| 276 |
+
tokenizer,
|
| 277 |
+
learnable_property="object", # [object, style]
|
| 278 |
+
size=512,
|
| 279 |
+
repeats=100,
|
| 280 |
+
interpolation="bicubic",
|
| 281 |
+
flip_p=0.5,
|
| 282 |
+
set="train",
|
| 283 |
+
placeholder_token="*",
|
| 284 |
+
center_crop=False,
|
| 285 |
+
):
|
| 286 |
+
self.data_root = data_root
|
| 287 |
+
self.tokenizer = tokenizer
|
| 288 |
+
self.learnable_property = learnable_property
|
| 289 |
+
self.size = size
|
| 290 |
+
self.placeholder_token = placeholder_token
|
| 291 |
+
self.center_crop = center_crop
|
| 292 |
+
self.flip_p = flip_p
|
| 293 |
+
|
| 294 |
+
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
| 295 |
+
|
| 296 |
+
self.num_images = len(self.image_paths)
|
| 297 |
+
self._length = self.num_images
|
| 298 |
+
|
| 299 |
+
if set == "train":
|
| 300 |
+
self._length = self.num_images * repeats
|
| 301 |
+
|
| 302 |
+
self.interpolation = {
|
| 303 |
+
"linear": PIL_INTERPOLATION["linear"],
|
| 304 |
+
"bilinear": PIL_INTERPOLATION["bilinear"],
|
| 305 |
+
"bicubic": PIL_INTERPOLATION["bicubic"],
|
| 306 |
+
"lanczos": PIL_INTERPOLATION["lanczos"],
|
| 307 |
+
}[interpolation]
|
| 308 |
+
|
| 309 |
+
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
| 310 |
+
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
| 311 |
+
|
| 312 |
+
def __len__(self):
|
| 313 |
+
return self._length
|
| 314 |
+
|
| 315 |
+
def __getitem__(self, i):
|
| 316 |
+
example = {}
|
| 317 |
+
image = Image.open(self.image_paths[i % self.num_images])
|
| 318 |
+
|
| 319 |
+
if not image.mode == "RGB":
|
| 320 |
+
image = image.convert("RGB")
|
| 321 |
+
|
| 322 |
+
placeholder_string = self.placeholder_token
|
| 323 |
+
text = random.choice(self.templates).format(placeholder_string)
|
| 324 |
+
|
| 325 |
+
example["input_ids"] = self.tokenizer(
|
| 326 |
+
text,
|
| 327 |
+
padding="max_length",
|
| 328 |
+
truncation=True,
|
| 329 |
+
max_length=self.tokenizer.model_max_length,
|
| 330 |
+
return_tensors="pt",
|
| 331 |
+
).input_ids[0]
|
| 332 |
+
|
| 333 |
+
# default to score-sde preprocessing
|
| 334 |
+
img = np.array(image).astype(np.uint8)
|
| 335 |
+
|
| 336 |
+
if self.center_crop:
|
| 337 |
+
crop = min(img.shape[0], img.shape[1])
|
| 338 |
+
(
|
| 339 |
+
h,
|
| 340 |
+
w,
|
| 341 |
+
) = (
|
| 342 |
+
img.shape[0],
|
| 343 |
+
img.shape[1],
|
| 344 |
+
)
|
| 345 |
+
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
| 346 |
+
|
| 347 |
+
image = Image.fromarray(img)
|
| 348 |
+
image = image.resize((self.size, self.size), resample=self.interpolation)
|
| 349 |
+
|
| 350 |
+
image = self.flip_transform(image)
|
| 351 |
+
image = np.array(image).astype(np.uint8)
|
| 352 |
+
image = (image / 127.5 - 1.0).astype(np.float32)
|
| 353 |
+
|
| 354 |
+
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
| 355 |
+
return example
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def freeze_params(params):
|
| 359 |
+
for param in params:
|
| 360 |
+
param.requires_grad = False
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def main():
|
| 364 |
+
args = parse_args()
|
| 365 |
+
|
| 366 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
| 367 |
+
raise ValueError(
|
| 368 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
| 369 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
| 373 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 374 |
+
accelerator = Accelerator(
|
| 375 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 376 |
+
mixed_precision=args.mixed_precision,
|
| 377 |
+
log_with=args.report_to,
|
| 378 |
+
project_config=accelerator_project_config,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# Disable AMP for MPS.
|
| 382 |
+
if torch.backends.mps.is_available():
|
| 383 |
+
accelerator.native_amp = False
|
| 384 |
+
|
| 385 |
+
# If passed along, set the training seed now.
|
| 386 |
+
if args.seed is not None:
|
| 387 |
+
set_seed(args.seed)
|
| 388 |
+
|
| 389 |
+
# Handle the repository creation
|
| 390 |
+
if accelerator.is_main_process:
|
| 391 |
+
if args.output_dir is not None:
|
| 392 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 393 |
+
|
| 394 |
+
if args.push_to_hub:
|
| 395 |
+
repo_id = create_repo(
|
| 396 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 397 |
+
).repo_id
|
| 398 |
+
|
| 399 |
+
# Load the tokenizer and add the placeholder token as a additional special token
|
| 400 |
+
if args.tokenizer_name:
|
| 401 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
| 402 |
+
elif args.pretrained_model_name_or_path:
|
| 403 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
| 404 |
+
|
| 405 |
+
# Add the placeholder token in tokenizer
|
| 406 |
+
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
| 407 |
+
if num_added_tokens == 0:
|
| 408 |
+
raise ValueError(
|
| 409 |
+
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
| 410 |
+
" `placeholder_token` that is not already in the tokenizer."
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Convert the initializer_token, placeholder_token to ids
|
| 414 |
+
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
| 415 |
+
# Check if initializer_token is a single token or a sequence of tokens
|
| 416 |
+
if len(token_ids) > 1:
|
| 417 |
+
raise ValueError("The initializer token must be a single token.")
|
| 418 |
+
|
| 419 |
+
initializer_token_id = token_ids[0]
|
| 420 |
+
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
| 421 |
+
|
| 422 |
+
# Load models and create wrapper for stable diffusion
|
| 423 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
| 424 |
+
args.pretrained_model_name_or_path,
|
| 425 |
+
subfolder="text_encoder",
|
| 426 |
+
revision=args.revision,
|
| 427 |
+
)
|
| 428 |
+
vae = AutoencoderKL.from_pretrained(
|
| 429 |
+
args.pretrained_model_name_or_path,
|
| 430 |
+
subfolder="vae",
|
| 431 |
+
revision=args.revision,
|
| 432 |
+
)
|
| 433 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 434 |
+
args.pretrained_model_name_or_path,
|
| 435 |
+
subfolder="unet",
|
| 436 |
+
revision=args.revision,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
| 440 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
| 441 |
+
|
| 442 |
+
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
| 443 |
+
token_embeds = text_encoder.get_input_embeddings().weight.data
|
| 444 |
+
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
| 445 |
+
|
| 446 |
+
# Freeze vae and unet
|
| 447 |
+
freeze_params(vae.parameters())
|
| 448 |
+
freeze_params(unet.parameters())
|
| 449 |
+
# Freeze all parameters except for the token embeddings in text encoder
|
| 450 |
+
params_to_freeze = itertools.chain(
|
| 451 |
+
text_encoder.text_model.encoder.parameters(),
|
| 452 |
+
text_encoder.text_model.final_layer_norm.parameters(),
|
| 453 |
+
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
| 454 |
+
)
|
| 455 |
+
freeze_params(params_to_freeze)
|
| 456 |
+
|
| 457 |
+
if args.scale_lr:
|
| 458 |
+
args.learning_rate = (
|
| 459 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# Initialize the optimizer
|
| 463 |
+
optimizer = torch.optim.AdamW(
|
| 464 |
+
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
| 465 |
+
lr=args.learning_rate,
|
| 466 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 467 |
+
weight_decay=args.adam_weight_decay,
|
| 468 |
+
eps=args.adam_epsilon,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 472 |
+
|
| 473 |
+
train_dataset = TextualInversionDataset(
|
| 474 |
+
data_root=args.train_data_dir,
|
| 475 |
+
tokenizer=tokenizer,
|
| 476 |
+
size=args.resolution,
|
| 477 |
+
placeholder_token=args.placeholder_token,
|
| 478 |
+
repeats=args.repeats,
|
| 479 |
+
learnable_property=args.learnable_property,
|
| 480 |
+
center_crop=args.center_crop,
|
| 481 |
+
set="train",
|
| 482 |
+
)
|
| 483 |
+
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
|
| 484 |
+
|
| 485 |
+
# Scheduler and math around the number of training steps.
|
| 486 |
+
overrode_max_train_steps = False
|
| 487 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 488 |
+
if args.max_train_steps is None:
|
| 489 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 490 |
+
overrode_max_train_steps = True
|
| 491 |
+
|
| 492 |
+
lr_scheduler = get_scheduler(
|
| 493 |
+
args.lr_scheduler,
|
| 494 |
+
optimizer=optimizer,
|
| 495 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 496 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 500 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
# Move vae and unet to device
|
| 504 |
+
vae.to(accelerator.device)
|
| 505 |
+
unet.to(accelerator.device)
|
| 506 |
+
|
| 507 |
+
# Keep vae and unet in eval model as we don't train these
|
| 508 |
+
vae.eval()
|
| 509 |
+
unet.eval()
|
| 510 |
+
|
| 511 |
+
unet = ipex.optimize(unet, dtype=torch.bfloat16, inplace=True)
|
| 512 |
+
vae = ipex.optimize(vae, dtype=torch.bfloat16, inplace=True)
|
| 513 |
+
|
| 514 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 515 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 516 |
+
if overrode_max_train_steps:
|
| 517 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 518 |
+
# Afterwards we recalculate our number of training epochs
|
| 519 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 520 |
+
|
| 521 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 522 |
+
# The trackers initializes automatically on the main process.
|
| 523 |
+
if accelerator.is_main_process:
|
| 524 |
+
accelerator.init_trackers("textual_inversion", config=vars(args))
|
| 525 |
+
|
| 526 |
+
# Train!
|
| 527 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 528 |
+
|
| 529 |
+
logger.info("***** Running training *****")
|
| 530 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 531 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 532 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 533 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 534 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 535 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 536 |
+
# Only show the progress bar once on each machine.
|
| 537 |
+
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
| 538 |
+
progress_bar.set_description("Steps")
|
| 539 |
+
global_step = 0
|
| 540 |
+
|
| 541 |
+
text_encoder.train()
|
| 542 |
+
text_encoder, optimizer = ipex.optimize(text_encoder, optimizer=optimizer, dtype=torch.bfloat16)
|
| 543 |
+
|
| 544 |
+
for epoch in range(args.num_train_epochs):
|
| 545 |
+
for step, batch in enumerate(train_dataloader):
|
| 546 |
+
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
| 547 |
+
with accelerator.accumulate(text_encoder):
|
| 548 |
+
# Convert images to latent space
|
| 549 |
+
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
|
| 550 |
+
latents = latents * vae.config.scaling_factor
|
| 551 |
+
|
| 552 |
+
# Sample noise that we'll add to the latents
|
| 553 |
+
noise = torch.randn(latents.shape).to(latents.device)
|
| 554 |
+
bsz = latents.shape[0]
|
| 555 |
+
# Sample a random timestep for each image
|
| 556 |
+
timesteps = torch.randint(
|
| 557 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
| 558 |
+
).long()
|
| 559 |
+
|
| 560 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 561 |
+
# (this is the forward diffusion process)
|
| 562 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 563 |
+
|
| 564 |
+
# Get the text embedding for conditioning
|
| 565 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
| 566 |
+
|
| 567 |
+
# Predict the noise residual
|
| 568 |
+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 569 |
+
|
| 570 |
+
# Get the target for loss depending on the prediction type
|
| 571 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 572 |
+
target = noise
|
| 573 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 574 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 575 |
+
else:
|
| 576 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 577 |
+
|
| 578 |
+
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
|
| 579 |
+
accelerator.backward(loss)
|
| 580 |
+
|
| 581 |
+
# Zero out the gradients for all token embeddings except the newly added
|
| 582 |
+
# embeddings for the concept, as we only want to optimize the concept embeddings
|
| 583 |
+
if accelerator.num_processes > 1:
|
| 584 |
+
grads = text_encoder.module.get_input_embeddings().weight.grad
|
| 585 |
+
else:
|
| 586 |
+
grads = text_encoder.get_input_embeddings().weight.grad
|
| 587 |
+
# Get the index for tokens that we want to zero the grads for
|
| 588 |
+
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
|
| 589 |
+
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
|
| 590 |
+
|
| 591 |
+
optimizer.step()
|
| 592 |
+
lr_scheduler.step()
|
| 593 |
+
optimizer.zero_grad()
|
| 594 |
+
|
| 595 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 596 |
+
if accelerator.sync_gradients:
|
| 597 |
+
progress_bar.update(1)
|
| 598 |
+
global_step += 1
|
| 599 |
+
if global_step % args.save_steps == 0:
|
| 600 |
+
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
|
| 601 |
+
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
| 602 |
+
|
| 603 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 604 |
+
progress_bar.set_postfix(**logs)
|
| 605 |
+
accelerator.log(logs, step=global_step)
|
| 606 |
+
|
| 607 |
+
if global_step >= args.max_train_steps:
|
| 608 |
+
break
|
| 609 |
+
|
| 610 |
+
accelerator.wait_for_everyone()
|
| 611 |
+
|
| 612 |
+
# Create the pipeline using using the trained modules and save it.
|
| 613 |
+
if accelerator.is_main_process:
|
| 614 |
+
if args.push_to_hub and args.only_save_embeds:
|
| 615 |
+
logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
|
| 616 |
+
save_full_model = True
|
| 617 |
+
else:
|
| 618 |
+
save_full_model = not args.only_save_embeds
|
| 619 |
+
if save_full_model:
|
| 620 |
+
pipeline = StableDiffusionPipeline(
|
| 621 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
| 622 |
+
vae=vae,
|
| 623 |
+
unet=unet,
|
| 624 |
+
tokenizer=tokenizer,
|
| 625 |
+
scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
|
| 626 |
+
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
| 627 |
+
feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
|
| 628 |
+
)
|
| 629 |
+
pipeline.save_pretrained(args.output_dir)
|
| 630 |
+
# Save the newly trained embeddings
|
| 631 |
+
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
| 632 |
+
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
| 633 |
+
|
| 634 |
+
if args.push_to_hub:
|
| 635 |
+
upload_folder(
|
| 636 |
+
repo_id=repo_id,
|
| 637 |
+
folder_path=args.output_dir,
|
| 638 |
+
commit_message="End of training",
|
| 639 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
accelerator.end_training()
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
if __name__ == "__main__":
|
| 646 |
+
main()
|
diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Distillation for quantization on Textual Inversion models to personalize text2image
|
| 2 |
+
|
| 3 |
+
[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images._By using just 3-5 images new concepts can be taught to Stable Diffusion and the model personalized on your own images_
|
| 4 |
+
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
| 5 |
+
We have enabled distillation for quantization in `textual_inversion.py` to do quantization aware training as well as distillation on the model generated by Textual Inversion method.
|
| 6 |
+
|
| 7 |
+
## Installing the dependencies
|
| 8 |
+
|
| 9 |
+
Before running the scripts, make sure to install the library's training dependencies:
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
pip install -r requirements.txt
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
## Prepare Datasets
|
| 16 |
+
|
| 17 |
+
One picture which is from the huggingface datasets [sd-concepts-library/dicoo2](https://huggingface.co/sd-concepts-library/dicoo2) is needed, and save it to the `./dicoo` directory. The picture is shown below:
|
| 18 |
+
|
| 19 |
+
<a href="https://huggingface.co/sd-concepts-library/dicoo2/blob/main/concept_images/1.jpeg">
|
| 20 |
+
<img src="https://huggingface.co/sd-concepts-library/dicoo2/resolve/main/concept_images/1.jpeg" width = "300" height="300">
|
| 21 |
+
</a>
|
| 22 |
+
|
| 23 |
+
## Get a FP32 Textual Inversion model
|
| 24 |
+
|
| 25 |
+
Use the following command to fine-tune the Stable Diffusion model on the above dataset to obtain the FP32 Textual Inversion model.
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
| 29 |
+
export DATA_DIR="./dicoo"
|
| 30 |
+
|
| 31 |
+
accelerate launch textual_inversion.py \
|
| 32 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
| 33 |
+
--train_data_dir=$DATA_DIR \
|
| 34 |
+
--learnable_property="object" \
|
| 35 |
+
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
| 36 |
+
--resolution=512 \
|
| 37 |
+
--train_batch_size=1 \
|
| 38 |
+
--gradient_accumulation_steps=4 \
|
| 39 |
+
--max_train_steps=3000 \
|
| 40 |
+
--learning_rate=5.0e-04 --scale_lr \
|
| 41 |
+
--lr_scheduler="constant" \
|
| 42 |
+
--lr_warmup_steps=0 \
|
| 43 |
+
--output_dir="dicoo_model"
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## Do distillation for quantization
|
| 47 |
+
|
| 48 |
+
Distillation for quantization is a method that combines [intermediate layer knowledge distillation](https://github.com/intel/neural-compressor/blob/master/docs/source/distillation.md#intermediate-layer-knowledge-distillation) and [quantization aware training](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization.md#quantization-aware-training) in the same training process to improve the performance of the quantized model. Provided a FP32 model, the distillation for quantization approach will take this model itself as the teacher model and transfer the knowledges of the specified layers to the student model, i.e. quantized version of the FP32 model, during the quantization aware training process.
|
| 49 |
+
|
| 50 |
+
Once you have the FP32 Textual Inversion model, the following command will take the FP32 Textual Inversion model as input to do distillation for quantization and generate the INT8 Textual Inversion model.
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
export FP32_MODEL_NAME="./dicoo_model"
|
| 54 |
+
export DATA_DIR="./dicoo"
|
| 55 |
+
|
| 56 |
+
accelerate launch textual_inversion.py \
|
| 57 |
+
--pretrained_model_name_or_path=$FP32_MODEL_NAME \
|
| 58 |
+
--train_data_dir=$DATA_DIR \
|
| 59 |
+
--use_ema --learnable_property="object" \
|
| 60 |
+
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
| 61 |
+
--resolution=512 \
|
| 62 |
+
--train_batch_size=1 \
|
| 63 |
+
--gradient_accumulation_steps=4 \
|
| 64 |
+
--max_train_steps=300 \
|
| 65 |
+
--learning_rate=5.0e-04 --max_grad_norm=3 \
|
| 66 |
+
--lr_scheduler="constant" \
|
| 67 |
+
--lr_warmup_steps=0 \
|
| 68 |
+
--output_dir="int8_model" \
|
| 69 |
+
--do_quantization --do_distillation --verify_loading
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
After the distillation for quantization process, the quantized UNet would be 4 times smaller (3279MB -> 827MB).
|
| 73 |
+
|
| 74 |
+
## Inference
|
| 75 |
+
|
| 76 |
+
Once you have trained a INT8 model with the above command, the inference can be done simply using the `text2images.py` script. Make sure to include the `placeholder_token` in your prompt.
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
export INT8_MODEL_NAME="./int8_model"
|
| 80 |
+
|
| 81 |
+
python text2images.py \
|
| 82 |
+
--pretrained_model_name_or_path=$INT8_MODEL_NAME \
|
| 83 |
+
--caption "a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brightly buildings." \
|
| 84 |
+
--images_num 4
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
Here is the comparison of images generated by the FP32 model (left) and INT8 model (right) respectively:
|
| 88 |
+
|
| 89 |
+
<p float="left">
|
| 90 |
+
<img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/FP32.png" width = "300" height = "300" alt="FP32" align=center />
|
| 91 |
+
<img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/INT8.png" width = "300" height = "300" alt="INT8" align=center />
|
| 92 |
+
</p>
|
| 93 |
+
|
diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
torchvision
|
| 3 |
+
transformers>=4.25.0
|
| 4 |
+
ftfy
|
| 5 |
+
tensorboard
|
| 6 |
+
modelcards
|
| 7 |
+
neural-compressor
|
diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from neural_compressor.utils.pytorch import load
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 9 |
+
|
| 10 |
+
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_args():
|
| 14 |
+
parser = argparse.ArgumentParser()
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
"-m",
|
| 17 |
+
"--pretrained_model_name_or_path",
|
| 18 |
+
type=str,
|
| 19 |
+
default=None,
|
| 20 |
+
required=True,
|
| 21 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 22 |
+
)
|
| 23 |
+
parser.add_argument(
|
| 24 |
+
"-c",
|
| 25 |
+
"--caption",
|
| 26 |
+
type=str,
|
| 27 |
+
default="robotic cat with wings",
|
| 28 |
+
help="Text used to generate images.",
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"-n",
|
| 32 |
+
"--images_num",
|
| 33 |
+
type=int,
|
| 34 |
+
default=4,
|
| 35 |
+
help="How much images to generate.",
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"-s",
|
| 39 |
+
"--seed",
|
| 40 |
+
type=int,
|
| 41 |
+
default=42,
|
| 42 |
+
help="Seed for random process.",
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"-ci",
|
| 46 |
+
"--cuda_id",
|
| 47 |
+
type=int,
|
| 48 |
+
default=0,
|
| 49 |
+
help="cuda_id.",
|
| 50 |
+
)
|
| 51 |
+
args = parser.parse_args()
|
| 52 |
+
return args
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def image_grid(imgs, rows, cols):
|
| 56 |
+
if not len(imgs) == rows * cols:
|
| 57 |
+
raise ValueError("The specified number of rows and columns are not correct.")
|
| 58 |
+
|
| 59 |
+
w, h = imgs[0].size
|
| 60 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
| 61 |
+
grid_w, grid_h = grid.size
|
| 62 |
+
|
| 63 |
+
for i, img in enumerate(imgs):
|
| 64 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
| 65 |
+
return grid
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def generate_images(
|
| 69 |
+
pipeline,
|
| 70 |
+
prompt="robotic cat with wings",
|
| 71 |
+
guidance_scale=7.5,
|
| 72 |
+
num_inference_steps=50,
|
| 73 |
+
num_images_per_prompt=1,
|
| 74 |
+
seed=42,
|
| 75 |
+
):
|
| 76 |
+
generator = torch.Generator(pipeline.device).manual_seed(seed)
|
| 77 |
+
images = pipeline(
|
| 78 |
+
prompt,
|
| 79 |
+
guidance_scale=guidance_scale,
|
| 80 |
+
num_inference_steps=num_inference_steps,
|
| 81 |
+
generator=generator,
|
| 82 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 83 |
+
).images
|
| 84 |
+
_rows = int(math.sqrt(num_images_per_prompt))
|
| 85 |
+
grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
|
| 86 |
+
return grid, images
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
args = parse_args()
|
| 90 |
+
# Load models and create wrapper for stable diffusion
|
| 91 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
| 92 |
+
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
| 93 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
| 94 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
| 95 |
+
|
| 96 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
| 97 |
+
args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
|
| 98 |
+
)
|
| 99 |
+
pipeline.safety_checker = lambda images, clip_input: (images, False)
|
| 100 |
+
if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
|
| 101 |
+
unet = load(args.pretrained_model_name_or_path, model=unet)
|
| 102 |
+
unet.eval()
|
| 103 |
+
setattr(pipeline, "unet", unet)
|
| 104 |
+
else:
|
| 105 |
+
unet = unet.to(torch.device("cuda", args.cuda_id))
|
| 106 |
+
pipeline = pipeline.to(unet.device)
|
| 107 |
+
grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
|
| 108 |
+
grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
|
| 109 |
+
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
|
| 110 |
+
os.makedirs(dirname, exist_ok=True)
|
| 111 |
+
for idx, image in enumerate(images):
|
| 112 |
+
image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
|
diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py
ADDED
|
@@ -0,0 +1,996 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import itertools
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Iterable
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import PIL
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.checkpoint
|
| 14 |
+
from accelerate import Accelerator
|
| 15 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 16 |
+
from huggingface_hub import create_repo, upload_folder
|
| 17 |
+
from neural_compressor.utils import logger
|
| 18 |
+
from packaging import version
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from torch.utils.data import Dataset
|
| 21 |
+
from torchvision import transforms
|
| 22 |
+
from tqdm.auto import tqdm
|
| 23 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 24 |
+
|
| 25 |
+
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
| 26 |
+
from diffusers.optimization import get_scheduler
|
| 27 |
+
from diffusers.utils import make_image_grid
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
| 31 |
+
PIL_INTERPOLATION = {
|
| 32 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
| 33 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
| 34 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
| 35 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
| 36 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
| 37 |
+
}
|
| 38 |
+
else:
|
| 39 |
+
PIL_INTERPOLATION = {
|
| 40 |
+
"linear": PIL.Image.LINEAR,
|
| 41 |
+
"bilinear": PIL.Image.BILINEAR,
|
| 42 |
+
"bicubic": PIL.Image.BICUBIC,
|
| 43 |
+
"lanczos": PIL.Image.LANCZOS,
|
| 44 |
+
"nearest": PIL.Image.NEAREST,
|
| 45 |
+
}
|
| 46 |
+
# ------------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
|
| 50 |
+
logger.info("Saving embeddings")
|
| 51 |
+
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
| 52 |
+
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
|
| 53 |
+
torch.save(learned_embeds_dict, save_path)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def parse_args():
|
| 57 |
+
parser = argparse.ArgumentParser(description="Example of distillation for quantization on Textual Inversion.")
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--save_steps",
|
| 60 |
+
type=int,
|
| 61 |
+
default=500,
|
| 62 |
+
help="Save learned_embeds.bin every X updates steps.",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--pretrained_model_name_or_path",
|
| 66 |
+
type=str,
|
| 67 |
+
default=None,
|
| 68 |
+
required=True,
|
| 69 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--revision",
|
| 73 |
+
type=str,
|
| 74 |
+
default=None,
|
| 75 |
+
required=False,
|
| 76 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
| 77 |
+
)
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--tokenizer_name",
|
| 80 |
+
type=str,
|
| 81 |
+
default=None,
|
| 82 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--placeholder_token",
|
| 89 |
+
type=str,
|
| 90 |
+
default=None,
|
| 91 |
+
required=True,
|
| 92 |
+
help="A token to use as a placeholder for the concept.",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
|
| 98 |
+
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--output_dir",
|
| 101 |
+
type=str,
|
| 102 |
+
default="text-inversion-model",
|
| 103 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--cache_dir",
|
| 107 |
+
type=str,
|
| 108 |
+
default=None,
|
| 109 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--resolution",
|
| 114 |
+
type=int,
|
| 115 |
+
default=512,
|
| 116 |
+
help=(
|
| 117 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 118 |
+
" resolution"
|
| 119 |
+
),
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
| 126 |
+
)
|
| 127 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"--max_train_steps",
|
| 130 |
+
type=int,
|
| 131 |
+
default=5000,
|
| 132 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
| 133 |
+
)
|
| 134 |
+
parser.add_argument(
|
| 135 |
+
"--gradient_accumulation_steps",
|
| 136 |
+
type=int,
|
| 137 |
+
default=1,
|
| 138 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 139 |
+
)
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--learning_rate",
|
| 142 |
+
type=float,
|
| 143 |
+
default=1e-4,
|
| 144 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--scale_lr",
|
| 148 |
+
action="store_true",
|
| 149 |
+
default=False,
|
| 150 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
| 151 |
+
)
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--lr_scheduler",
|
| 154 |
+
type=str,
|
| 155 |
+
default="constant",
|
| 156 |
+
help=(
|
| 157 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
| 158 |
+
' "constant", "constant_with_warmup"]'
|
| 159 |
+
),
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 165 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 166 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 167 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 168 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
| 169 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--hub_model_id",
|
| 172 |
+
type=str,
|
| 173 |
+
default=None,
|
| 174 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
| 175 |
+
)
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--logging_dir",
|
| 178 |
+
type=str,
|
| 179 |
+
default="logs",
|
| 180 |
+
help=(
|
| 181 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 182 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 183 |
+
),
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--mixed_precision",
|
| 187 |
+
type=str,
|
| 188 |
+
default="no",
|
| 189 |
+
choices=["no", "fp16", "bf16"],
|
| 190 |
+
help=(
|
| 191 |
+
"Whether to use mixed precision. Choose"
|
| 192 |
+
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
| 193 |
+
"and an Nvidia Ampere GPU."
|
| 194 |
+
),
|
| 195 |
+
)
|
| 196 |
+
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
| 197 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 198 |
+
parser.add_argument("--do_quantization", action="store_true", help="Whether or not to do quantization.")
|
| 199 |
+
parser.add_argument("--do_distillation", action="store_true", help="Whether or not to do distillation.")
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--verify_loading", action="store_true", help="Whether or not to verify the loading of the quantized model."
|
| 202 |
+
)
|
| 203 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 204 |
+
|
| 205 |
+
args = parser.parse_args()
|
| 206 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 207 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 208 |
+
args.local_rank = env_local_rank
|
| 209 |
+
|
| 210 |
+
if args.train_data_dir is None:
|
| 211 |
+
raise ValueError("You must specify a train data directory.")
|
| 212 |
+
|
| 213 |
+
return args
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
imagenet_templates_small = [
|
| 217 |
+
"a photo of a {}",
|
| 218 |
+
"a rendering of a {}",
|
| 219 |
+
"a cropped photo of the {}",
|
| 220 |
+
"the photo of a {}",
|
| 221 |
+
"a photo of a clean {}",
|
| 222 |
+
"a photo of a dirty {}",
|
| 223 |
+
"a dark photo of the {}",
|
| 224 |
+
"a photo of my {}",
|
| 225 |
+
"a photo of the cool {}",
|
| 226 |
+
"a close-up photo of a {}",
|
| 227 |
+
"a bright photo of the {}",
|
| 228 |
+
"a cropped photo of a {}",
|
| 229 |
+
"a photo of the {}",
|
| 230 |
+
"a good photo of the {}",
|
| 231 |
+
"a photo of one {}",
|
| 232 |
+
"a close-up photo of the {}",
|
| 233 |
+
"a rendition of the {}",
|
| 234 |
+
"a photo of the clean {}",
|
| 235 |
+
"a rendition of a {}",
|
| 236 |
+
"a photo of a nice {}",
|
| 237 |
+
"a good photo of a {}",
|
| 238 |
+
"a photo of the nice {}",
|
| 239 |
+
"a photo of the small {}",
|
| 240 |
+
"a photo of the weird {}",
|
| 241 |
+
"a photo of the large {}",
|
| 242 |
+
"a photo of a cool {}",
|
| 243 |
+
"a photo of a small {}",
|
| 244 |
+
]
|
| 245 |
+
|
| 246 |
+
imagenet_style_templates_small = [
|
| 247 |
+
"a painting in the style of {}",
|
| 248 |
+
"a rendering in the style of {}",
|
| 249 |
+
"a cropped painting in the style of {}",
|
| 250 |
+
"the painting in the style of {}",
|
| 251 |
+
"a clean painting in the style of {}",
|
| 252 |
+
"a dirty painting in the style of {}",
|
| 253 |
+
"a dark painting in the style of {}",
|
| 254 |
+
"a picture in the style of {}",
|
| 255 |
+
"a cool painting in the style of {}",
|
| 256 |
+
"a close-up painting in the style of {}",
|
| 257 |
+
"a bright painting in the style of {}",
|
| 258 |
+
"a cropped painting in the style of {}",
|
| 259 |
+
"a good painting in the style of {}",
|
| 260 |
+
"a close-up painting in the style of {}",
|
| 261 |
+
"a rendition in the style of {}",
|
| 262 |
+
"a nice painting in the style of {}",
|
| 263 |
+
"a small painting in the style of {}",
|
| 264 |
+
"a weird painting in the style of {}",
|
| 265 |
+
"a large painting in the style of {}",
|
| 266 |
+
]
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
| 270 |
+
class EMAModel:
|
| 271 |
+
"""
|
| 272 |
+
Exponential Moving Average of models weights
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
|
| 276 |
+
parameters = list(parameters)
|
| 277 |
+
self.shadow_params = [p.clone().detach() for p in parameters]
|
| 278 |
+
|
| 279 |
+
self.decay = decay
|
| 280 |
+
self.optimization_step = 0
|
| 281 |
+
|
| 282 |
+
def get_decay(self, optimization_step):
|
| 283 |
+
"""
|
| 284 |
+
Compute the decay factor for the exponential moving average.
|
| 285 |
+
"""
|
| 286 |
+
value = (1 + optimization_step) / (10 + optimization_step)
|
| 287 |
+
return 1 - min(self.decay, value)
|
| 288 |
+
|
| 289 |
+
@torch.no_grad()
|
| 290 |
+
def step(self, parameters):
|
| 291 |
+
parameters = list(parameters)
|
| 292 |
+
|
| 293 |
+
self.optimization_step += 1
|
| 294 |
+
self.decay = self.get_decay(self.optimization_step)
|
| 295 |
+
|
| 296 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
| 297 |
+
if param.requires_grad:
|
| 298 |
+
tmp = self.decay * (s_param - param)
|
| 299 |
+
s_param.sub_(tmp)
|
| 300 |
+
else:
|
| 301 |
+
s_param.copy_(param)
|
| 302 |
+
|
| 303 |
+
torch.cuda.empty_cache()
|
| 304 |
+
|
| 305 |
+
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
| 306 |
+
"""
|
| 307 |
+
Copy current averaged parameters into given collection of parameters.
|
| 308 |
+
Args:
|
| 309 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 310 |
+
updated with the stored moving averages. If `None`, the
|
| 311 |
+
parameters with which this `ExponentialMovingAverage` was
|
| 312 |
+
initialized will be used.
|
| 313 |
+
"""
|
| 314 |
+
parameters = list(parameters)
|
| 315 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
| 316 |
+
param.data.copy_(s_param.data)
|
| 317 |
+
|
| 318 |
+
def to(self, device=None, dtype=None) -> None:
|
| 319 |
+
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
|
| 320 |
+
Args:
|
| 321 |
+
device: like `device` argument to `torch.Tensor.to`
|
| 322 |
+
"""
|
| 323 |
+
# .to() on the tensors handles None correctly
|
| 324 |
+
self.shadow_params = [
|
| 325 |
+
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
|
| 326 |
+
for p in self.shadow_params
|
| 327 |
+
]
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class TextualInversionDataset(Dataset):
|
| 331 |
+
def __init__(
|
| 332 |
+
self,
|
| 333 |
+
data_root,
|
| 334 |
+
tokenizer,
|
| 335 |
+
learnable_property="object", # [object, style]
|
| 336 |
+
size=512,
|
| 337 |
+
repeats=100,
|
| 338 |
+
interpolation="bicubic",
|
| 339 |
+
flip_p=0.5,
|
| 340 |
+
set="train",
|
| 341 |
+
placeholder_token="*",
|
| 342 |
+
center_crop=False,
|
| 343 |
+
):
|
| 344 |
+
self.data_root = data_root
|
| 345 |
+
self.tokenizer = tokenizer
|
| 346 |
+
self.learnable_property = learnable_property
|
| 347 |
+
self.size = size
|
| 348 |
+
self.placeholder_token = placeholder_token
|
| 349 |
+
self.center_crop = center_crop
|
| 350 |
+
self.flip_p = flip_p
|
| 351 |
+
|
| 352 |
+
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
| 353 |
+
|
| 354 |
+
self.num_images = len(self.image_paths)
|
| 355 |
+
self._length = self.num_images
|
| 356 |
+
|
| 357 |
+
if set == "train":
|
| 358 |
+
self._length = self.num_images * repeats
|
| 359 |
+
|
| 360 |
+
self.interpolation = {
|
| 361 |
+
"linear": PIL_INTERPOLATION["linear"],
|
| 362 |
+
"bilinear": PIL_INTERPOLATION["bilinear"],
|
| 363 |
+
"bicubic": PIL_INTERPOLATION["bicubic"],
|
| 364 |
+
"lanczos": PIL_INTERPOLATION["lanczos"],
|
| 365 |
+
}[interpolation]
|
| 366 |
+
|
| 367 |
+
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
| 368 |
+
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
| 369 |
+
|
| 370 |
+
def __len__(self):
|
| 371 |
+
return self._length
|
| 372 |
+
|
| 373 |
+
def __getitem__(self, i):
|
| 374 |
+
example = {}
|
| 375 |
+
image = Image.open(self.image_paths[i % self.num_images])
|
| 376 |
+
|
| 377 |
+
if not image.mode == "RGB":
|
| 378 |
+
image = image.convert("RGB")
|
| 379 |
+
|
| 380 |
+
placeholder_string = self.placeholder_token
|
| 381 |
+
text = random.choice(self.templates).format(placeholder_string)
|
| 382 |
+
|
| 383 |
+
example["input_ids"] = self.tokenizer(
|
| 384 |
+
text,
|
| 385 |
+
padding="max_length",
|
| 386 |
+
truncation=True,
|
| 387 |
+
max_length=self.tokenizer.model_max_length,
|
| 388 |
+
return_tensors="pt",
|
| 389 |
+
).input_ids[0]
|
| 390 |
+
|
| 391 |
+
# default to score-sde preprocessing
|
| 392 |
+
img = np.array(image).astype(np.uint8)
|
| 393 |
+
|
| 394 |
+
if self.center_crop:
|
| 395 |
+
crop = min(img.shape[0], img.shape[1])
|
| 396 |
+
(
|
| 397 |
+
h,
|
| 398 |
+
w,
|
| 399 |
+
) = (
|
| 400 |
+
img.shape[0],
|
| 401 |
+
img.shape[1],
|
| 402 |
+
)
|
| 403 |
+
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
| 404 |
+
|
| 405 |
+
image = Image.fromarray(img)
|
| 406 |
+
image = image.resize((self.size, self.size), resample=self.interpolation)
|
| 407 |
+
|
| 408 |
+
image = self.flip_transform(image)
|
| 409 |
+
image = np.array(image).astype(np.uint8)
|
| 410 |
+
image = (image / 127.5 - 1.0).astype(np.float32)
|
| 411 |
+
|
| 412 |
+
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
| 413 |
+
return example
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def freeze_params(params):
|
| 417 |
+
for param in params:
|
| 418 |
+
param.requires_grad = False
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def generate_images(pipeline, prompt="", guidance_scale=7.5, num_inference_steps=50, num_images_per_prompt=1, seed=42):
|
| 422 |
+
generator = torch.Generator(pipeline.device).manual_seed(seed)
|
| 423 |
+
images = pipeline(
|
| 424 |
+
prompt,
|
| 425 |
+
guidance_scale=guidance_scale,
|
| 426 |
+
num_inference_steps=num_inference_steps,
|
| 427 |
+
generator=generator,
|
| 428 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 429 |
+
).images
|
| 430 |
+
_rows = int(math.sqrt(num_images_per_prompt))
|
| 431 |
+
grid = make_image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
|
| 432 |
+
return grid
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def main():
|
| 436 |
+
args = parse_args()
|
| 437 |
+
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
| 438 |
+
|
| 439 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 440 |
+
|
| 441 |
+
accelerator = Accelerator(
|
| 442 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 443 |
+
mixed_precision=args.mixed_precision,
|
| 444 |
+
log_with="tensorboard",
|
| 445 |
+
project_config=accelerator_project_config,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
# If passed along, set the training seed now.
|
| 449 |
+
if args.seed is not None:
|
| 450 |
+
set_seed(args.seed)
|
| 451 |
+
|
| 452 |
+
# Handle the repository creation
|
| 453 |
+
if accelerator.is_main_process:
|
| 454 |
+
if args.output_dir is not None:
|
| 455 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 456 |
+
|
| 457 |
+
if args.push_to_hub:
|
| 458 |
+
repo_id = create_repo(
|
| 459 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 460 |
+
).repo_id
|
| 461 |
+
|
| 462 |
+
# Load the tokenizer and add the placeholder token as a additional special token
|
| 463 |
+
if args.tokenizer_name:
|
| 464 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
| 465 |
+
elif args.pretrained_model_name_or_path:
|
| 466 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
| 467 |
+
|
| 468 |
+
# Load models and create wrapper for stable diffusion
|
| 469 |
+
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 470 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
| 471 |
+
args.pretrained_model_name_or_path,
|
| 472 |
+
subfolder="text_encoder",
|
| 473 |
+
revision=args.revision,
|
| 474 |
+
)
|
| 475 |
+
vae = AutoencoderKL.from_pretrained(
|
| 476 |
+
args.pretrained_model_name_or_path,
|
| 477 |
+
subfolder="vae",
|
| 478 |
+
revision=args.revision,
|
| 479 |
+
)
|
| 480 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 481 |
+
args.pretrained_model_name_or_path,
|
| 482 |
+
subfolder="unet",
|
| 483 |
+
revision=args.revision,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
train_unet = False
|
| 487 |
+
# Freeze vae and unet
|
| 488 |
+
freeze_params(vae.parameters())
|
| 489 |
+
if not args.do_quantization and not args.do_distillation:
|
| 490 |
+
# Add the placeholder token in tokenizer
|
| 491 |
+
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
| 492 |
+
if num_added_tokens == 0:
|
| 493 |
+
raise ValueError(
|
| 494 |
+
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
| 495 |
+
" `placeholder_token` that is not already in the tokenizer."
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
# Convert the initializer_token, placeholder_token to ids
|
| 499 |
+
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
| 500 |
+
# Check if initializer_token is a single token or a sequence of tokens
|
| 501 |
+
if len(token_ids) > 1:
|
| 502 |
+
raise ValueError("The initializer token must be a single token.")
|
| 503 |
+
|
| 504 |
+
initializer_token_id = token_ids[0]
|
| 505 |
+
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
| 506 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
| 507 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
| 508 |
+
|
| 509 |
+
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
| 510 |
+
token_embeds = text_encoder.get_input_embeddings().weight.data
|
| 511 |
+
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
| 512 |
+
|
| 513 |
+
freeze_params(unet.parameters())
|
| 514 |
+
# Freeze all parameters except for the token embeddings in text encoder
|
| 515 |
+
params_to_freeze = itertools.chain(
|
| 516 |
+
text_encoder.text_model.encoder.parameters(),
|
| 517 |
+
text_encoder.text_model.final_layer_norm.parameters(),
|
| 518 |
+
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
| 519 |
+
)
|
| 520 |
+
freeze_params(params_to_freeze)
|
| 521 |
+
else:
|
| 522 |
+
train_unet = True
|
| 523 |
+
freeze_params(text_encoder.parameters())
|
| 524 |
+
|
| 525 |
+
if args.scale_lr:
|
| 526 |
+
args.learning_rate = (
|
| 527 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# Initialize the optimizer
|
| 531 |
+
optimizer = torch.optim.AdamW(
|
| 532 |
+
# only optimize the unet or embeddings of text_encoder
|
| 533 |
+
unet.parameters() if train_unet else text_encoder.get_input_embeddings().parameters(),
|
| 534 |
+
lr=args.learning_rate,
|
| 535 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 536 |
+
weight_decay=args.adam_weight_decay,
|
| 537 |
+
eps=args.adam_epsilon,
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
train_dataset = TextualInversionDataset(
|
| 541 |
+
data_root=args.train_data_dir,
|
| 542 |
+
tokenizer=tokenizer,
|
| 543 |
+
size=args.resolution,
|
| 544 |
+
placeholder_token=args.placeholder_token,
|
| 545 |
+
repeats=args.repeats,
|
| 546 |
+
learnable_property=args.learnable_property,
|
| 547 |
+
center_crop=args.center_crop,
|
| 548 |
+
set="train",
|
| 549 |
+
)
|
| 550 |
+
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
|
| 551 |
+
|
| 552 |
+
# Scheduler and math around the number of training steps.
|
| 553 |
+
overrode_max_train_steps = False
|
| 554 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 555 |
+
if args.max_train_steps is None:
|
| 556 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 557 |
+
overrode_max_train_steps = True
|
| 558 |
+
|
| 559 |
+
lr_scheduler = get_scheduler(
|
| 560 |
+
args.lr_scheduler,
|
| 561 |
+
optimizer=optimizer,
|
| 562 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 563 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
if not train_unet:
|
| 567 |
+
text_encoder = accelerator.prepare(text_encoder)
|
| 568 |
+
unet.to(accelerator.device)
|
| 569 |
+
unet.eval()
|
| 570 |
+
else:
|
| 571 |
+
unet = accelerator.prepare(unet)
|
| 572 |
+
text_encoder.to(accelerator.device)
|
| 573 |
+
text_encoder.eval()
|
| 574 |
+
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
| 575 |
+
|
| 576 |
+
# Move vae to device
|
| 577 |
+
vae.to(accelerator.device)
|
| 578 |
+
|
| 579 |
+
# Keep vae in eval model as we don't train these
|
| 580 |
+
vae.eval()
|
| 581 |
+
|
| 582 |
+
compression_manager = None
|
| 583 |
+
|
| 584 |
+
def train_func(model):
|
| 585 |
+
if train_unet:
|
| 586 |
+
unet_ = model
|
| 587 |
+
text_encoder_ = text_encoder
|
| 588 |
+
else:
|
| 589 |
+
unet_ = unet
|
| 590 |
+
text_encoder_ = model
|
| 591 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 592 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 593 |
+
if overrode_max_train_steps:
|
| 594 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 595 |
+
# Afterwards we recalculate our number of training epochs
|
| 596 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 597 |
+
|
| 598 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 599 |
+
# The trackers initializes automatically on the main process.
|
| 600 |
+
if accelerator.is_main_process:
|
| 601 |
+
accelerator.init_trackers("textual_inversion", config=vars(args))
|
| 602 |
+
|
| 603 |
+
# Train!
|
| 604 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 605 |
+
|
| 606 |
+
logger.info("***** Running training *****")
|
| 607 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 608 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 609 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 610 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 611 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 612 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 613 |
+
# Only show the progress bar once on each machine.
|
| 614 |
+
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
| 615 |
+
progress_bar.set_description("Steps")
|
| 616 |
+
global_step = 0
|
| 617 |
+
|
| 618 |
+
if train_unet and args.use_ema:
|
| 619 |
+
ema_unet = EMAModel(unet_.parameters())
|
| 620 |
+
|
| 621 |
+
for epoch in range(args.num_train_epochs):
|
| 622 |
+
model.train()
|
| 623 |
+
train_loss = 0.0
|
| 624 |
+
for step, batch in enumerate(train_dataloader):
|
| 625 |
+
with accelerator.accumulate(model):
|
| 626 |
+
# Convert images to latent space
|
| 627 |
+
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
|
| 628 |
+
latents = latents * 0.18215
|
| 629 |
+
|
| 630 |
+
# Sample noise that we'll add to the latents
|
| 631 |
+
noise = torch.randn(latents.shape).to(latents.device)
|
| 632 |
+
bsz = latents.shape[0]
|
| 633 |
+
# Sample a random timestep for each image
|
| 634 |
+
timesteps = torch.randint(
|
| 635 |
+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
| 636 |
+
).long()
|
| 637 |
+
|
| 638 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 639 |
+
# (this is the forward diffusion process)
|
| 640 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 641 |
+
|
| 642 |
+
# Get the text embedding for conditioning
|
| 643 |
+
encoder_hidden_states = text_encoder_(batch["input_ids"])[0]
|
| 644 |
+
|
| 645 |
+
# Predict the noise residual
|
| 646 |
+
model_pred = unet_(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 647 |
+
|
| 648 |
+
loss = F.mse_loss(model_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
| 649 |
+
if train_unet and compression_manager:
|
| 650 |
+
unet_inputs = {
|
| 651 |
+
"sample": noisy_latents,
|
| 652 |
+
"timestep": timesteps,
|
| 653 |
+
"encoder_hidden_states": encoder_hidden_states,
|
| 654 |
+
}
|
| 655 |
+
loss = compression_manager.callbacks.on_after_compute_loss(unet_inputs, model_pred, loss)
|
| 656 |
+
|
| 657 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 658 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
| 659 |
+
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
| 660 |
+
|
| 661 |
+
# Backpropagate
|
| 662 |
+
accelerator.backward(loss)
|
| 663 |
+
|
| 664 |
+
if train_unet:
|
| 665 |
+
if accelerator.sync_gradients:
|
| 666 |
+
accelerator.clip_grad_norm_(unet_.parameters(), args.max_grad_norm)
|
| 667 |
+
else:
|
| 668 |
+
# Zero out the gradients for all token embeddings except the newly added
|
| 669 |
+
# embeddings for the concept, as we only want to optimize the concept embeddings
|
| 670 |
+
if accelerator.num_processes > 1:
|
| 671 |
+
grads = text_encoder_.module.get_input_embeddings().weight.grad
|
| 672 |
+
else:
|
| 673 |
+
grads = text_encoder_.get_input_embeddings().weight.grad
|
| 674 |
+
# Get the index for tokens that we want to zero the grads for
|
| 675 |
+
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
|
| 676 |
+
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
|
| 677 |
+
|
| 678 |
+
optimizer.step()
|
| 679 |
+
lr_scheduler.step()
|
| 680 |
+
optimizer.zero_grad()
|
| 681 |
+
|
| 682 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 683 |
+
if accelerator.sync_gradients:
|
| 684 |
+
if train_unet and args.use_ema:
|
| 685 |
+
ema_unet.step(unet_.parameters())
|
| 686 |
+
progress_bar.update(1)
|
| 687 |
+
global_step += 1
|
| 688 |
+
accelerator.log({"train_loss": train_loss}, step=global_step)
|
| 689 |
+
train_loss = 0.0
|
| 690 |
+
if not train_unet and global_step % args.save_steps == 0:
|
| 691 |
+
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
|
| 692 |
+
save_progress(text_encoder_, placeholder_token_id, accelerator, args, save_path)
|
| 693 |
+
|
| 694 |
+
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 695 |
+
progress_bar.set_postfix(**logs)
|
| 696 |
+
accelerator.log(logs, step=global_step)
|
| 697 |
+
|
| 698 |
+
if global_step >= args.max_train_steps:
|
| 699 |
+
break
|
| 700 |
+
accelerator.wait_for_everyone()
|
| 701 |
+
|
| 702 |
+
if train_unet and args.use_ema:
|
| 703 |
+
ema_unet.copy_to(unet_.parameters())
|
| 704 |
+
|
| 705 |
+
if not train_unet:
|
| 706 |
+
return text_encoder_
|
| 707 |
+
|
| 708 |
+
if not train_unet:
|
| 709 |
+
text_encoder = train_func(text_encoder)
|
| 710 |
+
else:
|
| 711 |
+
import copy
|
| 712 |
+
|
| 713 |
+
model = copy.deepcopy(unet)
|
| 714 |
+
confs = []
|
| 715 |
+
if args.do_quantization:
|
| 716 |
+
from neural_compressor import QuantizationAwareTrainingConfig
|
| 717 |
+
|
| 718 |
+
q_conf = QuantizationAwareTrainingConfig()
|
| 719 |
+
confs.append(q_conf)
|
| 720 |
+
|
| 721 |
+
if args.do_distillation:
|
| 722 |
+
teacher_model = copy.deepcopy(model)
|
| 723 |
+
|
| 724 |
+
def attention_fetcher(x):
|
| 725 |
+
return x.sample
|
| 726 |
+
|
| 727 |
+
layer_mappings = [
|
| 728 |
+
[
|
| 729 |
+
[
|
| 730 |
+
"conv_in",
|
| 731 |
+
]
|
| 732 |
+
],
|
| 733 |
+
[
|
| 734 |
+
[
|
| 735 |
+
"time_embedding",
|
| 736 |
+
]
|
| 737 |
+
],
|
| 738 |
+
[["down_blocks.0.attentions.0", attention_fetcher]],
|
| 739 |
+
[["down_blocks.0.attentions.1", attention_fetcher]],
|
| 740 |
+
[
|
| 741 |
+
[
|
| 742 |
+
"down_blocks.0.resnets.0",
|
| 743 |
+
]
|
| 744 |
+
],
|
| 745 |
+
[
|
| 746 |
+
[
|
| 747 |
+
"down_blocks.0.resnets.1",
|
| 748 |
+
]
|
| 749 |
+
],
|
| 750 |
+
[
|
| 751 |
+
[
|
| 752 |
+
"down_blocks.0.downsamplers.0",
|
| 753 |
+
]
|
| 754 |
+
],
|
| 755 |
+
[["down_blocks.1.attentions.0", attention_fetcher]],
|
| 756 |
+
[["down_blocks.1.attentions.1", attention_fetcher]],
|
| 757 |
+
[
|
| 758 |
+
[
|
| 759 |
+
"down_blocks.1.resnets.0",
|
| 760 |
+
]
|
| 761 |
+
],
|
| 762 |
+
[
|
| 763 |
+
[
|
| 764 |
+
"down_blocks.1.resnets.1",
|
| 765 |
+
]
|
| 766 |
+
],
|
| 767 |
+
[
|
| 768 |
+
[
|
| 769 |
+
"down_blocks.1.downsamplers.0",
|
| 770 |
+
]
|
| 771 |
+
],
|
| 772 |
+
[["down_blocks.2.attentions.0", attention_fetcher]],
|
| 773 |
+
[["down_blocks.2.attentions.1", attention_fetcher]],
|
| 774 |
+
[
|
| 775 |
+
[
|
| 776 |
+
"down_blocks.2.resnets.0",
|
| 777 |
+
]
|
| 778 |
+
],
|
| 779 |
+
[
|
| 780 |
+
[
|
| 781 |
+
"down_blocks.2.resnets.1",
|
| 782 |
+
]
|
| 783 |
+
],
|
| 784 |
+
[
|
| 785 |
+
[
|
| 786 |
+
"down_blocks.2.downsamplers.0",
|
| 787 |
+
]
|
| 788 |
+
],
|
| 789 |
+
[
|
| 790 |
+
[
|
| 791 |
+
"down_blocks.3.resnets.0",
|
| 792 |
+
]
|
| 793 |
+
],
|
| 794 |
+
[
|
| 795 |
+
[
|
| 796 |
+
"down_blocks.3.resnets.1",
|
| 797 |
+
]
|
| 798 |
+
],
|
| 799 |
+
[
|
| 800 |
+
[
|
| 801 |
+
"up_blocks.0.resnets.0",
|
| 802 |
+
]
|
| 803 |
+
],
|
| 804 |
+
[
|
| 805 |
+
[
|
| 806 |
+
"up_blocks.0.resnets.1",
|
| 807 |
+
]
|
| 808 |
+
],
|
| 809 |
+
[
|
| 810 |
+
[
|
| 811 |
+
"up_blocks.0.resnets.2",
|
| 812 |
+
]
|
| 813 |
+
],
|
| 814 |
+
[
|
| 815 |
+
[
|
| 816 |
+
"up_blocks.0.upsamplers.0",
|
| 817 |
+
]
|
| 818 |
+
],
|
| 819 |
+
[["up_blocks.1.attentions.0", attention_fetcher]],
|
| 820 |
+
[["up_blocks.1.attentions.1", attention_fetcher]],
|
| 821 |
+
[["up_blocks.1.attentions.2", attention_fetcher]],
|
| 822 |
+
[
|
| 823 |
+
[
|
| 824 |
+
"up_blocks.1.resnets.0",
|
| 825 |
+
]
|
| 826 |
+
],
|
| 827 |
+
[
|
| 828 |
+
[
|
| 829 |
+
"up_blocks.1.resnets.1",
|
| 830 |
+
]
|
| 831 |
+
],
|
| 832 |
+
[
|
| 833 |
+
[
|
| 834 |
+
"up_blocks.1.resnets.2",
|
| 835 |
+
]
|
| 836 |
+
],
|
| 837 |
+
[
|
| 838 |
+
[
|
| 839 |
+
"up_blocks.1.upsamplers.0",
|
| 840 |
+
]
|
| 841 |
+
],
|
| 842 |
+
[["up_blocks.2.attentions.0", attention_fetcher]],
|
| 843 |
+
[["up_blocks.2.attentions.1", attention_fetcher]],
|
| 844 |
+
[["up_blocks.2.attentions.2", attention_fetcher]],
|
| 845 |
+
[
|
| 846 |
+
[
|
| 847 |
+
"up_blocks.2.resnets.0",
|
| 848 |
+
]
|
| 849 |
+
],
|
| 850 |
+
[
|
| 851 |
+
[
|
| 852 |
+
"up_blocks.2.resnets.1",
|
| 853 |
+
]
|
| 854 |
+
],
|
| 855 |
+
[
|
| 856 |
+
[
|
| 857 |
+
"up_blocks.2.resnets.2",
|
| 858 |
+
]
|
| 859 |
+
],
|
| 860 |
+
[
|
| 861 |
+
[
|
| 862 |
+
"up_blocks.2.upsamplers.0",
|
| 863 |
+
]
|
| 864 |
+
],
|
| 865 |
+
[["up_blocks.3.attentions.0", attention_fetcher]],
|
| 866 |
+
[["up_blocks.3.attentions.1", attention_fetcher]],
|
| 867 |
+
[["up_blocks.3.attentions.2", attention_fetcher]],
|
| 868 |
+
[
|
| 869 |
+
[
|
| 870 |
+
"up_blocks.3.resnets.0",
|
| 871 |
+
]
|
| 872 |
+
],
|
| 873 |
+
[
|
| 874 |
+
[
|
| 875 |
+
"up_blocks.3.resnets.1",
|
| 876 |
+
]
|
| 877 |
+
],
|
| 878 |
+
[
|
| 879 |
+
[
|
| 880 |
+
"up_blocks.3.resnets.2",
|
| 881 |
+
]
|
| 882 |
+
],
|
| 883 |
+
[["mid_block.attentions.0", attention_fetcher]],
|
| 884 |
+
[
|
| 885 |
+
[
|
| 886 |
+
"mid_block.resnets.0",
|
| 887 |
+
]
|
| 888 |
+
],
|
| 889 |
+
[
|
| 890 |
+
[
|
| 891 |
+
"mid_block.resnets.1",
|
| 892 |
+
]
|
| 893 |
+
],
|
| 894 |
+
[
|
| 895 |
+
[
|
| 896 |
+
"conv_out",
|
| 897 |
+
]
|
| 898 |
+
],
|
| 899 |
+
]
|
| 900 |
+
layer_names = [layer_mapping[0][0] for layer_mapping in layer_mappings]
|
| 901 |
+
if not set(layer_names).issubset([n[0] for n in model.named_modules()]):
|
| 902 |
+
raise ValueError(
|
| 903 |
+
"Provided model is not compatible with the default layer_mappings, "
|
| 904 |
+
'please use the model fine-tuned from "CompVis/stable-diffusion-v1-4", '
|
| 905 |
+
"or modify the layer_mappings variable to fit your model."
|
| 906 |
+
f"\nDefault layer_mappings are as such:\n{layer_mappings}"
|
| 907 |
+
)
|
| 908 |
+
from neural_compressor.config import DistillationConfig, IntermediateLayersKnowledgeDistillationLossConfig
|
| 909 |
+
|
| 910 |
+
distillation_criterion = IntermediateLayersKnowledgeDistillationLossConfig(
|
| 911 |
+
layer_mappings=layer_mappings,
|
| 912 |
+
loss_types=["MSE"] * len(layer_mappings),
|
| 913 |
+
loss_weights=[1.0 / len(layer_mappings)] * len(layer_mappings),
|
| 914 |
+
add_origin_loss=True,
|
| 915 |
+
)
|
| 916 |
+
d_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)
|
| 917 |
+
confs.append(d_conf)
|
| 918 |
+
|
| 919 |
+
from neural_compressor.training import prepare_compression
|
| 920 |
+
|
| 921 |
+
compression_manager = prepare_compression(model, confs)
|
| 922 |
+
compression_manager.callbacks.on_train_begin()
|
| 923 |
+
model = compression_manager.model
|
| 924 |
+
train_func(model)
|
| 925 |
+
compression_manager.callbacks.on_train_end()
|
| 926 |
+
|
| 927 |
+
# Save the resulting model and its corresponding configuration in the given directory
|
| 928 |
+
model.save(args.output_dir)
|
| 929 |
+
|
| 930 |
+
logger.info(f"Optimized model saved to: {args.output_dir}.")
|
| 931 |
+
|
| 932 |
+
# change to framework model for further use
|
| 933 |
+
model = model.model
|
| 934 |
+
|
| 935 |
+
# Create the pipeline using using the trained modules and save it.
|
| 936 |
+
templates = imagenet_style_templates_small if args.learnable_property == "style" else imagenet_templates_small
|
| 937 |
+
prompt = templates[0].format(args.placeholder_token)
|
| 938 |
+
if accelerator.is_main_process:
|
| 939 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
| 940 |
+
args.pretrained_model_name_or_path,
|
| 941 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
| 942 |
+
vae=vae,
|
| 943 |
+
unet=accelerator.unwrap_model(unet),
|
| 944 |
+
tokenizer=tokenizer,
|
| 945 |
+
)
|
| 946 |
+
pipeline.save_pretrained(args.output_dir)
|
| 947 |
+
pipeline = pipeline.to(unet.device)
|
| 948 |
+
baseline_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
|
| 949 |
+
baseline_model_images.save(
|
| 950 |
+
os.path.join(args.output_dir, "{}_baseline_model.png".format("_".join(prompt.split())))
|
| 951 |
+
)
|
| 952 |
+
|
| 953 |
+
if not train_unet:
|
| 954 |
+
# Also save the newly trained embeddings
|
| 955 |
+
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
| 956 |
+
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
| 957 |
+
else:
|
| 958 |
+
setattr(pipeline, "unet", accelerator.unwrap_model(model))
|
| 959 |
+
if args.do_quantization:
|
| 960 |
+
pipeline = pipeline.to(torch.device("cpu"))
|
| 961 |
+
|
| 962 |
+
optimized_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
|
| 963 |
+
optimized_model_images.save(
|
| 964 |
+
os.path.join(args.output_dir, "{}_optimized_model.png".format("_".join(prompt.split())))
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
if args.push_to_hub:
|
| 968 |
+
upload_folder(
|
| 969 |
+
repo_id=repo_id,
|
| 970 |
+
folder_path=args.output_dir,
|
| 971 |
+
commit_message="End of training",
|
| 972 |
+
ignore_patterns=["step_*", "epoch_*"],
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
accelerator.end_training()
|
| 976 |
+
|
| 977 |
+
if args.do_quantization and args.verify_loading:
|
| 978 |
+
# Load the model obtained after Intel Neural Compressor quantization
|
| 979 |
+
from neural_compressor.utils.pytorch import load
|
| 980 |
+
|
| 981 |
+
loaded_model = load(args.output_dir, model=unet)
|
| 982 |
+
loaded_model.eval()
|
| 983 |
+
|
| 984 |
+
setattr(pipeline, "unet", loaded_model)
|
| 985 |
+
if args.do_quantization:
|
| 986 |
+
pipeline = pipeline.to(torch.device("cpu"))
|
| 987 |
+
|
| 988 |
+
loaded_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
|
| 989 |
+
if loaded_model_images != optimized_model_images:
|
| 990 |
+
logger.info("The quantized model was not successfully loaded.")
|
| 991 |
+
else:
|
| 992 |
+
logger.info("The quantized model was successfully loaded.")
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
if __name__ == "__main__":
|
| 996 |
+
main()
|
diffusers/examples/research_projects/ip_adapter/README.md
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# IP Adapter Training Example
|
| 2 |
+
|
| 3 |
+
[IP Adapter](https://huggingface.co/papers/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.
|
| 4 |
+
|
| 5 |
+
## Training locally with PyTorch
|
| 6 |
+
|
| 7 |
+
### Installing the dependencies
|
| 8 |
+
|
| 9 |
+
Before running the scripts, make sure to install the library's training dependencies:
|
| 10 |
+
|
| 11 |
+
**Important**
|
| 12 |
+
|
| 13 |
+
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
git clone https://github.com/huggingface/diffusers
|
| 17 |
+
cd diffusers
|
| 18 |
+
pip install -e .
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Then cd in the example folder and run
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
pip install -r requirements.txt
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
accelerate config
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
Or for a default accelerate configuration without answering questions about your environment
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
accelerate config default
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
Or if your environment doesn't support an interactive shell e.g. a notebook
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
from accelerate.utils import write_basic_config
|
| 43 |
+
write_basic_config()
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Certainly! Below is the documentation in pure Markdown format:
|
| 47 |
+
|
| 48 |
+
### Accelerate Launch Command Documentation
|
| 49 |
+
|
| 50 |
+
#### Description:
|
| 51 |
+
The Accelerate launch command is used to train a model using multiple GPUs and mixed precision training. It launches the training script `tutorial_train_ip-adapter.py` with specified parameters and configurations.
|
| 52 |
+
|
| 53 |
+
#### Usage Example:
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
accelerate launch --mixed_precision "fp16" \
|
| 57 |
+
tutorial_train_ip-adapter.py \
|
| 58 |
+
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
|
| 59 |
+
--image_encoder_path="{image_encoder_path}" \
|
| 60 |
+
--data_json_file="{data.json}" \
|
| 61 |
+
--data_root_path="{image_path}" \
|
| 62 |
+
--mixed_precision="fp16" \
|
| 63 |
+
--resolution=512 \
|
| 64 |
+
--train_batch_size=8 \
|
| 65 |
+
--dataloader_num_workers=4 \
|
| 66 |
+
--learning_rate=1e-04 \
|
| 67 |
+
--weight_decay=0.01 \
|
| 68 |
+
--output_dir="{output_dir}" \
|
| 69 |
+
--save_steps=10000
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### Multi-GPU Script:
|
| 73 |
+
```
|
| 74 |
+
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
|
| 75 |
+
tutorial_train_ip-adapter.py \
|
| 76 |
+
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
|
| 77 |
+
--image_encoder_path="{image_encoder_path}" \
|
| 78 |
+
--data_json_file="{data.json}" \
|
| 79 |
+
--data_root_path="{image_path}" \
|
| 80 |
+
--mixed_precision="fp16" \
|
| 81 |
+
--resolution=512 \
|
| 82 |
+
--train_batch_size=8 \
|
| 83 |
+
--dataloader_num_workers=4 \
|
| 84 |
+
--learning_rate=1e-04 \
|
| 85 |
+
--weight_decay=0.01 \
|
| 86 |
+
--output_dir="{output_dir}" \
|
| 87 |
+
--save_steps=10000
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
#### Parameters:
|
| 91 |
+
- `--num_processes`: Number of processes to launch for distributed training (in this example, 8 processes).
|
| 92 |
+
- `--multi_gpu`: Flag indicating the usage of multiple GPUs for training.
|
| 93 |
+
- `--mixed_precision "fp16"`: Enables mixed precision training with 16-bit floating-point precision.
|
| 94 |
+
- `tutorial_train_ip-adapter.py`: Name of the training script to be executed.
|
| 95 |
+
- `--pretrained_model_name_or_path`: Path or identifier for a pretrained model.
|
| 96 |
+
- `--image_encoder_path`: Path to the CLIP image encoder.
|
| 97 |
+
- `--data_json_file`: Path to the training data in JSON format.
|
| 98 |
+
- `--data_root_path`: Root path where training images are located.
|
| 99 |
+
- `--resolution`: Resolution of input images (512x512 in this example).
|
| 100 |
+
- `--train_batch_size`: Batch size for training data (8 in this example).
|
| 101 |
+
- `--dataloader_num_workers`: Number of subprocesses for data loading (4 in this example).
|
| 102 |
+
- `--learning_rate`: Learning rate for training (1e-04 in this example).
|
| 103 |
+
- `--weight_decay`: Weight decay for regularization (0.01 in this example).
|
| 104 |
+
- `--output_dir`: Directory to save model checkpoints and predictions.
|
| 105 |
+
- `--save_steps`: Frequency of saving checkpoints during training (10000 in this example).
|
| 106 |
+
|
| 107 |
+
### Inference
|
| 108 |
+
|
| 109 |
+
#### Description:
|
| 110 |
+
The provided inference code is used to load a trained model checkpoint and extract the components related to image projection and IP (Image Processing) adapter. These components are then saved into a binary file for later use in inference.
|
| 111 |
+
|
| 112 |
+
#### Usage Example:
|
| 113 |
+
```python
|
| 114 |
+
from safetensors.torch import load_file, save_file
|
| 115 |
+
|
| 116 |
+
# Load the trained model checkpoint in safetensors format
|
| 117 |
+
ckpt = "checkpoint-50000/pytorch_model.safetensors"
|
| 118 |
+
sd = load_file(ckpt) # Using safetensors load function
|
| 119 |
+
|
| 120 |
+
# Extract image projection and IP adapter components
|
| 121 |
+
image_proj_sd = {}
|
| 122 |
+
ip_sd = {}
|
| 123 |
+
|
| 124 |
+
for k in sd:
|
| 125 |
+
if k.startswith("unet"):
|
| 126 |
+
pass # Skip unet-related keys
|
| 127 |
+
elif k.startswith("image_proj_model"):
|
| 128 |
+
image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
|
| 129 |
+
elif k.startswith("adapter_modules"):
|
| 130 |
+
ip_sd[k.replace("adapter_modules.", "")] = sd[k]
|
| 131 |
+
|
| 132 |
+
# Save the components into separate safetensors files
|
| 133 |
+
save_file(image_proj_sd, "image_proj.safetensors")
|
| 134 |
+
save_file(ip_sd, "ip_adapter.safetensors")
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
### Sample Inference Script using the CLIP Model
|
| 138 |
+
|
| 139 |
+
```python
|
| 140 |
+
|
| 141 |
+
import torch
|
| 142 |
+
from safetensors.torch import load_file
|
| 143 |
+
from transformers import CLIPProcessor, CLIPModel # Using the Hugging Face CLIP model
|
| 144 |
+
|
| 145 |
+
# Load model components from safetensors
|
| 146 |
+
image_proj_ckpt = "image_proj.safetensors"
|
| 147 |
+
ip_adapter_ckpt = "ip_adapter.safetensors"
|
| 148 |
+
|
| 149 |
+
# Load the saved weights
|
| 150 |
+
image_proj_sd = load_file(image_proj_ckpt)
|
| 151 |
+
ip_adapter_sd = load_file(ip_adapter_ckpt)
|
| 152 |
+
|
| 153 |
+
# Define the model Parameters
|
| 154 |
+
class ImageProjectionModel(torch.nn.Module):
|
| 155 |
+
def __init__(self, input_dim=768, output_dim=512): # CLIP's default embedding size is 768
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.model = torch.nn.Linear(input_dim, output_dim)
|
| 158 |
+
|
| 159 |
+
def forward(self, x):
|
| 160 |
+
return self.model(x)
|
| 161 |
+
|
| 162 |
+
class IPAdapterModel(torch.nn.Module):
|
| 163 |
+
def __init__(self, input_dim=512, output_dim=10): # Example for 10 classes
|
| 164 |
+
super().__init__()
|
| 165 |
+
self.model = torch.nn.Linear(input_dim, output_dim)
|
| 166 |
+
|
| 167 |
+
def forward(self, x):
|
| 168 |
+
return self.model(x)
|
| 169 |
+
|
| 170 |
+
# Initialize models
|
| 171 |
+
image_proj_model = ImageProjectionModel()
|
| 172 |
+
ip_adapter_model = IPAdapterModel()
|
| 173 |
+
|
| 174 |
+
# Load weights into models
|
| 175 |
+
image_proj_model.load_state_dict(image_proj_sd)
|
| 176 |
+
ip_adapter_model.load_state_dict(ip_adapter_sd)
|
| 177 |
+
|
| 178 |
+
# Set models to evaluation mode
|
| 179 |
+
image_proj_model.eval()
|
| 180 |
+
ip_adapter_model.eval()
|
| 181 |
+
|
| 182 |
+
#Inference pipeline
|
| 183 |
+
def inference(image_tensor):
|
| 184 |
+
"""
|
| 185 |
+
Run inference using the loaded models.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
image_tensor: Preprocessed image tensor from CLIPProcessor
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
Final inference results
|
| 192 |
+
"""
|
| 193 |
+
with torch.no_grad():
|
| 194 |
+
# Step 1: Project the image features
|
| 195 |
+
image_proj = image_proj_model(image_tensor)
|
| 196 |
+
|
| 197 |
+
# Step 2: Pass the projected features through the IP Adapter
|
| 198 |
+
result = ip_adapter_model(image_proj)
|
| 199 |
+
|
| 200 |
+
return result
|
| 201 |
+
|
| 202 |
+
# Using CLIP for image preprocessing
|
| 203 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 204 |
+
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
| 205 |
+
|
| 206 |
+
#Image file path
|
| 207 |
+
image_path = "path/to/image.jpg"
|
| 208 |
+
|
| 209 |
+
# Preprocess the image
|
| 210 |
+
inputs = processor(images=image_path, return_tensors="pt")
|
| 211 |
+
image_features = clip_model.get_image_features(inputs["pixel_values"])
|
| 212 |
+
|
| 213 |
+
# Normalize the image features as per CLIP's recommendations
|
| 214 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
| 215 |
+
|
| 216 |
+
# Run inference
|
| 217 |
+
output = inference(image_features)
|
| 218 |
+
print("Inference output:", output)
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
#### Parameters:
|
| 222 |
+
- `ckpt`: Path to the trained model checkpoint file.
|
| 223 |
+
- `map_location="cpu"`: Specifies that the model should be loaded onto the CPU.
|
| 224 |
+
- `image_proj_sd`: Dictionary to store the components related to image projection.
|
| 225 |
+
- `ip_sd`: Dictionary to store the components related to the IP adapter.
|
| 226 |
+
- `"unet"`, `"image_proj_model"`, `"adapter_modules"`: Prefixes indicating components of the model.
|
diffusers/examples/research_projects/ip_adapter/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
torchvision
|
| 3 |
+
transformers>=4.25.1
|
| 4 |
+
ip_adapter
|
diffusers/examples/research_projects/ip_adapter/tutorial_train_faceid.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import itertools
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from accelerate import Accelerator
|
| 12 |
+
from accelerate.utils import ProjectConfiguration
|
| 13 |
+
from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
|
| 14 |
+
from ip_adapter.ip_adapter_faceid import MLPProjModel
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 18 |
+
|
| 19 |
+
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Dataset
|
| 23 |
+
class MyDataset(torch.utils.data.Dataset):
|
| 24 |
+
def __init__(
|
| 25 |
+
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
|
| 26 |
+
):
|
| 27 |
+
super().__init__()
|
| 28 |
+
|
| 29 |
+
self.tokenizer = tokenizer
|
| 30 |
+
self.size = size
|
| 31 |
+
self.i_drop_rate = i_drop_rate
|
| 32 |
+
self.t_drop_rate = t_drop_rate
|
| 33 |
+
self.ti_drop_rate = ti_drop_rate
|
| 34 |
+
self.image_root_path = image_root_path
|
| 35 |
+
|
| 36 |
+
self.data = json.load(
|
| 37 |
+
open(json_file)
|
| 38 |
+
) # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}]
|
| 39 |
+
|
| 40 |
+
self.transform = transforms.Compose(
|
| 41 |
+
[
|
| 42 |
+
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
|
| 43 |
+
transforms.CenterCrop(self.size),
|
| 44 |
+
transforms.ToTensor(),
|
| 45 |
+
transforms.Normalize([0.5], [0.5]),
|
| 46 |
+
]
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def __getitem__(self, idx):
|
| 50 |
+
item = self.data[idx]
|
| 51 |
+
text = item["text"]
|
| 52 |
+
image_file = item["image_file"]
|
| 53 |
+
|
| 54 |
+
# read image
|
| 55 |
+
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
|
| 56 |
+
image = self.transform(raw_image.convert("RGB"))
|
| 57 |
+
|
| 58 |
+
face_id_embed = torch.load(item["id_embed_file"], map_location="cpu")
|
| 59 |
+
face_id_embed = torch.from_numpy(face_id_embed)
|
| 60 |
+
|
| 61 |
+
# drop
|
| 62 |
+
drop_image_embed = 0
|
| 63 |
+
rand_num = random.random()
|
| 64 |
+
if rand_num < self.i_drop_rate:
|
| 65 |
+
drop_image_embed = 1
|
| 66 |
+
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
|
| 67 |
+
text = ""
|
| 68 |
+
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
|
| 69 |
+
text = ""
|
| 70 |
+
drop_image_embed = 1
|
| 71 |
+
if drop_image_embed:
|
| 72 |
+
face_id_embed = torch.zeros_like(face_id_embed)
|
| 73 |
+
# get text and tokenize
|
| 74 |
+
text_input_ids = self.tokenizer(
|
| 75 |
+
text,
|
| 76 |
+
max_length=self.tokenizer.model_max_length,
|
| 77 |
+
padding="max_length",
|
| 78 |
+
truncation=True,
|
| 79 |
+
return_tensors="pt",
|
| 80 |
+
).input_ids
|
| 81 |
+
|
| 82 |
+
return {
|
| 83 |
+
"image": image,
|
| 84 |
+
"text_input_ids": text_input_ids,
|
| 85 |
+
"face_id_embed": face_id_embed,
|
| 86 |
+
"drop_image_embed": drop_image_embed,
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
def __len__(self):
|
| 90 |
+
return len(self.data)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def collate_fn(data):
|
| 94 |
+
images = torch.stack([example["image"] for example in data])
|
| 95 |
+
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
|
| 96 |
+
face_id_embed = torch.stack([example["face_id_embed"] for example in data])
|
| 97 |
+
drop_image_embeds = [example["drop_image_embed"] for example in data]
|
| 98 |
+
|
| 99 |
+
return {
|
| 100 |
+
"images": images,
|
| 101 |
+
"text_input_ids": text_input_ids,
|
| 102 |
+
"face_id_embed": face_id_embed,
|
| 103 |
+
"drop_image_embeds": drop_image_embeds,
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class IPAdapter(torch.nn.Module):
|
| 108 |
+
"""IP-Adapter"""
|
| 109 |
+
|
| 110 |
+
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.unet = unet
|
| 113 |
+
self.image_proj_model = image_proj_model
|
| 114 |
+
self.adapter_modules = adapter_modules
|
| 115 |
+
|
| 116 |
+
if ckpt_path is not None:
|
| 117 |
+
self.load_from_checkpoint(ckpt_path)
|
| 118 |
+
|
| 119 |
+
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
| 120 |
+
ip_tokens = self.image_proj_model(image_embeds)
|
| 121 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
| 122 |
+
# Predict the noise residual
|
| 123 |
+
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 124 |
+
return noise_pred
|
| 125 |
+
|
| 126 |
+
def load_from_checkpoint(self, ckpt_path: str):
|
| 127 |
+
# Calculate original checksums
|
| 128 |
+
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
| 129 |
+
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
| 130 |
+
|
| 131 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
| 132 |
+
|
| 133 |
+
# Load state dict for image_proj_model and adapter_modules
|
| 134 |
+
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
|
| 135 |
+
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
| 136 |
+
|
| 137 |
+
# Calculate new checksums
|
| 138 |
+
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
| 139 |
+
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
| 140 |
+
|
| 141 |
+
# Verify if the weights have changed
|
| 142 |
+
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
| 143 |
+
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
| 144 |
+
|
| 145 |
+
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def parse_args():
|
| 149 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--pretrained_model_name_or_path",
|
| 152 |
+
type=str,
|
| 153 |
+
default=None,
|
| 154 |
+
required=True,
|
| 155 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--pretrained_ip_adapter_path",
|
| 159 |
+
type=str,
|
| 160 |
+
default=None,
|
| 161 |
+
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--data_json_file",
|
| 165 |
+
type=str,
|
| 166 |
+
default=None,
|
| 167 |
+
required=True,
|
| 168 |
+
help="Training data",
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--data_root_path",
|
| 172 |
+
type=str,
|
| 173 |
+
default="",
|
| 174 |
+
required=True,
|
| 175 |
+
help="Training data root path",
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
"--image_encoder_path",
|
| 179 |
+
type=str,
|
| 180 |
+
default=None,
|
| 181 |
+
required=True,
|
| 182 |
+
help="Path to CLIP image encoder",
|
| 183 |
+
)
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--output_dir",
|
| 186 |
+
type=str,
|
| 187 |
+
default="sd-ip_adapter",
|
| 188 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--logging_dir",
|
| 192 |
+
type=str,
|
| 193 |
+
default="logs",
|
| 194 |
+
help=(
|
| 195 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 196 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 197 |
+
),
|
| 198 |
+
)
|
| 199 |
+
parser.add_argument(
|
| 200 |
+
"--resolution",
|
| 201 |
+
type=int,
|
| 202 |
+
default=512,
|
| 203 |
+
help=("The resolution for input images"),
|
| 204 |
+
)
|
| 205 |
+
parser.add_argument(
|
| 206 |
+
"--learning_rate",
|
| 207 |
+
type=float,
|
| 208 |
+
default=1e-4,
|
| 209 |
+
help="Learning rate to use.",
|
| 210 |
+
)
|
| 211 |
+
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 212 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 213 |
+
parser.add_argument(
|
| 214 |
+
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--dataloader_num_workers",
|
| 218 |
+
type=int,
|
| 219 |
+
default=0,
|
| 220 |
+
help=(
|
| 221 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 222 |
+
),
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--save_steps",
|
| 226 |
+
type=int,
|
| 227 |
+
default=2000,
|
| 228 |
+
help=("Save a checkpoint of the training state every X updates"),
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--mixed_precision",
|
| 232 |
+
type=str,
|
| 233 |
+
default=None,
|
| 234 |
+
choices=["no", "fp16", "bf16"],
|
| 235 |
+
help=(
|
| 236 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 237 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 238 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 239 |
+
),
|
| 240 |
+
)
|
| 241 |
+
parser.add_argument(
|
| 242 |
+
"--report_to",
|
| 243 |
+
type=str,
|
| 244 |
+
default="tensorboard",
|
| 245 |
+
help=(
|
| 246 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 247 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 248 |
+
),
|
| 249 |
+
)
|
| 250 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 251 |
+
|
| 252 |
+
args = parser.parse_args()
|
| 253 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 254 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 255 |
+
args.local_rank = env_local_rank
|
| 256 |
+
|
| 257 |
+
return args
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def main():
|
| 261 |
+
args = parse_args()
|
| 262 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 263 |
+
|
| 264 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 265 |
+
|
| 266 |
+
accelerator = Accelerator(
|
| 267 |
+
mixed_precision=args.mixed_precision,
|
| 268 |
+
log_with=args.report_to,
|
| 269 |
+
project_config=accelerator_project_config,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
if accelerator.is_main_process:
|
| 273 |
+
if args.output_dir is not None:
|
| 274 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 275 |
+
|
| 276 |
+
# Load scheduler, tokenizer and models.
|
| 277 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 278 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
| 279 |
+
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
| 280 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
| 281 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
| 282 |
+
# image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
|
| 283 |
+
# freeze parameters of models to save more memory
|
| 284 |
+
unet.requires_grad_(False)
|
| 285 |
+
vae.requires_grad_(False)
|
| 286 |
+
text_encoder.requires_grad_(False)
|
| 287 |
+
# image_encoder.requires_grad_(False)
|
| 288 |
+
|
| 289 |
+
# ip-adapter
|
| 290 |
+
image_proj_model = MLPProjModel(
|
| 291 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
| 292 |
+
id_embeddings_dim=512,
|
| 293 |
+
num_tokens=4,
|
| 294 |
+
)
|
| 295 |
+
# init adapter modules
|
| 296 |
+
lora_rank = 128
|
| 297 |
+
attn_procs = {}
|
| 298 |
+
unet_sd = unet.state_dict()
|
| 299 |
+
for name in unet.attn_processors.keys():
|
| 300 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
| 301 |
+
if name.startswith("mid_block"):
|
| 302 |
+
hidden_size = unet.config.block_out_channels[-1]
|
| 303 |
+
elif name.startswith("up_blocks"):
|
| 304 |
+
block_id = int(name[len("up_blocks.")])
|
| 305 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
| 306 |
+
elif name.startswith("down_blocks"):
|
| 307 |
+
block_id = int(name[len("down_blocks.")])
|
| 308 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
| 309 |
+
if cross_attention_dim is None:
|
| 310 |
+
attn_procs[name] = LoRAAttnProcessor(
|
| 311 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
layer_name = name.split(".processor")[0]
|
| 315 |
+
weights = {
|
| 316 |
+
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
| 317 |
+
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
| 318 |
+
}
|
| 319 |
+
attn_procs[name] = LoRAIPAttnProcessor(
|
| 320 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
|
| 321 |
+
)
|
| 322 |
+
attn_procs[name].load_state_dict(weights, strict=False)
|
| 323 |
+
unet.set_attn_processor(attn_procs)
|
| 324 |
+
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
| 325 |
+
|
| 326 |
+
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
| 327 |
+
|
| 328 |
+
weight_dtype = torch.float32
|
| 329 |
+
if accelerator.mixed_precision == "fp16":
|
| 330 |
+
weight_dtype = torch.float16
|
| 331 |
+
elif accelerator.mixed_precision == "bf16":
|
| 332 |
+
weight_dtype = torch.bfloat16
|
| 333 |
+
# unet.to(accelerator.device, dtype=weight_dtype)
|
| 334 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 335 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 336 |
+
# image_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 337 |
+
|
| 338 |
+
# optimizer
|
| 339 |
+
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
|
| 340 |
+
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
|
| 341 |
+
|
| 342 |
+
# dataloader
|
| 343 |
+
train_dataset = MyDataset(
|
| 344 |
+
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
|
| 345 |
+
)
|
| 346 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 347 |
+
train_dataset,
|
| 348 |
+
shuffle=True,
|
| 349 |
+
collate_fn=collate_fn,
|
| 350 |
+
batch_size=args.train_batch_size,
|
| 351 |
+
num_workers=args.dataloader_num_workers,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Prepare everything with our `accelerator`.
|
| 355 |
+
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
|
| 356 |
+
|
| 357 |
+
global_step = 0
|
| 358 |
+
for epoch in range(0, args.num_train_epochs):
|
| 359 |
+
begin = time.perf_counter()
|
| 360 |
+
for step, batch in enumerate(train_dataloader):
|
| 361 |
+
load_data_time = time.perf_counter() - begin
|
| 362 |
+
with accelerator.accumulate(ip_adapter):
|
| 363 |
+
# Convert images to latent space
|
| 364 |
+
with torch.no_grad():
|
| 365 |
+
latents = vae.encode(
|
| 366 |
+
batch["images"].to(accelerator.device, dtype=weight_dtype)
|
| 367 |
+
).latent_dist.sample()
|
| 368 |
+
latents = latents * vae.config.scaling_factor
|
| 369 |
+
|
| 370 |
+
# Sample noise that we'll add to the latents
|
| 371 |
+
noise = torch.randn_like(latents)
|
| 372 |
+
bsz = latents.shape[0]
|
| 373 |
+
# Sample a random timestep for each image
|
| 374 |
+
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
| 375 |
+
timesteps = timesteps.long()
|
| 376 |
+
|
| 377 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 378 |
+
# (this is the forward diffusion process)
|
| 379 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 380 |
+
|
| 381 |
+
image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype)
|
| 382 |
+
|
| 383 |
+
with torch.no_grad():
|
| 384 |
+
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
|
| 385 |
+
|
| 386 |
+
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
|
| 387 |
+
|
| 388 |
+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
| 389 |
+
|
| 390 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 391 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
|
| 392 |
+
|
| 393 |
+
# Backpropagate
|
| 394 |
+
accelerator.backward(loss)
|
| 395 |
+
optimizer.step()
|
| 396 |
+
optimizer.zero_grad()
|
| 397 |
+
|
| 398 |
+
if accelerator.is_main_process:
|
| 399 |
+
print(
|
| 400 |
+
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
|
| 401 |
+
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
|
| 402 |
+
)
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
global_step += 1
|
| 406 |
+
|
| 407 |
+
if global_step % args.save_steps == 0:
|
| 408 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 409 |
+
accelerator.save_state(save_path)
|
| 410 |
+
|
| 411 |
+
begin = time.perf_counter()
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
if __name__ == "__main__":
|
| 415 |
+
main()
|
diffusers/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import itertools
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from accelerate import Accelerator
|
| 12 |
+
from accelerate.utils import ProjectConfiguration
|
| 13 |
+
from ip_adapter.ip_adapter import ImageProjModel
|
| 14 |
+
from ip_adapter.utils import is_torch2_available
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 18 |
+
|
| 19 |
+
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if is_torch2_available():
|
| 23 |
+
from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
|
| 24 |
+
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
|
| 25 |
+
else:
|
| 26 |
+
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Dataset
|
| 30 |
+
class MyDataset(torch.utils.data.Dataset):
|
| 31 |
+
def __init__(
|
| 32 |
+
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
|
| 33 |
+
):
|
| 34 |
+
super().__init__()
|
| 35 |
+
|
| 36 |
+
self.tokenizer = tokenizer
|
| 37 |
+
self.size = size
|
| 38 |
+
self.i_drop_rate = i_drop_rate
|
| 39 |
+
self.t_drop_rate = t_drop_rate
|
| 40 |
+
self.ti_drop_rate = ti_drop_rate
|
| 41 |
+
self.image_root_path = image_root_path
|
| 42 |
+
|
| 43 |
+
self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
|
| 44 |
+
|
| 45 |
+
self.transform = transforms.Compose(
|
| 46 |
+
[
|
| 47 |
+
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
|
| 48 |
+
transforms.CenterCrop(self.size),
|
| 49 |
+
transforms.ToTensor(),
|
| 50 |
+
transforms.Normalize([0.5], [0.5]),
|
| 51 |
+
]
|
| 52 |
+
)
|
| 53 |
+
self.clip_image_processor = CLIPImageProcessor()
|
| 54 |
+
|
| 55 |
+
def __getitem__(self, idx):
|
| 56 |
+
item = self.data[idx]
|
| 57 |
+
text = item["text"]
|
| 58 |
+
image_file = item["image_file"]
|
| 59 |
+
|
| 60 |
+
# read image
|
| 61 |
+
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
|
| 62 |
+
image = self.transform(raw_image.convert("RGB"))
|
| 63 |
+
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
|
| 64 |
+
|
| 65 |
+
# drop
|
| 66 |
+
drop_image_embed = 0
|
| 67 |
+
rand_num = random.random()
|
| 68 |
+
if rand_num < self.i_drop_rate:
|
| 69 |
+
drop_image_embed = 1
|
| 70 |
+
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
|
| 71 |
+
text = ""
|
| 72 |
+
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
|
| 73 |
+
text = ""
|
| 74 |
+
drop_image_embed = 1
|
| 75 |
+
# get text and tokenize
|
| 76 |
+
text_input_ids = self.tokenizer(
|
| 77 |
+
text,
|
| 78 |
+
max_length=self.tokenizer.model_max_length,
|
| 79 |
+
padding="max_length",
|
| 80 |
+
truncation=True,
|
| 81 |
+
return_tensors="pt",
|
| 82 |
+
).input_ids
|
| 83 |
+
|
| 84 |
+
return {
|
| 85 |
+
"image": image,
|
| 86 |
+
"text_input_ids": text_input_ids,
|
| 87 |
+
"clip_image": clip_image,
|
| 88 |
+
"drop_image_embed": drop_image_embed,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
def __len__(self):
|
| 92 |
+
return len(self.data)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def collate_fn(data):
|
| 96 |
+
images = torch.stack([example["image"] for example in data])
|
| 97 |
+
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
|
| 98 |
+
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
|
| 99 |
+
drop_image_embeds = [example["drop_image_embed"] for example in data]
|
| 100 |
+
|
| 101 |
+
return {
|
| 102 |
+
"images": images,
|
| 103 |
+
"text_input_ids": text_input_ids,
|
| 104 |
+
"clip_images": clip_images,
|
| 105 |
+
"drop_image_embeds": drop_image_embeds,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class IPAdapter(torch.nn.Module):
|
| 110 |
+
"""IP-Adapter"""
|
| 111 |
+
|
| 112 |
+
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.unet = unet
|
| 115 |
+
self.image_proj_model = image_proj_model
|
| 116 |
+
self.adapter_modules = adapter_modules
|
| 117 |
+
|
| 118 |
+
if ckpt_path is not None:
|
| 119 |
+
self.load_from_checkpoint(ckpt_path)
|
| 120 |
+
|
| 121 |
+
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
| 122 |
+
ip_tokens = self.image_proj_model(image_embeds)
|
| 123 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
| 124 |
+
# Predict the noise residual
|
| 125 |
+
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 126 |
+
return noise_pred
|
| 127 |
+
|
| 128 |
+
def load_from_checkpoint(self, ckpt_path: str):
|
| 129 |
+
# Calculate original checksums
|
| 130 |
+
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
| 131 |
+
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
| 132 |
+
|
| 133 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
| 134 |
+
|
| 135 |
+
# Load state dict for image_proj_model and adapter_modules
|
| 136 |
+
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
|
| 137 |
+
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
| 138 |
+
|
| 139 |
+
# Calculate new checksums
|
| 140 |
+
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
| 141 |
+
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
| 142 |
+
|
| 143 |
+
# Verify if the weights have changed
|
| 144 |
+
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
| 145 |
+
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
| 146 |
+
|
| 147 |
+
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def parse_args():
|
| 151 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--pretrained_model_name_or_path",
|
| 154 |
+
type=str,
|
| 155 |
+
default=None,
|
| 156 |
+
required=True,
|
| 157 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 158 |
+
)
|
| 159 |
+
parser.add_argument(
|
| 160 |
+
"--pretrained_ip_adapter_path",
|
| 161 |
+
type=str,
|
| 162 |
+
default=None,
|
| 163 |
+
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
| 164 |
+
)
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--data_json_file",
|
| 167 |
+
type=str,
|
| 168 |
+
default=None,
|
| 169 |
+
required=True,
|
| 170 |
+
help="Training data",
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--data_root_path",
|
| 174 |
+
type=str,
|
| 175 |
+
default="",
|
| 176 |
+
required=True,
|
| 177 |
+
help="Training data root path",
|
| 178 |
+
)
|
| 179 |
+
parser.add_argument(
|
| 180 |
+
"--image_encoder_path",
|
| 181 |
+
type=str,
|
| 182 |
+
default=None,
|
| 183 |
+
required=True,
|
| 184 |
+
help="Path to CLIP image encoder",
|
| 185 |
+
)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--output_dir",
|
| 188 |
+
type=str,
|
| 189 |
+
default="sd-ip_adapter",
|
| 190 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--logging_dir",
|
| 194 |
+
type=str,
|
| 195 |
+
default="logs",
|
| 196 |
+
help=(
|
| 197 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 198 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 199 |
+
),
|
| 200 |
+
)
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--resolution",
|
| 203 |
+
type=int,
|
| 204 |
+
default=512,
|
| 205 |
+
help=("The resolution for input images"),
|
| 206 |
+
)
|
| 207 |
+
parser.add_argument(
|
| 208 |
+
"--learning_rate",
|
| 209 |
+
type=float,
|
| 210 |
+
default=1e-4,
|
| 211 |
+
help="Learning rate to use.",
|
| 212 |
+
)
|
| 213 |
+
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 214 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--dataloader_num_workers",
|
| 220 |
+
type=int,
|
| 221 |
+
default=0,
|
| 222 |
+
help=(
|
| 223 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 224 |
+
),
|
| 225 |
+
)
|
| 226 |
+
parser.add_argument(
|
| 227 |
+
"--save_steps",
|
| 228 |
+
type=int,
|
| 229 |
+
default=2000,
|
| 230 |
+
help=("Save a checkpoint of the training state every X updates"),
|
| 231 |
+
)
|
| 232 |
+
parser.add_argument(
|
| 233 |
+
"--mixed_precision",
|
| 234 |
+
type=str,
|
| 235 |
+
default=None,
|
| 236 |
+
choices=["no", "fp16", "bf16"],
|
| 237 |
+
help=(
|
| 238 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 239 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 240 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 241 |
+
),
|
| 242 |
+
)
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--report_to",
|
| 245 |
+
type=str,
|
| 246 |
+
default="tensorboard",
|
| 247 |
+
help=(
|
| 248 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 249 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 250 |
+
),
|
| 251 |
+
)
|
| 252 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 253 |
+
|
| 254 |
+
args = parser.parse_args()
|
| 255 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 256 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 257 |
+
args.local_rank = env_local_rank
|
| 258 |
+
|
| 259 |
+
return args
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def main():
|
| 263 |
+
args = parse_args()
|
| 264 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 265 |
+
|
| 266 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 267 |
+
|
| 268 |
+
accelerator = Accelerator(
|
| 269 |
+
mixed_precision=args.mixed_precision,
|
| 270 |
+
log_with=args.report_to,
|
| 271 |
+
project_config=accelerator_project_config,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
if accelerator.is_main_process:
|
| 275 |
+
if args.output_dir is not None:
|
| 276 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 277 |
+
|
| 278 |
+
# Load scheduler, tokenizer and models.
|
| 279 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 280 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
| 281 |
+
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
| 282 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
| 283 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
| 284 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
|
| 285 |
+
# freeze parameters of models to save more memory
|
| 286 |
+
unet.requires_grad_(False)
|
| 287 |
+
vae.requires_grad_(False)
|
| 288 |
+
text_encoder.requires_grad_(False)
|
| 289 |
+
image_encoder.requires_grad_(False)
|
| 290 |
+
|
| 291 |
+
# ip-adapter
|
| 292 |
+
image_proj_model = ImageProjModel(
|
| 293 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
| 294 |
+
clip_embeddings_dim=image_encoder.config.projection_dim,
|
| 295 |
+
clip_extra_context_tokens=4,
|
| 296 |
+
)
|
| 297 |
+
# init adapter modules
|
| 298 |
+
attn_procs = {}
|
| 299 |
+
unet_sd = unet.state_dict()
|
| 300 |
+
for name in unet.attn_processors.keys():
|
| 301 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
| 302 |
+
if name.startswith("mid_block"):
|
| 303 |
+
hidden_size = unet.config.block_out_channels[-1]
|
| 304 |
+
elif name.startswith("up_blocks"):
|
| 305 |
+
block_id = int(name[len("up_blocks.")])
|
| 306 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
| 307 |
+
elif name.startswith("down_blocks"):
|
| 308 |
+
block_id = int(name[len("down_blocks.")])
|
| 309 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
| 310 |
+
if cross_attention_dim is None:
|
| 311 |
+
attn_procs[name] = AttnProcessor()
|
| 312 |
+
else:
|
| 313 |
+
layer_name = name.split(".processor")[0]
|
| 314 |
+
weights = {
|
| 315 |
+
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
| 316 |
+
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
| 317 |
+
}
|
| 318 |
+
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
| 319 |
+
attn_procs[name].load_state_dict(weights)
|
| 320 |
+
unet.set_attn_processor(attn_procs)
|
| 321 |
+
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
| 322 |
+
|
| 323 |
+
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
| 324 |
+
|
| 325 |
+
weight_dtype = torch.float32
|
| 326 |
+
if accelerator.mixed_precision == "fp16":
|
| 327 |
+
weight_dtype = torch.float16
|
| 328 |
+
elif accelerator.mixed_precision == "bf16":
|
| 329 |
+
weight_dtype = torch.bfloat16
|
| 330 |
+
# unet.to(accelerator.device, dtype=weight_dtype)
|
| 331 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 332 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 333 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 334 |
+
|
| 335 |
+
# optimizer
|
| 336 |
+
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
|
| 337 |
+
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
|
| 338 |
+
|
| 339 |
+
# dataloader
|
| 340 |
+
train_dataset = MyDataset(
|
| 341 |
+
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
|
| 342 |
+
)
|
| 343 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 344 |
+
train_dataset,
|
| 345 |
+
shuffle=True,
|
| 346 |
+
collate_fn=collate_fn,
|
| 347 |
+
batch_size=args.train_batch_size,
|
| 348 |
+
num_workers=args.dataloader_num_workers,
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
# Prepare everything with our `accelerator`.
|
| 352 |
+
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
|
| 353 |
+
|
| 354 |
+
global_step = 0
|
| 355 |
+
for epoch in range(0, args.num_train_epochs):
|
| 356 |
+
begin = time.perf_counter()
|
| 357 |
+
for step, batch in enumerate(train_dataloader):
|
| 358 |
+
load_data_time = time.perf_counter() - begin
|
| 359 |
+
with accelerator.accumulate(ip_adapter):
|
| 360 |
+
# Convert images to latent space
|
| 361 |
+
with torch.no_grad():
|
| 362 |
+
latents = vae.encode(
|
| 363 |
+
batch["images"].to(accelerator.device, dtype=weight_dtype)
|
| 364 |
+
).latent_dist.sample()
|
| 365 |
+
latents = latents * vae.config.scaling_factor
|
| 366 |
+
|
| 367 |
+
# Sample noise that we'll add to the latents
|
| 368 |
+
noise = torch.randn_like(latents)
|
| 369 |
+
bsz = latents.shape[0]
|
| 370 |
+
# Sample a random timestep for each image
|
| 371 |
+
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
| 372 |
+
timesteps = timesteps.long()
|
| 373 |
+
|
| 374 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 375 |
+
# (this is the forward diffusion process)
|
| 376 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 377 |
+
|
| 378 |
+
with torch.no_grad():
|
| 379 |
+
image_embeds = image_encoder(
|
| 380 |
+
batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
|
| 381 |
+
).image_embeds
|
| 382 |
+
image_embeds_ = []
|
| 383 |
+
for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
|
| 384 |
+
if drop_image_embed == 1:
|
| 385 |
+
image_embeds_.append(torch.zeros_like(image_embed))
|
| 386 |
+
else:
|
| 387 |
+
image_embeds_.append(image_embed)
|
| 388 |
+
image_embeds = torch.stack(image_embeds_)
|
| 389 |
+
|
| 390 |
+
with torch.no_grad():
|
| 391 |
+
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
|
| 392 |
+
|
| 393 |
+
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
|
| 394 |
+
|
| 395 |
+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
| 396 |
+
|
| 397 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 398 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
|
| 399 |
+
|
| 400 |
+
# Backpropagate
|
| 401 |
+
accelerator.backward(loss)
|
| 402 |
+
optimizer.step()
|
| 403 |
+
optimizer.zero_grad()
|
| 404 |
+
|
| 405 |
+
if accelerator.is_main_process:
|
| 406 |
+
print(
|
| 407 |
+
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
|
| 408 |
+
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
|
| 409 |
+
)
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
global_step += 1
|
| 413 |
+
|
| 414 |
+
if global_step % args.save_steps == 0:
|
| 415 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 416 |
+
accelerator.save_state(save_path)
|
| 417 |
+
|
| 418 |
+
begin = time.perf_counter()
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
if __name__ == "__main__":
|
| 422 |
+
main()
|
diffusers/examples/research_projects/ip_adapter/tutorial_train_plus.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import itertools
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from accelerate import Accelerator
|
| 12 |
+
from accelerate.utils import ProjectConfiguration
|
| 13 |
+
from ip_adapter.resampler import Resampler
|
| 14 |
+
from ip_adapter.utils import is_torch2_available
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 18 |
+
|
| 19 |
+
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if is_torch2_available():
|
| 23 |
+
from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
|
| 24 |
+
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
|
| 25 |
+
else:
|
| 26 |
+
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Dataset
|
| 30 |
+
class MyDataset(torch.utils.data.Dataset):
|
| 31 |
+
def __init__(
|
| 32 |
+
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
|
| 33 |
+
):
|
| 34 |
+
super().__init__()
|
| 35 |
+
|
| 36 |
+
self.tokenizer = tokenizer
|
| 37 |
+
self.size = size
|
| 38 |
+
self.i_drop_rate = i_drop_rate
|
| 39 |
+
self.t_drop_rate = t_drop_rate
|
| 40 |
+
self.ti_drop_rate = ti_drop_rate
|
| 41 |
+
self.image_root_path = image_root_path
|
| 42 |
+
|
| 43 |
+
self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
|
| 44 |
+
|
| 45 |
+
self.transform = transforms.Compose(
|
| 46 |
+
[
|
| 47 |
+
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
|
| 48 |
+
transforms.CenterCrop(self.size),
|
| 49 |
+
transforms.ToTensor(),
|
| 50 |
+
transforms.Normalize([0.5], [0.5]),
|
| 51 |
+
]
|
| 52 |
+
)
|
| 53 |
+
self.clip_image_processor = CLIPImageProcessor()
|
| 54 |
+
|
| 55 |
+
def __getitem__(self, idx):
|
| 56 |
+
item = self.data[idx]
|
| 57 |
+
text = item["text"]
|
| 58 |
+
image_file = item["image_file"]
|
| 59 |
+
|
| 60 |
+
# read image
|
| 61 |
+
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
|
| 62 |
+
image = self.transform(raw_image.convert("RGB"))
|
| 63 |
+
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
|
| 64 |
+
|
| 65 |
+
# drop
|
| 66 |
+
drop_image_embed = 0
|
| 67 |
+
rand_num = random.random()
|
| 68 |
+
if rand_num < self.i_drop_rate:
|
| 69 |
+
drop_image_embed = 1
|
| 70 |
+
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
|
| 71 |
+
text = ""
|
| 72 |
+
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
|
| 73 |
+
text = ""
|
| 74 |
+
drop_image_embed = 1
|
| 75 |
+
# get text and tokenize
|
| 76 |
+
text_input_ids = self.tokenizer(
|
| 77 |
+
text,
|
| 78 |
+
max_length=self.tokenizer.model_max_length,
|
| 79 |
+
padding="max_length",
|
| 80 |
+
truncation=True,
|
| 81 |
+
return_tensors="pt",
|
| 82 |
+
).input_ids
|
| 83 |
+
|
| 84 |
+
return {
|
| 85 |
+
"image": image,
|
| 86 |
+
"text_input_ids": text_input_ids,
|
| 87 |
+
"clip_image": clip_image,
|
| 88 |
+
"drop_image_embed": drop_image_embed,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
def __len__(self):
|
| 92 |
+
return len(self.data)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def collate_fn(data):
|
| 96 |
+
images = torch.stack([example["image"] for example in data])
|
| 97 |
+
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
|
| 98 |
+
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
|
| 99 |
+
drop_image_embeds = [example["drop_image_embed"] for example in data]
|
| 100 |
+
|
| 101 |
+
return {
|
| 102 |
+
"images": images,
|
| 103 |
+
"text_input_ids": text_input_ids,
|
| 104 |
+
"clip_images": clip_images,
|
| 105 |
+
"drop_image_embeds": drop_image_embeds,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class IPAdapter(torch.nn.Module):
|
| 110 |
+
"""IP-Adapter"""
|
| 111 |
+
|
| 112 |
+
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.unet = unet
|
| 115 |
+
self.image_proj_model = image_proj_model
|
| 116 |
+
self.adapter_modules = adapter_modules
|
| 117 |
+
|
| 118 |
+
if ckpt_path is not None:
|
| 119 |
+
self.load_from_checkpoint(ckpt_path)
|
| 120 |
+
|
| 121 |
+
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
| 122 |
+
ip_tokens = self.image_proj_model(image_embeds)
|
| 123 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
| 124 |
+
# Predict the noise residual
|
| 125 |
+
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 126 |
+
return noise_pred
|
| 127 |
+
|
| 128 |
+
def load_from_checkpoint(self, ckpt_path: str):
|
| 129 |
+
# Calculate original checksums
|
| 130 |
+
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
| 131 |
+
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
| 132 |
+
|
| 133 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
| 134 |
+
|
| 135 |
+
# Check if 'latents' exists in both the saved state_dict and the current model's state_dict
|
| 136 |
+
strict_load_image_proj_model = True
|
| 137 |
+
if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
|
| 138 |
+
# Check if the shapes are mismatched
|
| 139 |
+
if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
|
| 140 |
+
print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
|
| 141 |
+
print("Removing 'latents' from checkpoint and loading the rest of the weights.")
|
| 142 |
+
del state_dict["image_proj"]["latents"]
|
| 143 |
+
strict_load_image_proj_model = False
|
| 144 |
+
|
| 145 |
+
# Load state dict for image_proj_model and adapter_modules
|
| 146 |
+
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
|
| 147 |
+
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
| 148 |
+
|
| 149 |
+
# Calculate new checksums
|
| 150 |
+
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
| 151 |
+
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
| 152 |
+
|
| 153 |
+
# Verify if the weights have changed
|
| 154 |
+
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
| 155 |
+
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
| 156 |
+
|
| 157 |
+
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def parse_args():
|
| 161 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--pretrained_model_name_or_path",
|
| 164 |
+
type=str,
|
| 165 |
+
default=None,
|
| 166 |
+
required=True,
|
| 167 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
| 168 |
+
)
|
| 169 |
+
parser.add_argument(
|
| 170 |
+
"--pretrained_ip_adapter_path",
|
| 171 |
+
type=str,
|
| 172 |
+
default=None,
|
| 173 |
+
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
| 174 |
+
)
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--num_tokens",
|
| 177 |
+
type=int,
|
| 178 |
+
default=16,
|
| 179 |
+
help="Number of tokens to query from the CLIP image encoding.",
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--data_json_file",
|
| 183 |
+
type=str,
|
| 184 |
+
default=None,
|
| 185 |
+
required=True,
|
| 186 |
+
help="Training data",
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--data_root_path",
|
| 190 |
+
type=str,
|
| 191 |
+
default="",
|
| 192 |
+
required=True,
|
| 193 |
+
help="Training data root path",
|
| 194 |
+
)
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--image_encoder_path",
|
| 197 |
+
type=str,
|
| 198 |
+
default=None,
|
| 199 |
+
required=True,
|
| 200 |
+
help="Path to CLIP image encoder",
|
| 201 |
+
)
|
| 202 |
+
parser.add_argument(
|
| 203 |
+
"--output_dir",
|
| 204 |
+
type=str,
|
| 205 |
+
default="sd-ip_adapter",
|
| 206 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--logging_dir",
|
| 210 |
+
type=str,
|
| 211 |
+
default="logs",
|
| 212 |
+
help=(
|
| 213 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
| 214 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
| 215 |
+
),
|
| 216 |
+
)
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--resolution",
|
| 219 |
+
type=int,
|
| 220 |
+
default=512,
|
| 221 |
+
help=("The resolution for input images"),
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--learning_rate",
|
| 225 |
+
type=float,
|
| 226 |
+
default=1e-4,
|
| 227 |
+
help="Learning rate to use.",
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
| 230 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
| 231 |
+
parser.add_argument(
|
| 232 |
+
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
|
| 233 |
+
)
|
| 234 |
+
parser.add_argument(
|
| 235 |
+
"--dataloader_num_workers",
|
| 236 |
+
type=int,
|
| 237 |
+
default=0,
|
| 238 |
+
help=(
|
| 239 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
| 240 |
+
),
|
| 241 |
+
)
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--save_steps",
|
| 244 |
+
type=int,
|
| 245 |
+
default=2000,
|
| 246 |
+
help=("Save a checkpoint of the training state every X updates"),
|
| 247 |
+
)
|
| 248 |
+
parser.add_argument(
|
| 249 |
+
"--mixed_precision",
|
| 250 |
+
type=str,
|
| 251 |
+
default=None,
|
| 252 |
+
choices=["no", "fp16", "bf16"],
|
| 253 |
+
help=(
|
| 254 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 255 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 256 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 257 |
+
),
|
| 258 |
+
)
|
| 259 |
+
parser.add_argument(
|
| 260 |
+
"--report_to",
|
| 261 |
+
type=str,
|
| 262 |
+
default="tensorboard",
|
| 263 |
+
help=(
|
| 264 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 265 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 266 |
+
),
|
| 267 |
+
)
|
| 268 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 269 |
+
|
| 270 |
+
args = parser.parse_args()
|
| 271 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 272 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 273 |
+
args.local_rank = env_local_rank
|
| 274 |
+
|
| 275 |
+
return args
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def main():
|
| 279 |
+
args = parse_args()
|
| 280 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 281 |
+
|
| 282 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 283 |
+
|
| 284 |
+
accelerator = Accelerator(
|
| 285 |
+
mixed_precision=args.mixed_precision,
|
| 286 |
+
log_with=args.report_to,
|
| 287 |
+
project_config=accelerator_project_config,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
if accelerator.is_main_process:
|
| 291 |
+
if args.output_dir is not None:
|
| 292 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 293 |
+
|
| 294 |
+
# Load scheduler, tokenizer and models.
|
| 295 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 296 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
| 297 |
+
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
| 298 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
| 299 |
+
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
| 300 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
|
| 301 |
+
# freeze parameters of models to save more memory
|
| 302 |
+
unet.requires_grad_(False)
|
| 303 |
+
vae.requires_grad_(False)
|
| 304 |
+
text_encoder.requires_grad_(False)
|
| 305 |
+
image_encoder.requires_grad_(False)
|
| 306 |
+
|
| 307 |
+
# ip-adapter-plus
|
| 308 |
+
image_proj_model = Resampler(
|
| 309 |
+
dim=unet.config.cross_attention_dim,
|
| 310 |
+
depth=4,
|
| 311 |
+
dim_head=64,
|
| 312 |
+
heads=12,
|
| 313 |
+
num_queries=args.num_tokens,
|
| 314 |
+
embedding_dim=image_encoder.config.hidden_size,
|
| 315 |
+
output_dim=unet.config.cross_attention_dim,
|
| 316 |
+
ff_mult=4,
|
| 317 |
+
)
|
| 318 |
+
# init adapter modules
|
| 319 |
+
attn_procs = {}
|
| 320 |
+
unet_sd = unet.state_dict()
|
| 321 |
+
for name in unet.attn_processors.keys():
|
| 322 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
| 323 |
+
if name.startswith("mid_block"):
|
| 324 |
+
hidden_size = unet.config.block_out_channels[-1]
|
| 325 |
+
elif name.startswith("up_blocks"):
|
| 326 |
+
block_id = int(name[len("up_blocks.")])
|
| 327 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
| 328 |
+
elif name.startswith("down_blocks"):
|
| 329 |
+
block_id = int(name[len("down_blocks.")])
|
| 330 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
| 331 |
+
if cross_attention_dim is None:
|
| 332 |
+
attn_procs[name] = AttnProcessor()
|
| 333 |
+
else:
|
| 334 |
+
layer_name = name.split(".processor")[0]
|
| 335 |
+
weights = {
|
| 336 |
+
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
| 337 |
+
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
| 338 |
+
}
|
| 339 |
+
attn_procs[name] = IPAttnProcessor(
|
| 340 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens
|
| 341 |
+
)
|
| 342 |
+
attn_procs[name].load_state_dict(weights)
|
| 343 |
+
unet.set_attn_processor(attn_procs)
|
| 344 |
+
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
| 345 |
+
|
| 346 |
+
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
| 347 |
+
|
| 348 |
+
weight_dtype = torch.float32
|
| 349 |
+
if accelerator.mixed_precision == "fp16":
|
| 350 |
+
weight_dtype = torch.float16
|
| 351 |
+
elif accelerator.mixed_precision == "bf16":
|
| 352 |
+
weight_dtype = torch.bfloat16
|
| 353 |
+
# unet.to(accelerator.device, dtype=weight_dtype)
|
| 354 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 355 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 356 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 357 |
+
|
| 358 |
+
# optimizer
|
| 359 |
+
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
|
| 360 |
+
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
|
| 361 |
+
|
| 362 |
+
# dataloader
|
| 363 |
+
train_dataset = MyDataset(
|
| 364 |
+
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
|
| 365 |
+
)
|
| 366 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 367 |
+
train_dataset,
|
| 368 |
+
shuffle=True,
|
| 369 |
+
collate_fn=collate_fn,
|
| 370 |
+
batch_size=args.train_batch_size,
|
| 371 |
+
num_workers=args.dataloader_num_workers,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
# Prepare everything with our `accelerator`.
|
| 375 |
+
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
|
| 376 |
+
|
| 377 |
+
global_step = 0
|
| 378 |
+
for epoch in range(0, args.num_train_epochs):
|
| 379 |
+
begin = time.perf_counter()
|
| 380 |
+
for step, batch in enumerate(train_dataloader):
|
| 381 |
+
load_data_time = time.perf_counter() - begin
|
| 382 |
+
with accelerator.accumulate(ip_adapter):
|
| 383 |
+
# Convert images to latent space
|
| 384 |
+
with torch.no_grad():
|
| 385 |
+
latents = vae.encode(
|
| 386 |
+
batch["images"].to(accelerator.device, dtype=weight_dtype)
|
| 387 |
+
).latent_dist.sample()
|
| 388 |
+
latents = latents * vae.config.scaling_factor
|
| 389 |
+
|
| 390 |
+
# Sample noise that we'll add to the latents
|
| 391 |
+
noise = torch.randn_like(latents)
|
| 392 |
+
bsz = latents.shape[0]
|
| 393 |
+
# Sample a random timestep for each image
|
| 394 |
+
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
| 395 |
+
timesteps = timesteps.long()
|
| 396 |
+
|
| 397 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 398 |
+
# (this is the forward diffusion process)
|
| 399 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 400 |
+
|
| 401 |
+
clip_images = []
|
| 402 |
+
for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
|
| 403 |
+
if drop_image_embed == 1:
|
| 404 |
+
clip_images.append(torch.zeros_like(clip_image))
|
| 405 |
+
else:
|
| 406 |
+
clip_images.append(clip_image)
|
| 407 |
+
clip_images = torch.stack(clip_images, dim=0)
|
| 408 |
+
with torch.no_grad():
|
| 409 |
+
image_embeds = image_encoder(
|
| 410 |
+
clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True
|
| 411 |
+
).hidden_states[-2]
|
| 412 |
+
|
| 413 |
+
with torch.no_grad():
|
| 414 |
+
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
|
| 415 |
+
|
| 416 |
+
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
|
| 417 |
+
|
| 418 |
+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
| 419 |
+
|
| 420 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
| 421 |
+
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
|
| 422 |
+
|
| 423 |
+
# Backpropagate
|
| 424 |
+
accelerator.backward(loss)
|
| 425 |
+
optimizer.step()
|
| 426 |
+
optimizer.zero_grad()
|
| 427 |
+
|
| 428 |
+
if accelerator.is_main_process:
|
| 429 |
+
print(
|
| 430 |
+
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
|
| 431 |
+
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
|
| 432 |
+
)
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
global_step += 1
|
| 436 |
+
|
| 437 |
+
if global_step % args.save_steps == 0:
|
| 438 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 439 |
+
accelerator.save_state(save_path)
|
| 440 |
+
|
| 441 |
+
begin = time.perf_counter()
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
if __name__ == "__main__":
|
| 445 |
+
main()
|