lsmpp commited on
Commit
ca32b0e
·
verified ·
1 Parent(s): 6c61ed4

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. diffusers/examples/kandinsky2_2/text_to_image/README.md +317 -0
  2. diffusers/examples/kandinsky2_2/text_to_image/requirements.txt +7 -0
  3. diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +929 -0
  4. diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +812 -0
  5. diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +844 -0
  6. diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +958 -0
  7. diffusers/examples/research_projects/anytext/README.md +40 -0
  8. diffusers/examples/research_projects/anytext/anytext.py +0 -0
  9. diffusers/examples/research_projects/anytext/anytext_controlnet.py +463 -0
  10. diffusers/examples/research_projects/anytext/ocr_recog/RNN.py +209 -0
  11. diffusers/examples/research_projects/anytext/ocr_recog/RecCTCHead.py +45 -0
  12. diffusers/examples/research_projects/anytext/ocr_recog/RecModel.py +49 -0
  13. diffusers/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py +197 -0
  14. diffusers/examples/research_projects/anytext/ocr_recog/RecSVTR.py +570 -0
  15. diffusers/examples/research_projects/anytext/ocr_recog/common.py +74 -0
  16. diffusers/examples/research_projects/anytext/ocr_recog/en_dict.txt +95 -0
  17. diffusers/examples/research_projects/consistency_training/README.md +24 -0
  18. diffusers/examples/research_projects/consistency_training/requirements.txt +6 -0
  19. diffusers/examples/research_projects/consistency_training/train_cm_ct_unconditional.py +1438 -0
  20. diffusers/examples/research_projects/diffusion_dpo/README.md +94 -0
  21. diffusers/examples/research_projects/diffusion_dpo/requirements.txt +8 -0
  22. diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +982 -0
  23. diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +1140 -0
  24. diffusers/examples/research_projects/diffusion_orpo/README.md +1 -0
  25. diffusers/examples/research_projects/diffusion_orpo/requirements.txt +7 -0
  26. diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py +1092 -0
  27. diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py +1095 -0
  28. diffusers/examples/research_projects/dreambooth_inpaint/README.md +118 -0
  29. diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt +7 -0
  30. diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py +812 -0
  31. diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py +831 -0
  32. diffusers/examples/research_projects/flux_lora_quantization/README.md +167 -0
  33. diffusers/examples/research_projects/flux_lora_quantization/accelerate.yaml +17 -0
  34. diffusers/examples/research_projects/flux_lora_quantization/compute_embeddings.py +107 -0
  35. diffusers/examples/research_projects/flux_lora_quantization/ds2.yaml +23 -0
  36. diffusers/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py +1200 -0
  37. diffusers/examples/research_projects/intel_opts/README.md +36 -0
  38. diffusers/examples/research_projects/intel_opts/inference_bf16.py +56 -0
  39. diffusers/examples/research_projects/intel_opts/textual_inversion/README.md +68 -0
  40. diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt +7 -0
  41. diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +646 -0
  42. diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md +93 -0
  43. diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt +7 -0
  44. diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py +112 -0
  45. diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py +996 -0
  46. diffusers/examples/research_projects/ip_adapter/README.md +226 -0
  47. diffusers/examples/research_projects/ip_adapter/requirements.txt +4 -0
  48. diffusers/examples/research_projects/ip_adapter/tutorial_train_faceid.py +415 -0
  49. diffusers/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py +422 -0
  50. diffusers/examples/research_projects/ip_adapter/tutorial_train_plus.py +445 -0
diffusers/examples/kandinsky2_2/text_to_image/README.md ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Kandinsky2.2 text-to-image fine-tuning
2
+
3
+ Kandinsky 2.2 includes a prior pipeline that generates image embeddings from text prompts, and a decoder pipeline that generates the output image based on the image embeddings. We provide `train_text_to_image_prior.py` and `train_text_to_image_decoder.py` scripts to show you how to fine-tune the Kandinsky prior and decoder models separately based on your own dataset. To achieve the best results, you should fine-tune **_both_** your prior and decoder models.
4
+
5
+ ___Note___:
6
+
7
+ ___This script is experimental. The script fine-tunes the whole model and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
8
+
9
+
10
+ ## Running locally with PyTorch
11
+
12
+ Before running the scripts, make sure to install the library's training dependencies:
13
+
14
+ **Important**
15
+
16
+ To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
17
+ ```bash
18
+ git clone https://github.com/huggingface/diffusers
19
+ cd diffusers
20
+ pip install .
21
+ ```
22
+
23
+ Then cd in the example folder and run
24
+ ```bash
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
29
+
30
+ ```bash
31
+ accelerate config
32
+ ```
33
+ For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the --push_to_hub flag.
34
+
35
+ ___
36
+
37
+ ### Naruto example
38
+
39
+ For all our examples, we will directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag. In order to do that, you have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to the [User Access Tokens](https://huggingface.co/docs/hub/security-tokens) guide.
40
+
41
+ Run the following command to authenticate your token
42
+
43
+ ```bash
44
+ huggingface-cli login
45
+ ```
46
+
47
+ We also use [Weights and Biases](https://docs.wandb.ai/quickstart) logging by default, because it is really useful to monitor the training progress by regularly generating sample images during training. To install wandb, run
48
+
49
+ ```bash
50
+ pip install wandb
51
+ ```
52
+
53
+ To disable wandb logging, remove the `--report_to="wandb"` and `--validation_prompts="A robot naruto, 4k photo"` flags from the examples below
54
+
55
+ #### Fine-tune decoder
56
+ <br>
57
+
58
+ <!-- accelerate_snippet_start -->
59
+ ```bash
60
+ export DATASET_NAME="lambdalabs/naruto-blip-captions"
61
+
62
+ accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
63
+ --dataset_name=$DATASET_NAME \
64
+ --resolution=768 \
65
+ --train_batch_size=1 \
66
+ --gradient_accumulation_steps=4 \
67
+ --gradient_checkpointing \
68
+ --max_train_steps=15000 \
69
+ --learning_rate=1e-05 \
70
+ --max_grad_norm=1 \
71
+ --checkpoints_total_limit=3 \
72
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
73
+ --validation_prompts="A robot naruto, 4k photo" \
74
+ --report_to="wandb" \
75
+ --push_to_hub \
76
+ --output_dir="kandi2-decoder-naruto-model"
77
+ ```
78
+ <!-- accelerate_snippet_end -->
79
+
80
+
81
+ To train on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in the [ImageFolder with metadata](https://huggingface.co/docs/datasets/en/image_load#imagefolder-with-metadata) guide.
82
+ If you wish to use custom loading logic, you should modify the script and we have left pointers for that in the training script.
83
+
84
+ ```bash
85
+ export TRAIN_DIR="path_to_your_dataset"
86
+
87
+ accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
88
+ --train_data_dir=$TRAIN_DIR \
89
+ --resolution=768 \
90
+ --train_batch_size=1 \
91
+ --gradient_accumulation_steps=4 \
92
+ --gradient_checkpointing \
93
+ --max_train_steps=15000 \
94
+ --learning_rate=1e-05 \
95
+ --max_grad_norm=1 \
96
+ --checkpoints_total_limit=3 \
97
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
98
+ --validation_prompts="A robot naruto, 4k photo" \
99
+ --report_to="wandb" \
100
+ --push_to_hub \
101
+ --output_dir="kandi22-decoder-naruto-model"
102
+ ```
103
+
104
+
105
+ Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `kandi22-decoder-naruto-model`. To load the fine-tuned model for inference just pass that path to `AutoPipelineForText2Image`
106
+
107
+ ```python
108
+ from diffusers import AutoPipelineForText2Image
109
+ import torch
110
+
111
+ pipe = AutoPipelineForText2Image.from_pretrained(output_dir, torch_dtype=torch.float16)
112
+ pipe.enable_model_cpu_offload()
113
+
114
+ prompt='A robot naruto, 4k photo'
115
+ images = pipe(prompt=prompt).images
116
+ images[0].save("robot-naruto.png")
117
+ ```
118
+
119
+ Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet
120
+ ```python
121
+ from diffusers import AutoPipelineForText2Image, UNet2DConditionModel
122
+
123
+ model_path = "path_to_saved_model"
124
+
125
+ unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet")
126
+
127
+ pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet, torch_dtype=torch.float16)
128
+ pipe.enable_model_cpu_offload()
129
+
130
+ image = pipe(prompt="A robot naruto, 4k photo").images[0]
131
+ image.save("robot-naruto.png")
132
+ ```
133
+
134
+ #### Fine-tune prior
135
+
136
+ You can fine-tune the Kandinsky prior model with `train_text_to_image_prior.py` script. Note that we currently do not support `--gradient_checkpointing` for prior model fine-tuning.
137
+
138
+ <br>
139
+
140
+ <!-- accelerate_snippet_start -->
141
+ ```bash
142
+ export DATASET_NAME="lambdalabs/naruto-blip-captions"
143
+
144
+ accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \
145
+ --dataset_name=$DATASET_NAME \
146
+ --resolution=768 \
147
+ --train_batch_size=1 \
148
+ --gradient_accumulation_steps=4 \
149
+ --max_train_steps=15000 \
150
+ --learning_rate=1e-05 \
151
+ --max_grad_norm=1 \
152
+ --checkpoints_total_limit=3 \
153
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
154
+ --validation_prompts="A robot naruto, 4k photo" \
155
+ --report_to="wandb" \
156
+ --push_to_hub \
157
+ --output_dir="kandi2-prior-naruto-model"
158
+ ```
159
+ <!-- accelerate_snippet_end -->
160
+
161
+
162
+ To perform inference with the fine-tuned prior model, you will need to first create a prior pipeline by passing the `output_dir` to `DiffusionPipeline`. Then create a `KandinskyV22CombinedPipeline` from a pretrained or fine-tuned decoder checkpoint along with all the modules of the prior pipeline you just created.
163
+
164
+ ```python
165
+ from diffusers import AutoPipelineForText2Image, DiffusionPipeline
166
+ import torch
167
+
168
+ pipe_prior = DiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)
169
+ prior_components = {"prior_" + k: v for k,v in pipe_prior.components.items()}
170
+ pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components, torch_dtype=torch.float16)
171
+
172
+ pipe.enable_model_cpu_offload()
173
+ prompt='A robot naruto, 4k photo'
174
+ images = pipe(prompt=prompt).images
175
+ images[0]
176
+ ```
177
+
178
+ If you want to use a fine-tuned decoder checkpoint along with your fine-tuned prior checkpoint, you can simply replace the "kandinsky-community/kandinsky-2-2-decoder" in the above code with your custom model repo name. Note that in order to be able to create a `KandinskyV22CombinedPipeline`, your model repository needs to have a prior tag. If you have created your model repo using our training script, the prior tag is automatically included.
179
+
180
+ #### Training with multiple GPUs
181
+
182
+ `accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
183
+ for running distributed training with `accelerate`. Here is an example command:
184
+
185
+ ```bash
186
+ export DATASET_NAME="lambdalabs/naruto-blip-captions"
187
+
188
+ accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image_decoder.py \
189
+ --dataset_name=$DATASET_NAME \
190
+ --resolution=768 \
191
+ --train_batch_size=1 \
192
+ --gradient_accumulation_steps=4 \
193
+ --gradient_checkpointing \
194
+ --max_train_steps=15000 \
195
+ --learning_rate=1e-05 \
196
+ --max_grad_norm=1 \
197
+ --checkpoints_total_limit=3 \
198
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
199
+ --validation_prompts="A robot naruto, 4k photo" \
200
+ --report_to="wandb" \
201
+ --push_to_hub \
202
+ --output_dir="kandi2-decoder-naruto-model"
203
+ ```
204
+
205
+
206
+ #### Training with Min-SNR weighting
207
+
208
+ We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://huggingface.co/papers/2303.09556) which helps achieve faster convergence
209
+ by rebalancing the loss. Enable the `--snr_gamma` argument and set it to the recommended
210
+ value of 5.0.
211
+
212
+
213
+ ## Training with LoRA
214
+
215
+ Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
216
+
217
+ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
218
+
219
+ - Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
220
+ - Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable.
221
+ - LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter.
222
+
223
+ [cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
224
+
225
+ With LoRA, it's possible to fine-tune Kandinsky 2.2 on a custom image-caption pair dataset
226
+ on consumer GPUs like Tesla T4, Tesla V100.
227
+
228
+ ### Training
229
+
230
+ First, you need to set up your development environment as explained in the [installation](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) and the [Narutos dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).
231
+
232
+
233
+ #### Train decoder
234
+
235
+ ```bash
236
+ export DATASET_NAME="lambdalabs/naruto-blip-captions"
237
+
238
+ accelerate launch --mixed_precision="fp16" train_text_to_image_lora_decoder.py \
239
+ --dataset_name=$DATASET_NAME --caption_column="text" \
240
+ --resolution=768 \
241
+ --train_batch_size=1 \
242
+ --num_train_epochs=100 --checkpointing_steps=5000 \
243
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
244
+ --seed=42 \
245
+ --rank=4 \
246
+ --gradient_checkpointing \
247
+ --output_dir="kandi22-decoder-naruto-lora" \
248
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
249
+ --push_to_hub
250
+ ```
251
+
252
+ #### Train prior
253
+
254
+ ```bash
255
+ export DATASET_NAME="lambdalabs/naruto-blip-captions"
256
+
257
+ accelerate launch --mixed_precision="fp16" train_text_to_image_lora_prior.py \
258
+ --dataset_name=$DATASET_NAME --caption_column="text" \
259
+ --resolution=768 \
260
+ --train_batch_size=1 \
261
+ --num_train_epochs=100 --checkpointing_steps=5000 \
262
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
263
+ --seed=42 \
264
+ --rank=4 \
265
+ --output_dir="kandi22-prior-naruto-lora" \
266
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
267
+ --push_to_hub
268
+ ```
269
+
270
+ **___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run above scripts in consumer GPUs like T4 or V100.___**
271
+
272
+
273
+ ### Inference
274
+
275
+ #### Inference using fine-tuned LoRA checkpoint for decoder
276
+
277
+ Once you have trained a Kandinsky decoder model using the above command, inference can be done with the `AutoPipelineForText2Image` after loading the trained LoRA weights. You need to pass the `output_dir` for loading the LoRA weights, which in this case is `kandi22-decoder-naruto-lora`.
278
+
279
+
280
+ ```python
281
+ from diffusers import AutoPipelineForText2Image
282
+ import torch
283
+
284
+ pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
285
+ pipe.unet.load_attn_procs(output_dir)
286
+ pipe.enable_model_cpu_offload()
287
+
288
+ prompt='A robot naruto, 4k photo'
289
+ image = pipe(prompt=prompt).images[0]
290
+ image.save("robot_naruto.png")
291
+ ```
292
+
293
+ #### Inference using fine-tuned LoRA checkpoint for prior
294
+
295
+ ```python
296
+ from diffusers import AutoPipelineForText2Image
297
+ import torch
298
+
299
+ pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
300
+ pipe.prior_prior.load_attn_procs(output_dir)
301
+ pipe.enable_model_cpu_offload()
302
+
303
+ prompt='A robot naruto, 4k photo'
304
+ image = pipe(prompt=prompt).images[0]
305
+ image.save("robot_naruto.png")
306
+ image
307
+ ```
308
+
309
+ ### Training with xFormers:
310
+
311
+ You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.
312
+
313
+ xFormers training is not available for fine-tuning the prior model.
314
+
315
+ **Note**:
316
+
317
+ According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training in some GPUs. If you observe that problem, please install a development version as indicated in that comment.
diffusers/examples/kandinsky2_2/text_to_image/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ datasets
5
+ ftfy
6
+ tensorboard
7
+ Jinja2
diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py ADDED
@@ -0,0 +1,929 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import logging
18
+ import math
19
+ import os
20
+ import shutil
21
+ from pathlib import Path
22
+
23
+ import accelerate
24
+ import datasets
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ import transformers
30
+ from accelerate import Accelerator
31
+ from accelerate.logging import get_logger
32
+ from accelerate.state import AcceleratorState
33
+ from accelerate.utils import ProjectConfiguration, set_seed
34
+ from datasets import load_dataset
35
+ from huggingface_hub import create_repo, upload_folder
36
+ from packaging import version
37
+ from PIL import Image
38
+ from tqdm import tqdm
39
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
40
+ from transformers.utils import ContextManagers
41
+
42
+ import diffusers
43
+ from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
44
+ from diffusers.optimization import get_scheduler
45
+ from diffusers.training_utils import EMAModel, compute_snr
46
+ from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
47
+ from diffusers.utils.import_utils import is_xformers_available
48
+
49
+
50
+ if is_wandb_available():
51
+ import wandb
52
+
53
+
54
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
55
+ check_min_version("0.35.0.dev0")
56
+
57
+ logger = get_logger(__name__, log_level="INFO")
58
+
59
+
60
def save_model_card(
    args,
    repo_id: str,
    images=None,
    repo_folder=None,
):
    """Write a model card (README.md with YAML front matter) into ``repo_folder``.

    Args:
        args: Parsed CLI namespace; reads the pretrained decoder/prior paths,
            dataset name, validation prompts and key training hyperparameters.
        repo_id: Hub repository id, embedded in the card's usage snippet.
        images: Optional list of validation PIL images. When non-empty they are
            saved as a single grid image and linked from the card.
        repo_folder: Local directory where ``README.md`` (and the optional
            ``val_imgs_grid.png``) are written.
    """
    img_str = ""
    # `images` defaults to None; guard with truthiness so both None and an
    # empty list are handled (the previous `len(images) > 0` raised TypeError
    # when the default None was passed through).
    if images:
        image_grid = make_image_grid(images, 1, len(args.validation_prompts))
        image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
        img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {args.pretrained_decoder_model_name_or_path}
datasets:
- {args.dataset_name}
prior:
- {args.pretrained_prior_model_name_or_path}
tags:
- kandinsky
- text-to-image
- diffusers
- diffusers-training
inference: true
---
    """
    model_card = f"""
# Finetuning - {repo_id}

This pipeline was finetuned from **{args.pretrained_decoder_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
{img_str}

## Pipeline usage

You can use the pipeline like so:

```python
from diffusers import DiffusionPipeline
import torch

pipeline = AutoPipelineForText2Image.from_pretrained("{repo_id}", torch_dtype=torch.float16)
prompt = "{args.validation_prompts[0]}"
image = pipeline(prompt).images[0]
image.save("my_image.png")
```

## Training info

These are the key hyperparameters used during training:

* Epochs: {args.num_train_epochs}
* Learning rate: {args.learning_rate}
* Batch size: {args.train_batch_size}
* Gradient accumulation steps: {args.gradient_accumulation_steps}
* Image resolution: {args.resolution}
* Mixed-precision: {args.mixed_precision}

"""
    # Append a link to the wandb run page when one is active, so the card
    # points back at the full training logs.
    wandb_info = ""
    if is_wandb_available():
        wandb_run_url = None
        if wandb.run is not None:
            wandb_run_url = wandb.run.url

        if wandb_run_url is not None:
            wandb_info = f"""
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
"""

    model_card += wandb_info

    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
135
+
136
+
137
def log_validation(vae, image_encoder, image_processor, unet, args, accelerator, weight_dtype, epoch):
    """Generate validation images with the current weights and log them.

    Builds a text-to-image pipeline from the (unwrapped) training modules,
    samples one image per validation prompt, and forwards the results to every
    configured tracker (tensorboard / wandb). Returns the list of PIL images.
    """
    logger.info("Running validation... ")

    unwrap = accelerator.unwrap_model
    pipeline = AutoPipelineForText2Image.from_pretrained(
        args.pretrained_decoder_model_name_or_path,
        vae=unwrap(vae),
        prior_image_encoder=unwrap(image_encoder),
        prior_image_processor=image_processor,
        unet=unwrap(unet),
        torch_dtype=weight_dtype,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    # Seeded generator for reproducible sampling; None keeps it nondeterministic.
    generator = (
        None
        if args.seed is None
        else torch.Generator(device=accelerator.device).manual_seed(args.seed)
    )

    images = []
    for prompt in args.validation_prompts:
        with torch.autocast("cuda"):
            sample = pipeline(prompt, num_inference_steps=20, generator=generator).images[0]
        images.append(sample)

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
        elif tracker.name == "wandb":
            logged_images = [
                wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
                for i, image in enumerate(images)
            ]
            tracker.log({"validation": logged_images})
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    # Drop the throwaway validation pipeline and release its GPU memory.
    del pipeline
    torch.cuda.empty_cache()

    return images
186
+
187
+
188
+ def parse_args():
189
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
190
+ parser.add_argument(
191
+ "--pretrained_decoder_model_name_or_path",
192
+ type=str,
193
+ default="kandinsky-community/kandinsky-2-2-decoder",
194
+ required=False,
195
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
196
+ )
197
+ parser.add_argument(
198
+ "--pretrained_prior_model_name_or_path",
199
+ type=str,
200
+ default="kandinsky-community/kandinsky-2-2-prior",
201
+ required=False,
202
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
203
+ )
204
+ parser.add_argument(
205
+ "--dataset_name",
206
+ type=str,
207
+ default=None,
208
+ help=(
209
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
210
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
211
+ " or to a folder containing files that 🤗 Datasets can understand."
212
+ ),
213
+ )
214
+ parser.add_argument(
215
+ "--dataset_config_name",
216
+ type=str,
217
+ default=None,
218
+ help="The config of the Dataset, leave as None if there's only one config.",
219
+ )
220
+ parser.add_argument(
221
+ "--train_data_dir",
222
+ type=str,
223
+ default=None,
224
+ help=(
225
+ "A folder containing the training data. Folder contents must follow the structure described in"
226
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
227
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
228
+ ),
229
+ )
230
+ parser.add_argument(
231
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
232
+ )
233
+ parser.add_argument(
234
+ "--max_train_samples",
235
+ type=int,
236
+ default=None,
237
+ help=(
238
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
239
+ "value if set."
240
+ ),
241
+ )
242
+ parser.add_argument(
243
+ "--validation_prompts",
244
+ type=str,
245
+ default=None,
246
+ nargs="+",
247
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
248
+ )
249
+ parser.add_argument(
250
+ "--output_dir",
251
+ type=str,
252
+ default="kandi_2_2-model-finetuned",
253
+ help="The output directory where the model predictions and checkpoints will be written.",
254
+ )
255
+ parser.add_argument(
256
+ "--cache_dir",
257
+ type=str,
258
+ default=None,
259
+ help="The directory where the downloaded models and datasets will be stored.",
260
+ )
261
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
262
+ parser.add_argument(
263
+ "--resolution",
264
+ type=int,
265
+ default=512,
266
+ help=(
267
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
268
+ " resolution"
269
+ ),
270
+ )
271
+ parser.add_argument(
272
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
273
+ )
274
+ parser.add_argument("--num_train_epochs", type=int, default=100)
275
+ parser.add_argument(
276
+ "--max_train_steps",
277
+ type=int,
278
+ default=None,
279
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
280
+ )
281
+ parser.add_argument(
282
+ "--gradient_accumulation_steps",
283
+ type=int,
284
+ default=1,
285
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
286
+ )
287
+ parser.add_argument(
288
+ "--gradient_checkpointing",
289
+ action="store_true",
290
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
291
+ )
292
+ parser.add_argument(
293
+ "--learning_rate",
294
+ type=float,
295
+ default=1e-4,
296
+ help="learning rate",
297
+ )
298
+ parser.add_argument(
299
+ "--lr_scheduler",
300
+ type=str,
301
+ default="constant",
302
+ help=(
303
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
304
+ ' "constant", "constant_with_warmup"]'
305
+ ),
306
+ )
307
+ parser.add_argument(
308
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
309
+ )
310
+ parser.add_argument(
311
+ "--snr_gamma",
312
+ type=float,
313
+ default=None,
314
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
315
+ "More details here: https://huggingface.co/papers/2303.09556.",
316
+ )
317
+ parser.add_argument(
318
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
319
+ )
320
+ parser.add_argument(
321
+ "--allow_tf32",
322
+ action="store_true",
323
+ help=(
324
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
325
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
326
+ ),
327
+ )
328
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
329
+ parser.add_argument(
330
+ "--dataloader_num_workers",
331
+ type=int,
332
+ default=0,
333
+ help=(
334
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
335
+ ),
336
+ )
337
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
338
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
339
+ parser.add_argument(
340
+ "--adam_weight_decay",
341
+ type=float,
342
+ default=0.0,
343
+ required=False,
344
+ help="weight decay_to_use",
345
+ )
346
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
347
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
348
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
349
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
350
+ parser.add_argument(
351
+ "--hub_model_id",
352
+ type=str,
353
+ default=None,
354
+ help="The name of the repository to keep in sync with the local `output_dir`.",
355
+ )
356
+ parser.add_argument(
357
+ "--logging_dir",
358
+ type=str,
359
+ default="logs",
360
+ help=(
361
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
362
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
363
+ ),
364
+ )
365
+ parser.add_argument(
366
+ "--mixed_precision",
367
+ type=str,
368
+ default=None,
369
+ choices=["no", "fp16", "bf16"],
370
+ help=(
371
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
372
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
373
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
374
+ ),
375
+ )
376
+ parser.add_argument(
377
+ "--report_to",
378
+ type=str,
379
+ default="tensorboard",
380
+ help=(
381
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
382
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
383
+ ),
384
+ )
385
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
386
+ parser.add_argument(
387
+ "--checkpointing_steps",
388
+ type=int,
389
+ default=500,
390
+ help=(
391
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
392
+ " training using `--resume_from_checkpoint`."
393
+ ),
394
+ )
395
+ parser.add_argument(
396
+ "--checkpoints_total_limit",
397
+ type=int,
398
+ default=None,
399
+ help=("Max number of checkpoints to store."),
400
+ )
401
+ parser.add_argument(
402
+ "--resume_from_checkpoint",
403
+ type=str,
404
+ default=None,
405
+ help=(
406
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
407
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
408
+ ),
409
+ )
410
+ parser.add_argument(
411
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
412
+ )
413
+ parser.add_argument(
414
+ "--validation_epochs",
415
+ type=int,
416
+ default=5,
417
+ help="Run validation every X epochs.",
418
+ )
419
+ parser.add_argument(
420
+ "--tracker_project_name",
421
+ type=str,
422
+ default="text2image-fine-tune",
423
+ help=(
424
+ "The `project_name` argument passed to Accelerator.init_trackers for"
425
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
426
+ ),
427
+ )
428
+
429
+ args = parser.parse_args()
430
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
431
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
432
+ args.local_rank = env_local_rank
433
+
434
+ # Sanity checks
435
+ if args.dataset_name is None and args.train_data_dir is None:
436
+ raise ValueError("Need either a dataset name or a training folder.")
437
+
438
+ return args
439
+
440
+
441
+ def main():
442
+ args = parse_args()
443
+
444
+ if args.report_to == "wandb" and args.hub_token is not None:
445
+ raise ValueError(
446
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
447
+ " Please use `huggingface-cli login` to authenticate with the Hub."
448
+ )
449
+
450
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
451
+ accelerator_project_config = ProjectConfiguration(
452
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
453
+ )
454
+ accelerator = Accelerator(
455
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
456
+ mixed_precision=args.mixed_precision,
457
+ log_with=args.report_to,
458
+ project_config=accelerator_project_config,
459
+ )
460
+
461
+ # Disable AMP for MPS.
462
+ if torch.backends.mps.is_available():
463
+ accelerator.native_amp = False
464
+
465
+ # Make one log on every process with the configuration for debugging.
466
+ logging.basicConfig(
467
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
468
+ datefmt="%m/%d/%Y %H:%M:%S",
469
+ level=logging.INFO,
470
+ )
471
+ logger.info(accelerator.state, main_process_only=False)
472
+ if accelerator.is_local_main_process:
473
+ datasets.utils.logging.set_verbosity_warning()
474
+ transformers.utils.logging.set_verbosity_warning()
475
+ diffusers.utils.logging.set_verbosity_info()
476
+ else:
477
+ datasets.utils.logging.set_verbosity_error()
478
+ transformers.utils.logging.set_verbosity_error()
479
+ diffusers.utils.logging.set_verbosity_error()
480
+
481
+ # If passed along, set the training seed now.
482
+ if args.seed is not None:
483
+ set_seed(args.seed)
484
+
485
+ # Handle the repository creation
486
+ if accelerator.is_main_process:
487
+ if args.output_dir is not None:
488
+ os.makedirs(args.output_dir, exist_ok=True)
489
+
490
+ if args.push_to_hub:
491
+ repo_id = create_repo(
492
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
493
+ ).repo_id
494
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="scheduler")
495
+ image_processor = CLIPImageProcessor.from_pretrained(
496
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
497
+ )
498
+
499
+ def deepspeed_zero_init_disabled_context_manager():
500
+ """
501
+ returns either a context list that includes one that will disable zero.Init or an empty context list
502
+ """
503
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
504
+ if deepspeed_plugin is None:
505
+ return []
506
+
507
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
508
+
509
+ weight_dtype = torch.float32
510
+ if accelerator.mixed_precision == "fp16":
511
+ weight_dtype = torch.float16
512
+ elif accelerator.mixed_precision == "bf16":
513
+ weight_dtype = torch.bfloat16
514
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
515
+ vae = VQModel.from_pretrained(
516
+ args.pretrained_decoder_model_name_or_path, subfolder="movq", torch_dtype=weight_dtype
517
+ ).eval()
518
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
519
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
520
+ ).eval()
521
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
522
+
523
+ # Freeze vae and image_encoder
524
+ vae.requires_grad_(False)
525
+ image_encoder.requires_grad_(False)
526
+
527
+ # Set unet to trainable.
528
+ unet.train()
529
+
530
+ # Create EMA for the unet.
531
+ if args.use_ema:
532
+ ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
533
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
534
+ ema_unet.to(accelerator.device)
535
+ if args.enable_xformers_memory_efficient_attention:
536
+ if is_xformers_available():
537
+ import xformers
538
+
539
+ xformers_version = version.parse(xformers.__version__)
540
+ if xformers_version == version.parse("0.0.16"):
541
+ logger.warning(
542
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
543
+ )
544
+ unet.enable_xformers_memory_efficient_attention()
545
+ else:
546
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
547
+
548
+ # `accelerate` 0.16.0 will have better support for customized saving
549
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
550
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
551
+ def save_model_hook(models, weights, output_dir):
552
+ if args.use_ema:
553
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
554
+
555
+ for i, model in enumerate(models):
556
+ model.save_pretrained(os.path.join(output_dir, "unet"))
557
+
558
+ # make sure to pop weight so that corresponding model is not saved again
559
+ weights.pop()
560
+
561
+ def load_model_hook(models, input_dir):
562
+ if args.use_ema:
563
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
564
+ ema_unet.load_state_dict(load_model.state_dict())
565
+ ema_unet.to(accelerator.device)
566
+ del load_model
567
+
568
+ for i in range(len(models)):
569
+ # pop models so that they are not loaded again
570
+ model = models.pop()
571
+
572
+ # load diffusers style into model
573
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
574
+ model.register_to_config(**load_model.config)
575
+
576
+ model.load_state_dict(load_model.state_dict())
577
+ del load_model
578
+
579
+ accelerator.register_save_state_pre_hook(save_model_hook)
580
+ accelerator.register_load_state_pre_hook(load_model_hook)
581
+
582
+ if args.gradient_checkpointing:
583
+ unet.enable_gradient_checkpointing()
584
+
585
+ # Enable TF32 for faster training on Ampere GPUs,
586
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
587
+ if args.allow_tf32:
588
+ torch.backends.cuda.matmul.allow_tf32 = True
589
+
590
+ if args.use_8bit_adam:
591
+ try:
592
+ import bitsandbytes as bnb
593
+ except ImportError:
594
+ raise ImportError(
595
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
596
+ )
597
+
598
+ optimizer_cls = bnb.optim.AdamW8bit
599
+ else:
600
+ optimizer_cls = torch.optim.AdamW
601
+
602
+ optimizer = optimizer_cls(
603
+ unet.parameters(),
604
+ lr=args.learning_rate,
605
+ betas=(args.adam_beta1, args.adam_beta2),
606
+ weight_decay=args.adam_weight_decay,
607
+ eps=args.adam_epsilon,
608
+ )
609
+
610
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
611
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
612
+
613
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
614
+ # download the dataset.
615
+ if args.dataset_name is not None:
616
+ # Downloading and loading a dataset from the hub.
617
+ dataset = load_dataset(
618
+ args.dataset_name,
619
+ args.dataset_config_name,
620
+ cache_dir=args.cache_dir,
621
+ )
622
+ else:
623
+ data_files = {}
624
+ if args.train_data_dir is not None:
625
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
626
+ dataset = load_dataset(
627
+ "imagefolder",
628
+ data_files=data_files,
629
+ cache_dir=args.cache_dir,
630
+ )
631
+ # See more about loading custom images at
632
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
633
+
634
+ # Preprocessing the datasets.
635
+ # We need to tokenize inputs and targets.
636
+ column_names = dataset["train"].column_names
637
+
638
+ image_column = args.image_column
639
+ if image_column not in column_names:
640
+ raise ValueError(f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
641
+
642
+ def center_crop(image):
643
+ width, height = image.size
644
+ new_size = min(width, height)
645
+ left = (width - new_size) / 2
646
+ top = (height - new_size) / 2
647
+ right = (width + new_size) / 2
648
+ bottom = (height + new_size) / 2
649
+ return image.crop((left, top, right, bottom))
650
+
651
+ def train_transforms(img):
652
+ img = center_crop(img)
653
+ img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
654
+ img = np.array(img).astype(np.float32) / 127.5 - 1
655
+ img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
656
+ return img
657
+
658
+ def preprocess_train(examples):
659
+ images = [image.convert("RGB") for image in examples[image_column]]
660
+ examples["pixel_values"] = [train_transforms(image) for image in images]
661
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
662
+ return examples
663
+
664
+ with accelerator.main_process_first():
665
+ if args.max_train_samples is not None:
666
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
667
+ # Set the training transforms
668
+ train_dataset = dataset["train"].with_transform(preprocess_train)
669
+
670
+ def collate_fn(examples):
671
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
672
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
673
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
674
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
675
+ return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values}
676
+
677
+ train_dataloader = torch.utils.data.DataLoader(
678
+ train_dataset,
679
+ shuffle=True,
680
+ collate_fn=collate_fn,
681
+ batch_size=args.train_batch_size,
682
+ num_workers=args.dataloader_num_workers,
683
+ )
684
+
685
+ # Scheduler and math around the number of training steps.
686
+ overrode_max_train_steps = False
687
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
688
+ if args.max_train_steps is None:
689
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
690
+ overrode_max_train_steps = True
691
+
692
+ lr_scheduler = get_scheduler(
693
+ args.lr_scheduler,
694
+ optimizer=optimizer,
695
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
696
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
697
+ )
698
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
699
+ unet, optimizer, train_dataloader, lr_scheduler
700
+ )
701
+ # Move image_encode and vae to gpu and cast to weight_dtype
702
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
703
+ vae.to(accelerator.device, dtype=weight_dtype)
704
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
705
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
706
+ if overrode_max_train_steps:
707
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
708
+ # Afterwards we recalculate our number of training epochs
709
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
710
+
711
+ # We need to initialize the trackers we use, and also store our configuration.
712
+ # The trackers initializes automatically on the main process.
713
+ if accelerator.is_main_process:
714
+ tracker_config = dict(vars(args))
715
+ tracker_config.pop("validation_prompts")
716
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
717
+
718
+ # Train!
719
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
720
+
721
+ logger.info("***** Running training *****")
722
+ logger.info(f" Num examples = {len(train_dataset)}")
723
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
724
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
725
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
726
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
727
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
728
+ global_step = 0
729
+ first_epoch = 0
730
+ if args.resume_from_checkpoint:
731
+ if args.resume_from_checkpoint != "latest":
732
+ path = os.path.basename(args.resume_from_checkpoint)
733
+ else:
734
+ # Get the most recent checkpoint
735
+ dirs = os.listdir(args.output_dir)
736
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
737
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
738
+ path = dirs[-1] if len(dirs) > 0 else None
739
+
740
+ if path is None:
741
+ accelerator.print(
742
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
743
+ )
744
+ args.resume_from_checkpoint = None
745
+ initial_global_step = 0
746
+ else:
747
+ accelerator.print(f"Resuming from checkpoint {path}")
748
+ accelerator.load_state(os.path.join(args.output_dir, path))
749
+ global_step = int(path.split("-")[1])
750
+
751
+ initial_global_step = global_step
752
+ first_epoch = global_step // num_update_steps_per_epoch
753
+ else:
754
+ initial_global_step = 0
755
+
756
+ progress_bar = tqdm(
757
+ range(0, args.max_train_steps),
758
+ initial=initial_global_step,
759
+ desc="Steps",
760
+ # Only show the progress bar once on each machine.
761
+ disable=not accelerator.is_local_main_process,
762
+ )
763
+
764
+ for epoch in range(first_epoch, args.num_train_epochs):
765
+ train_loss = 0.0
766
+ for step, batch in enumerate(train_dataloader):
767
+ with accelerator.accumulate(unet):
768
+ # Convert images to latent space
769
+ images = batch["pixel_values"].to(weight_dtype)
770
+ clip_images = batch["clip_pixel_values"].to(weight_dtype)
771
+ latents = vae.encode(images).latents
772
+ image_embeds = image_encoder(clip_images).image_embeds
773
+ # Sample noise that we'll add to the latents
774
+ noise = torch.randn_like(latents)
775
+ bsz = latents.shape[0]
776
+ # Sample a random timestep for each image
777
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
778
+ timesteps = timesteps.long()
779
+
780
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
781
+
782
+ target = noise
783
+
784
+ # Predict the noise residual and compute loss
785
+ added_cond_kwargs = {"image_embeds": image_embeds}
786
+
787
+ model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
788
+
789
+ if args.snr_gamma is None:
790
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
791
+ else:
792
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
793
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
794
+ # This is discussed in Section 4.2 of the same paper.
795
+ snr = compute_snr(noise_scheduler, timesteps)
796
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
797
+ dim=1
798
+ )[0]
799
+ if noise_scheduler.config.prediction_type == "epsilon":
800
+ mse_loss_weights = mse_loss_weights / snr
801
+ elif noise_scheduler.config.prediction_type == "v_prediction":
802
+ mse_loss_weights = mse_loss_weights / (snr + 1)
803
+
804
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
805
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
806
+ loss = loss.mean()
807
+
808
+ # Gather the losses across all processes for logging (if we use distributed training).
809
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
810
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
811
+
812
+ # Backpropagate
813
+ accelerator.backward(loss)
814
+ if accelerator.sync_gradients:
815
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
816
+ optimizer.step()
817
+ lr_scheduler.step()
818
+ optimizer.zero_grad()
819
+
820
+ # Checks if the accelerator has performed an optimization step behind the scenes
821
+ if accelerator.sync_gradients:
822
+ if args.use_ema:
823
+ ema_unet.step(unet.parameters())
824
+ progress_bar.update(1)
825
+ global_step += 1
826
+ accelerator.log({"train_loss": train_loss}, step=global_step)
827
+ train_loss = 0.0
828
+
829
+ if global_step % args.checkpointing_steps == 0:
830
+ if accelerator.is_main_process:
831
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
832
+ if args.checkpoints_total_limit is not None:
833
+ checkpoints = os.listdir(args.output_dir)
834
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
835
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
836
+
837
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
838
+ if len(checkpoints) >= args.checkpoints_total_limit:
839
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
840
+ removing_checkpoints = checkpoints[0:num_to_remove]
841
+
842
+ logger.info(
843
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
844
+ )
845
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
846
+
847
+ for removing_checkpoint in removing_checkpoints:
848
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
849
+ shutil.rmtree(removing_checkpoint)
850
+
851
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
852
+ accelerator.save_state(save_path)
853
+ logger.info(f"Saved state to {save_path}")
854
+
855
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
856
+ progress_bar.set_postfix(**logs)
857
+
858
+ if global_step >= args.max_train_steps:
859
+ break
860
+
861
+ if accelerator.is_main_process:
862
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
863
+ if args.use_ema:
864
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
865
+ ema_unet.store(unet.parameters())
866
+ ema_unet.copy_to(unet.parameters())
867
+ log_validation(
868
+ vae,
869
+ image_encoder,
870
+ image_processor,
871
+ unet,
872
+ args,
873
+ accelerator,
874
+ weight_dtype,
875
+ global_step,
876
+ )
877
+ if args.use_ema:
878
+ # Switch back to the original UNet parameters.
879
+ ema_unet.restore(unet.parameters())
880
+
881
+ # Create the pipeline using the trained modules and save it.
882
+ accelerator.wait_for_everyone()
883
+ if accelerator.is_main_process:
884
+ unet = accelerator.unwrap_model(unet)
885
+ if args.use_ema:
886
+ ema_unet.copy_to(unet.parameters())
887
+
888
+ pipeline = AutoPipelineForText2Image.from_pretrained(
889
+ args.pretrained_decoder_model_name_or_path,
890
+ vae=vae,
891
+ unet=unet,
892
+ )
893
+ pipeline.decoder_pipe.save_pretrained(args.output_dir)
894
+
895
+ # Run a final round of inference.
896
+ images = []
897
+ if args.validation_prompts is not None:
898
+ logger.info("Running inference for collecting generated images...")
899
+ pipeline.torch_dtype = weight_dtype
900
+ pipeline.set_progress_bar_config(disable=True)
901
+ pipeline.enable_model_cpu_offload()
902
+
903
+ if args.enable_xformers_memory_efficient_attention:
904
+ pipeline.enable_xformers_memory_efficient_attention()
905
+
906
+ if args.seed is None:
907
+ generator = None
908
+ else:
909
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
910
+
911
+ for i in range(len(args.validation_prompts)):
912
+ with torch.autocast("cuda"):
913
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
914
+ images.append(image)
915
+
916
+ if args.push_to_hub:
917
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
918
+ upload_folder(
919
+ repo_id=repo_id,
920
+ folder_path=args.output_dir,
921
+ commit_message="End of training",
922
+ ignore_patterns=["step_*", "epoch_*"],
923
+ )
924
+
925
+ accelerator.end_training()
926
+
927
+
928
+ if __name__ == "__main__":
929
+ main()
diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py ADDED
@@ -0,0 +1,812 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fine-tuning script for Kandinsky with support for LoRA."""
16
+
17
+ import argparse
18
+ import logging
19
+ import math
20
+ import os
21
+ import shutil
22
+ from pathlib import Path
23
+
24
+ import datasets
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ import transformers
30
+ from accelerate import Accelerator
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import ProjectConfiguration, set_seed
33
+ from datasets import load_dataset
34
+ from huggingface_hub import create_repo, upload_folder
35
+ from PIL import Image
36
+ from tqdm import tqdm
37
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
38
+
39
+ import diffusers
40
+ from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
41
+ from diffusers.loaders import AttnProcsLayers
42
+ from diffusers.models.attention_processor import LoRAAttnAddedKVProcessor
43
+ from diffusers.optimization import get_scheduler
44
+ from diffusers.training_utils import compute_snr
45
+ from diffusers.utils import check_min_version, is_wandb_available
46
+
47
+
48
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
49
+ check_min_version("0.35.0.dev0")
50
+
51
+ logger = get_logger(__name__, log_level="INFO")
52
+
53
+
54
+ def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
55
+ img_str = ""
56
+ for i, image in enumerate(images):
57
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
58
+ img_str += f"![img_{i}](./image_{i}.png)\n"
59
+
60
+ yaml = f"""
61
+ ---
62
+ license: creativeml-openrail-m
63
+ base_model: {base_model}
64
+ tags:
65
+ - kandinsky
66
+ - text-to-image
67
+ - diffusers
68
+ - diffusers-training
69
+ - lora
70
+ inference: true
71
+ ---
72
+ """
73
+ model_card = f"""
74
+ # LoRA text2image fine-tuning - {repo_id}
75
+ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
76
+ {img_str}
77
+ """
78
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
79
+ f.write(yaml + model_card)
80
+
81
+
82
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for LoRA fine-tuning of the Kandinsky 2.2 decoder.

    Returns the parsed `argparse.Namespace`. Also reconciles `--local_rank`
    with the `LOCAL_RANK` environment variable (the env var wins when set),
    and raises `ValueError` if neither `--dataset_name` nor `--train_data_dir`
    was provided.
    """
    parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2 with LoRA.")
    # Model / data sources.
    parser.add_argument(
        "--pretrained_decoder_model_name_or_path",
        type=str,
        default="kandinsky-community/kandinsky-2-2-decoder",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_prior_model_name_or_path",
        type=str,
        default="kandinsky-community/kandinsky-2-2-prior",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    # Validation / sampling during training.
    parser.add_argument(
        "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    # Output / reproducibility.
    parser.add_argument(
        "--output_dir",
        type=str,
        default="kandi_2_2-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    # Training schedule / optimization.
    parser.add_argument(
        "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
        "More details here: https://huggingface.co/papers/2303.09556.",
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    # Adam hyperparameters.
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # Hub / logging.
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    # Checkpointing / resume.
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )

    args = parser.parse_args()
    # `LOCAL_RANK` set by the distributed launcher takes precedence over the flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args
325
+
326
+
327
def main():
    """Fine-tune the Kandinsky 2.2 decoder UNet with LoRA attention processors.

    End-to-end driver: parses CLI args, sets up `accelerate`, loads the frozen
    scheduler/VQ-VAE/image-encoder/UNet, attaches trainable LoRA attention
    processors, builds the dataset and dataloader, runs the denoising training
    loop (with optional SNR loss re-weighting, checkpointing, resume, and
    per-epoch validation sampling), and finally saves/uploads the LoRA weights.
    """
    args = parse_args()

    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    # Keep third-party library logging quiet on non-main processes.
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id
    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="scheduler")
    image_processor = CLIPImageProcessor.from_pretrained(
        args.pretrained_prior_model_name_or_path, subfolder="image_processor"
    )
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained_prior_model_name_or_path, subfolder="image_encoder"
    )

    vae = VQModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="movq")

    unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)

    image_encoder.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)

    # Build one LoRA attention processor per attention layer of the UNet; the
    # hidden size is derived from the block the processor's layer lives in.
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention ("attn1") layers have no cross-attention conditioning.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnAddedKVProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            rank=args.rank,
        )

    unet.set_attn_processor(lora_attn_procs)

    # Only the LoRA processor parameters are trained/optimized.
    lora_layers = AttnProcsLayers(unet.attn_processors)

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
    # See more about loading custom images at
    # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    image_column = args.image_column
    if image_column not in column_names:
        # NOTE(review): the error message below has a stray leading quote
        # ("--image_column' ..."); cosmetic only, the check itself is correct.
        raise ValueError(f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")

    def center_crop(image):
        # Crop the largest centered square out of a PIL image.
        width, height = image.size
        new_size = min(width, height)
        left = (width - new_size) / 2
        top = (height - new_size) / 2
        right = (width + new_size) / 2
        bottom = (height + new_size) / 2
        return image.crop((left, top, right, bottom))

    def train_transforms(img):
        # PIL image -> CHW float tensor scaled to [-1, 1] at args.resolution.
        img = center_crop(img)
        img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
        img = np.array(img).astype(np.float32) / 127.5 - 1
        img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
        return img

    def preprocess_train(examples):
        # Produce both VQ-VAE inputs ("pixel_values") and CLIP image-encoder
        # inputs ("clip_pixel_values") from the raw images.
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
        clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
        return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )
    # Prepare everything with our `accelerator`.
    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                images = batch["pixel_values"].to(weight_dtype)
                clip_images = batch["clip_pixel_values"].to(weight_dtype)
                latents = vae.encode(images).latents
                image_embeds = image_encoder(clip_images).image_embeds
                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                target = noise

                # Predict the noise residual and compute loss
                added_cond_kwargs = {"image_embeds": image_embeds}

                # NOTE(review): the `[:, :4]` slice keeps only the first 4
                # output channels of the decoder UNet prediction — presumably
                # the epsilon part of a (mean, variance)-style output; confirm
                # against the Kandinsky 2.2 decoder UNet config.
                model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]

                if args.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                    # This is discussed in Section 4.2 of the same paper.
                    snr = compute_snr(noise_scheduler, timesteps)
                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                        dim=1
                    )[0]
                    if noise_scheduler.config.prediction_type == "epsilon":
                        mse_loss_weights = mse_loss_weights / snr
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        mse_loss_weights = mse_loss_weights / (snr + 1)

                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = lora_layers.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_prompt}."
                )
                # create pipeline
                pipeline = AutoPipelineForText2Image.from_pretrained(
                    args.pretrained_decoder_model_name_or_path,
                    unet=accelerator.unwrap_model(unet),
                    torch_dtype=weight_dtype,
                )
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                generator = torch.Generator(device=accelerator.device)
                if args.seed is not None:
                    generator = generator.manual_seed(args.seed)
                images = []
                for _ in range(args.num_validation_images):
                    images.append(
                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
                    )

                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        np_images = np.stack([np.asarray(img) for img in images])
                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                    if tracker.name == "wandb":
                        tracker.log(
                            {
                                "validation": [
                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                    for i, image in enumerate(images)
                                ]
                            }
                        )

                del pipeline
                torch.cuda.empty_cache()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)

        if args.push_to_hub:
            # NOTE(review): `images` is only bound if the validation branch above
            # ran at least once; with --push_to_hub but no --validation_prompt
            # this raises NameError — verify intended usage always sets both.
            save_model_card(
                repo_id,
                images=images,
                base_model=args.pretrained_decoder_model_name_or_path,
                dataset_name=args.dataset_name,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    # Final inference
    # Load previous pipeline
    pipeline = AutoPipelineForText2Image.from_pretrained(
        args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype
    )
    pipeline = pipeline.to(accelerator.device)

    # load attention processors
    pipeline.unet.load_attn_procs(args.output_dir)

    # run inference
    # NOTE(review): this final pass is executed even when --validation_prompt
    # is None, so the pipeline is called with a None prompt — confirm that is
    # acceptable or guard like the per-epoch validation does.
    generator = torch.Generator(device=accelerator.device)
    if args.seed is not None:
        generator = generator.manual_seed(args.seed)
    images = []
    for _ in range(args.num_validation_images):
        images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])

    if accelerator.is_main_process:
        for tracker in accelerator.trackers:
            if len(images) != 0:
                if tracker.name == "tensorboard":
                    np_images = np.stack([np.asarray(img) for img in images])
                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
                if tracker.name == "wandb":
                    tracker.log(
                        {
                            "test": [
                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )

    accelerator.end_training()
809
+
810
+
811
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
16
+
17
+ import argparse
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from pathlib import Path
24
+
25
+ import datasets
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ import transformers
31
+ from accelerate import Accelerator
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import ProjectConfiguration, set_seed
34
+ from datasets import load_dataset
35
+ from huggingface_hub import create_repo, upload_folder
36
+ from tqdm import tqdm
37
+ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
38
+
39
+ import diffusers
40
+ from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
41
+ from diffusers.loaders import AttnProcsLayers
42
+ from diffusers.models.attention_processor import LoRAAttnProcessor
43
+ from diffusers.optimization import get_scheduler
44
+ from diffusers.training_utils import compute_snr
45
+ from diffusers.utils import check_min_version, is_wandb_available
46
+
47
+
48
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")

# Accelerate-aware logger: respects process rank so messages are not duplicated
# across distributed workers unless explicitly requested.
logger = get_logger(__name__, log_level="INFO")
+
53
+
54
def save_model_card(repo_id: str, images=None, base_model=None, dataset_name=None, repo_folder=None):
    """Write a Hub model card (README.md) for the trained LoRA weights.

    Args:
        repo_id: Full repository id the card is written for.
        images: Optional list of PIL images; each is saved next to the card as
            ``image_{i}.png`` and embedded in the card body.
        base_model: Name/path of the base prior model the LoRA was trained from.
        dataset_name: Name of the dataset used for fine-tuning.
        repo_folder: Local folder the README.md (and images) are written into.
    """
    # Fix: the original signature used `base_model=str, dataset_name=str` — the
    # *type object* `str` as default — which would render "<class 'str'>" in the
    # card when omitted. `images=None` would also crash the loop below, so we
    # normalize it to an empty list.
    images = images or []

    img_str = ""
    for i, image in enumerate(images):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- kandinsky
- text-to-image
- diffusers
- diffusers-training
- lora
inference: true
---
"""
    model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
80
+
81
+
82
def parse_args():
    """Parse and sanity-check the command-line arguments for LoRA prior fine-tuning.

    Returns:
        argparse.Namespace: all training hyperparameters and paths. Raises
        ValueError when neither ``--dataset_name`` nor ``--train_data_dir`` is given.
    """
    parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
    parser.add_argument(
        "--pretrained_decoder_model_name_or_path",
        type=str,
        default="kandinsky-community/kandinsky-2-2-decoder",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_prior_model_name_or_path",
        type=str,
        default="kandinsky-community/kandinsky-2-2-prior",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="kandi_2_2-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    # NOTE(review): --resolution is accepted but never referenced in this script
    # (the CLIP image processor controls the input size) — confirm before relying on it.
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="learning rate",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
        "More details here: https://huggingface.co/papers/2303.09556.",
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=0.0,
        required=False,
        help="weight decay_to_use",
    )
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )

    args = parser.parse_args()
    # Let a launcher-provided LOCAL_RANK env var override the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args
332
+
333
+
334
# Maps known Hub dataset names to their (image_column, caption_column) pair so
# that users of those datasets do not need to pass --image_column/--caption_column.
DATASET_NAME_MAPPING = {
    "lambdalabs/naruto-blip-captions": ("image", "text"),
}
337
+
338
+
339
def main():
    """Run LoRA fine-tuning of the Kandinsky 2.2 prior on an image-caption dataset.

    Loads the frozen CLIP text/image encoders and the prior transformer, attaches
    LoRA attention processors (only those are trained), regresses normalized CLIP
    image embeddings from text embeddings, periodically checkpoints and validates,
    and finally saves — and optionally pushes — the LoRA attention weights.
    """
    args = parse_args()

    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(
        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    # Keep library logging quiet on non-main processes.
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            # NOTE(review): `repo_id` is only bound on the main process when
            # --push_to_hub is set; it is used again near the end of training.
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id
    # Load scheduler, image_processor, tokenizer and models.
    noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
    image_processor = CLIPImageProcessor.from_pretrained(
        args.pretrained_prior_model_name_or_path, subfolder="image_processor"
    )
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained_prior_model_name_or_path, subfolder="image_encoder"
    )
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        args.pretrained_prior_model_name_or_path, subfolder="text_encoder"
    )
    prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
    # freeze parameters of models to save more memory
    image_encoder.requires_grad_(False)
    prior.requires_grad_(False)
    text_encoder.requires_grad_(False)
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move image_encoder, text_encoder and prior to device and cast to weight_dtype
    prior.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    # Attach one LoRA processor per attention layer; only these carry trainable params.
    lora_attn_procs = {}
    for name in prior.attn_processors.keys():
        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=args.rank)

    prior.set_attn_processor(lora_attn_procs)
    # Wrapper exposing only the LoRA parameters to the optimizer/accelerator.
    lora_layers = AttnProcsLayers(prior.attn_processors)

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
    # See more about loading custom images at
    # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        # Tokenize each caption (randomly picking one when a sample has several).
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        text_input_ids = inputs.input_ids
        text_mask = inputs.attention_mask.bool()
        return text_input_ids, text_mask

    def preprocess_train(examples):
        # On-the-fly transform: CLIP-preprocess images and tokenize captions.
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
        examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        # Stack per-sample tensors into a batch dict for the dataloader.
        clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
        clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
        text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
        text_mask = torch.stack([example["text_mask"] for example in examples])
        return {"clip_pixel_values": clip_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )
    # Snapshot the prior's embedding statistics before `prepare` may wrap the model.
    clip_mean = prior.clip_mean.clone()
    clip_std = prior.clip_std.clone()
    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint dirs are named "checkpoint-<global_step>".
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
    clip_std = clip_std.to(weight_dtype).to(accelerator.device)

    for epoch in range(first_epoch, args.num_train_epochs):
        prior.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(prior):
                # Convert images to latent space
                text_input_ids, text_mask, clip_images = (
                    batch["text_input_ids"],
                    batch["text_mask"],
                    batch["clip_pixel_values"].to(weight_dtype),
                )
                # Encoders are frozen: run them (and the target construction) without grads.
                with torch.no_grad():
                    text_encoder_output = text_encoder(text_input_ids)
                    prompt_embeds = text_encoder_output.text_embeds
                    text_encoder_hidden_states = text_encoder_output.last_hidden_state

                    image_embeds = image_encoder(clip_images).image_embeds
                    # Sample noise that we'll add to the image_embeds
                    noise = torch.randn_like(image_embeds)
                    bsz = image_embeds.shape[0]
                    # Sample a random timestep for each image
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device
                    )
                    timesteps = timesteps.long()
                    # Normalize with the prior's CLIP statistics before adding noise.
                    image_embeds = (image_embeds - clip_mean) / clip_std
                    noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)

                    # Scheduler uses prediction_type="sample": the target is the
                    # clean (normalized) embedding itself, not the noise.
                    target = image_embeds

                # Predict the noise residual and compute loss
                model_pred = prior(
                    noisy_latents,
                    timestep=timesteps,
                    proj_embedding=prompt_embeds,
                    encoder_hidden_states=text_encoder_hidden_states,
                    attention_mask=text_mask,
                ).predicted_image_embedding

                if args.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                    # This is discussed in Section 4.2 of the same paper.
                    snr = compute_snr(noise_scheduler, timesteps)
                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                        dim=1
                    )[0]
                    # NOTE(review): this scheduler is configured with
                    # prediction_type="sample", so neither branch below fires;
                    # the weights stay min(snr, gamma) — confirm intended.
                    if noise_scheduler.config.prediction_type == "epsilon":
                        mse_loss_weights = mse_loss_weights / snr
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        mse_loss_weights = mse_loss_weights / (snr + 1)

                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_prompt}."
                )
                # create pipeline
                # Prior components are addressed with a `prior_` prefix by the
                # combined pipeline, hence the `prior_prior` kwarg.
                pipeline = AutoPipelineForText2Image.from_pretrained(
                    args.pretrained_decoder_model_name_or_path,
                    prior_prior=accelerator.unwrap_model(prior),
                    torch_dtype=weight_dtype,
                )
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                generator = torch.Generator(device=accelerator.device)
                if args.seed is not None:
                    generator = generator.manual_seed(args.seed)
                images = []
                for _ in range(args.num_validation_images):
                    images.append(
                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
                    )

                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        np_images = np.stack([np.asarray(img) for img in images])
                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                    if tracker.name == "wandb":
                        tracker.log(
                            {
                                "validation": [
                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                    for i, image in enumerate(images)
                                ]
                            }
                        )

                del pipeline
                torch.cuda.empty_cache()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        prior = prior.to(torch.float32)
        prior.save_attn_procs(args.output_dir)

        if args.push_to_hub:
            # NOTE(review): `images` is only bound if the validation branch above
            # ran at least once; --push_to_hub without --validation_prompt would
            # raise NameError here — confirm.
            save_model_card(
                repo_id,
                images=images,
                base_model=args.pretrained_prior_model_name_or_path,
                dataset_name=args.dataset_name,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    # Final inference
    # Load previous pipeline
    pipeline = AutoPipelineForText2Image.from_pretrained(
        args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype
    )
    pipeline = pipeline.to(accelerator.device)

    # load attention processors
    pipeline.prior_prior.load_attn_procs(args.output_dir)

    # run inference
    # NOTE(review): this final pass uses `args.validation_prompt` (may be None)
    # and logs with the last `epoch` value — both assume training/validation ran.
    generator = torch.Generator(device=accelerator.device)
    if args.seed is not None:
        generator = generator.manual_seed(args.seed)
    images = []
    for _ in range(args.num_validation_images):
        images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])

    if accelerator.is_main_process:
        for tracker in accelerator.trackers:
            if len(images) != 0:
                if tracker.name == "tensorboard":
                    np_images = np.stack([np.asarray(img) for img in images])
                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
                if tracker.name == "wandb":
                    tracker.log(
                        {
                            "test": [
                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )

    accelerator.end_training()
841
+
842
+
843
# Standard script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()
diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py ADDED
@@ -0,0 +1,958 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import logging
18
+ import math
19
+ import os
20
+ import random
21
+ import shutil
22
+ from pathlib import Path
23
+
24
+ import accelerate
25
+ import datasets
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ import transformers
31
+ from accelerate import Accelerator
32
+ from accelerate.logging import get_logger
33
+ from accelerate.state import AcceleratorState
34
+ from accelerate.utils import ProjectConfiguration, set_seed
35
+ from datasets import load_dataset
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from tqdm import tqdm
39
+ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
40
+ from transformers.utils import ContextManagers
41
+
42
+ import diffusers
43
+ from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
44
+ from diffusers.optimization import get_scheduler
45
+ from diffusers.training_utils import EMAModel, compute_snr
46
+ from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
47
+
48
+
49
+ if is_wandb_available():
50
+ import wandb
51
+
52
+
53
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
54
+ check_min_version("0.35.0.dev0")
55
+
56
+ logger = get_logger(__name__, log_level="INFO")
57
+
58
+ DATASET_NAME_MAPPING = {
59
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
60
+ }
61
+
62
+
63
def save_model_card(
    args,
    repo_id: str,
    images=None,
    repo_folder=None,
):
    """Write a README.md model card for the finetuned prior into `repo_folder`.

    Args:
        args: Parsed CLI namespace; reads the pretrained model paths, dataset
            name, validation prompts, and key training hyperparameters.
        repo_id: Hub repository id used in the card title and usage snippet.
        images: Optional list of validation PIL images. When provided and
            non-empty, they are saved as a grid next to the card and embedded.
        repo_folder: Local directory where README.md (and the image grid)
            are written.
    """
    img_str = ""
    # Fix: `images` defaults to None, so guard before calling `len()` —
    # previously `len(images) > 0` raised TypeError when no images were passed.
    if images is not None and len(images) > 0:
        image_grid = make_image_grid(images, 1, len(args.validation_prompts))
        image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
        img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {args.pretrained_prior_model_name_or_path}
datasets:
- {args.dataset_name}
tags:
- kandinsky
- text-to-image
- diffusers
- diffusers-training
inference: true
---
    """
    model_card = f"""
# Finetuning - {repo_id}

This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
{img_str}

## Pipeline usage

You can use the pipeline like so:

```python
from diffusers import DiffusionPipeline
import torch

pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype=torch.float16)
prompt = "{args.validation_prompts[0]}"
image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
image = pipe_t2i(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0]
image.save("my_image.png")
```

## Training info

These are the key hyperparameters used during training:

* Epochs: {args.num_train_epochs}
* Learning rate: {args.learning_rate}
* Batch size: {args.train_batch_size}
* Gradient accumulation steps: {args.gradient_accumulation_steps}
* Image resolution: {args.resolution}
* Mixed-precision: {args.mixed_precision}

"""
    wandb_info = ""
    if is_wandb_available():
        wandb_run_url = None
        if wandb.run is not None:
            wandb_run_url = wandb.run.url

        if wandb_run_url is not None:
            wandb_info = f"""
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
"""

    model_card += wandb_info

    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
138
+
139
+
140
def log_validation(
    image_encoder, image_processor, text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch
):
    """Generate validation images with the in-training models and log them.

    Builds a Kandinsky 2.2 text-to-image pipeline around the (unwrapped)
    modules currently being trained, samples one image per validation prompt,
    pushes the results to every active tracker (tensorboard / wandb), and
    returns the list of generated PIL images.
    """
    logger.info("Running validation... ")

    # Assemble the full pipeline from the decoder checkpoint, swapping in the
    # accelerate-wrapped prior-side modules being trained.
    validation_pipe = AutoPipelineForText2Image.from_pretrained(
        args.pretrained_decoder_model_name_or_path,
        prior_image_encoder=accelerator.unwrap_model(image_encoder),
        prior_image_processor=image_processor,
        prior_text_encoder=accelerator.unwrap_model(text_encoder),
        prior_tokenizer=tokenizer,
        prior_prior=accelerator.unwrap_model(prior),
        torch_dtype=weight_dtype,
    ).to(accelerator.device)
    validation_pipe.set_progress_bar_config(disable=True)

    # Seed a dedicated generator for reproducible sampling when requested.
    generator = (
        None
        if args.seed is None
        else torch.Generator(device=accelerator.device).manual_seed(args.seed)
    )

    images = []
    for prompt in args.validation_prompts:
        with torch.autocast("cuda"):
            sample = validation_pipe(prompt, num_inference_steps=20, generator=generator).images[0]

        images.append(sample)

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            stacked = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("validation", stacked, epoch, dataformats="NHWC")
        elif tracker.name == "wandb":
            logged_images = [
                wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
                for i, image in enumerate(images)
            ]
            tracker.log({"validation": logged_images})
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    # Free GPU memory held by the throwaway validation pipeline.
    del validation_pipe
    torch.cuda.empty_cache()

    return images
189
+
190
+
191
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for finetuning the Kandinsky 2.2 prior.

    Returns the parsed namespace after reconciling `--local_rank` with the
    launcher-provided `LOCAL_RANK` environment variable and validating that a
    data source was supplied.

    Raises:
        ValueError: if neither `--dataset_name` nor `--train_data_dir` is set.
    """
    parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
    # --- Pretrained checkpoints (decoder is only used for validation/inference) ---
    parser.add_argument(
        "--pretrained_decoder_model_name_or_path",
        type=str,
        default="kandinsky-community/kandinsky-2-2-decoder",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_prior_model_name_or_path",
        type=str,
        default="kandinsky-community/kandinsky-2-2-prior",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    # --- Dataset selection and column mapping ---
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    # --- Validation / output ---
    parser.add_argument(
        "--validation_prompts",
        type=str,
        default=None,
        nargs="+",
        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="kandi_2_2-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    # --- Core training hyperparameters ---
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="learning rate",
    )
    # --- LR schedule ---
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    # --- Loss weighting (Min-SNR) ---
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
        "More details here: https://huggingface.co/papers/2303.09556.",
    )
    # --- Optimizer / hardware knobs ---
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=0.0,
        required=False,
        help="weight decay_to_use",
    )
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # --- Hub publishing ---
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    # --- Logging / precision / reporting ---
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    # --- Checkpointing / resume ---
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=5,
        help="Run validation every X epochs.",
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="text2image-fine-tune",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    args = parser.parse_args()
    # Let the launcher-provided LOCAL_RANK env var override the CLI flag so
    # `torchrun`/`accelerate launch` work without explicitly passing --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args
440
+
441
+
442
+ def main():
443
+ args = parse_args()
444
+
445
+ if args.report_to == "wandb" and args.hub_token is not None:
446
+ raise ValueError(
447
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
448
+ " Please use `huggingface-cli login` to authenticate with the Hub."
449
+ )
450
+
451
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
452
+ accelerator_project_config = ProjectConfiguration(
453
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
454
+ )
455
+ accelerator = Accelerator(
456
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
457
+ mixed_precision=args.mixed_precision,
458
+ log_with=args.report_to,
459
+ project_config=accelerator_project_config,
460
+ )
461
+
462
+ # Disable AMP for MPS.
463
+ if torch.backends.mps.is_available():
464
+ accelerator.native_amp = False
465
+
466
+ # Make one log on every process with the configuration for debugging.
467
+ logging.basicConfig(
468
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
469
+ datefmt="%m/%d/%Y %H:%M:%S",
470
+ level=logging.INFO,
471
+ )
472
+ logger.info(accelerator.state, main_process_only=False)
473
+ if accelerator.is_local_main_process:
474
+ datasets.utils.logging.set_verbosity_warning()
475
+ transformers.utils.logging.set_verbosity_warning()
476
+ diffusers.utils.logging.set_verbosity_info()
477
+ else:
478
+ datasets.utils.logging.set_verbosity_error()
479
+ transformers.utils.logging.set_verbosity_error()
480
+ diffusers.utils.logging.set_verbosity_error()
481
+
482
+ # If passed along, set the training seed now.
483
+ if args.seed is not None:
484
+ set_seed(args.seed)
485
+
486
+ # Handle the repository creation
487
+ if accelerator.is_main_process:
488
+ if args.output_dir is not None:
489
+ os.makedirs(args.output_dir, exist_ok=True)
490
+
491
+ if args.push_to_hub:
492
+ repo_id = create_repo(
493
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
494
+ ).repo_id
495
+
496
+ # Load scheduler, image_processor, tokenizer and models.
497
+ noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
498
+ image_processor = CLIPImageProcessor.from_pretrained(
499
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
500
+ )
501
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
502
+
503
+ def deepspeed_zero_init_disabled_context_manager():
504
+ """
505
+ returns either a context list that includes one that will disable zero.Init or an empty context list
506
+ """
507
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
508
+ if deepspeed_plugin is None:
509
+ return []
510
+
511
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
512
+
513
+ weight_dtype = torch.float32
514
+ if accelerator.mixed_precision == "fp16":
515
+ weight_dtype = torch.float16
516
+ elif accelerator.mixed_precision == "bf16":
517
+ weight_dtype = torch.bfloat16
518
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
519
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
520
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
521
+ ).eval()
522
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
523
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
524
+ ).eval()
525
+
526
+ prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
527
+
528
+ # Freeze text_encoder and image_encoder
529
+ text_encoder.requires_grad_(False)
530
+ image_encoder.requires_grad_(False)
531
+
532
+ # Set prior to trainable.
533
+ prior.train()
534
+
535
+ # Create EMA for the prior.
536
+ if args.use_ema:
537
+ ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
538
+ ema_prior = EMAModel(ema_prior.parameters(), model_cls=PriorTransformer, model_config=ema_prior.config)
539
+ ema_prior.to(accelerator.device)
540
+
541
+ # `accelerate` 0.16.0 will have better support for customized saving
542
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
543
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
544
+ def save_model_hook(models, weights, output_dir):
545
+ if args.use_ema:
546
+ ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
547
+
548
+ for i, model in enumerate(models):
549
+ model.save_pretrained(os.path.join(output_dir, "prior"))
550
+
551
+ # make sure to pop weight so that corresponding model is not saved again
552
+ weights.pop()
553
+
554
+ def load_model_hook(models, input_dir):
555
+ if args.use_ema:
556
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), PriorTransformer)
557
+ ema_prior.load_state_dict(load_model.state_dict())
558
+ ema_prior.to(accelerator.device)
559
+ del load_model
560
+
561
+ for i in range(len(models)):
562
+ # pop models so that they are not loaded again
563
+ model = models.pop()
564
+
565
+ # load diffusers style into model
566
+ load_model = PriorTransformer.from_pretrained(input_dir, subfolder="prior")
567
+ model.register_to_config(**load_model.config)
568
+
569
+ model.load_state_dict(load_model.state_dict())
570
+ del load_model
571
+
572
+ accelerator.register_save_state_pre_hook(save_model_hook)
573
+ accelerator.register_load_state_pre_hook(load_model_hook)
574
+
575
+ if args.allow_tf32:
576
+ torch.backends.cuda.matmul.allow_tf32 = True
577
+
578
+ if args.use_8bit_adam:
579
+ try:
580
+ import bitsandbytes as bnb
581
+ except ImportError:
582
+ raise ImportError(
583
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
584
+ )
585
+
586
+ optimizer_cls = bnb.optim.AdamW8bit
587
+ else:
588
+ optimizer_cls = torch.optim.AdamW
589
+ optimizer = optimizer_cls(
590
+ prior.parameters(),
591
+ lr=args.learning_rate,
592
+ betas=(args.adam_beta1, args.adam_beta2),
593
+ weight_decay=args.adam_weight_decay,
594
+ eps=args.adam_epsilon,
595
+ )
596
+
597
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
598
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
599
+
600
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
601
+ # download the dataset.
602
+ if args.dataset_name is not None:
603
+ # Downloading and loading a dataset from the hub.
604
+ dataset = load_dataset(
605
+ args.dataset_name,
606
+ args.dataset_config_name,
607
+ cache_dir=args.cache_dir,
608
+ )
609
+ else:
610
+ data_files = {}
611
+ if args.train_data_dir is not None:
612
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
613
+ dataset = load_dataset(
614
+ "imagefolder",
615
+ data_files=data_files,
616
+ cache_dir=args.cache_dir,
617
+ )
618
+ # See more about loading custom images at
619
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
620
+
621
+ # Preprocessing the datasets.
622
+ # We need to tokenize inputs and targets.
623
+ column_names = dataset["train"].column_names
624
+
625
+ # 6. Get the column names for input/target.
626
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
627
+ if args.image_column is None:
628
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
629
+ else:
630
+ image_column = args.image_column
631
+ if image_column not in column_names:
632
+ raise ValueError(
633
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
634
+ )
635
+ if args.caption_column is None:
636
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
637
+ else:
638
+ caption_column = args.caption_column
639
+ if caption_column not in column_names:
640
+ raise ValueError(
641
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
642
+ )
643
+
644
+ # Preprocessing the datasets.
645
+ # We need to tokenize input captions and transform the images.
646
+ def tokenize_captions(examples, is_train=True):
647
+ captions = []
648
+ for caption in examples[caption_column]:
649
+ if isinstance(caption, str):
650
+ captions.append(caption)
651
+ elif isinstance(caption, (list, np.ndarray)):
652
+ # take a random caption if there are multiple
653
+ captions.append(random.choice(caption) if is_train else caption[0])
654
+ else:
655
+ raise ValueError(
656
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
657
+ )
658
+ inputs = tokenizer(
659
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
660
+ )
661
+ text_input_ids = inputs.input_ids
662
+ text_mask = inputs.attention_mask.bool()
663
+ return text_input_ids, text_mask
664
+
665
+ def preprocess_train(examples):
666
+ images = [image.convert("RGB") for image in examples[image_column]]
667
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
668
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
669
+ return examples
670
+
671
+ with accelerator.main_process_first():
672
+ if args.max_train_samples is not None:
673
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
674
+ # Set the training transforms
675
+ train_dataset = dataset["train"].with_transform(preprocess_train)
676
+
677
+ def collate_fn(examples):
678
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
679
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
680
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
681
+ text_mask = torch.stack([example["text_mask"] for example in examples])
682
+ return {"clip_pixel_values": clip_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
683
+
684
+ # DataLoaders creation:
685
+ train_dataloader = torch.utils.data.DataLoader(
686
+ train_dataset,
687
+ shuffle=True,
688
+ collate_fn=collate_fn,
689
+ batch_size=args.train_batch_size,
690
+ num_workers=args.dataloader_num_workers,
691
+ )
692
+
693
+ # Scheduler and math around the number of training steps.
694
+ overrode_max_train_steps = False
695
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
696
+ if args.max_train_steps is None:
697
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
698
+ overrode_max_train_steps = True
699
+
700
+ lr_scheduler = get_scheduler(
701
+ args.lr_scheduler,
702
+ optimizer=optimizer,
703
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
704
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
705
+ )
706
+
707
+ clip_mean = prior.clip_mean.clone()
708
+ clip_std = prior.clip_std.clone()
709
+
710
+ prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
711
+ prior, optimizer, train_dataloader, lr_scheduler
712
+ )
713
+
714
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
715
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
716
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
717
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
718
+ if overrode_max_train_steps:
719
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
720
+ # Afterwards we recalculate our number of training epochs
721
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
722
+
723
+ # We need to initialize the trackers we use, and also store our configuration.
724
+ # The trackers initializes automatically on the main process.
725
+ if accelerator.is_main_process:
726
+ tracker_config = dict(vars(args))
727
+ tracker_config.pop("validation_prompts")
728
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
729
+
730
+ # Train!
731
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
732
+
733
+ logger.info("***** Running training *****")
734
+ logger.info(f" Num examples = {len(train_dataset)}")
735
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
736
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
737
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
738
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
739
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
740
+ global_step = 0
741
+ first_epoch = 0
742
+
743
+ # Potentially load in the weights and states from a previous save
744
+ if args.resume_from_checkpoint:
745
+ if args.resume_from_checkpoint != "latest":
746
+ path = os.path.basename(args.resume_from_checkpoint)
747
+ else:
748
+ # Get the most recent checkpoint
749
+ dirs = os.listdir(args.output_dir)
750
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
751
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
752
+ path = dirs[-1] if len(dirs) > 0 else None
753
+
754
+ if path is None:
755
+ accelerator.print(
756
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
757
+ )
758
+ args.resume_from_checkpoint = None
759
+ initial_global_step = 0
760
+ else:
761
+ accelerator.print(f"Resuming from checkpoint {path}")
762
+ accelerator.load_state(os.path.join(args.output_dir, path))
763
+ global_step = int(path.split("-")[1])
764
+
765
+ initial_global_step = global_step
766
+ first_epoch = global_step // num_update_steps_per_epoch
767
+ else:
768
+ initial_global_step = 0
769
+
770
+ progress_bar = tqdm(
771
+ range(0, args.max_train_steps),
772
+ initial=initial_global_step,
773
+ desc="Steps",
774
+ # Only show the progress bar once on each machine.
775
+ disable=not accelerator.is_local_main_process,
776
+ )
777
+
778
+ clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
779
+ clip_std = clip_std.to(weight_dtype).to(accelerator.device)
780
+
781
+ for epoch in range(first_epoch, args.num_train_epochs):
782
+ train_loss = 0.0
783
+ for step, batch in enumerate(train_dataloader):
784
+ with accelerator.accumulate(prior):
785
+ # Convert images to latent space
786
+ text_input_ids, text_mask, clip_images = (
787
+ batch["text_input_ids"],
788
+ batch["text_mask"],
789
+ batch["clip_pixel_values"].to(weight_dtype),
790
+ )
791
+ with torch.no_grad():
792
+ text_encoder_output = text_encoder(text_input_ids)
793
+ prompt_embeds = text_encoder_output.text_embeds
794
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
795
+
796
+ image_embeds = image_encoder(clip_images).image_embeds
797
+ # Sample noise that we'll add to the image_embeds
798
+ noise = torch.randn_like(image_embeds)
799
+ bsz = image_embeds.shape[0]
800
+ # Sample a random timestep for each image
801
+ timesteps = torch.randint(
802
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device
803
+ )
804
+ timesteps = timesteps.long()
805
+ image_embeds = (image_embeds - clip_mean) / clip_std
806
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
807
+
808
+ target = image_embeds
809
+
810
+ # Predict the noise residual and compute loss
811
+ model_pred = prior(
812
+ noisy_latents,
813
+ timestep=timesteps,
814
+ proj_embedding=prompt_embeds,
815
+ encoder_hidden_states=text_encoder_hidden_states,
816
+ attention_mask=text_mask,
817
+ ).predicted_image_embedding
818
+
819
+ if args.snr_gamma is None:
820
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
821
+ else:
822
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
823
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
824
+ # This is discussed in Section 4.2 of the same paper.
825
+ snr = compute_snr(noise_scheduler, timesteps)
826
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
827
+ dim=1
828
+ )[0]
829
+ if noise_scheduler.config.prediction_type == "epsilon":
830
+ mse_loss_weights = mse_loss_weights / snr
831
+ elif noise_scheduler.config.prediction_type == "v_prediction":
832
+ mse_loss_weights = mse_loss_weights / (snr + 1)
833
+
834
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
835
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
836
+ loss = loss.mean()
837
+
838
+ # Gather the losses across all processes for logging (if we use distributed training).
839
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
840
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
841
+
842
+ # Backpropagate
843
+ accelerator.backward(loss)
844
+ if accelerator.sync_gradients:
845
+ accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
846
+ optimizer.step()
847
+ lr_scheduler.step()
848
+ optimizer.zero_grad()
849
+
850
+ # Checks if the accelerator has performed an optimization step behind the scenes
851
+ if accelerator.sync_gradients:
852
+ if args.use_ema:
853
+ ema_prior.step(prior.parameters())
854
+ progress_bar.update(1)
855
+ global_step += 1
856
+ accelerator.log({"train_loss": train_loss}, step=global_step)
857
+ train_loss = 0.0
858
+
859
+ if global_step % args.checkpointing_steps == 0:
860
+ if accelerator.is_main_process:
861
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
862
+ if args.checkpoints_total_limit is not None:
863
+ checkpoints = os.listdir(args.output_dir)
864
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
865
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
866
+
867
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
868
+ if len(checkpoints) >= args.checkpoints_total_limit:
869
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
870
+ removing_checkpoints = checkpoints[0:num_to_remove]
871
+
872
+ logger.info(
873
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
874
+ )
875
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
876
+
877
+ for removing_checkpoint in removing_checkpoints:
878
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
879
+ shutil.rmtree(removing_checkpoint)
880
+
881
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
882
+ accelerator.save_state(save_path)
883
+ logger.info(f"Saved state to {save_path}")
884
+
885
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
886
+ progress_bar.set_postfix(**logs)
887
+
888
+ if global_step >= args.max_train_steps:
889
+ break
890
+
891
+ if accelerator.is_main_process:
892
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
893
+ if args.use_ema:
894
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
895
+ ema_prior.store(prior.parameters())
896
+ ema_prior.copy_to(prior.parameters())
897
+ log_validation(
898
+ image_encoder,
899
+ image_processor,
900
+ text_encoder,
901
+ tokenizer,
902
+ prior,
903
+ args,
904
+ accelerator,
905
+ weight_dtype,
906
+ global_step,
907
+ )
908
+ if args.use_ema:
909
+ # Switch back to the original UNet parameters.
910
+ ema_prior.restore(prior.parameters())
911
+
912
+ # Create the pipeline using the trained modules and save it.
913
+ accelerator.wait_for_everyone()
914
+ if accelerator.is_main_process:
915
+ prior = accelerator.unwrap_model(prior)
916
+ if args.use_ema:
917
+ ema_prior.copy_to(prior.parameters())
918
+
919
+ pipeline = AutoPipelineForText2Image.from_pretrained(
920
+ args.pretrained_decoder_model_name_or_path,
921
+ prior_image_encoder=image_encoder,
922
+ prior_text_encoder=text_encoder,
923
+ prior_prior=prior,
924
+ )
925
+ pipeline.prior_pipe.save_pretrained(args.output_dir)
926
+
927
+ # Run a final round of inference.
928
+ images = []
929
+ if args.validation_prompts is not None:
930
+ logger.info("Running inference for collecting generated images...")
931
+ pipeline = pipeline.to(accelerator.device)
932
+ pipeline.torch_dtype = weight_dtype
933
+ pipeline.set_progress_bar_config(disable=True)
934
+
935
+ if args.seed is None:
936
+ generator = None
937
+ else:
938
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
939
+
940
+ for i in range(len(args.validation_prompts)):
941
+ with torch.autocast("cuda"):
942
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
943
+ images.append(image)
944
+
945
+ if args.push_to_hub:
946
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
947
+ upload_folder(
948
+ repo_id=repo_id,
949
+ folder_path=args.output_dir,
950
+ commit_message="End of training",
951
+ ignore_patterns=["step_*", "epoch_*"],
952
+ )
953
+
954
+ accelerator.end_training()
955
+
956
+
957
+ if __name__ == "__main__":
958
+ main()
diffusers/examples/research_projects/anytext/README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AnyTextPipeline
2
+
3
+ Project page: https://aigcdesigngroup.github.io/homepage_anytext
4
+
5
+ "AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
6
+
7
+ > **Note:** Each text line that needs to be generated should be enclosed in double quotes.
8
+
9
+ For any usage questions, please refer to the [paper](https://huggingface.co/papers/2311.03054).
10
+
11
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)
12
+
13
+ ```py
14
+ # This example requires the `anytext_controlnet.py` file:
15
+ # !git clone --depth 1 https://github.com/huggingface/diffusers.git
16
+ # %cd diffusers/examples/research_projects/anytext
17
+ # Let's choose a font file shared by an HF staff:
18
+ # !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
19
+
20
+ import torch
21
+ from diffusers import DiffusionPipeline
22
+ from anytext_controlnet import AnyTextControlNetModel
23
+ from diffusers.utils import load_image
24
+
25
+
26
+ anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
27
+ variant="fp16",)
28
+ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
29
+ controlnet=anytext_controlnet, torch_dtype=torch.float16,
30
+ trust_remote_code=False, # One needs to give permission to run this pipeline's code
31
+ ).to("cuda")
32
+
33
+ # generate image
34
+ prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
35
+ draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
36
+ # There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
37
+ image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
38
+ ).images[0]
39
+ image
40
+ ```
diffusers/examples/research_projects/anytext/anytext.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers/examples/research_projects/anytext/anytext_controlnet.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).
16
+ # Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie
17
+ # Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license
18
+ #
19
+ # Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
20
+
21
+
22
+ from typing import Any, Dict, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from diffusers.configuration_utils import register_to_config
28
+ from diffusers.models.controlnets.controlnet import (
29
+ ControlNetModel,
30
+ ControlNetOutput,
31
+ )
32
+ from diffusers.utils import logging
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
class AnyTextControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://huggingface.co/papers/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."

    Unlike the stock ControlNet conditioning stem, this embedding fuses three conditioning inputs: a glyph
    image, a position mask, and the masked VAE latents supplied via ``text_info["masked_x"]``.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        glyph_channels: int = 1,
        position_channels: int = 1,
    ):
        super().__init__()

        # Glyph branch: four stride-2 convs downsample the glyph image by 16x,
        # ending at 256 channels.
        self.glyph_block = nn.Sequential(
            nn.Conv2d(glyph_channels, 8, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(8, 8, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(8, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
        )

        # Position branch: three stride-2 convs downsample the position mask by 8x,
        # ending at 64 channels.
        self.position_block = nn.Sequential(
            nn.Conv2d(position_channels, 8, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(8, 8, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(8, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 64, 3, padding=1, stride=2),
            nn.SiLU(),
        )

        # Fuses glyph (256 ch) + position (64 ch) + masked latents (4 ch) into the
        # UNet's first-block channel count.
        self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)

    def forward(self, glyphs, positions, text_info):
        # NOTE(review): the glyph input appears to be expected at 2x the spatial size of the
        # position mask so that both branches (16x vs. 8x downsampling) produce matching
        # feature-map sizes for the concatenation below — confirm against the pipeline caller.
        glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))
        position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))
        guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))

        return guided_hint
102
+
103
+
104
class AnyTextControlNetModel(ControlNetModel):
    """
    An AnyTextControlNetModel model: a ControlNet whose conditioning stem is replaced by
    [`AnyTextControlNetConditioningEmbedding`] (glyph + position + masked-latent fusion).

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The scale factor to use for the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the normalization. If None, normalization and activation layers is skipped
            in post-processing.
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads.
        use_linear_projection (`bool`, defaults to `False`):
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
        num_class_embeds (`int`, *optional*, defaults to 0):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
            `class_embed_type="projection"`.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `conditioning_embedding` layer.
        global_pool_conditions (`bool`, defaults to `False`):
            TODO(Patrick) - unused parameter.
        addition_embed_type_num_heads (`int`, defaults to 64):
            The number of heads to use for the `TextTimeEmbedding` layer.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 1,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
    ):
        # Build the full ControlNet backbone; all arguments are forwarded positionally
        # in the exact order of ControlNetModel.__init__.
        super().__init__(
            in_channels,
            conditioning_channels,
            flip_sin_to_cos,
            freq_shift,
            down_block_types,
            mid_block_type,
            only_cross_attention,
            block_out_channels,
            layers_per_block,
            downsample_padding,
            mid_block_scale_factor,
            act_fn,
            norm_num_groups,
            norm_eps,
            cross_attention_dim,
            transformer_layers_per_block,
            encoder_hid_dim,
            encoder_hid_dim_type,
            attention_head_dim,
            num_attention_heads,
            use_linear_projection,
            class_embed_type,
            addition_embed_type,
            addition_time_embed_dim,
            num_class_embeds,
            upcast_attention,
            resnet_time_scale_shift,
            projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels,
            global_pool_conditions,
            addition_embed_type_num_heads,
        )

        # control net conditioning embedding
        # Overrides the base class's image-conditioning stem with AnyText's
        # glyph + position + masked-latent fusion embedding.
        self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            glyph_channels=conditioning_channels,
            position_channels=conditioning_channels,
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
        """
        The [`~AnyTextControlNetModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor.
            timestep (`Union[torch.Tensor, float, int]`):
                The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            controlnet_cond:
                Unlike the base class, this is unpacked as positional arguments into
                `AnyTextControlNetConditioningEmbedding.forward`, i.e. a sequence of
                `(glyphs, positions, text_info)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
                timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
                embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            added_cond_kwargs (`dict`):
                Additional conditions for the Stable Diffusion XL UNet.
            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        # elif channel_order == "bgr":
        #     controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask: convert the 0/1 mask into an additive bias.
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.config.addition_embed_type is not None:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb if aug_emb is not None else emb

        # 2. pre-process
        sample = self.conv_in(sample)

        # `controlnet_cond` is unpacked into (glyphs, positions, text_info) for AnyText's
        # custom conditioning embedding (differs from the base ControlNetModel).
        controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)
        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = self.mid_block(sample, emb)

        # 5. Control net blocks: project each residual through its zero-initialized conv.
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
457
+
458
+
459
# Copied from diffusers.models.controlnet.zero_module
def zero_module(module):
    """Zero-initialize every parameter of *module* in place and return it."""
    for param in module.parameters():
        nn.init.zeros_(param)
    return module
diffusers/examples/research_projects/anytext/ocr_recog/RNN.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .RecSVTR import Block
5
+
6
+
7
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    BUGFIX: the original defined ``__int__`` (a typo for ``__init__``);
    the method was dead code and ``nn.Module.__init__`` ran instead, so
    behavior is unchanged by fixing the name.
    """

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
13
+
14
+
15
class Im2Im(nn.Module):
    """Identity neck: passes the feature map through unchanged."""

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        # Output width mirrors the input so downstream heads can chain off it.
        self.out_channels = in_channels

    def forward(self, x):
        return x
22
+
23
+
24
class Im2Seq(nn.Module):
    """Flatten a conv feature map ``(B, C, H, W)`` into a sequence ``(B, H*W, C)``."""

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_channels = in_channels

    def forward(self, x):
        batch, channels, height, width = x.shape
        # Collapse the spatial grid into a single time axis, channels last.
        # (The original ``assert H == 1`` stays disabled.)
        seq = x.reshape(batch, channels, height * width)
        return seq.permute((0, 2, 1))
35
+
36
+
37
+ class EncoderWithRNN(nn.Module):
38
+ def __init__(self, in_channels, **kwargs):
39
+ super(EncoderWithRNN, self).__init__()
40
+ hidden_size = kwargs.get("hidden_size", 256)
41
+ self.out_channels = hidden_size * 2
42
+ self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)
43
+
44
+ def forward(self, x):
45
+ self.lstm.flatten_parameters()
46
+ x, _ = self.lstm(x)
47
+ return x
48
+
49
+
50
class SequenceEncoder(nn.Module):
    """Convert ``(B, C, H, W)`` features to sequences via a selectable neck.

    ``encoder_type`` picks ``"reshape"`` (bare image-to-sequence reshape),
    ``"rnn"`` (:class:`EncoderWithRNN`) or ``"svtr"`` (:class:`EncoderWithSVTR`).
    """

    def __init__(self, in_channels, encoder_type="rnn", **kwargs):
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)
        self.out_channels = self.encoder_reshape.out_channels
        self.encoder_type = encoder_type
        if encoder_type == "reshape":
            # No learned encoder; only the reshape runs in forward().
            self.only_reshape = True
        else:
            support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR}
            assert encoder_type in support_encoder_dict, "{} must in {}".format(
                encoder_type, support_encoder_dict.keys()
            )
            self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)
            self.out_channels = self.encoder.out_channels
            self.only_reshape = False

    def forward(self, x):
        if self.encoder_type == "svtr":
            # SVTR consumes the 4-D map directly; flatten to a sequence after.
            return self.encoder_reshape(self.encoder(x))
        x = self.encoder_reshape(x)
        return x if self.only_reshape else self.encoder(x)
78
+
79
+
80
class ConvBNLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> Swish.

    NOTE(review): the ``act`` argument is accepted but unused — the
    activation is hard-wired to :class:`Swish`, exactly as in the original
    (callers in this file pass ``act="swish"``).
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias_attr,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = Swish()

    def forward(self, inputs):
        return self.act(self.norm(self.conv(inputs)))
103
+
104
+
105
class EncoderWithSVTR(nn.Module):
    """SVTR mixing neck: conv down-projection, a stack of global-attention
    ``Block``s, then conv fusion of the block output with the input features.

    Args mirror the PaddleOCR implementation; when ``use_guide`` is set the
    branch runs on a gradient-detached copy of the input.
    """

    def __init__(
        self,
        in_channels,
        dims=64,  # XS
        depth=2,
        hidden_dims=120,
        use_guide=False,
        num_heads=8,
        qkv_bias=True,
        mlp_ratio=2.0,
        drop_rate=0.1,
        attn_drop_rate=0.1,
        drop_path=0.0,
        qk_scale=None,
    ):
        super(EncoderWithSVTR, self).__init__()
        self.depth = depth
        self.use_guide = use_guide
        self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish")
        self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish")

        self.svtr_block = nn.ModuleList(
            [
                Block(
                    dim=hidden_dims,
                    num_heads=num_heads,
                    mixer="Global",
                    HW=None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer="swish",
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path,
                    norm_layer="nn.LayerNorm",
                    epsilon=1e-05,
                    prenorm=False,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
        # last conv-nxn, the input is concat of input tensor and conv3 output tensor
        self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish")

        self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
        self.out_channels = dims
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Kaiming for convs, small normal for linears, unit/zero for norms.
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        if self.use_guide:
            # BUGFIX: the original did ``z = x.clone(); z.stop_gradient = True``,
            # a Paddle idiom that is a no-op attribute assignment on torch
            # tensors, so gradients still flowed.  ``detach`` actually stops them.
            z = x.clone().detach()
        else:
            z = x
        # for short cut
        h = z
        # reduce dim
        z = self.conv1(z)
        z = self.conv2(z)
        # SVTR global block
        B, C, H, W = z.shape
        z = z.flatten(2).permute(0, 2, 1)

        for blk in self.svtr_block:
            z = blk(z)

        z = self.norm(z)
        # last stage: restore the (B, C, H, W) layout, fuse with the shortcut.
        z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
        z = self.conv3(z)
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))

        return z
205
+
206
+
207
+ if __name__ == "__main__":
208
+ svtrRNN = EncoderWithSVTR(56)
209
+ print(svtrRNN)
diffusers/examples/research_projects/anytext/ocr_recog/RecCTCHead.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
class CTCHead(nn.Module):
    """Linear CTC classification head.

    With ``mid_channels`` unset a single ``fc`` maps features to logits;
    otherwise two stacked linears (``fc1`` -> ``fc2``) are used.  When
    ``return_feats`` is set, forward returns ``{"ctc": logits, "ctc_neck":
    pre-logit features}`` instead of the bare logits.
    """

    def __init__(
        self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs
    ):
        super(CTCHead, self).__init__()
        if mid_channels is None:
            self.fc = nn.Linear(in_channels, out_channels, bias=True)
        else:
            self.fc1 = nn.Linear(in_channels, mid_channels, bias=True)
            self.fc2 = nn.Linear(mid_channels, out_channels, bias=True)

        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, labels=None):
        if self.mid_channels is None:
            logits = self.fc(x)
        else:
            x = self.fc1(x)
            logits = self.fc2(x)

        if not self.return_feats:
            return logits
        # NOTE: with mid_channels unset, "ctc_neck" is the raw input features.
        return {"ctc": logits, "ctc_neck": x}
diffusers/examples/research_projects/anytext/ocr_recog/RecModel.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ from .RecCTCHead import CTCHead
4
+ from .RecMv1_enhance import MobileNetV1Enhance
5
+ from .RNN import Im2Im, Im2Seq, SequenceEncoder
6
+
7
+
8
+ backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
9
+ neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
10
+ head_dict = {"CTCHead": CTCHead}
11
+
12
+
13
class RecModel(nn.Module):
    """Backbone -> neck -> head text-recognition model built from a dict
    config holding ``in_channels`` plus ``backbone``/``neck``/``head``
    sections, each with a ``type`` key (popped here) and constructor kwargs.
    """

    def __init__(self, config):
        super().__init__()
        assert "in_channels" in config, "in_channels must in model config"

        backbone_type = config["backbone"].pop("type")
        assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
        self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])

        neck_type = config["neck"].pop("type")
        assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"])

        head_type = config["head"].pop("type")
        assert head_type in head_dict, f"head.type must in {head_dict}"
        self.head = head_dict[head_type](self.neck.out_channels, **config["head"])

        self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"

    def load_3rd_state_dict(self, _3rd_name, _state):
        # Delegates third-party weight loading to each stage in turn.
        for stage in (self.backbone, self.neck, self.head):
            stage.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        import torch  # local import kept from the original

        x = x.to(torch.float32)
        return self.head(self.neck(self.backbone(x)))

    def encode(self, x):
        # Feature path only, finishing at the head's CTC encoder.
        return self.head.ctc_encoder(self.neck(self.backbone(x)))
diffusers/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .common import Activation
6
+
7
+
8
class ConvBNLayer(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> optional activation by name."""

    def __init__(
        self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish"
    ):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias=False,
        )
        self._batch_norm = nn.BatchNorm2d(num_filters)
        if act is not None:
            # Activation resolved by the shared helper (e.g. "hard_swish").
            self._act = Activation(act_type=act, inplace=True)

    def forward(self, inputs):
        out = self._batch_norm(self._conv(inputs))
        return self._act(out) if self.act is not None else out
36
+
37
+
38
class DepthwiseSeparable(nn.Module):
    """Depthwise conv followed by a 1x1 pointwise conv, with optional SE gate."""

    def __init__(
        self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False
    ):
        super(DepthwiseSeparable, self).__init__()
        self.use_se = use_se
        self._depthwise_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=int(num_filters1 * scale),
            filter_size=dw_size,
            stride=stride,
            padding=padding,
            num_groups=int(num_groups * scale),
        )
        if use_se:
            self._se = SEModule(int(num_filters1 * scale))
        self._pointwise_conv = ConvBNLayer(
            num_channels=int(num_filters1 * scale),
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0,
        )

    def forward(self, inputs):
        features = self._depthwise_conv(inputs)
        if self.use_se:
            features = self._se(features)
        return self._pointwise_conv(features)
68
+
69
+
70
class MobileNetV1Enhance(nn.Module):
    """MobileNetV1 backbone (PaddleOCR "enhance" variant) for recognition.

    Width is scaled by ``scale``; strides of ``(2, 1)`` halve only the height
    so the width (sequence) axis survives for CTC decoding.
    """

    def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs):
        super().__init__()
        self.scale = scale

        self.conv1 = ConvBNLayer(
            num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1
        )

        # Per-block spec: (in, filters1, filters2, groups, stride, dw_size, padding, use_se)
        specs = [
            (32, 32, 64, 32, 1, 3, 1, False),
            (64, 64, 128, 64, 1, 3, 1, False),
            (128, 128, 128, 128, 1, 3, 1, False),
            (128, 128, 256, 128, (2, 1), 3, 1, False),
            (256, 256, 256, 256, 1, 3, 1, False),
            (256, 256, 512, 256, (2, 1), 3, 1, False),
        ]
        specs += [(512, 512, 512, 512, 1, 5, 2, False)] * 5
        specs.append((512, 512, 1024, 512, (2, 1), 5, 2, True))
        specs.append((1024, 1024, 1024, 1024, last_conv_stride, 5, 2, True))

        blocks = [
            DepthwiseSeparable(
                num_channels=int(cin * scale),
                num_filters1=f1,
                num_filters2=f2,
                num_groups=groups,
                stride=stride,
                dw_size=dw,
                padding=pad,
                scale=scale,
                use_se=se,
            )
            for cin, f1, f2, groups, stride, dw, pad, se in specs
        ]
        self.block_list = nn.Sequential(*blocks)

        if last_pool_type == "avg":
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.out_channels = int(1024 * scale)

    def forward(self, inputs):
        features = self.conv1(inputs)
        features = self.block_list(features)
        return self.pool(features)
172
+
173
+
174
def hardsigmoid(x):
    """Piecewise-linear sigmoid: ``clamp((x + 3) / 6, 0, 1)`` via relu6."""
    return F.relu6(x + 3.0, inplace=True) / 6.0
176
+
177
+
178
class SEModule(nn.Module):
    """Squeeze-and-excitation: global average pool, two 1x1 convs, then a
    hard-sigmoid gate multiplied back onto the input."""

    def __init__(self, channel, reduction=4):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        squeezed = channel // reduction
        self.conv1 = nn.Conv2d(
            in_channels=channel, out_channels=squeezed, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.conv2 = nn.Conv2d(
            in_channels=squeezed, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True
        )

    def forward(self, inputs):
        gate = self.avg_pool(inputs)
        gate = F.relu(self.conv1(gate))
        gate = hardsigmoid(self.conv2(gate))
        return torch.mul(inputs, gate)
diffusers/examples/research_projects/anytext/ocr_recog/RecSVTR.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional
5
+ from torch.nn.init import ones_, trunc_normal_, zeros_
6
+
7
+
8
def drop_path(x, drop_prob=0.0, training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in the main
    path of residual blocks).

    Each sample in the batch is zeroed with probability ``drop_prob`` and the
    survivors are rescaled by ``1 / (1 - drop_prob)`` so the expectation is
    preserved.  A no-op when ``drop_prob == 0`` or not in training mode.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One random draw per sample, broadcast over all remaining dims.
    shape = (x.size(0),) + (1,) * (x.ndim - 1)
    # BUGFIX: the original omitted ``device=``, so the mask was created on
    # CPU and the multiply crashed for CUDA inputs.
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor = torch.floor(random_tensor)  # binarize: 1 w.p. keep_prob
    return x / keep_prob * random_tensor
21
+
22
+
23
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    BUGFIX: the original defined ``__int__`` (typo for ``__init__``); it was
    never invoked, so correcting the name does not change behavior.
    """

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
29
+
30
+
31
class ConvBNLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> activation (``act`` is a module class)."""

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias_attr,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = act()

    def forward(self, inputs):
        return self.act(self.norm(self.conv(inputs)))
54
+
55
+
56
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # Probability of zeroing a whole sample's residual branch.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Only active while the module is in training mode.
        return drop_path(x, self.drop_prob, self.training)
65
+
66
+
67
class Identity(nn.Module):
    """No-op module; returns its input unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input
73
+
74
+
75
class Mlp(nn.Module):
    """Transformer feed-forward: fc1 -> act -> drop -> fc2 -> drop."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        # A string act_layer (e.g. "swish") selects the Swish module.
        self.act = Swish() if isinstance(act_layer, str) else act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
95
+
96
+
97
class ConvMixer(nn.Module):
    """Local token mixer: reshape the ``(B, N, C)`` sequence back to its
    ``(H, W)`` grid and apply a grouped convolution over the neighbourhood.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        HW=(8, 25),
        local_k=(3, 3),
    ):
        super().__init__()
        self.HW = HW
        self.dim = dim
        # "Same"-padded grouped conv acting as the local mixer.
        self.local_mixer = nn.Conv2d(
            dim,
            dim,
            local_k,
            1,
            (local_k[0] // 2, local_k[1] // 2),
            groups=num_heads,
        )

    def forward(self, x):
        h, w = self.HW
        # BUGFIX: the original used Paddle-style ``transpose([0, 2, 1])`` and
        # ``reshape([0, self.dim, h, w])`` — both invalid in PyTorch (torch
        # transpose takes two dims; 0 is not a legal reshape size).  Use
        # permute and reshape(-1, ...) for the same layout change.
        x = x.permute(0, 2, 1).reshape(-1, self.dim, h, w)
        x = self.local_mixer(x)
        return x.flatten(2).permute(0, 2, 1)
125
+
126
+
127
class Attention(nn.Module):
    """Multi-head self-attention with optional local windowed masking.

    ``mixer="Local"`` (with a known feature-map size ``HW``) restricts each
    token to a ``local_k`` neighbourhood via an additive ``-inf`` mask;
    ``mixer="Global"`` is plain full attention.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        mixer="Global",
        HW=(8, 25),
        local_k=(7, 11),
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) scaling unless explicitly overridden.
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.HW = HW
        if HW is not None:
            H = HW[0]
            W = HW[1]
            # Fixed sequence length when the token grid is known ahead of time.
            self.N = H * W
            self.C = dim
        if mixer == "Local" and HW is not None:
            hk = local_k[0]
            wk = local_k[1]
            # Padded (N, H + hk - 1, W + wk - 1) grid: 0 inside each token's
            # local window, 1 outside it.
            mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
            for h in range(0, H):
                for w in range(0, W):
                    mask[h * W + w, h : h + hk, w : w + wk] = 0.0
            # Crop the padding back to the true grid and flatten to (N, N).
            mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
            mask_inf = torch.full([H * W, H * W], fill_value=float("-inf"))
            # Keep 0 inside the window, -inf outside (additive attention mask).
            mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
            # NOTE(review): stored as a plain attribute, not a registered
            # buffer — it will not follow ``.to(device)``; confirm callers
            # run on the device the mask was built on.
            self.mask = mask[None, None, :]
            # self.mask = mask.unsqueeze([0, 1])
        self.mixer = mixer

    def forward(self, x):
        if self.HW is not None:
            N = self.N
            C = self.C
        else:
            _, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
        # Scale q up front instead of scaling the attention logits.
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = q.matmul(k.permute((0, 1, 3, 2)))
        if self.mixer == "Local":
            attn += self.mask
        attn = functional.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
188
+
189
+
190
class Block(nn.Module):
    """SVTR transformer block: token mixer (attention or conv) plus MLP,
    each with a residual connection and drop-path.

    NOTE: despite the name, ``prenorm=True`` applies each norm *after* the
    residual sum (post-norm layout), matching the original implementation.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mixer="Global",
        local_mixer=(7, 11),
        HW=(8, 25),
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="nn.LayerNorm",
        epsilon=1e-6,
        prenorm=True,
    ):
        super().__init__()
        # norm_layer may arrive as a string such as "nn.LayerNorm".
        self.norm1 = eval(norm_layer)(dim, eps=epsilon) if isinstance(norm_layer, str) else norm_layer(dim)
        if mixer in ("Global", "Local"):
            self.mixer = Attention(
                dim,
                num_heads=num_heads,
                mixer=mixer,
                HW=HW,
                local_k=local_mixer,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )
        elif mixer == "Conv":
            self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
        else:
            raise TypeError("The mixer must be one of [Global, Local, Conv]")

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = eval(norm_layer)(dim, eps=epsilon) if isinstance(norm_layer, str) else norm_layer(dim)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.prenorm = prenorm

    def forward(self, x):
        if self.prenorm:
            x = self.norm1(x + self.drop_path(self.mixer(x)))
            x = self.norm2(x + self.drop_path(self.mlp(x)))
        else:
            x = x + self.drop_path(self.mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
249
+
250
+
251
class PatchEmbed(nn.Module):
    """Image to Patch Embedding via a stack of stride-2 ConvBNLayers.

    ``sub_num`` of 2 or 3 selects a 4x or 8x spatial reduction; any other
    value leaves ``self.proj`` undefined (matching the original behavior).
    """

    def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
        super().__init__()
        self.img_size = img_size
        self.num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
        self.embed_dim = embed_dim
        self.norm = None
        if sub_num == 2:
            widths = [in_channels, embed_dim // 2, embed_dim]
        elif sub_num == 3:
            widths = [in_channels, embed_dim // 4, embed_dim // 2, embed_dim]
        else:
            widths = None
        if widths is not None:
            # Each stage halves H and W while widening the channel count.
            self.proj = nn.Sequential(
                *[
                    ConvBNLayer(
                        in_channels=cin,
                        out_channels=cout,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        act=nn.GELU,
                        bias_attr=False,
                    )
                    for cin, cout in zip(widths[:-1], widths[1:])
                ]
            )

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], (
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        )
        return self.proj(x).flatten(2).permute(0, 2, 1)
320
+
321
+
322
class SubSample(nn.Module):
    """Down-sample the token grid between SVTR stages.

    ``types="Pool"`` averages an avg- and max-pool then projects linearly;
    otherwise a strided conv is used.  Output is a normalized ``(B, N, C')``
    sequence.
    """

    def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None):
        super().__init__()
        self.types = types
        if types == "Pool":
            self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
            self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
            self.proj = nn.Linear(in_channels, out_channels)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        # sub_norm arrives as a string such as "nn.LayerNorm".
        self.norm = eval(sub_norm)(out_channels)
        self.act = act() if act is not None else None

    def forward(self, x):
        if self.types == "Pool":
            # Blend avg and max pooling, flatten to a sequence, then project.
            pooled = (self.avgpool(x) + self.maxpool(x)) * 0.5
            out = self.proj(pooled.flatten(2).permute((0, 2, 1)))
        else:
            out = self.conv(x).flatten(2).permute((0, 2, 1))
        out = self.norm(out)
        if self.act is not None:
            out = self.act(out)
        return out
359
+
360
+
361
class SVTRNet(nn.Module):
    """SVTR text-recognition backbone: patch embedding, three stages of
    mixing ``Block``s with optional down-sampling between stages, and a
    pooled conv head producing ``(B, out_channels, 1, out_char_num)``.
    """

    def __init__(
        self,
        img_size=[48, 100],
        in_channels=3,
        embed_dim=[64, 128, 256],
        depth=[3, 6, 3],
        num_heads=[2, 4, 8],
        mixer=["Local"] * 6 + ["Global"] * 6,  # Local atten, Global atten, Conv
        local_mixer=[[7, 11], [7, 11], [7, 11]],
        patch_merging="Conv",  # Conv, Pool, None
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        last_drop=0.1,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer="nn.LayerNorm",
        sub_norm="nn.LayerNorm",
        epsilon=1e-6,
        out_channels=192,
        out_char_num=25,
        block_unit="Block",
        act="nn.GELU",
        last_stage=True,
        sub_num=2,
        prenorm=True,
        use_lenhead=False,
        **kwargs,
    ):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = out_channels
        self.prenorm = prenorm
        patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging
        self.patch_embed = PatchEmbed(
            img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num
        )
        num_patches = self.patch_embed.num_patches
        self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
        # Learned absolute position embedding over the full patch grid.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)
        Block_unit = eval(block_unit)

        # Per-block stochastic-depth rates, increasing linearly with depth.
        dpr = np.linspace(0, drop_path_rate, sum(depth))
        self.blocks1 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[0],
                    num_heads=num_heads[0],
                    mixer=mixer[0 : depth[0]][i],
                    HW=self.HW,
                    local_mixer=local_mixer[0],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=eval(act),
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[0 : depth[0]][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[0])
            ]
        )
        if patch_merging is not None:
            self.sub_sample1 = SubSample(
                embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
            )
            HW = [self.HW[0] // 2, self.HW[1]]
        else:
            HW = self.HW
        self.patch_merging = patch_merging
        self.blocks2 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[1],
                    num_heads=num_heads[1],
                    mixer=mixer[depth[0] : depth[0] + depth[1]][i],
                    HW=HW,
                    local_mixer=local_mixer[1],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=eval(act),
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[1])
            ]
        )
        if patch_merging is not None:
            self.sub_sample2 = SubSample(
                embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
            )
            HW = [self.HW[0] // 4, self.HW[1]]
        else:
            HW = self.HW
        self.blocks3 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[2],
                    num_heads=num_heads[2],
                    mixer=mixer[depth[0] + depth[1] :][i],
                    HW=HW,
                    local_mixer=local_mixer[2],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=eval(act),
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[depth[0] + depth[1] :][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[2])
            ]
        )
        self.last_stage = last_stage
        if last_stage:
            self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
            self.last_conv = nn.Conv2d(
                in_channels=embed_dim[2],
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.hardswish = nn.Hardswish()
            self.dropout = nn.Dropout(p=last_drop)
        if not prenorm:
            # BUGFIX: the original passed ``epsilon=epsilon``, which is not a
            # valid nn.LayerNorm keyword and raised TypeError whenever
            # prenorm=False; LayerNorm's argument is named ``eps``.
            self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)
        self.use_lenhead = use_lenhead
        if use_lenhead:
            self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
            self.hardswish_len = nn.Hardswish()
            self.dropout_len = nn.Dropout(p=last_drop)

        trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linears, unit/zero LayerNorms.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        if self.patch_merging is not None:
            # Back to (B, C, H, W) for the strided merge, halving the height.
            x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
        for blk in self.blocks2:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
        for blk in self.blocks3:
            x = blk(x)
        if not self.prenorm:
            x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if self.use_lenhead:
            len_x = self.len_conv(x.mean(1))
            len_x = self.dropout_len(self.hardswish_len(len_x))
        if self.last_stage:
            if self.patch_merging is not None:
                h = self.HW[0] // 4
            else:
                h = self.HW[0]
            # Pool the height axis to 1 and the width to out_char_num slots.
            x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]]))
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
        if self.use_lenhead:
            return x, len_x
        return x
562
+
563
+
564
+ if __name__ == "__main__":
565
+ a = torch.rand(1, 3, 48, 100)
566
+ svtr = SVTRNet()
567
+
568
+ out = svtr(a)
569
+ print(svtr)
570
+ print(out.size())
diffusers/examples/research_projects/anytext/ocr_recog/common.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class Hswish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``."""

    def __init__(self, inplace=True):
        super(Hswish, self).__init__()
        # Forwarded to relu6 so the clamp can reuse the temporary buffer.
        self.inplace = inplace

    def forward(self, x):
        clipped = F.relu6(x + 3.0, inplace=self.inplace)
        return x * clipped / 6.0
13
+
14
+
15
+ # out = max(0, min(1, slop*x+offset))
16
+ # paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
17
class Hsigmoid(nn.Module):
    """Hard-sigmoid matching PaddlePaddle's variant.

    Computes ``relu6(1.2 * x + 3) / 6``. Note the 1.2 slope: this deliberately
    mirrors paddle's ``hard_sigmoid`` rather than the common torch form
    ``relu6(x + 3) / 6``.
    """

    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        scaled = 1.2 * x + 3.0
        return F.relu6(scaled, inplace=self.inplace) / 6.0
26
+
27
+
28
class GELU(nn.Module):
    """Gaussian Error Linear Unit wrapper.

    ``inplace`` is accepted only for API parity with the other activations in
    this module; ``F.gelu`` has no in-place variant, so it is unused.
    """

    def __init__(self, inplace=True):
        super(GELU, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return F.gelu(x)
35
+
36
+
37
class Swish(nn.Module):
    """Swish/SiLU activation ``x * sigmoid(x)``.

    When ``inplace`` is true the input tensor itself is mutated and returned.
    """

    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if not self.inplace:
            return x * torch.sigmoid(x)
        # sigmoid(x) is evaluated on the original values before the in-place multiply.
        x.mul_(torch.sigmoid(x))
        return x
48
+
49
+
50
class Activation(nn.Module):
    """Factory wrapper that builds a named activation layer.

    Supported ``act_type`` values (case-insensitive): ``relu``, ``relu6``,
    ``hard_sigmoid``, ``hard_swish``, ``leakyrelu``, ``gelu``, ``swish``.
    ``sigmoid`` and any unknown name raise ``NotImplementedError``, matching
    the upstream implementation.
    """

    def __init__(self, act_type, inplace=True):
        super(Activation, self).__init__()
        # Lazy factories: only the selected activation is ever constructed.
        factories = {
            "relu": lambda: nn.ReLU(inplace=inplace),
            "relu6": lambda: nn.ReLU6(inplace=inplace),
            "hard_sigmoid": lambda: Hsigmoid(inplace),
            "hard_swish": lambda: Hswish(inplace=inplace),
            "leakyrelu": lambda: nn.LeakyReLU(inplace=inplace),
            "gelu": lambda: GELU(inplace=inplace),
            "swish": lambda: Swish(inplace=inplace),
        }
        key = act_type.lower()
        if key not in factories:
            # Covers both "sigmoid" (deliberately unsupported) and unknown names.
            raise NotImplementedError
        self.act = factories[key]()

    def forward(self, inputs):
        return self.act(inputs)
diffusers/examples/research_projects/anytext/ocr_recog/en_dict.txt ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 0
2
+ 1
3
+ 2
4
+ 3
5
+ 4
6
+ 5
7
+ 6
8
+ 7
9
+ 8
10
+ 9
11
+ :
12
+ ;
13
+ <
14
+ =
15
+ >
16
+ ?
17
+ @
18
+ A
19
+ B
20
+ C
21
+ D
22
+ E
23
+ F
24
+ G
25
+ H
26
+ I
27
+ J
28
+ K
29
+ L
30
+ M
31
+ N
32
+ O
33
+ P
34
+ Q
35
+ R
36
+ S
37
+ T
38
+ U
39
+ V
40
+ W
41
+ X
42
+ Y
43
+ Z
44
+ [
45
+ \
46
+ ]
47
+ ^
48
+ _
49
+ `
50
+ a
51
+ b
52
+ c
53
+ d
54
+ e
55
+ f
56
+ g
57
+ h
58
+ i
59
+ j
60
+ k
61
+ l
62
+ m
63
+ n
64
+ o
65
+ p
66
+ q
67
+ r
68
+ s
69
+ t
70
+ u
71
+ v
72
+ w
73
+ x
74
+ y
75
+ z
76
+ {
77
+ |
78
+ }
79
+ ~
80
+ !
81
+ "
82
+ #
83
+ $
84
+ %
85
+ &
86
+ '
87
+ (
88
+ )
89
+ *
90
+ +
91
+ ,
92
+ -
93
+ .
94
+ /
95
+
diffusers/examples/research_projects/consistency_training/README.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Consistency Training
2
+
3
+ `train_cm_ct_unconditional.py` trains a consistency model (CM) from scratch following the consistency training (CT) algorithm introduced in [Consistency Models](https://huggingface.co/papers/2303.01469) and refined in [Improved Techniques for Training Consistency Models](https://huggingface.co/papers/2310.14189). Both unconditional and class-conditional training are supported.
4
+
5
+ A usage example is as follows:
6
+
7
+ ```bash
8
+ accelerate launch examples/research_projects/consistency_training/train_cm_ct_unconditional.py \
9
+ --dataset_name="cifar10" \
10
+ --dataset_image_column_name="img" \
11
+ --output_dir="/path/to/output/dir" \
12
+ --mixed_precision=fp16 \
13
+ --resolution=32 \
14
+ --max_train_steps=1000 --max_train_samples=10000 \
15
+ --dataloader_num_workers=8 \
16
+ --noise_precond_type="cm" --input_precond_type="cm" \
17
+ --train_batch_size=4 \
18
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
19
+ --use_8bit_adam \
20
+ --use_ema \
21
+ --validation_steps=100 --eval_batch_size=4 \
22
+ --checkpointing_steps=100 --checkpoints_total_limit=10 \
23
 + --class_conditional --num_classes=10
24
+ ```
diffusers/examples/research_projects/consistency_training/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
diffusers/examples/research_projects/consistency_training/train_cm_ct_unconditional.py ADDED
@@ -0,0 +1,1438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ """Script to train a consistency model from scratch via (improved) consistency training."""
16
+
17
+ import argparse
18
+ import gc
19
+ import logging
20
+ import math
21
+ import os
22
+ import shutil
23
+ from datetime import timedelta
24
+ from pathlib import Path
25
+
26
+ import accelerate
27
+ import datasets
28
+ import numpy as np
29
+ import torch
30
+ from accelerate import Accelerator, InitProcessGroupKwargs
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import ProjectConfiguration, set_seed
33
+ from datasets import load_dataset
34
+ from huggingface_hub import create_repo, upload_folder
35
+ from packaging import version
36
+ from torchvision import transforms
37
+ from tqdm.auto import tqdm
38
+
39
+ import diffusers
40
+ from diffusers import (
41
+ CMStochasticIterativeScheduler,
42
+ ConsistencyModelPipeline,
43
+ UNet2DModel,
44
+ )
45
+ from diffusers.optimization import get_scheduler
46
+ from diffusers.training_utils import EMAModel, resolve_interpolation_mode
47
+ from diffusers.utils import is_tensorboard_available, is_wandb_available
48
+ from diffusers.utils.import_utils import is_xformers_available
49
+ from diffusers.utils.torch_utils import is_compiled_module
50
+
51
+
52
+ if is_wandb_available():
53
+ import wandb
54
+
55
+
56
+ logger = get_logger(__name__, log_level="INFO")
57
+
58
+
59
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
60
+ """
61
+ Extract values from a 1-D numpy array for a batch of indices.
62
+
63
+ :param arr: the 1-D numpy array.
64
+ :param timesteps: a tensor of indices into the array to extract.
65
+ :param broadcast_shape: a larger shape of K dimensions with the batch
66
+ dimension equal to the length of timesteps.
67
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
68
+ """
69
+ if not isinstance(arr, torch.Tensor):
70
+ arr = torch.from_numpy(arr)
71
+ res = arr[timesteps].float().to(timesteps.device)
72
+ while len(res.shape) < len(broadcast_shape):
73
+ res = res[..., None]
74
+ return res.expand(broadcast_shape)
75
+
76
+
77
def append_dims(x, target_dims):
    """Append trailing singleton dimensions to ``x`` until it has ``target_dims`` dims.

    Raises ``ValueError`` if ``x`` already has more than ``target_dims`` dims.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x.reshape(x.shape + (1,) * missing)
83
+
84
+
85
def extract_into_tensor(a, t, x_shape):
    """Gather ``a`` at indices ``t`` and reshape to broadcast against ``x_shape``."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # Keep the batch axis, make every remaining axis a singleton.
    return gathered.reshape((batch,) + (1,) * (len(x_shape) - 1))
89
+
90
+
91
def get_discretization_steps(global_step: int, max_train_steps: int, s_0: int = 10, s_1: int = 1280, constant=False):
    """
    Calculates the current discretization steps at global step k using the
    discretization curriculum N(k) from the iCT paper.

    When ``constant`` is set, a fixed ``s_0 + 1`` is returned — useful for
    short test runs where ``k_prime`` would otherwise be zero.
    """
    if constant:
        return s_0 + 1

    # Number of training steps between each doubling, chosen so N(k) tops out at s_1.
    k_prime = math.floor(max_train_steps / (math.log2(math.floor(s_1 / s_0)) + 1))
    doublings = math.floor(global_step / k_prime)
    return min(s_0 * 2**doublings, s_1) + 1
102
+
103
+
104
def get_skip_steps(global_step, initial_skip: int = 1):
    """Skip-step curriculum; only a constant schedule (always ``initial_skip``) is supported."""
    return initial_skip
107
+
108
+
109
def get_karras_sigmas(
    num_discretization_steps: int,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    rho: float = 7.0,
    dtype=torch.float32,
):
    """
    Calculates the Karras sigmas timestep discretization of [sigma_min, sigma_max].

    The result is returned in increasing order (sigma_min first), following
    section 2 of the iCT paper.
    """
    ramp = np.linspace(0, 1, num_discretization_steps)
    inv_rho_min = sigma_min ** (1 / rho)
    inv_rho_max = sigma_max ** (1 / rho)
    # Standard Karras schedule, naturally decreasing from sigma_max to sigma_min.
    decreasing = (inv_rho_max + ramp * (inv_rho_min - inv_rho_max)) ** rho
    # Reverse (with a copy, since negative-stride arrays can't back a tensor).
    return torch.from_numpy(decreasing[::-1].copy()).to(dtype=dtype)
127
+
128
+
129
def get_discretized_lognormal_weights(noise_levels: torch.Tensor, p_mean: float = -1.1, p_std: float = 2.0):
    """
    Calculates the unnormalized sampling weights for each noise-level interval
    based on the discretized lognormal distribution used in the iCT paper
    (Equation 10).
    """
    denom = math.sqrt(2) * p_std
    # erf evaluated at every boundary; adjacent differences give each bucket's mass.
    erf_at_levels = torch.special.erf((torch.log(noise_levels) - p_mean) / denom)
    return erf_at_levels[1:] - erf_at_levels[:-1]
138
+
139
+
140
def get_loss_weighting_schedule(noise_levels: torch.Tensor):
    """Loss weighting lambda(sigma_i) = 1 / (sigma_{i+1} - sigma_i) from the iCT paper."""
    gaps = noise_levels[1:] - noise_levels[:-1]
    return 1.0 / gaps
145
+
146
+
147
def add_noise(original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
    """Perturb ``original_samples`` with ``noise`` scaled by the Karras sigmas in ``timesteps``."""
    # Match device/dtype, then append singleton dims so sigmas broadcast per-sample.
    sigmas = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
    sigmas = sigmas.reshape(sigmas.shape + (1,) * (original_samples.ndim - sigmas.ndim))

    return original_samples + noise * sigmas
156
+
157
+
158
def get_noise_preconditioning(sigmas, noise_precond_type: str = "cm"):
    """
    Noise preconditioning function c_noise, mapping raw Karras sigmas to the
    timestep input of the U-Net.

    ``none`` is the identity, ``edm`` uses ``0.25 * log(sigma)``, and ``cm``
    additionally scales by 1000 (with a tiny epsilon guarding log(0)).
    Raises ``ValueError`` for any other type.
    """
    if noise_precond_type == "none":
        return sigmas
    if noise_precond_type == "edm":
        return 0.25 * torch.log(sigmas)
    if noise_precond_type == "cm":
        # 1e-44 keeps log finite when sigma is exactly zero; 250 == 1000 * 0.25.
        return 250 * torch.log(sigmas + 1e-44)
    raise ValueError(
        f"Noise preconditioning type {noise_precond_type} is not current supported. Currently supported noise"
        f" preconditioning types are `none` (which uses the sigmas as is), `edm`, and `cm`."
    )
174
+
175
+
176
def get_input_preconditioning(sigmas, sigma_data=0.5, input_precond_type: str = "cm"):
    """
    Input preconditioning factor c_in, used to scale the U-Net image input.

    ``none`` returns a scaling of 1; ``cm`` returns ``1 / (sigma^2 + sigma_data^2)``.
    Raises ``ValueError`` for any other type.

    NOTE(review): the EDM/CM papers define c_in with a square root,
    ``1 / sqrt(sigma^2 + sigma_data^2)``; this implementation omits it —
    confirm against the intended training recipe.
    """
    if input_precond_type == "none":
        return 1
    if input_precond_type == "cm":
        variance = sigmas**2 + sigma_data**2
        return 1.0 / variance
    raise ValueError(
        f"Input preconditioning type {input_precond_type} is not current supported. Currently supported input"
        f" preconditioning types are `none` (which uses a scaling factor of 1.0) and `cm`."
    )
189
+
190
+
191
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=1.0):
    """Boundary-condition scalings (c_skip, c_out) so the consistency map is the identity at the minimum noise level."""
    t = timestep_scaling * timestep
    denom = t**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = t / denom**0.5
    return c_skip, c_out
196
+
197
+
198
def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name="teacher"):
    """Sample single-step images from the current consistency model and log
    them to the configured trackers (tensorboard / wandb).

    Returns the list of per-class image logs; ``weight_dtype`` is accepted for
    signature parity with callers but is not used here.
    """
    logger.info("Running validation... ")

    unet = accelerator.unwrap_model(unet)
    pipeline = ConsistencyModelPipeline(unet=unet, scheduler=scheduler)
    pipeline = pipeline.to(device=accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)

    # With class conditioning, sample one batch per class; otherwise a single
    # unconditional batch (class label None).
    class_labels = [None]
    if args.class_conditional:
        if args.num_classes is not None:
            class_labels = list(range(args.num_classes))
        else:
            logger.warning(
                "The model is class-conditional but the number of classes is not set. The generated images will be"
                " unconditional rather than class-conditional."
            )

    image_logs = []
    for class_label in class_labels:
        with torch.autocast("cuda"):
            images = pipeline(
                num_inference_steps=1,
                batch_size=args.eval_batch_size,
                class_labels=[class_label] * args.eval_batch_size,
                generator=generator,
            ).images
        entry = {"images": images}
        if args.class_conditional and class_label is not None:
            entry["class_label"] = str(class_label)
        else:
            entry["class_label"] = "images"
        image_logs.append(entry)

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for entry in image_logs:
                stacked = np.stack([np.asarray(image) for image in entry["images"]])
                tracker.writer.add_images(entry["class_label"], stacked, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            formatted_images = [
                wandb.Image(image, caption=entry["class_label"])
                for entry in image_logs
                for image in entry["images"]
            ]
            tracker.log({f"validation/{name}": formatted_images})
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    # Release the pipeline's GPU memory before training resumes.
    del pipeline
    gc.collect()
    torch.cuda.empty_cache()

    return image_logs
276
+
277
+
278
+ def parse_args():
279
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
280
+ # ------------Model Arguments-----------
281
+ parser.add_argument(
282
+ "--model_config_name_or_path",
283
+ type=str,
284
+ default=None,
285
+ help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
286
+ )
287
+ parser.add_argument(
288
+ "--pretrained_model_name_or_path",
289
+ type=str,
290
+ default=None,
291
+ help=(
292
+ "If initializing the weights from a pretrained model, the path to the pretrained model or model identifier"
293
+ " from huggingface.co/models."
294
+ ),
295
+ )
296
+ parser.add_argument(
297
+ "--revision",
298
+ type=str,
299
+ default=None,
300
+ required=False,
301
+ help="Revision of pretrained model identifier from huggingface.co/models.",
302
+ )
303
+ parser.add_argument(
304
+ "--variant",
305
+ type=str,
306
+ default=None,
307
+ help=(
308
+ "Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. `fp16`,"
309
+ " `non_ema`, etc.",
310
+ ),
311
+ )
312
+ # ------------Dataset Arguments-----------
313
+ parser.add_argument(
314
+ "--train_data_dir",
315
+ type=str,
316
+ default=None,
317
+ help=(
318
+ "A folder containing the training data. Folder contents must follow the structure described in"
319
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
320
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
321
+ ),
322
+ )
323
+ parser.add_argument(
324
+ "--dataset_name",
325
+ type=str,
326
+ default=None,
327
+ help=(
328
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
329
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
330
+ " or to a folder containing files that HF Datasets can understand."
331
+ ),
332
+ )
333
+ parser.add_argument(
334
+ "--dataset_config_name",
335
+ type=str,
336
+ default=None,
337
+ help="The config of the Dataset, leave as None if there's only one config.",
338
+ )
339
+ parser.add_argument(
340
+ "--dataset_image_column_name",
341
+ type=str,
342
+ default="image",
343
+ help="The name of the image column in the dataset to use for training.",
344
+ )
345
+ parser.add_argument(
346
+ "--dataset_class_label_column_name",
347
+ type=str,
348
+ default="label",
349
+ help="If doing class-conditional training, the name of the class label column in the dataset to use.",
350
+ )
351
+ # ------------Image Processing Arguments-----------
352
+ parser.add_argument(
353
+ "--resolution",
354
+ type=int,
355
+ default=64,
356
+ help=(
357
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
358
+ " resolution"
359
+ ),
360
+ )
361
+ parser.add_argument(
362
+ "--interpolation_type",
363
+ type=str,
364
+ default="bilinear",
365
+ help=(
366
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
367
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
368
+ ),
369
+ )
370
+ parser.add_argument(
371
+ "--center_crop",
372
+ default=False,
373
+ action="store_true",
374
+ help=(
375
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
376
+ " cropped. The images will be resized to the resolution first before cropping."
377
+ ),
378
+ )
379
+ parser.add_argument(
380
+ "--random_flip",
381
+ default=False,
382
+ action="store_true",
383
+ help="whether to randomly flip images horizontally",
384
+ )
385
+ parser.add_argument(
386
+ "--class_conditional",
387
+ action="store_true",
388
+ help=(
389
+ "Whether to train a class-conditional model. If set, the class labels will be taken from the `label`"
390
+ " column of the provided dataset."
391
+ ),
392
+ )
393
+ parser.add_argument(
394
+ "--num_classes",
395
+ type=int,
396
+ default=None,
397
+ help="The number of classes in the training data, if training a class-conditional model.",
398
+ )
399
+ parser.add_argument(
400
+ "--class_embed_type",
401
+ type=str,
402
+ default=None,
403
+ help=(
404
+ "The class embedding type to use. Choose from `None`, `identity`, and `timestep`. If `class_conditional`"
405
+ " and `num_classes` and set, but `class_embed_type` is `None`, a embedding matrix will be used."
406
+ ),
407
+ )
408
+ # ------------Dataloader Arguments-----------
409
+ parser.add_argument(
410
+ "--dataloader_num_workers",
411
+ type=int,
412
+ default=0,
413
+ help=(
414
+ "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
415
+ " process."
416
+ ),
417
+ )
418
+ # ------------Training Arguments-----------
419
+ # ----General Training Arguments----
420
+ parser.add_argument(
421
+ "--output_dir",
422
+ type=str,
423
+ default="ddpm-model-64",
424
+ help="The output directory where the model predictions and checkpoints will be written.",
425
+ )
426
+ parser.add_argument("--overwrite_output_dir", action="store_true")
427
+ parser.add_argument(
428
+ "--cache_dir",
429
+ type=str,
430
+ default=None,
431
+ help="The directory where the downloaded models and datasets will be stored.",
432
+ )
433
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
434
+ # ----Batch Size and Training Length----
435
+ parser.add_argument(
436
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
437
+ )
438
+ parser.add_argument("--num_train_epochs", type=int, default=100)
439
+ parser.add_argument(
440
+ "--max_train_steps",
441
+ type=int,
442
+ default=None,
443
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
444
+ )
445
+ parser.add_argument(
446
+ "--max_train_samples",
447
+ type=int,
448
+ default=None,
449
+ help=(
450
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
451
+ "value if set."
452
+ ),
453
+ )
454
+ # ----Learning Rate----
455
+ parser.add_argument(
456
+ "--learning_rate",
457
+ type=float,
458
+ default=1e-4,
459
+ help="Initial learning rate (after the potential warmup period) to use.",
460
+ )
461
+ parser.add_argument(
462
+ "--scale_lr",
463
+ action="store_true",
464
+ default=False,
465
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
466
+ )
467
+ parser.add_argument(
468
+ "--lr_scheduler",
469
+ type=str,
470
+ default="cosine",
471
+ help=(
472
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
473
+ ' "constant", "constant_with_warmup"]'
474
+ ),
475
+ )
476
+ parser.add_argument(
477
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
478
+ )
479
+ # ----Optimizer (Adam) Arguments----
480
+ parser.add_argument(
481
+ "--optimizer_type",
482
+ type=str,
483
+ default="adamw",
484
+ help=(
485
+ "The optimizer algorithm to use for training. Choose between `radam` and `adamw`. The iCT paper uses"
486
+ " RAdam."
487
+ ),
488
+ )
489
+ parser.add_argument(
490
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
491
+ )
492
+ parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
493
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
494
+ parser.add_argument(
495
+ "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
496
+ )
497
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
498
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
499
+ # ----Consistency Training (CT) Specific Arguments----
500
+ parser.add_argument(
501
+ "--prediction_type",
502
+ type=str,
503
+ default="sample",
504
+ choices=["sample"],
505
+ help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
506
+ )
507
+ parser.add_argument("--ddpm_num_steps", type=int, default=1000)
508
+ parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
509
+ parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
510
+ parser.add_argument(
511
+ "--sigma_min",
512
+ type=float,
513
+ default=0.002,
514
+ help=(
515
+ "The lower boundary for the timestep discretization, which should be set to a small positive value close"
516
+ " to zero to avoid numerical issues when solving the PF-ODE backwards in time."
517
+ ),
518
+ )
519
+ parser.add_argument(
520
+ "--sigma_max",
521
+ type=float,
522
+ default=80.0,
523
+ help=(
524
+ "The upper boundary for the timestep discretization, which also determines the variance of the Gaussian"
525
+ " prior."
526
+ ),
527
+ )
528
+ parser.add_argument(
529
+ "--rho",
530
+ type=float,
531
+ default=7.0,
532
+ help="The rho parameter for the Karras sigmas timestep dicretization.",
533
+ )
534
+ parser.add_argument(
535
+ "--huber_c",
536
+ type=float,
537
+ default=None,
538
+ help=(
539
+ "The Pseudo-Huber loss parameter c. If not set, this will default to the value recommended in the Improved"
540
+ " Consistency Training (iCT) paper of 0.00054 * sqrt(d), where d is the data dimensionality."
541
+ ),
542
+ )
543
+ parser.add_argument(
544
+ "--discretization_s_0",
545
+ type=int,
546
+ default=10,
547
+ help=(
548
+ "The s_0 parameter in the discretization curriculum N(k). This controls the number of training steps after"
549
+ " which the number of discretization steps N will be doubled."
550
+ ),
551
+ )
552
+ parser.add_argument(
553
+ "--discretization_s_1",
554
+ type=int,
555
+ default=1280,
556
+ help=(
557
+ "The s_1 parameter in the discretization curriculum N(k). This controls the upper limit to the number of"
558
+ " discretization steps used. Increasing this value will reduce the bias at the cost of higher variance."
559
+ ),
560
+ )
561
+ parser.add_argument(
562
+ "--constant_discretization_steps",
563
+ action="store_true",
564
+ help=(
565
+ "Whether to set the discretization curriculum N(k) to be the constant value `discretization_s_0 + 1`. This"
566
+ " is useful for testing when `max_number_steps` is small, when `k_prime` would otherwise be 0, causing"
567
+ " a divide-by-zero error."
568
+ ),
569
+ )
570
+ parser.add_argument(
571
+ "--p_mean",
572
+ type=float,
573
+ default=-1.1,
574
+ help=(
575
+ "The mean parameter P_mean for the (discretized) lognormal noise schedule, which controls the probability"
576
+ " of sampling a (discrete) noise level sigma_i."
577
+ ),
578
+ )
579
+ parser.add_argument(
580
+ "--p_std",
581
+ type=float,
582
+ default=2.0,
583
+ help=(
584
+ "The standard deviation parameter P_std for the (discretized) noise schedule, which controls the"
585
+ " probability of sampling a (discrete) noise level sigma_i."
586
+ ),
587
+ )
588
+ parser.add_argument(
589
+ "--noise_precond_type",
590
+ type=str,
591
+ default="cm",
592
+ help=(
593
+ "The noise preconditioning function to use for transforming the raw Karras sigmas into the timestep"
594
+ " argument of the U-Net. Choose between `none` (the identity function), `edm`, and `cm`."
595
+ ),
596
+ )
597
+ parser.add_argument(
598
+ "--input_precond_type",
599
+ type=str,
600
+ default="cm",
601
+ help=(
602
+ "The input preconditioning function to use for scaling the image input of the U-Net. Choose between `none`"
603
+ " (a scaling factor of 1) and `cm`."
604
+ ),
605
+ )
606
+ parser.add_argument(
607
+ "--skip_steps",
608
+ type=int,
609
+ default=1,
610
+ help=(
611
+ "The gap in indices between the student and teacher noise levels. In the iCT paper this is always set to"
612
+ " 1, but theoretically this could be greater than 1 and/or altered according to a curriculum throughout"
613
+ " training, much like the number of discretization steps is."
614
+ ),
615
+ )
616
+ parser.add_argument(
617
+ "--cast_teacher",
618
+ action="store_true",
619
+ help="Whether to cast the teacher U-Net model to `weight_dtype` or leave it in full precision.",
620
+ )
621
+ # ----Exponential Moving Average (EMA) Arguments----
622
+ parser.add_argument(
623
+ "--use_ema",
624
+ action="store_true",
625
+ help="Whether to use Exponential Moving Average for the final model weights.",
626
+ )
627
+ parser.add_argument(
628
+ "--ema_min_decay",
629
+ type=float,
630
+ default=None,
631
+ help=(
632
+ "The minimum decay magnitude for EMA. If not set, this will default to the value of `ema_max_decay`,"
633
+ " resulting in a constant EMA decay rate."
634
+ ),
635
+ )
636
+ parser.add_argument(
637
+ "--ema_max_decay",
638
+ type=float,
639
+ default=0.99993,
640
+ help=(
641
+ "The maximum decay magnitude for EMA. Setting `ema_min_decay` equal to this value will result in a"
642
+ " constant decay rate."
643
+ ),
644
+ )
645
+ parser.add_argument(
646
+ "--use_ema_warmup",
647
+ action="store_true",
648
+ help="Whether to use EMA warmup.",
649
+ )
650
+ parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
651
+ parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
652
+ # ----Training Optimization Arguments----
653
+ parser.add_argument(
654
+ "--mixed_precision",
655
+ type=str,
656
+ default="no",
657
+ choices=["no", "fp16", "bf16"],
658
+ help=(
659
+ "Whether to use mixed precision. Choose"
660
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
661
+ "and an Nvidia Ampere GPU."
662
+ ),
663
+ )
664
+ parser.add_argument(
665
+ "--allow_tf32",
666
+ action="store_true",
667
+ help=(
668
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
669
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
670
+ ),
671
+ )
672
+ parser.add_argument(
673
+ "--gradient_checkpointing",
674
+ action="store_true",
675
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
676
+ )
677
+ parser.add_argument(
678
+ "--gradient_accumulation_steps",
679
+ type=int,
680
+ default=1,
681
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
682
+ )
683
+ parser.add_argument(
684
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
685
+ )
686
+ # ----Distributed Training Arguments----
687
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
688
+ # ------------Validation Arguments-----------
689
+ parser.add_argument(
690
+ "--validation_steps",
691
+ type=int,
692
+ default=200,
693
+ help="Run validation every X steps.",
694
+ )
695
+ parser.add_argument(
696
+ "--eval_batch_size",
697
+ type=int,
698
+ default=16,
699
+ help=(
700
+ "The number of images to generate for evaluation. Note that if `class_conditional` and `num_classes` is"
701
+ " set the effective number of images generated per evaluation step is `eval_batch_size * num_classes`."
702
+ ),
703
+ )
704
+ parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
705
+ # ------------Validation Arguments-----------
706
+ parser.add_argument(
707
+ "--checkpointing_steps",
708
+ type=int,
709
+ default=500,
710
+ help=(
711
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
712
+ " training using `--resume_from_checkpoint`."
713
+ ),
714
+ )
715
+ parser.add_argument(
716
+ "--checkpoints_total_limit",
717
+ type=int,
718
+ default=None,
719
+ help=("Max number of checkpoints to store."),
720
+ )
721
+ parser.add_argument(
722
+ "--resume_from_checkpoint",
723
+ type=str,
724
+ default=None,
725
+ help=(
726
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
727
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
728
+ ),
729
+ )
730
+ parser.add_argument(
731
+ "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
732
+ )
733
+ # ------------Logging Arguments-----------
734
+ parser.add_argument(
735
+ "--report_to",
736
+ type=str,
737
+ default="tensorboard",
738
+ help=(
739
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
740
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
741
+ ),
742
+ )
743
+ parser.add_argument(
744
+ "--logging_dir",
745
+ type=str,
746
+ default="logs",
747
+ help=(
748
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
749
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
750
+ ),
751
+ )
752
+ # ------------HuggingFace Hub Arguments-----------
753
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
754
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
755
+ parser.add_argument(
756
+ "--hub_model_id",
757
+ type=str,
758
+ default=None,
759
+ help="The name of the repository to keep in sync with the local `output_dir`.",
760
+ )
761
+ parser.add_argument(
762
+ "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
763
+ )
764
+ # ------------Accelerate Arguments-----------
765
+ parser.add_argument(
766
+ "--tracker_project_name",
767
+ type=str,
768
+ default="consistency-training",
769
+ help=(
770
+ "The `project_name` argument passed to Accelerator.init_trackers for"
771
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
772
+ ),
773
+ )
774
+
775
+ args = parser.parse_args()
776
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
777
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
778
+ args.local_rank = env_local_rank
779
+
780
+ if args.dataset_name is None and args.train_data_dir is None:
781
+ raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
782
+
783
+ return args
784
+
785
+
786
+ def main(args):
787
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
788
+
789
+ if args.report_to == "wandb" and args.hub_token is not None:
790
+ raise ValueError(
791
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
792
+ " Please use `huggingface-cli login` to authenticate with the Hub."
793
+ )
794
+
795
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
796
+
797
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200)) # a big number for high resolution or big dataset
798
+ accelerator = Accelerator(
799
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
800
+ mixed_precision=args.mixed_precision,
801
+ log_with=args.report_to,
802
+ project_config=accelerator_project_config,
803
+ kwargs_handlers=[kwargs],
804
+ )
805
+
806
+ if args.report_to == "tensorboard":
807
+ if not is_tensorboard_available():
808
+ raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
809
+
810
+ elif args.report_to == "wandb":
811
+ if not is_wandb_available():
812
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
813
+
814
+ # Make one log on every process with the configuration for debugging.
815
+ logging.basicConfig(
816
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
817
+ datefmt="%m/%d/%Y %H:%M:%S",
818
+ level=logging.INFO,
819
+ )
820
+ logger.info(accelerator.state, main_process_only=False)
821
+ if accelerator.is_local_main_process:
822
+ datasets.utils.logging.set_verbosity_warning()
823
+ diffusers.utils.logging.set_verbosity_info()
824
+ else:
825
+ datasets.utils.logging.set_verbosity_error()
826
+ diffusers.utils.logging.set_verbosity_error()
827
+
828
+ # If passed along, set the training seed now.
829
+ if args.seed is not None:
830
+ set_seed(args.seed)
831
+
832
+ # Handle the repository creation
833
+ if accelerator.is_main_process:
834
+ if args.output_dir is not None:
835
+ os.makedirs(args.output_dir, exist_ok=True)
836
+
837
+ if args.push_to_hub:
838
+ repo_id = create_repo(
839
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
840
+ ).repo_id
841
+
842
+ # 1. Initialize the noise scheduler.
843
+ initial_discretization_steps = get_discretization_steps(
844
+ 0,
845
+ args.max_train_steps,
846
+ s_0=args.discretization_s_0,
847
+ s_1=args.discretization_s_1,
848
+ constant=args.constant_discretization_steps,
849
+ )
850
+ noise_scheduler = CMStochasticIterativeScheduler(
851
+ num_train_timesteps=initial_discretization_steps,
852
+ sigma_min=args.sigma_min,
853
+ sigma_max=args.sigma_max,
854
+ rho=args.rho,
855
+ )
856
+
857
+ # 2. Initialize the student U-Net model.
858
+ if args.pretrained_model_name_or_path is not None:
859
+ logger.info(f"Loading pretrained U-Net weights from {args.pretrained_model_name_or_path}... ")
860
+ unet = UNet2DModel.from_pretrained(
861
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
862
+ )
863
+ elif args.model_config_name_or_path is None:
864
+ # TODO: use default architectures from iCT paper
865
+ if not args.class_conditional and (args.num_classes is not None or args.class_embed_type is not None):
866
+ logger.warning(
867
+ f"`--class_conditional` is set to `False` but `--num_classes` is set to {args.num_classes} and"
868
+ f" `--class_embed_type` is set to {args.class_embed_type}. These values will be overridden to `None`."
869
+ )
870
+ args.num_classes = None
871
+ args.class_embed_type = None
872
+ elif args.class_conditional and args.num_classes is None and args.class_embed_type is None:
873
+ logger.warning(
874
+ "`--class_conditional` is set to `True` but neither `--num_classes` nor `--class_embed_type` is set."
875
+ "`class_conditional` will be overridden to `False`."
876
+ )
877
+ args.class_conditional = False
878
+ unet = UNet2DModel(
879
+ sample_size=args.resolution,
880
+ in_channels=3,
881
+ out_channels=3,
882
+ layers_per_block=2,
883
+ block_out_channels=(128, 128, 256, 256, 512, 512),
884
+ down_block_types=(
885
+ "DownBlock2D",
886
+ "DownBlock2D",
887
+ "DownBlock2D",
888
+ "DownBlock2D",
889
+ "AttnDownBlock2D",
890
+ "DownBlock2D",
891
+ ),
892
+ up_block_types=(
893
+ "UpBlock2D",
894
+ "AttnUpBlock2D",
895
+ "UpBlock2D",
896
+ "UpBlock2D",
897
+ "UpBlock2D",
898
+ "UpBlock2D",
899
+ ),
900
+ class_embed_type=args.class_embed_type,
901
+ num_class_embeds=args.num_classes,
902
+ )
903
+ else:
904
+ config = UNet2DModel.load_config(args.model_config_name_or_path)
905
+ unet = UNet2DModel.from_config(config)
906
+ unet.train()
907
+
908
+ # Create EMA for the student U-Net model.
909
+ if args.use_ema:
910
+ if args.ema_min_decay is None:
911
+ args.ema_min_decay = args.ema_max_decay
912
+ ema_unet = EMAModel(
913
+ unet.parameters(),
914
+ decay=args.ema_max_decay,
915
+ min_decay=args.ema_min_decay,
916
+ use_ema_warmup=args.use_ema_warmup,
917
+ inv_gamma=args.ema_inv_gamma,
918
+ power=args.ema_power,
919
+ model_cls=UNet2DModel,
920
+ model_config=unet.config,
921
+ )
922
+
923
+ # 3. Initialize the teacher U-Net model from the student U-Net model.
924
+ # Note that following the improved Consistency Training paper, the teacher U-Net is not updated via EMA (e.g. the
925
+ # EMA decay rate is 0.)
926
+ teacher_unet = UNet2DModel.from_config(unet.config)
927
+ teacher_unet.load_state_dict(unet.state_dict())
928
+ teacher_unet.train()
929
+ teacher_unet.requires_grad_(False)
930
+
931
+ # 4. Handle mixed precision and device placement
932
+ weight_dtype = torch.float32
933
+ if accelerator.mixed_precision == "fp16":
934
+ weight_dtype = torch.float16
935
+ args.mixed_precision = accelerator.mixed_precision
936
+ elif accelerator.mixed_precision == "bf16":
937
+ weight_dtype = torch.bfloat16
938
+ args.mixed_precision = accelerator.mixed_precision
939
+
940
+ # Cast teacher_unet to weight_dtype if cast_teacher is set.
941
+ if args.cast_teacher:
942
+ teacher_dtype = weight_dtype
943
+ else:
944
+ teacher_dtype = torch.float32
945
+
946
+ teacher_unet.to(accelerator.device)
947
+ if args.use_ema:
948
+ ema_unet.to(accelerator.device)
949
+
950
+ # 5. Handle saving and loading of checkpoints.
951
+ # `accelerate` 0.16.0 will have better support for customized saving
952
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
953
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
954
+ def save_model_hook(models, weights, output_dir):
955
+ if accelerator.is_main_process:
956
+ teacher_unet.save_pretrained(os.path.join(output_dir, "unet_teacher"))
957
+ if args.use_ema:
958
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
959
+
960
+ for i, model in enumerate(models):
961
+ model.save_pretrained(os.path.join(output_dir, "unet"))
962
+
963
+ # make sure to pop weight so that corresponding model is not saved again
964
+ weights.pop()
965
+
966
+ def load_model_hook(models, input_dir):
967
+ load_model = UNet2DModel.from_pretrained(os.path.join(input_dir, "unet_teacher"))
968
+ teacher_unet.load_state_dict(load_model.state_dict())
969
+ teacher_unet.to(accelerator.device)
970
+ del load_model
971
+
972
+ if args.use_ema:
973
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
974
+ ema_unet.load_state_dict(load_model.state_dict())
975
+ ema_unet.to(accelerator.device)
976
+ del load_model
977
+
978
+ for i in range(len(models)):
979
+ # pop models so that they are not loaded again
980
+ model = models.pop()
981
+
982
+ # load diffusers style into model
983
+ load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
984
+ model.register_to_config(**load_model.config)
985
+
986
+ model.load_state_dict(load_model.state_dict())
987
+ del load_model
988
+
989
+ accelerator.register_save_state_pre_hook(save_model_hook)
990
+ accelerator.register_load_state_pre_hook(load_model_hook)
991
+
992
+ # 6. Enable optimizations
993
+ if args.enable_xformers_memory_efficient_attention:
994
+ if is_xformers_available():
995
+ import xformers
996
+
997
+ xformers_version = version.parse(xformers.__version__)
998
+ if xformers_version == version.parse("0.0.16"):
999
+ logger.warning(
1000
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1001
+ )
1002
+ unet.enable_xformers_memory_efficient_attention()
1003
+ teacher_unet.enable_xformers_memory_efficient_attention()
1004
+ if args.use_ema:
1005
+ ema_unet.enable_xformers_memory_efficient_attention()
1006
+ else:
1007
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
1008
+
1009
+ # Enable TF32 for faster training on Ampere GPUs,
1010
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1011
+ if args.allow_tf32:
1012
+ torch.backends.cuda.matmul.allow_tf32 = True
1013
+
1014
+ if args.gradient_checkpointing:
1015
+ unet.enable_gradient_checkpointing()
1016
+
1017
+ if args.optimizer_type == "radam":
1018
+ optimizer_class = torch.optim.RAdam
1019
+ elif args.optimizer_type == "adamw":
1020
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model for 16GB GPUs
1021
+ if args.use_8bit_adam:
1022
+ try:
1023
+ import bitsandbytes as bnb
1024
+ except ImportError:
1025
+ raise ImportError(
1026
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1027
+ )
1028
+
1029
+ optimizer_class = bnb.optim.AdamW8bit
1030
+ else:
1031
+ optimizer_class = torch.optim.AdamW
1032
+ else:
1033
+ raise ValueError(
1034
+ f"Optimizer type {args.optimizer_type} is not supported. Currently supported optimizer types are `radam`"
1035
+ f" and `adamw`."
1036
+ )
1037
+
1038
+ # 7. Initialize the optimizer
1039
+ optimizer = optimizer_class(
1040
+ unet.parameters(),
1041
+ lr=args.learning_rate,
1042
+ betas=(args.adam_beta1, args.adam_beta2),
1043
+ weight_decay=args.adam_weight_decay,
1044
+ eps=args.adam_epsilon,
1045
+ )
1046
+
1047
+ # 8. Dataset creation and data preprocessing
1048
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
1049
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
1050
+
1051
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
1052
+ # download the dataset.
1053
+ if args.dataset_name is not None:
1054
+ dataset = load_dataset(
1055
+ args.dataset_name,
1056
+ args.dataset_config_name,
1057
+ cache_dir=args.cache_dir,
1058
+ split="train",
1059
+ )
1060
+ else:
1061
+ dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
1062
+ # See more about loading custom images at
1063
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
1064
+
1065
+ # Preprocessing the datasets and DataLoaders creation.
1066
+ interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
1067
+ augmentations = transforms.Compose(
1068
+ [
1069
+ transforms.Resize(args.resolution, interpolation=interpolation_mode),
1070
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
1071
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
1072
+ transforms.ToTensor(),
1073
+ transforms.Normalize([0.5], [0.5]),
1074
+ ]
1075
+ )
1076
+
1077
+ def transform_images(examples):
1078
+ images = [augmentations(image.convert("RGB")) for image in examples[args.dataset_image_column_name]]
1079
+ batch_dict = {"images": images}
1080
+ if args.class_conditional:
1081
+ batch_dict["class_labels"] = examples[args.dataset_class_label_column_name]
1082
+ return batch_dict
1083
+
1084
+ logger.info(f"Dataset size: {len(dataset)}")
1085
+
1086
+ dataset.set_transform(transform_images)
1087
+ train_dataloader = torch.utils.data.DataLoader(
1088
+ dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
1089
+ )
1090
+
1091
+ # 9. Initialize the learning rate scheduler
1092
+ # Scheduler and math around the number of training steps.
1093
+ overrode_max_train_steps = False
1094
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1095
+ if args.max_train_steps is None:
1096
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1097
+ overrode_max_train_steps = True
1098
+
1099
+ lr_scheduler = get_scheduler(
1100
+ args.lr_scheduler,
1101
+ optimizer=optimizer,
1102
+ num_warmup_steps=args.lr_warmup_steps,
1103
+ num_training_steps=args.max_train_steps,
1104
+ )
1105
+
1106
+ # 10. Prepare for training
1107
+ # Prepare everything with our `accelerator`.
1108
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1109
+ unet, optimizer, train_dataloader, lr_scheduler
1110
+ )
1111
+
1112
+ def recalculate_num_discretization_step_values(discretization_steps, skip_steps):
1113
+ """
1114
+ Recalculates all quantities depending on the number of discretization steps N.
1115
+ """
1116
+ noise_scheduler = CMStochasticIterativeScheduler(
1117
+ num_train_timesteps=discretization_steps,
1118
+ sigma_min=args.sigma_min,
1119
+ sigma_max=args.sigma_max,
1120
+ rho=args.rho,
1121
+ )
1122
+ current_timesteps = get_karras_sigmas(discretization_steps, args.sigma_min, args.sigma_max, args.rho)
1123
+ valid_teacher_timesteps_plus_one = current_timesteps[: len(current_timesteps) - skip_steps + 1]
1124
+ # timestep_weights are the unnormalized probabilities of sampling the timestep/noise level at each index
1125
+ timestep_weights = get_discretized_lognormal_weights(
1126
+ valid_teacher_timesteps_plus_one, p_mean=args.p_mean, p_std=args.p_std
1127
+ )
1128
+ # timestep_loss_weights is the timestep-dependent loss weighting schedule lambda(sigma_i)
1129
+ timestep_loss_weights = get_loss_weighting_schedule(valid_teacher_timesteps_plus_one)
1130
+
1131
+ current_timesteps = current_timesteps.to(accelerator.device)
1132
+ timestep_weights = timestep_weights.to(accelerator.device)
1133
+ timestep_loss_weights = timestep_loss_weights.to(accelerator.device)
1134
+
1135
+ return noise_scheduler, current_timesteps, timestep_weights, timestep_loss_weights
1136
+
1137
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1138
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1139
+ if overrode_max_train_steps:
1140
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1141
+ # Afterwards we recalculate our number of training epochs
1142
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1143
+
1144
+ # We need to initialize the trackers we use, and also store our configuration.
1145
+ # The trackers initializes automatically on the main process.
1146
+ if accelerator.is_main_process:
1147
+ tracker_config = dict(vars(args))
1148
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1149
+
1150
+ # Function for unwrapping if torch.compile() was used in accelerate.
1151
+ def unwrap_model(model):
1152
+ model = accelerator.unwrap_model(model)
1153
+ model = model._orig_mod if is_compiled_module(model) else model
1154
+ return model
1155
+
1156
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1157
+
1158
+ logger.info("***** Running training *****")
1159
+ logger.info(f" Num examples = {len(dataset)}")
1160
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1161
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1162
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1163
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1164
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1165
+
1166
+ global_step = 0
1167
+ first_epoch = 0
1168
+
1169
+ # Potentially load in the weights and states from a previous save
1170
+ if args.resume_from_checkpoint:
1171
+ if args.resume_from_checkpoint != "latest":
1172
+ path = os.path.basename(args.resume_from_checkpoint)
1173
+ else:
1174
+ # Get the most recent checkpoint
1175
+ dirs = os.listdir(args.output_dir)
1176
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1177
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1178
+ path = dirs[-1] if len(dirs) > 0 else None
1179
+
1180
+ if path is None:
1181
+ accelerator.print(
1182
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1183
+ )
1184
+ args.resume_from_checkpoint = None
1185
+ initial_global_step = 0
1186
+ else:
1187
+ accelerator.print(f"Resuming from checkpoint {path}")
1188
+ accelerator.load_state(os.path.join(args.output_dir, path))
1189
+ global_step = int(path.split("-")[1])
1190
+
1191
+ initial_global_step = global_step
1192
+ first_epoch = global_step // num_update_steps_per_epoch
1193
+ else:
1194
+ initial_global_step = 0
1195
+
1196
+ # Resolve the c parameter for the Pseudo-Huber loss
1197
+ if args.huber_c is None:
1198
+ args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)
1199
+
1200
+ # Get current number of discretization steps N according to our discretization curriculum
1201
+ current_discretization_steps = get_discretization_steps(
1202
+ initial_global_step,
1203
+ args.max_train_steps,
1204
+ s_0=args.discretization_s_0,
1205
+ s_1=args.discretization_s_1,
1206
+ constant=args.constant_discretization_steps,
1207
+ )
1208
+ current_skip_steps = get_skip_steps(initial_global_step, initial_skip=args.skip_steps)
1209
+ if current_skip_steps >= current_discretization_steps:
1210
+ raise ValueError(
1211
+ f"The current skip steps is {current_skip_steps}, but should be smaller than the current number of"
1212
+ f" discretization steps {current_discretization_steps}"
1213
+ )
1214
+ # Recalculate all quantities depending on the number of discretization steps N
1215
+ (
1216
+ noise_scheduler,
1217
+ current_timesteps,
1218
+ timestep_weights,
1219
+ timestep_loss_weights,
1220
+ ) = recalculate_num_discretization_step_values(current_discretization_steps, current_skip_steps)
1221
+
1222
+ progress_bar = tqdm(
1223
+ range(0, args.max_train_steps),
1224
+ initial=initial_global_step,
1225
+ desc="Steps",
1226
+ # Only show the progress bar once on each machine.
1227
+ disable=not accelerator.is_local_main_process,
1228
+ )
1229
+
1230
+ # 11. Train!
1231
+ for epoch in range(first_epoch, args.num_train_epochs):
1232
+ unet.train()
1233
+ for step, batch in enumerate(train_dataloader):
1234
+ # 1. Get batch of images from dataloader (sample x ~ p_data(x))
1235
+ clean_images = batch["images"].to(weight_dtype)
1236
+ if args.class_conditional:
1237
+ class_labels = batch["class_labels"]
1238
+ else:
1239
+ class_labels = None
1240
+ bsz = clean_images.shape[0]
1241
+
1242
+ # 2. Sample a random timestep for each image according to the noise schedule.
1243
+ # Sample random indices i ~ p(i), where p(i) is the dicretized lognormal distribution in the iCT paper
1244
+ # NOTE: timestep_indices should be in the range [0, len(current_timesteps) - k - 1] inclusive
1245
+ timestep_indices = torch.multinomial(timestep_weights, bsz, replacement=True).long()
1246
+ teacher_timesteps = current_timesteps[timestep_indices]
1247
+ student_timesteps = current_timesteps[timestep_indices + current_skip_steps]
1248
+
1249
+ # 3. Sample noise and add it to the clean images for both teacher and student unets
1250
+ # Sample noise z ~ N(0, I) that we'll add to the images
1251
+ noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
1252
+ # Add noise to the clean images according to the noise magnitude at each timestep
1253
+ # (this is the forward diffusion process)
1254
+ teacher_noisy_images = add_noise(clean_images, noise, teacher_timesteps)
1255
+ student_noisy_images = add_noise(clean_images, noise, student_timesteps)
1256
+
1257
+ # 4. Calculate preconditioning and scalings for boundary conditions for the consistency model.
1258
+ teacher_rescaled_timesteps = get_noise_preconditioning(teacher_timesteps, args.noise_precond_type)
1259
+ student_rescaled_timesteps = get_noise_preconditioning(student_timesteps, args.noise_precond_type)
1260
+
1261
+ c_in_teacher = get_input_preconditioning(teacher_timesteps, input_precond_type=args.input_precond_type)
1262
+ c_in_student = get_input_preconditioning(student_timesteps, input_precond_type=args.input_precond_type)
1263
+
1264
+ c_skip_teacher, c_out_teacher = scalings_for_boundary_conditions(teacher_timesteps)
1265
+ c_skip_student, c_out_student = scalings_for_boundary_conditions(student_timesteps)
1266
+
1267
+ c_skip_teacher, c_out_teacher, c_in_teacher = [
1268
+ append_dims(x, clean_images.ndim) for x in [c_skip_teacher, c_out_teacher, c_in_teacher]
1269
+ ]
1270
+ c_skip_student, c_out_student, c_in_student = [
1271
+ append_dims(x, clean_images.ndim) for x in [c_skip_student, c_out_student, c_in_student]
1272
+ ]
1273
+
1274
+ with accelerator.accumulate(unet):
1275
+ # 5. Get the student unet denoising prediction on the student timesteps
1276
+ # Get rng state now to ensure that dropout is synced between the student and teacher models.
1277
+ dropout_state = torch.get_rng_state()
1278
+ student_model_output = unet(
1279
+ c_in_student * student_noisy_images, student_rescaled_timesteps, class_labels=class_labels
1280
+ ).sample
1281
+ # NOTE: currently only support prediction_type == sample, so no need to convert model_output
1282
+ student_denoise_output = c_skip_student * student_noisy_images + c_out_student * student_model_output
1283
+
1284
+ # 6. Get the teacher unet denoising prediction on the teacher timesteps
1285
+ with torch.no_grad(), torch.autocast("cuda", dtype=teacher_dtype):
1286
+ torch.set_rng_state(dropout_state)
1287
+ teacher_model_output = teacher_unet(
1288
+ c_in_teacher * teacher_noisy_images, teacher_rescaled_timesteps, class_labels=class_labels
1289
+ ).sample
1290
+ # NOTE: currently only support prediction_type == sample, so no need to convert model_output
1291
+ teacher_denoise_output = (
1292
+ c_skip_teacher * teacher_noisy_images + c_out_teacher * teacher_model_output
1293
+ )
1294
+
1295
+ # 7. Calculate the weighted Pseudo-Huber loss
1296
+ if args.prediction_type == "sample":
1297
+ # Note that the loss weights should be those at the (teacher) timestep indices.
1298
+ lambda_t = _extract_into_tensor(
1299
+ timestep_loss_weights, timestep_indices, (bsz,) + (1,) * (clean_images.ndim - 1)
1300
+ )
1301
+ loss = lambda_t * (
1302
+ torch.sqrt(
1303
+ (student_denoise_output.float() - teacher_denoise_output.float()) ** 2 + args.huber_c**2
1304
+ )
1305
+ - args.huber_c
1306
+ )
1307
+ loss = loss.mean()
1308
+ else:
1309
+ raise ValueError(
1310
+ f"Unsupported prediction type: {args.prediction_type}. Currently, only `sample` is supported."
1311
+ )
1312
+
1313
+ # 8. Backpropagate on the consistency training loss
1314
+ accelerator.backward(loss)
1315
+ if accelerator.sync_gradients:
1316
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
1317
+ optimizer.step()
1318
+ lr_scheduler.step()
1319
+ optimizer.zero_grad()
1320
+
1321
+ # Checks if the accelerator has performed an optimization step behind the scenes
1322
+ if accelerator.sync_gradients:
1323
+ # 9. Update teacher_unet and ema_unet parameters using unet's parameters.
1324
+ teacher_unet.load_state_dict(unet.state_dict())
1325
+ if args.use_ema:
1326
+ ema_unet.step(unet.parameters())
1327
+ progress_bar.update(1)
1328
+ global_step += 1
1329
+
1330
+ if accelerator.is_main_process:
1331
+ # 10. Recalculate quantities depending on the global step, if necessary.
1332
+ new_discretization_steps = get_discretization_steps(
1333
+ global_step,
1334
+ args.max_train_steps,
1335
+ s_0=args.discretization_s_0,
1336
+ s_1=args.discretization_s_1,
1337
+ constant=args.constant_discretization_steps,
1338
+ )
1339
+ current_skip_steps = get_skip_steps(global_step, initial_skip=args.skip_steps)
1340
+ if current_skip_steps >= new_discretization_steps:
1341
+ raise ValueError(
1342
+ f"The current skip steps is {current_skip_steps}, but should be smaller than the current"
1343
+ f" number of discretization steps {new_discretization_steps}."
1344
+ )
1345
+ if new_discretization_steps != current_discretization_steps:
1346
+ (
1347
+ noise_scheduler,
1348
+ current_timesteps,
1349
+ timestep_weights,
1350
+ timestep_loss_weights,
1351
+ ) = recalculate_num_discretization_step_values(new_discretization_steps, current_skip_steps)
1352
+ current_discretization_steps = new_discretization_steps
1353
+
1354
+ if global_step % args.checkpointing_steps == 0:
1355
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1356
+ if args.checkpoints_total_limit is not None:
1357
+ checkpoints = os.listdir(args.output_dir)
1358
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1359
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1360
+
1361
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1362
+ if len(checkpoints) >= args.checkpoints_total_limit:
1363
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1364
+ removing_checkpoints = checkpoints[0:num_to_remove]
1365
+
1366
+ logger.info(
1367
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1368
+ )
1369
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1370
+
1371
+ for removing_checkpoint in removing_checkpoints:
1372
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1373
+ shutil.rmtree(removing_checkpoint)
1374
+
1375
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1376
+ accelerator.save_state(save_path)
1377
+ logger.info(f"Saved state to {save_path}")
1378
+
1379
+ if global_step % args.validation_steps == 0:
1380
+ # NOTE: since we do not use EMA for the teacher model, the teacher parameters and student
1381
+ # parameters are the same at this point in time
1382
+ log_validation(unet, noise_scheduler, args, accelerator, weight_dtype, global_step, "teacher")
1383
+ # teacher_unet.to(dtype=teacher_dtype)
1384
+
1385
+ if args.use_ema:
1386
+ # Store the student unet weights and load the EMA weights.
1387
+ ema_unet.store(unet.parameters())
1388
+ ema_unet.copy_to(unet.parameters())
1389
+
1390
+ log_validation(
1391
+ unet,
1392
+ noise_scheduler,
1393
+ args,
1394
+ accelerator,
1395
+ weight_dtype,
1396
+ global_step,
1397
+ "ema_student",
1398
+ )
1399
+
1400
+ # Restore student unet weights
1401
+ ema_unet.restore(unet.parameters())
1402
+
1403
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
1404
+ if args.use_ema:
1405
+ logs["ema_decay"] = ema_unet.cur_decay_value
1406
+ progress_bar.set_postfix(**logs)
1407
+ accelerator.log(logs, step=global_step)
1408
+
1409
+ if global_step >= args.max_train_steps:
1410
+ break
1411
+ # progress_bar.close()
1412
+
1413
+ accelerator.wait_for_everyone()
1414
+ if accelerator.is_main_process:
1415
+ unet = unwrap_model(unet)
1416
+ pipeline = ConsistencyModelPipeline(unet=unet, scheduler=noise_scheduler)
1417
+ pipeline.save_pretrained(args.output_dir)
1418
+
1419
+ # If using EMA, save EMA weights as well.
1420
+ if args.use_ema:
1421
+ ema_unet.copy_to(unet.parameters())
1422
+
1423
+ unet.save_pretrained(os.path.join(args.output_dir, "ema_unet"))
1424
+
1425
+ if args.push_to_hub:
1426
+ upload_folder(
1427
+ repo_id=repo_id,
1428
+ folder_path=args.output_dir,
1429
+ commit_message="End of training",
1430
+ ignore_patterns=["step_*", "epoch_*"],
1431
+ )
1432
+
1433
+ accelerator.end_training()
1434
+
1435
+
1436
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then run the training loop.
    args = parse_args()
    main(args)
diffusers/examples/research_projects/diffusion_dpo/README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diffusion Model Alignment Using Direct Preference Optimization
2
+
3
+ This directory provides LoRA implementations of Diffusion DPO proposed in [Diffusion Model Alignment Using Direct Preference Optimization](https://huggingface.co/papers/2311.12908) by Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik.
4
+
5
+ We provide implementations for both Stable Diffusion (SD) and Stable Diffusion XL (SDXL). The original checkpoints are available at the URLs below:
6
+
7
+ * [mhdang/dpo-sd1.5-text2image-v1](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1)
8
+ * [mhdang/dpo-sdxl-text2image-v1](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)
9
+
10
+ > 💡 Note: The scripts are highly experimental and were only tested on low-data regimes. Proceed with caution. Feel free to let us know about your findings via GitHub issues.
11
+
12
+ ## SD training command
13
+
14
+ ```bash
15
+ accelerate launch train_diffusion_dpo.py \
16
+ --pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5 \
17
+ --output_dir="diffusion-dpo" \
18
+ --mixed_precision="fp16" \
19
+ --dataset_name=kashif/pickascore \
20
+ --resolution=512 \
21
+ --train_batch_size=16 \
22
+ --gradient_accumulation_steps=2 \
23
+ --gradient_checkpointing \
24
+ --use_8bit_adam \
25
+ --rank=8 \
26
+ --learning_rate=1e-5 \
27
+ --report_to="wandb" \
28
+ --lr_scheduler="constant" \
29
+ --lr_warmup_steps=0 \
30
+ --max_train_steps=10000 \
31
+ --checkpointing_steps=2000 \
32
+ --run_validation --validation_steps=200 \
33
+ --seed="0" \
34
+ --report_to="wandb" \
35
+ --push_to_hub
36
+ ```
37
+
38
+ ## SDXL training command
39
+
40
+ ```bash
41
+ accelerate launch train_diffusion_dpo_sdxl.py \
42
+ --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
43
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
44
+ --output_dir="diffusion-sdxl-dpo" \
45
+ --mixed_precision="fp16" \
46
+ --dataset_name=kashif/pickascore \
47
+ --train_batch_size=8 \
48
+ --gradient_accumulation_steps=2 \
49
+ --gradient_checkpointing \
50
+ --use_8bit_adam \
51
+ --rank=8 \
52
+ --learning_rate=1e-5 \
53
+ --report_to="wandb" \
54
+ --lr_scheduler="constant" \
55
+ --lr_warmup_steps=0 \
56
+ --max_train_steps=2000 \
57
+ --checkpointing_steps=500 \
58
+ --run_validation --validation_steps=50 \
59
+ --seed="0" \
60
+ --report_to="wandb" \
61
+ --push_to_hub
62
+ ```
63
+
64
+ ## SDXL Turbo training command
65
+
66
+ ```bash
67
+ accelerate launch train_diffusion_dpo_sdxl.py \
68
+ --pretrained_model_name_or_path=stabilityai/sdxl-turbo \
69
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
70
+ --output_dir="diffusion-sdxl-turbo-dpo" \
71
+ --mixed_precision="fp16" \
72
+ --dataset_name=kashif/pickascore \
73
+ --train_batch_size=8 \
74
+ --gradient_accumulation_steps=2 \
75
+ --gradient_checkpointing \
76
+ --use_8bit_adam \
77
+ --rank=8 \
78
+ --learning_rate=1e-5 \
79
+ --report_to="wandb" \
80
+ --lr_scheduler="constant" \
81
+ --lr_warmup_steps=0 \
82
+ --max_train_steps=2000 \
83
+ --checkpointing_steps=500 \
84
+ --run_validation --validation_steps=50 \
85
+ --seed="0" \
86
+ --report_to="wandb" \
87
+ --is_turbo --resolution 512 \
88
+ --push_to_hub
89
+ ```
90
+
91
+
92
+ ## Acknowledgements
93
+
94
+ This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
diffusers/examples/research_projects/diffusion_dpo/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
7
+ peft
8
+ wandb
diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py ADDED
@@ -0,0 +1,982 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 bram-w, The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import contextlib
18
+ import io
19
+ import logging
20
+ import math
21
+ import os
22
+ import shutil
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ import transformers
30
+ import wandb
31
+ from accelerate import Accelerator
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import ProjectConfiguration, set_seed
34
+ from datasets import load_dataset
35
+ from huggingface_hub import create_repo, upload_folder
36
+ from packaging import version
37
+ from peft import LoraConfig
38
+ from peft.utils import get_peft_model_state_dict
39
+ from PIL import Image
40
+ from torchvision import transforms
41
+ from tqdm.auto import tqdm
42
+ from transformers import AutoTokenizer, PretrainedConfig
43
+
44
+ import diffusers
45
+ from diffusers import (
46
+ AutoencoderKL,
47
+ DDPMScheduler,
48
+ DiffusionPipeline,
49
+ DPMSolverMultistepScheduler,
50
+ UNet2DConditionModel,
51
+ )
52
+ from diffusers.loaders import StableDiffusionLoraLoaderMixin
53
+ from diffusers.optimization import get_scheduler
54
+ from diffusers.utils import check_min_version, convert_state_dict_to_diffusers
55
+ from diffusers.utils.import_utils import is_xformers_available
56
+
57
+
58
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
59
+ check_min_version("0.25.0.dev0")
60
+
61
+ logger = get_logger(__name__)
62
+
63
+
64
# Fixed prompt set used by `log_validation` for qualitative checks during and
# after training; images for these prompts are logged to the trackers.
VALIDATION_PROMPTS = [
    "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
70
+
71
+
72
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Resolve the text-encoder class for a checkpoint from its saved config.

    Reads the `text_encoder` subfolder config of the given checkpoint and maps
    the declared architecture to the corresponding transformers class. Only
    CLIP-based checkpoints are supported; any other architecture raises.

    Raises:
        ValueError: If the architecture is not ``CLIPTextModel``.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    architecture = config.architectures[0]

    if architecture != "CLIPTextModel":
        raise ValueError(f"{architecture} is not supported.")

    from transformers import CLIPTextModel

    return CLIPTextModel
86
+
87
+
88
def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
    """Run inference on the fixed VALIDATION_PROMPTS and log images to trackers.

    Args:
        args: Parsed CLI namespace (reads ``pretrained_model_name_or_path``,
            ``revision``, ``variant``, ``output_dir``, ``seed``).
        unet: The UNet being trained; spliced into the pipeline for
            intermediate validation runs.
        accelerator: Accelerate ``Accelerator`` providing the device, trackers,
            and ``unwrap_model``.
        weight_dtype: torch dtype used to load the frozen pipeline weights.
        epoch: Step/epoch index attached to the logged images.
        is_final_validation: When True, load the serialized LoRA weights from
            ``args.output_dir`` instead of the live UNet, log under the "test"
            key, and additionally log LoRA-disabled images for comparison.
    """
    logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")

    # create pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    if not is_final_validation:
        # Mid-training: evaluate the live (LoRA-augmented) UNet directly.
        pipeline.unet = accelerator.unwrap_model(unet)
    else:
        # Post-training: evaluate the LoRA weights saved to disk.
        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    # FIX: compare against None explicitly. The previous truthiness check
    # (`if args.seed`) silently produced an unseeded generator when the user
    # passed `--seed=0` (as the README commands do).
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
    images = []
    # Autocast only for mid-training validation; the final validation runs the
    # pipeline as loaded, without an AMP context.
    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()

    for prompt in VALIDATION_PROMPTS:
        with context:
            image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
            images.append(image)

    tracker_key = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    tracker_key: [
                        wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
                    ]
                }
            )

    # Also log images without the LoRA params for comparison.
    if is_final_validation:
        pipeline.disable_lora()
        no_lora_images = [
            pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
        ]

        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in no_lora_images])
                tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
            if tracker.name == "wandb":
                tracker.log(
                    {
                        "test_without_lora": [
                            wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
                            for i, image in enumerate(no_lora_images)
                        ]
                    }
                )
150
+ )
151
+
152
+
153
def parse_args(input_args=None):
    """Parse command-line arguments for the Diffusion DPO LoRA training script.

    Args:
        input_args: Optional list of argument strings to parse instead of
            ``sys.argv`` (useful for tests and programmatic invocation).

    Returns:
        argparse.Namespace with all training options; ``local_rank`` is
        overridden by the ``LOCAL_RANK`` environment variable when set.

    Raises:
        ValueError: If ``--dataset_name`` is not provided.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # Model / dataset selection.
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_split_name",
        type=str,
        default="validation",
        help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    # Validation options.
    parser.add_argument(
        "--run_validation",
        default=False,
        action="store_true",
        help="Whether to run validation inference in between training and also after training. Helps to track progress.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=200,
        help="Run validation every X steps.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    # Output / caching.
    parser.add_argument(
        "--output_dir",
        type=str,
        default="diffusion-dpo-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    # Image preprocessing.
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--vae_encode_batch_size",
        type=int,
        default=8,
        help="Batch size to use for VAE encoding of the images for efficient processing.",
    )
    parser.add_argument(
        "--no_hflip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--random_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
        ),
    )
    # Training schedule and checkpointing.
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    # DPO-specific hyperparameters.
    parser.add_argument(
        "--beta_dpo",
        type=int,
        default=2500,
        help="DPO KL Divergence penalty.",
    )
    parser.add_argument(
        "--loss_type",
        type=str,
        default="sigmoid",
        help="DPO loss type. Can be one of 'sigmoid' (default), 'ipo', or 'cpo'",
    )
    # Optimizer / LR scheduler.
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # Hub / logging integration.
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    # Hardware / precision.
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="diffusion-dpo-lora",
        help=("The name of the tracker to report results to."),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.dataset_name is None:
        raise ValueError("Must provide a `dataset_name`.")

    # Honor the LOCAL_RANK set by distributed launchers over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
443
+
444
+
445
def tokenize_captions(tokenizer, examples):
    """Tokenize the `caption` column of a batch into fixed-length input ids.

    Args:
        tokenizer: A Hugging Face tokenizer; provides ``model_max_length`` and
            is callable with padding/truncation options.
        examples: Batch dict containing a "caption" sequence of strings.

    Returns:
        The tokenizer's ``input_ids`` — one row per caption, padded/truncated
        to ``tokenizer.model_max_length``.
    """
    max_length = tokenizer.model_max_length
    # The original element-by-element copy loop added nothing; materialize the
    # column as a plain list in one step.
    captions = list(examples["caption"])

    text_inputs = tokenizer(
        captions, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt"
    )

    return text_inputs.input_ids
456
+
457
+
458
@torch.no_grad()
def encode_prompt(text_encoder, input_ids):
    """Encode token ids with the frozen text encoder and return the embeddings.

    Moves ``input_ids`` onto the encoder's device, runs the encoder without an
    attention mask, and returns the first element of its output (the last
    hidden state / prompt embeddings). Gradients are disabled throughout.
    """
    ids_on_device = input_ids.to(text_encoder.device)
    encoder_output = text_encoder(ids_on_device, attention_mask=None)
    return encoder_output[0]
467
+
468
+
469
+ def main(args):
470
+ if args.report_to == "wandb" and args.hub_token is not None:
471
+ raise ValueError(
472
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
473
+ " Please use `huggingface-cli login` to authenticate with the Hub."
474
+ )
475
+
476
+ logging_dir = Path(args.output_dir, args.logging_dir)
477
+
478
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
479
+
480
+ accelerator = Accelerator(
481
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
482
+ mixed_precision=args.mixed_precision,
483
+ log_with=args.report_to,
484
+ project_config=accelerator_project_config,
485
+ )
486
+
487
+ # Disable AMP for MPS.
488
+ if torch.backends.mps.is_available():
489
+ accelerator.native_amp = False
490
+
491
+ # Make one log on every process with the configuration for debugging.
492
+ logging.basicConfig(
493
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
494
+ datefmt="%m/%d/%Y %H:%M:%S",
495
+ level=logging.INFO,
496
+ )
497
+ logger.info(accelerator.state, main_process_only=False)
498
+ if accelerator.is_local_main_process:
499
+ transformers.utils.logging.set_verbosity_warning()
500
+ diffusers.utils.logging.set_verbosity_info()
501
+ else:
502
+ transformers.utils.logging.set_verbosity_error()
503
+ diffusers.utils.logging.set_verbosity_error()
504
+
505
+ # If passed along, set the training seed now.
506
+ if args.seed is not None:
507
+ set_seed(args.seed)
508
+
509
+ # Handle the repository creation
510
+ if accelerator.is_main_process:
511
+ if args.output_dir is not None:
512
+ os.makedirs(args.output_dir, exist_ok=True)
513
+
514
+ if args.push_to_hub:
515
+ repo_id = create_repo(
516
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
517
+ ).repo_id
518
+
519
+ # Load the tokenizer
520
+ tokenizer = AutoTokenizer.from_pretrained(
521
+ args.pretrained_model_name_or_path,
522
+ subfolder="tokenizer",
523
+ revision=args.revision,
524
+ use_fast=False,
525
+ )
526
+
527
+ # import correct text encoder class
528
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
529
+
530
+ # Load scheduler and models
531
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
532
+ text_encoder = text_encoder_cls.from_pretrained(
533
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
534
+ )
535
+ vae = AutoencoderKL.from_pretrained(
536
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
537
+ )
538
+
539
+ unet = UNet2DConditionModel.from_pretrained(
540
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
541
+ )
542
+
543
+ vae.requires_grad_(False)
544
+ text_encoder.requires_grad_(False)
545
+ unet.requires_grad_(False)
546
+
547
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
548
+ # as these weights are only used for inference, keeping weights in full precision is not required.
549
+ weight_dtype = torch.float32
550
+ if accelerator.mixed_precision == "fp16":
551
+ weight_dtype = torch.float16
552
+ elif accelerator.mixed_precision == "bf16":
553
+ weight_dtype = torch.bfloat16
554
+
555
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
556
+ unet.to(accelerator.device, dtype=weight_dtype)
557
+ vae.to(accelerator.device, dtype=weight_dtype)
558
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
559
+
560
+ # Set up LoRA.
561
+ unet_lora_config = LoraConfig(
562
+ r=args.rank,
563
+ lora_alpha=args.rank,
564
+ init_lora_weights="gaussian",
565
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
566
+ )
567
+ # Add adapter and make sure the trainable params are in float32.
568
+ unet.add_adapter(unet_lora_config)
569
+ if args.mixed_precision == "fp16":
570
+ for param in unet.parameters():
571
+ # only upcast trainable parameters (LoRA) into fp32
572
+ if param.requires_grad:
573
+ param.data = param.to(torch.float32)
574
+
575
+ if args.enable_xformers_memory_efficient_attention:
576
+ if is_xformers_available():
577
+ import xformers
578
+
579
+ xformers_version = version.parse(xformers.__version__)
580
+ if xformers_version == version.parse("0.0.16"):
581
+ logger.warning(
582
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
583
+ )
584
+ unet.enable_xformers_memory_efficient_attention()
585
+ else:
586
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
587
+
588
+ if args.gradient_checkpointing:
589
+ unet.enable_gradient_checkpointing()
590
+
591
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
592
+ def save_model_hook(models, weights, output_dir):
593
+ if accelerator.is_main_process:
594
+ # there are only two options here. Either are just the unet attn processor layers
595
+ # or there are the unet and text encoder atten layers
596
+ unet_lora_layers_to_save = None
597
+
598
+ for model in models:
599
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
600
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
601
+ else:
602
+ raise ValueError(f"unexpected save model: {model.__class__}")
603
+
604
+ # make sure to pop weight so that corresponding model is not saved again
605
+ weights.pop()
606
+
607
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
608
+ output_dir,
609
+ unet_lora_layers=unet_lora_layers_to_save,
610
+ text_encoder_lora_layers=None,
611
+ )
612
+
613
+ def load_model_hook(models, input_dir):
614
+ unet_ = None
615
+
616
+ while len(models) > 0:
617
+ model = models.pop()
618
+
619
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
620
+ unet_ = model
621
+ else:
622
+ raise ValueError(f"unexpected save model: {model.__class__}")
623
+
624
+ lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
625
+ StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
626
+
627
+ accelerator.register_save_state_pre_hook(save_model_hook)
628
+ accelerator.register_load_state_pre_hook(load_model_hook)
629
+
630
+ # Enable TF32 for faster training on Ampere GPUs,
631
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
632
+ if args.allow_tf32:
633
+ torch.backends.cuda.matmul.allow_tf32 = True
634
+
635
+ if args.scale_lr:
636
+ args.learning_rate = (
637
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
638
+ )
639
+
640
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
641
+ if args.use_8bit_adam:
642
+ try:
643
+ import bitsandbytes as bnb
644
+ except ImportError:
645
+ raise ImportError(
646
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
647
+ )
648
+
649
+ optimizer_class = bnb.optim.AdamW8bit
650
+ else:
651
+ optimizer_class = torch.optim.AdamW
652
+
653
+ # Optimizer creation
654
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
655
+ optimizer = optimizer_class(
656
+ params_to_optimize,
657
+ lr=args.learning_rate,
658
+ betas=(args.adam_beta1, args.adam_beta2),
659
+ weight_decay=args.adam_weight_decay,
660
+ eps=args.adam_epsilon,
661
+ )
662
+
663
+ # Dataset and DataLoaders creation:
664
+ train_dataset = load_dataset(
665
+ args.dataset_name,
666
+ cache_dir=args.cache_dir,
667
+ split=args.dataset_split_name,
668
+ )
669
+
670
+ train_transforms = transforms.Compose(
671
+ [
672
+ transforms.Resize(int(args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
673
+ transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution),
674
+ transforms.Lambda(lambda x: x) if args.no_hflip else transforms.RandomHorizontalFlip(),
675
+ transforms.ToTensor(),
676
+ transforms.Normalize([0.5], [0.5]),
677
+ ]
678
+ )
679
+
680
+ def preprocess_train(examples):
681
+ all_pixel_values = []
682
+ for col_name in ["jpg_0", "jpg_1"]:
683
+ images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
684
+ pixel_values = [train_transforms(image) for image in images]
685
+ all_pixel_values.append(pixel_values)
686
+
687
+ # Double on channel dim, jpg_y then jpg_w
688
+ im_tup_iterator = zip(*all_pixel_values)
689
+ combined_pixel_values = []
690
+ for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
691
+ if label_0 == 0:
692
+ im_tup = im_tup[::-1]
693
+ combined_im = torch.cat(im_tup, dim=0) # no batch dim
694
+ combined_pixel_values.append(combined_im)
695
+ examples["pixel_values"] = combined_pixel_values
696
+
697
+ examples["input_ids"] = tokenize_captions(tokenizer, examples)
698
+ return examples
699
+
700
+ with accelerator.main_process_first():
701
+ if args.max_train_samples is not None:
702
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
703
+ # Set the training transforms
704
+ train_dataset = train_dataset.with_transform(preprocess_train)
705
+
706
+ def collate_fn(examples):
707
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
708
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
709
+ final_dict = {"pixel_values": pixel_values}
710
+ final_dict["input_ids"] = torch.stack([example["input_ids"] for example in examples])
711
+ return final_dict
712
+
713
+ train_dataloader = torch.utils.data.DataLoader(
714
+ train_dataset,
715
+ batch_size=args.train_batch_size,
716
+ shuffle=True,
717
+ collate_fn=collate_fn,
718
+ num_workers=args.dataloader_num_workers,
719
+ )
720
+
721
+ # Scheduler and math around the number of training steps.
722
+ overrode_max_train_steps = False
723
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
724
+ if args.max_train_steps is None:
725
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
726
+ overrode_max_train_steps = True
727
+
728
+ lr_scheduler = get_scheduler(
729
+ args.lr_scheduler,
730
+ optimizer=optimizer,
731
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
732
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
733
+ num_cycles=args.lr_num_cycles,
734
+ power=args.lr_power,
735
+ )
736
+
737
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
738
+ unet, optimizer, train_dataloader, lr_scheduler
739
+ )
740
+
741
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
742
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
743
+ if overrode_max_train_steps:
744
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
745
+ # Afterwards we recalculate our number of training epochs
746
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
747
+
748
+ # We need to initialize the trackers we use, and also store our configuration.
749
+ # The trackers initializes automatically on the main process.
750
+ if accelerator.is_main_process:
751
+ accelerator.init_trackers(args.tracker_name, config=vars(args))
752
+
753
+ # Train!
754
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
755
+
756
+ logger.info("***** Running training *****")
757
+ logger.info(f" Num examples = {len(train_dataset)}")
758
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
759
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
760
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
761
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
762
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
763
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
764
+ global_step = 0
765
+ first_epoch = 0
766
+
767
+ # Potentially load in the weights and states from a previous save
768
+ if args.resume_from_checkpoint:
769
+ if args.resume_from_checkpoint != "latest":
770
+ path = os.path.basename(args.resume_from_checkpoint)
771
+ else:
772
 + # Get the most recent checkpoint
773
+ dirs = os.listdir(args.output_dir)
774
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
775
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
776
+ path = dirs[-1] if len(dirs) > 0 else None
777
+
778
+ if path is None:
779
+ accelerator.print(
780
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
781
+ )
782
+ args.resume_from_checkpoint = None
783
+ initial_global_step = 0
784
+ else:
785
+ accelerator.print(f"Resuming from checkpoint {path}")
786
+ accelerator.load_state(os.path.join(args.output_dir, path))
787
+ global_step = int(path.split("-")[1])
788
+
789
+ initial_global_step = global_step
790
+ first_epoch = global_step // num_update_steps_per_epoch
791
+ else:
792
+ initial_global_step = 0
793
+
794
+ progress_bar = tqdm(
795
+ range(0, args.max_train_steps),
796
+ initial=initial_global_step,
797
+ desc="Steps",
798
+ # Only show the progress bar once on each machine.
799
+ disable=not accelerator.is_local_main_process,
800
+ )
801
+
802
+ unet.train()
803
+ for epoch in range(first_epoch, args.num_train_epochs):
804
+ for step, batch in enumerate(train_dataloader):
805
+ with accelerator.accumulate(unet):
806
+ # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
807
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
808
+ feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
809
+
810
+ latents = []
811
+ for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
812
+ latents.append(
813
+ vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
814
+ )
815
+ latents = torch.cat(latents, dim=0)
816
+ latents = latents * vae.config.scaling_factor
817
+
818
+ # Sample noise that we'll add to the latents
819
+ noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
820
+
821
+ # Sample a random timestep for each image
822
+ bsz = latents.shape[0] // 2
823
+ timesteps = torch.randint(
824
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
825
+ ).repeat(2)
826
+
827
+ # Add noise to the model input according to the noise magnitude at each timestep
828
+ # (this is the forward diffusion process)
829
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
830
+
831
+ # Get the text embedding for conditioning
832
+ encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"]).repeat(2, 1, 1)
833
+
834
+ # Predict the noise residual
835
+ model_pred = unet(
836
+ noisy_model_input,
837
+ timesteps,
838
+ encoder_hidden_states,
839
+ ).sample
840
+
841
+ # Get the target for loss depending on the prediction type
842
+ if noise_scheduler.config.prediction_type == "epsilon":
843
+ target = noise
844
+ elif noise_scheduler.config.prediction_type == "v_prediction":
845
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
846
+ else:
847
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
848
+
849
+ # Compute losses.
850
+ model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
851
+ model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
852
+ model_losses_w, model_losses_l = model_losses.chunk(2)
853
+
854
+ # For logging
855
+ raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
856
+ model_diff = model_losses_w - model_losses_l # These are both LBS (as is t)
857
+
858
+ # Reference model predictions.
859
+ accelerator.unwrap_model(unet).disable_adapters()
860
+ with torch.no_grad():
861
+ ref_preds = unet(
862
+ noisy_model_input,
863
+ timesteps,
864
+ encoder_hidden_states,
865
+ ).sample.detach()
866
+ ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
867
+ ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
868
+
869
+ ref_losses_w, ref_losses_l = ref_loss.chunk(2)
870
+ ref_diff = ref_losses_w - ref_losses_l
871
+ raw_ref_loss = ref_loss.mean()
872
+
873
+ # Re-enable adapters.
874
+ accelerator.unwrap_model(unet).enable_adapters()
875
+
876
+ # Final loss.
877
+ logits = ref_diff - model_diff
878
+ if args.loss_type == "sigmoid":
879
+ loss = -1 * F.logsigmoid(args.beta_dpo * logits).mean()
880
+ elif args.loss_type == "hinge":
881
+ loss = torch.relu(1 - args.beta_dpo * logits).mean()
882
+ elif args.loss_type == "ipo":
883
+ losses = (logits - 1 / (2 * args.beta)) ** 2
884
+ loss = losses.mean()
885
+ else:
886
+ raise ValueError(f"Unknown loss type {args.loss_type}")
887
+
888
+ implicit_acc = (logits > 0).sum().float() / logits.size(0)
889
+ implicit_acc += 0.5 * (logits == 0).sum().float() / logits.size(0)
890
+
891
+ accelerator.backward(loss)
892
+ if accelerator.sync_gradients:
893
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
894
+ optimizer.step()
895
+ lr_scheduler.step()
896
+ optimizer.zero_grad()
897
+
898
+ # Checks if the accelerator has performed an optimization step behind the scenes
899
+ if accelerator.sync_gradients:
900
+ progress_bar.update(1)
901
+ global_step += 1
902
+
903
+ if accelerator.is_main_process:
904
+ if global_step % args.checkpointing_steps == 0:
905
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
906
+ if args.checkpoints_total_limit is not None:
907
+ checkpoints = os.listdir(args.output_dir)
908
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
909
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
910
+
911
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
912
+ if len(checkpoints) >= args.checkpoints_total_limit:
913
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
914
+ removing_checkpoints = checkpoints[0:num_to_remove]
915
+
916
+ logger.info(
917
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
918
+ )
919
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
920
+
921
+ for removing_checkpoint in removing_checkpoints:
922
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
923
+ shutil.rmtree(removing_checkpoint)
924
+
925
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
926
+ accelerator.save_state(save_path)
927
+ logger.info(f"Saved state to {save_path}")
928
+
929
+ if args.run_validation and global_step % args.validation_steps == 0:
930
+ log_validation(
931
+ args, unet=unet, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
932
+ )
933
+
934
+ logs = {
935
+ "loss": loss.detach().item(),
936
+ "raw_model_loss": raw_model_loss.detach().item(),
937
+ "ref_loss": raw_ref_loss.detach().item(),
938
+ "implicit_acc": implicit_acc.detach().item(),
939
+ "lr": lr_scheduler.get_last_lr()[0],
940
+ }
941
+ progress_bar.set_postfix(**logs)
942
+ accelerator.log(logs, step=global_step)
943
+
944
+ if global_step >= args.max_train_steps:
945
+ break
946
+
947
+ # Save the lora layers
948
+ accelerator.wait_for_everyone()
949
+ if accelerator.is_main_process:
950
+ unet = accelerator.unwrap_model(unet)
951
+ unet = unet.to(torch.float32)
952
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
953
+
954
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
955
+ save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=None
956
+ )
957
+
958
+ # Final validation?
959
+ if args.run_validation:
960
+ log_validation(
961
+ args,
962
+ unet=None,
963
+ accelerator=accelerator,
964
+ weight_dtype=weight_dtype,
965
+ epoch=epoch,
966
+ is_final_validation=True,
967
+ )
968
+
969
+ if args.push_to_hub:
970
+ upload_folder(
971
+ repo_id=repo_id,
972
+ folder_path=args.output_dir,
973
+ commit_message="End of training",
974
+ ignore_patterns=["step_*", "epoch_*"],
975
+ )
976
+
977
+ accelerator.end_training()
978
+
979
+
980
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the training loop.
    main(parse_args())
diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py ADDED
@@ -0,0 +1,1140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 bram-w, The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
 + # See the License for the specific language governing permissions and
 + # limitations under the License.
15
+
16
+ import argparse
17
+ import contextlib
18
+ import io
19
+ import logging
20
+ import math
21
+ import os
22
+ import random
23
+ import shutil
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ import transformers
31
+ import wandb
32
+ from accelerate import Accelerator
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import ProjectConfiguration, set_seed
35
+ from datasets import load_dataset
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from peft import LoraConfig
39
+ from peft.utils import get_peft_model_state_dict
40
+ from PIL import Image
41
+ from torchvision import transforms
42
+ from torchvision.transforms.functional import crop
43
+ from tqdm.auto import tqdm
44
+ from transformers import AutoTokenizer, PretrainedConfig
45
+
46
+ import diffusers
47
+ from diffusers import (
48
+ AutoencoderKL,
49
+ DDPMScheduler,
50
+ DiffusionPipeline,
51
+ UNet2DConditionModel,
52
+ )
53
+ from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
54
+ from diffusers.optimization import get_scheduler
55
+ from diffusers.utils import check_min_version, convert_state_dict_to_diffusers
56
+ from diffusers.utils.import_utils import is_xformers_available
57
+
58
+
59
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
60
+ check_min_version("0.25.0.dev0")
61
+
62
+ logger = get_logger(__name__)
63
+
64
+
65
# Fixed prompt set used for qualitative validation during/after training
# (consumed by `log_validation`, which generates one image per prompt).
VALIDATION_PROMPTS = [
    "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
71
+
72
+
73
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Resolve the concrete text-encoder class declared in a checkpoint's config.

    Reads the `architectures` field from the (sub)model config on the Hub and
    lazily imports the matching `transformers` class.

    Raises:
        ValueError: If the declared architecture is not a supported CLIP text encoder.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    architecture = config.architectures[0]

    # Lazy imports: only pull in the class the checkpoint actually needs.
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if architecture == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    raise ValueError(f"{architecture} is not supported.")
91
+
92
+
93
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
    """Generate images for the fixed VALIDATION_PROMPTS and log them to the trackers.

    During training (`is_final_validation=False`) the in-training UNet is spliced
    into a freshly loaded pipeline; at the end of training the saved LoRA weights
    are loaded from `args.output_dir` instead, and a second no-LoRA pass is logged
    for comparison.
    """
    logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")

    # For the final pass the VAE is cast to the training dtype when running fp16.
    if is_final_validation:
        if args.mixed_precision == "fp16":
            vae.to(weight_dtype)

    # create pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    if not is_final_validation:
        # Use the live (possibly wrapped) training UNet.
        pipeline.unet = accelerator.unwrap_model(unet)
    else:
        # Load the LoRA weights saved at the end of training.
        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    # NOTE(review): `if args.seed` is falsy for seed == 0, so seed 0 silently
    # disables the generator — confirm intended.
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
    images = []
    # Mid-training inference runs under CUDA autocast; the final pass does not.
    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()

    # SDXL Turbo is distilled for few-step, guidance-free sampling.
    guidance_scale = 5.0
    num_inference_steps = 25
    if args.is_turbo:
        guidance_scale = 0.0
        num_inference_steps = 4
    for prompt in VALIDATION_PROMPTS:
        with context:
            image = pipeline(
                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images[0]
        images.append(image)

    tracker_key = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    tracker_key: [
                        wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
                    ]
                }
            )

    # Also log images without the LoRA params for comparison.
    if is_final_validation:
        pipeline.disable_lora()
        no_lora_images = [
            pipeline(
                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images[0]
            for prompt in VALIDATION_PROMPTS
        ]

        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in no_lora_images])
                tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
            if tracker.name == "wandb":
                tracker.log(
                    {
                        "test_without_lora": [
                            wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
                            for i, image in enumerate(no_lora_images)
                        ]
                    }
                )
170
+
171
+
172
def parse_args(input_args=None):
    """Parse and validate command-line arguments for Diffusion-DPO LoRA SDXL training.

    Args:
        input_args: Optional list of argument strings. When ``None``, arguments
            are read from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the parsed (and validated) options.

    Raises:
        ValueError: If ``--dataset_name`` is not provided, or if ``--is_turbo``
            is set for a checkpoint whose name does not contain ``"turbo"``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_split_name",
        type=str,
        default="validation",
        help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--run_validation",
        default=False,
        action="store_true",
        help="Whether to run validation inference in between training and also after training. Helps to track progress.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=200,
        help="Run validation every X steps.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="diffusion-dpo-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--vae_encode_batch_size",
        type=int,
        default=8,
        help="Batch size to use for VAE encoding of the images for efficient processing.",
    )
    parser.add_argument(
        "--no_hflip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--random_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    # NOTE(review): `type=int` restricts beta to integer values; upstream DPO
    # uses large integer betas (default 5000), so the restriction is harmless.
    parser.add_argument(
        "--beta_dpo",
        type=int,
        default=5000,
        help="DPO KL Divergence penalty.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--is_turbo",
        action="store_true",
        help=("Use if tuning SDXL Turbo instead of SDXL"),
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="diffusion-dpo-lora-sdxl",
        help=("The name of the tracker to report results to."),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.dataset_name is None:
        raise ValueError("Must provide a `dataset_name`.")

    # Validate with an explicit exception instead of `assert`: asserts are
    # stripped when Python runs with -O, which would silently skip this check.
    if args.is_turbo and "turbo" not in args.pretrained_model_name_or_path:
        raise ValueError(
            "`--is_turbo` requires an SDXL Turbo checkpoint; the model name must contain 'turbo'."
        )

    # Distributed launchers (e.g. torchrun) export LOCAL_RANK; prefer it over
    # the CLI flag when they disagree.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
470
+
471
+
472
def tokenize_captions(tokenizers, examples):
    """Tokenize every caption in the batch with both SDXL tokenizers.

    Returns a tuple ``(tokens_one, tokens_two)`` of padded/truncated input-id
    tensors, one per tokenizer, in the order the tokenizers were supplied.
    """
    captions = list(examples["caption"])

    def _encode(tokenizer):
        # Pad/truncate to the tokenizer's own max length and return tensors.
        return tokenizer(
            captions,
            truncation=True,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    tokens_one, tokens_two = (_encode(tokenizer) for tokenizer in tokenizers)
    return tokens_one, tokens_two
485
+
486
+
487
@torch.no_grad()
def encode_prompt(text_encoders, text_input_ids_list):
    """Encode tokenized prompts with both SDXL text encoders.

    Concatenates the penultimate hidden states of each encoder along the
    feature dimension and returns them together with the pooled output of
    the *last* encoder.

    Returns:
        Tuple ``(prompt_embeds, pooled_prompt_embeds)``.
    """
    per_encoder_embeds = []
    pooled_prompt_embeds = None

    for text_encoder, text_input_ids in zip(text_encoders, text_input_ids_list):
        outputs = text_encoder(
            text_input_ids.to(text_encoder.device),
            output_hidden_states=True,
        )

        # We are only ALWAYS interested in the pooled output of the final
        # text encoder, so each iteration overwrites the previous value.
        pooled_prompt_embeds = outputs[0]
        hidden = outputs.hidden_states[-2]
        bs_embed, seq_len, _ = hidden.shape
        per_encoder_embeds.append(hidden.view(bs_embed, seq_len, -1))

    prompt_embeds = torch.concat(per_encoder_embeds, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
509
+
510
+
511
def main(args):
    """Run Diffusion-DPO LoRA fine-tuning of an SDXL (or SDXL-Turbo) UNet.

    End-to-end training driver: configures accelerate/logging, loads the
    frozen SDXL components (tokenizers, text encoders, VAE, UNet, noise
    scheduler), attaches a LoRA adapter to the UNet, builds the paired
    preference dataloader, and runs the DPO training loop with periodic
    checkpointing and validation, finally saving the LoRA weights (and
    optionally pushing them to the Hub).

    Args:
        args: Parsed namespace from ``parse_args()``.
    """
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load the tokenizers
    tokenizer_one = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )
    tokenizer_two = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=args.revision,
        use_fast=False,
    )

    # import correct text encoder classes
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
    )

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    def enforce_zero_terminal_snr(scheduler):
        # Rescale the scheduler's alphas_cumprod in place so that the final
        # timestep has zero signal-to-noise ratio.
        # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
        # Original implementation https://huggingface.co/papers/2305.08891
        # Turbo needs zero terminal SNR
        # Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
        # Convert betas to alphas_bar_sqrt
        alphas = 1 - scheduler.betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])

        alphas_cumprod = torch.cumprod(alphas, dim=0)
        scheduler.alphas_cumprod = alphas_cumprod
        return

    if args.is_turbo:
        enforce_zero_terminal_snr(noise_scheduler)

    text_encoder_one = text_encoder_cls_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
    )
    vae_path = (
        args.pretrained_model_name_or_path
        if args.pretrained_vae_model_name_or_path is None
        else args.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
        revision=args.revision,
        variant=args.variant,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
    )

    # We only train the additional adapter LoRA layers
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)
    unet.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet and text_encoders to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
    text_encoder_two.to(accelerator.device, dtype=weight_dtype)

    # The VAE is always in float32 to avoid NaN losses.
    vae.to(accelerator.device, dtype=torch.float32)

    # Set up LoRA.
    unet_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    # Add adapter and make sure the trainable params are in float32.
    unet.add_adapter(unet_lora_config)
    if args.mixed_precision == "fp16":
        for param in unet.parameters():
            # only upcast trainable parameters (LoRA) into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            # there are only two options here. Either are just the unet attn processor layers
            # or there are the unet and text encoder atten layers
            unet_lora_layers_to_save = None

            for model in models:
                if isinstance(model, type(accelerator.unwrap_model(unet))):
                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

            StableDiffusionXLLoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save)

    def load_model_hook(models, input_dir):
        # Mirror of save_model_hook: restores the LoRA layers into the live UNet.
        unet_ = None

        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(accelerator.unwrap_model(unet))):
                unet_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
        StableDiffusionXLLoraLoaderMixin.load_lora_into_unet(
            lora_state_dict, network_alphas=network_alphas, unet=unet_
        )

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation — only the (trainable) LoRA parameters are optimized.
    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = load_dataset(
        args.dataset_name,
        cache_dir=args.cache_dir,
        split=args.dataset_split_name,
    )

    # Preprocessing the datasets.
    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
    train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
    train_flip = transforms.RandomHorizontalFlip(p=1.0)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5], [0.5])

    def preprocess_train(examples):
        # Build, for every example, a single tensor stacking the preferred and
        # rejected images on the channel dimension (preferred first).
        # NOTE(review): assumes the dataset provides "jpg_0"/"jpg_1" image bytes
        # and a "label_0" preference flag (Pick-a-Pic schema) — confirm for new datasets.
        all_pixel_values = []
        images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples["jpg_0"]]
        original_sizes = [(image.height, image.width) for image in images]
        crop_top_lefts = []

        for col_name in ["jpg_0", "jpg_1"]:
            images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
            if col_name == "jpg_1":
                # Need to bring down the image to the same resolution.
                # This seems like the simplest reasonable approach.
                # "::-1" because PIL resize takes (width, height).
                images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)]
            pixel_values = [to_tensor(image) for image in images]
            all_pixel_values.append(pixel_values)

        # Double on channel dim, jpg_y then jpg_w
        im_tup_iterator = zip(*all_pixel_values)
        combined_pixel_values = []
        for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
            if label_0 == 0:
                # label_0 == 0 means jpg_1 is the preferred image; swap the pair
                # so the preferred image always comes first.
                im_tup = im_tup[::-1]

            combined_im = torch.cat(im_tup, dim=0)  # no batch dim

            # Resize.
            combined_im = train_resize(combined_im)

            # Flipping.
            if not args.no_hflip and random.random() < 0.5:
                combined_im = train_flip(combined_im)

            # Cropping.
            if not args.random_crop:
                y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
                x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
                combined_im = train_crop(combined_im)
            else:
                y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
                combined_im = crop(combined_im, y1, x1, h, w)

            crop_top_left = (y1, x1)
            crop_top_lefts.append(crop_top_left)
            combined_im = normalize(combined_im)
            combined_pixel_values.append(combined_im)

        examples["pixel_values"] = combined_pixel_values
        examples["original_sizes"] = original_sizes
        examples["crop_top_lefts"] = crop_top_lefts
        tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples)
        examples["input_ids_one"] = tokens_one
        examples["input_ids_two"] = tokens_two
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = train_dataset.with_transform(preprocess_train)

    def collate_fn(examples):
        # Stack per-example tensors into a batch; sizes/crops stay as Python lists.
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        original_sizes = [example["original_sizes"] for example in examples]
        crop_top_lefts = [example["crop_top_lefts"] for example in examples]
        input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
        input_ids_two = torch.stack([example["input_ids_two"] for example in examples])

        return {
            "pixel_values": pixel_values,
            "input_ids_one": input_ids_one,
            "input_ids_two": input_ids_two,
            "original_sizes": original_sizes,
            "crop_top_lefts": crop_top_lefts,
        }

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers(args.tracker_name, config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    unet.train()
    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))

                # Encode through the VAE in sub-batches to bound memory use.
                latents = []
                for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
                    latents.append(
                        vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
                    )
                latents = torch.cat(latents, dim=0)
                latents = latents * vae.config.scaling_factor
                if args.pretrained_vae_model_name_or_path is None:
                    latents = latents.to(weight_dtype)

                # Sample noise that we'll add to the latents.
                # The same noise is used for the preferred and rejected halves
                # (chunk then repeat), as required by the DPO formulation.
                noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)

                # Sample a random timestep for each image (shared across the pair).
                bsz = latents.shape[0] // 2
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
                ).repeat(2)
                if args.is_turbo:
                    # Learn a 4 timestep schedule
                    timesteps_0_to_3 = timesteps % 4
                    timesteps = 250 * timesteps_0_to_3 + 249

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)

                # time ids
                def compute_time_ids(original_size, crops_coords_top_left):
                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                    target_size = (args.resolution, args.resolution)
                    add_time_ids = list(original_size + crops_coords_top_left + target_size)
                    add_time_ids = torch.tensor([add_time_ids])
                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
                    return add_time_ids

                add_time_ids = torch.cat(
                    [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
                ).repeat(2, 1)

                # Get the text embedding for conditioning
                prompt_embeds, pooled_prompt_embeds = encode_prompt(
                    [text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
                )
                prompt_embeds = prompt_embeds.repeat(2, 1, 1)
                pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)

                # Predict the noise residual
                model_pred = unet(
                    noisy_model_input,
                    timesteps,
                    prompt_embeds,
                    added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Compute losses: per-sample MSE, then split into the preferred
                # ("w") and rejected ("l") halves of the doubled batch.
                model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
                model_losses_w, model_losses_l = model_losses.chunk(2)

                # For logging
                raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
                model_diff = model_losses_w - model_losses_l  # These are both LBS (as is t)

                # Reference model predictions: temporarily disable the LoRA
                # adapters so the same UNet acts as the frozen reference.
                accelerator.unwrap_model(unet).disable_adapters()
                with torch.no_grad():
                    ref_preds = unet(
                        noisy_model_input,
                        timesteps,
                        prompt_embeds,
                        added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
                    ).sample
                    ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
                    ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))

                    ref_losses_w, ref_losses_l = ref_loss.chunk(2)
                    ref_diff = ref_losses_w - ref_losses_l
                    raw_ref_loss = ref_loss.mean()

                # Re-enable adapters.
                accelerator.unwrap_model(unet).enable_adapters()

                # Final loss: the Diffusion-DPO objective
                # -log(sigmoid(-beta/2 * ((model_w - model_l) - (ref_w - ref_l)))).
                scale_term = -0.5 * args.beta_dpo
                inside_term = scale_term * (model_diff - ref_diff)
                loss = -1 * F.logsigmoid(inside_term).mean()

                # Fraction of pairs where the model implicitly prefers the chosen
                # image; ties count as half.
                implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
                implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.run_validation and global_step % args.validation_steps == 0:
                        log_validation(
                            args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
                        )

            logs = {
                "loss": loss.detach().item(),
                "raw_model_loss": raw_model_loss.detach().item(),
                "ref_loss": raw_ref_loss.detach().item(),
                "implicit_acc": implicit_acc.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

        StableDiffusionXLLoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
            unet_lora_layers=unet_lora_state_dict,
            text_encoder_lora_layers=None,
            text_encoder_2_lora_layers=None,
        )

        # Final validation?
        if args.run_validation:
            log_validation(
                args,
                unet=None,
                vae=vae,
                accelerator=accelerator,
                weight_dtype=weight_dtype,
                epoch=epoch,
                is_final_validation=True,
            )

        if args.push_to_hub:
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
1136
+
1137
+
1138
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and launch training.
    args = parse_args()
    main(args)
diffusers/examples/research_projects/diffusion_orpo/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ This project has a new home now: [https://mapo-t2i.github.io/](https://mapo-t2i.github.io/). We formally studied the use of ORPO in the context of diffusion models and open-sourced our codebase, models, and datasets. We released our paper too!
diffusers/examples/research_projects/diffusion_orpo/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ datasets
2
+ accelerate
3
+ transformers
4
+ torchvision
5
+ wandb
6
+ peft
7
+ webdataset
diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py ADDED
@@ -0,0 +1,1092 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
16
+ import argparse
17
+ import contextlib
18
+ import io
19
+ import logging
20
+ import math
21
+ import os
22
+ import random
23
+ import shutil
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ import transformers
31
+ import wandb
32
+ from accelerate import Accelerator
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import ProjectConfiguration, set_seed
35
+ from datasets import load_dataset
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from peft import LoraConfig, set_peft_model_state_dict
39
+ from peft.utils import get_peft_model_state_dict
40
+ from PIL import Image
41
+ from torchvision import transforms
42
+ from torchvision.transforms.functional import crop
43
+ from tqdm.auto import tqdm
44
+ from transformers import AutoTokenizer, PretrainedConfig
45
+
46
+ import diffusers
47
+ from diffusers import (
48
+ AutoencoderKL,
49
+ DDPMScheduler,
50
+ DiffusionPipeline,
51
+ UNet2DConditionModel,
52
+ )
53
+ from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
54
+ from diffusers.optimization import get_scheduler
55
+ from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
56
+ from diffusers.utils.import_utils import is_xformers_available
57
+
58
+
59
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.25.0.dev0")

# Accelerate-aware module logger (logs once per machine by default).
logger = get_logger(__name__)


# Fixed prompts used for qualitative validation image generation during/after training.
VALIDATION_PROMPTS = [
    "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
71
+
72
+
73
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Return the transformers class matching the checkpoint's text-encoder architecture.

    Reads the text-encoder config from the checkpoint and imports the concrete
    class lazily, so only the needed transformers class is loaded.

    Raises:
        ValueError: If the architecture is neither ``CLIPTextModel`` nor
            ``CLIPTextModelWithProjection``.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    if model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection

    raise ValueError(f"{model_class} is not supported.")
91
+
92
+
93
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
    """Generate images for ``VALIDATION_PROMPTS`` and log them to the active trackers.

    During training (``is_final_validation=False``) the current in-memory UNet is
    used. On the final validation, the LoRA weights saved in ``args.output_dir``
    are loaded into a fresh pipeline, and a second image set is logged with the
    LoRA adapters disabled for comparison.
    """
    logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")

    if is_final_validation:
        if args.mixed_precision == "fp16":
            vae.to(weight_dtype)

    # create pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    if not is_final_validation:
        # Use the UNet currently being trained (unwrapped from accelerate's wrapper).
        pipeline.unet = accelerator.unwrap_model(unet)
    else:
        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
    images = []
    # Mid-training validation runs under CUDA autocast; the final pipeline is
    # already in weight_dtype, so no autocast is needed then.
    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()

    guidance_scale = 5.0
    num_inference_steps = 25
    for prompt in VALIDATION_PROMPTS:
        with context:
            image = pipeline(
                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images[0]
            images.append(image)

    tracker_key = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    tracker_key: [
                        wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
                    ]
                }
            )

    # Also log images without the LoRA params for comparison.
    if is_final_validation:
        pipeline.disable_lora()
        # Re-seed so the LoRA / no-LoRA images differ only by the adapters.
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
        no_lora_images = [
            pipeline(
                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images[0]
            for prompt in VALIDATION_PROMPTS
        ]

        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in no_lora_images])
                tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
            if tracker.name == "wandb":
                tracker.log(
                    {
                        "test_without_lora": [
                            wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
                            for i, image in enumerate(no_lora_images)
                        ]
                    }
                )
168
+
169
+
170
def parse_args(input_args=None):
    """Parse command-line arguments for the ORPO LoRA SDXL training script.

    Args:
        input_args: Optional list of argument strings; when ``None``,
            arguments are read from ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with every training option.

    Raises:
        ValueError: If ``--dataset_name`` is not provided.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_split_name",
        type=str,
        default="validation",
        help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--run_validation",
        default=False,
        action="store_true",
        help="Whether to run validation inference in between training and also after training. Helps to track progress.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=200,
        help="Run validation every X steps.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="diffusion-orpo-lora-sdxl",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--vae_encode_batch_size",
        type=int,
        default=8,
        help="Batch size to use for VAE encoding of the images for efficient processing.",
    )
    parser.add_argument(
        "--no_hflip",
        action="store_true",
        # Fixed help text: the old text described the opposite of what this flag does.
        help="Disable random horizontal flipping of the training images.",
    )
    parser.add_argument(
        "--random_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--beta_orpo",
        type=float,
        default=0.1,
        help="ORPO contribution factor.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="diffusion-orpo-lora-sdxl",
        help=("The name of the tracker to report results to."),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.dataset_name is None:
        raise ValueError("Must provide a `dataset_name`.")

    # Keep --local_rank in sync with the LOCAL_RANK env var set by distributed launchers.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
460
+
461
+
462
def tokenize_captions(tokenizers, examples):
    """Tokenize the captions in `examples` with both SDXL tokenizers.

    Args:
        tokenizers: Sequence of two tokenizers (one per SDXL text encoder).
        examples: Mapping with a "caption" key holding a list of caption strings.

    Returns:
        Tuple ``(tokens_one, tokens_two)`` of padded/truncated input ids, one
        per tokenizer.
    """
    # Materialize the captions once; the old element-by-element append loop was
    # just a copy (ruff PERF401).
    captions = list(examples["caption"])

    tokens_one = tokenizers[0](
        captions, truncation=True, padding="max_length", max_length=tokenizers[0].model_max_length, return_tensors="pt"
    ).input_ids
    tokens_two = tokenizers[1](
        captions, truncation=True, padding="max_length", max_length=tokenizers[1].model_max_length, return_tensors="pt"
    ).input_ids

    return tokens_one, tokens_two
475
+
476
+
477
@torch.no_grad()
def encode_prompt(text_encoders, text_input_ids_list):
    """Encode tokenized prompts with every text encoder and assemble SDXL embeddings.

    Returns:
        Tuple of (per-token embeddings concatenated across encoders on the last
        dim, pooled embedding taken from the final text encoder).
    """
    per_encoder_embeds = []

    for idx, text_encoder in enumerate(text_encoders):
        encoder_out = text_encoder(
            text_input_ids_list[idx].to(text_encoder.device),
            output_hidden_states=True,
        )

        # Overwritten every pass: only the pooled output of the *final* text
        # encoder is kept, by design.
        pooled_prompt_embeds = encoder_out[0]
        # SDXL conditions on the penultimate hidden layer, not the last one.
        hidden = encoder_out.hidden_states[-2]
        bs_embed, seq_len, _ = hidden.shape
        per_encoder_embeds.append(hidden.view(bs_embed, seq_len, -1))

    prompt_embeds = torch.concat(per_encoder_embeds, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
499
+
500
+
501
+ def main(args):
502
+ if args.report_to == "wandb" and args.hub_token is not None:
503
+ raise ValueError(
504
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
505
+ " Please use `huggingface-cli login` to authenticate with the Hub."
506
+ )
507
+
508
+ logging_dir = Path(args.output_dir, args.logging_dir)
509
+
510
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
511
+
512
+ accelerator = Accelerator(
513
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
514
+ mixed_precision=args.mixed_precision,
515
+ log_with=args.report_to,
516
+ project_config=accelerator_project_config,
517
+ )
518
+
519
+ # Disable AMP for MPS.
520
+ if torch.backends.mps.is_available():
521
+ accelerator.native_amp = False
522
+
523
+ # Make one log on every process with the configuration for debugging.
524
+ logging.basicConfig(
525
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
526
+ datefmt="%m/%d/%Y %H:%M:%S",
527
+ level=logging.INFO,
528
+ )
529
+ logger.info(accelerator.state, main_process_only=False)
530
+ if accelerator.is_local_main_process:
531
+ transformers.utils.logging.set_verbosity_warning()
532
+ diffusers.utils.logging.set_verbosity_info()
533
+ else:
534
+ transformers.utils.logging.set_verbosity_error()
535
+ diffusers.utils.logging.set_verbosity_error()
536
+
537
+ # If passed along, set the training seed now.
538
+ if args.seed is not None:
539
+ set_seed(args.seed)
540
+
541
+ # Handle the repository creation
542
+ if accelerator.is_main_process:
543
+ if args.output_dir is not None:
544
+ os.makedirs(args.output_dir, exist_ok=True)
545
+
546
+ if args.push_to_hub:
547
+ repo_id = create_repo(
548
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
549
+ ).repo_id
550
+
551
+ # Load the tokenizers
552
+ tokenizer_one = AutoTokenizer.from_pretrained(
553
+ args.pretrained_model_name_or_path,
554
+ subfolder="tokenizer",
555
+ revision=args.revision,
556
+ use_fast=False,
557
+ )
558
+ tokenizer_two = AutoTokenizer.from_pretrained(
559
+ args.pretrained_model_name_or_path,
560
+ subfolder="tokenizer_2",
561
+ revision=args.revision,
562
+ use_fast=False,
563
+ )
564
+
565
+ # import correct text encoder classes
566
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
567
+ args.pretrained_model_name_or_path, args.revision
568
+ )
569
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
570
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
571
+ )
572
+
573
+ # Load scheduler and models
574
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
575
+
576
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
577
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
578
+ )
579
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
580
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
581
+ )
582
+ vae_path = (
583
+ args.pretrained_model_name_or_path
584
+ if args.pretrained_vae_model_name_or_path is None
585
+ else args.pretrained_vae_model_name_or_path
586
+ )
587
+ vae = AutoencoderKL.from_pretrained(
588
+ vae_path,
589
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
590
+ revision=args.revision,
591
+ variant=args.variant,
592
+ )
593
+ unet = UNet2DConditionModel.from_pretrained(
594
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
595
+ )
596
+
597
+ # We only train the additional adapter LoRA layers
598
+ vae.requires_grad_(False)
599
+ text_encoder_one.requires_grad_(False)
600
+ text_encoder_two.requires_grad_(False)
601
+ unet.requires_grad_(False)
602
+
603
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
604
+ # as these weights are only used for inference, keeping weights in full precision is not required.
605
+ weight_dtype = torch.float32
606
+ if accelerator.mixed_precision == "fp16":
607
+ weight_dtype = torch.float16
608
+ elif accelerator.mixed_precision == "bf16":
609
+ weight_dtype = torch.bfloat16
610
+
611
+ # Move unet and text_encoders to device and cast to weight_dtype
612
+ unet.to(accelerator.device, dtype=weight_dtype)
613
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
614
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
615
+
616
+ # The VAE is always in float32 to avoid NaN losses.
617
+ vae.to(accelerator.device, dtype=torch.float32)
618
+
619
+ # Set up LoRA.
620
+ unet_lora_config = LoraConfig(
621
+ r=args.rank,
622
+ lora_alpha=args.rank,
623
+ init_lora_weights="gaussian",
624
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
625
+ )
626
+ # Add adapter and make sure the trainable params are in float32.
627
+ unet.add_adapter(unet_lora_config)
628
+ if args.mixed_precision == "fp16":
629
+ for param in unet.parameters():
630
+ # only upcast trainable parameters (LoRA) into fp32
631
+ if param.requires_grad:
632
+ param.data = param.to(torch.float32)
633
+
634
+ if args.enable_xformers_memory_efficient_attention:
635
+ if is_xformers_available():
636
+ import xformers
637
+
638
+ xformers_version = version.parse(xformers.__version__)
639
+ if xformers_version == version.parse("0.0.16"):
640
+ logger.warning(
641
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
642
+ )
643
+ unet.enable_xformers_memory_efficient_attention()
644
+ else:
645
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
646
+
647
+ if args.gradient_checkpointing:
648
+ unet.enable_gradient_checkpointing()
649
+
650
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
651
def save_model_hook(models, weights, output_dir):
    """Accelerate save-state pre-hook: serialize UNet LoRA layers in diffusers format.

    Only the UNet carries trainable (LoRA) weights in this script, so any other
    model handed to the hook is an error.
    """
    if accelerator.is_main_process:
        # there are only two options here. Either are just the unet attn processor layers
        # or there are the unet and text encoder atten layers
        unet_lora_layers_to_save = None

        for model in models:
            if isinstance(model, type(accelerator.unwrap_model(unet))):
                # Extract the LoRA state dict from peft and rename keys to diffusers' scheme.
                unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()

        StableDiffusionXLLoraLoaderMixin.save_lora_weights(
            output_dir,
            unet_lora_layers=unet_lora_layers_to_save,
            text_encoder_lora_layers=None,
            text_encoder_2_lora_layers=None,
        )
672
+
673
def load_model_hook(models, input_dir):
    """Accelerate load-state pre-hook: restore UNet LoRA weights from `input_dir`."""
    unet_ = None

    while len(models) > 0:
        model = models.pop()

        if isinstance(model, type(accelerator.unwrap_model(unet))):
            unet_ = model
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

    lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)

    # Keep only the UNet entries and convert their naming to peft's scheme
    # before injecting them into the adapter.
    unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
    unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
    incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
    if incompatible_keys is not None:
        # check only for unexpected keys
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logger.warning(
                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                f" {unexpected_keys}. "
            )
697
+
698
+ accelerator.register_save_state_pre_hook(save_model_hook)
699
+ accelerator.register_load_state_pre_hook(load_model_hook)
700
+
701
+ # Enable TF32 for faster training on Ampere GPUs,
702
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
703
+ if args.allow_tf32:
704
+ torch.backends.cuda.matmul.allow_tf32 = True
705
+
706
+ if args.scale_lr:
707
+ args.learning_rate = (
708
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
709
+ )
710
+
711
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
712
+ if args.use_8bit_adam:
713
+ try:
714
+ import bitsandbytes as bnb
715
+ except ImportError:
716
+ raise ImportError(
717
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
718
+ )
719
+
720
+ optimizer_class = bnb.optim.AdamW8bit
721
+ else:
722
+ optimizer_class = torch.optim.AdamW
723
+
724
+ # Optimizer creation
725
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
726
+ optimizer = optimizer_class(
727
+ params_to_optimize,
728
+ lr=args.learning_rate,
729
+ betas=(args.adam_beta1, args.adam_beta2),
730
+ weight_decay=args.adam_weight_decay,
731
+ eps=args.adam_epsilon,
732
+ )
733
+
734
+ # Dataset and DataLoaders creation:
735
+ train_dataset = load_dataset(
736
+ args.dataset_name,
737
+ cache_dir=args.cache_dir,
738
+ split=args.dataset_split_name,
739
+ )
740
+
741
+ # Preprocessing the datasets.
742
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
743
+ train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
744
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
745
+ to_tensor = transforms.ToTensor()
746
+ normalize = transforms.Normalize([0.5], [0.5])
747
+
748
def preprocess_train(examples):
    """Dataset transform: decode preference image pairs, augment, and tokenize captions.

    For each example, the preferred/rejected JPEG pair ("jpg_0"/"jpg_1") is
    decoded, channel-concatenated (preferred first), resized/flipped/cropped as
    one tensor so both images get identical augmentation, then normalized.
    Adds "pixel_values", "original_sizes", "crop_top_lefts",
    "input_ids_one"/"input_ids_two" to `examples`.

    NOTE(review): assumes "label_0" is the preference score for jpg_0 with
    0.5 meaning a tie — confirm against the dataset card.
    """
    all_pixel_values = []
    images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples["jpg_0"]]
    original_sizes = [(image.height, image.width) for image in images]
    crop_top_lefts = []

    for col_name in ["jpg_0", "jpg_1"]:
        images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
        if col_name == "jpg_1":
            # Need to bring down the image to the same resolution.
            # This seems like the simplest reasonable approach.
            # "::-1" because PIL resize takes (width, height).
            images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)]
        pixel_values = [to_tensor(image) for image in images]
        all_pixel_values.append(pixel_values)

    # Double on channel dim, jpg_y then jpg_w
    im_tup_iterator = zip(*all_pixel_values)
    combined_pixel_values = []
    for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
        # We randomize selection and rejection.
        if label_0 == 0.5:
            if random.random() < 0.5:
                label_0 = 0
            else:
                label_0 = 1

        # label_0 == 0 means jpg_1 is preferred, so swap the pair order.
        if label_0 == 0:
            im_tup = im_tup[::-1]

        combined_im = torch.cat(im_tup, dim=0)  # no batch dim

        # Resize.
        combined_im = train_resize(combined_im)

        # Flipping.
        if not args.no_hflip and random.random() < 0.5:
            combined_im = train_flip(combined_im)

        # Cropping.
        if not args.random_crop:
            # Center crop: compute the top-left corner for SDXL's crop conditioning.
            y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
            x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
            combined_im = train_crop(combined_im)
        else:
            y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
            combined_im = crop(combined_im, y1, x1, h, w)

        crop_top_left = (y1, x1)
        crop_top_lefts.append(crop_top_left)
        combined_im = normalize(combined_im)
        combined_pixel_values.append(combined_im)

    examples["pixel_values"] = combined_pixel_values
    examples["original_sizes"] = original_sizes
    examples["crop_top_lefts"] = crop_top_lefts
    tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples)
    examples["input_ids_one"] = tokens_one
    examples["input_ids_two"] = tokens_two
    return examples
808
+
809
+ with accelerator.main_process_first():
810
+ if args.max_train_samples is not None:
811
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
812
+ # Set the training transforms
813
+ train_dataset = train_dataset.with_transform(preprocess_train)
814
+
815
def collate_fn(examples):
    """Batch transformed examples into the tensors/lists the training loop consumes."""
    stacked_pixels = torch.stack([ex["pixel_values"] for ex in examples])
    batch = {
        "pixel_values": stacked_pixels.to(memory_format=torch.contiguous_format).float(),
        "input_ids_one": torch.stack([ex["input_ids_one"] for ex in examples]),
        "input_ids_two": torch.stack([ex["input_ids_two"] for ex in examples]),
        "original_sizes": [ex["original_sizes"] for ex in examples],
        "crop_top_lefts": [ex["crop_top_lefts"] for ex in examples],
    }
    return batch
830
+
831
+ train_dataloader = torch.utils.data.DataLoader(
832
+ train_dataset,
833
+ batch_size=args.train_batch_size,
834
+ shuffle=True,
835
+ collate_fn=collate_fn,
836
+ num_workers=args.dataloader_num_workers,
837
+ )
838
+
839
+ # Scheduler and math around the number of training steps.
840
+ overrode_max_train_steps = False
841
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
842
+ if args.max_train_steps is None:
843
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
844
+ overrode_max_train_steps = True
845
+
846
+ lr_scheduler = get_scheduler(
847
+ args.lr_scheduler,
848
+ optimizer=optimizer,
849
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
850
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
851
+ num_cycles=args.lr_num_cycles,
852
+ power=args.lr_power,
853
+ )
854
+
855
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
856
+ unet, optimizer, train_dataloader, lr_scheduler
857
+ )
858
+
859
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
860
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
861
+ if overrode_max_train_steps:
862
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
863
+ # Afterwards we recalculate our number of training epochs
864
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
865
+
866
+ # We need to initialize the trackers we use, and also store our configuration.
867
+ # The trackers initializes automatically on the main process.
868
+ if accelerator.is_main_process:
869
+ accelerator.init_trackers(args.tracker_name, config=vars(args))
870
+
871
+ # Train!
872
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
873
+
874
+ logger.info("***** Running training *****")
875
+ logger.info(f" Num examples = {len(train_dataset)}")
876
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
877
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
878
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
879
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
880
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
881
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
882
+ global_step = 0
883
+ first_epoch = 0
884
+
885
+ # Potentially load in the weights and states from a previous save
886
+ if args.resume_from_checkpoint:
887
+ if args.resume_from_checkpoint != "latest":
888
+ path = os.path.basename(args.resume_from_checkpoint)
889
+ else:
890
+ # Get the mos recent checkpoint
891
+ dirs = os.listdir(args.output_dir)
892
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
893
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
894
+ path = dirs[-1] if len(dirs) > 0 else None
895
+
896
+ if path is None:
897
+ accelerator.print(
898
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
899
+ )
900
+ args.resume_from_checkpoint = None
901
+ initial_global_step = 0
902
+ else:
903
+ accelerator.print(f"Resuming from checkpoint {path}")
904
+ accelerator.load_state(os.path.join(args.output_dir, path))
905
+ global_step = int(path.split("-")[1])
906
+
907
+ initial_global_step = global_step
908
+ first_epoch = global_step // num_update_steps_per_epoch
909
+ else:
910
+ initial_global_step = 0
911
+
912
+ progress_bar = tqdm(
913
+ range(0, args.max_train_steps),
914
+ initial=initial_global_step,
915
+ desc="Steps",
916
+ # Only show the progress bar once on each machine.
917
+ disable=not accelerator.is_local_main_process,
918
+ )
919
+
920
+ unet.train()
921
+ for epoch in range(first_epoch, args.num_train_epochs):
922
+ for step, batch in enumerate(train_dataloader):
923
+ with accelerator.accumulate(unet):
924
+ # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
925
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
926
+ feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
927
+
928
+ latents = []
929
+ for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
930
+ latents.append(
931
+ vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
932
+ )
933
+ latents = torch.cat(latents, dim=0)
934
+ latents = latents * vae.config.scaling_factor
935
+ if args.pretrained_vae_model_name_or_path is None:
936
+ latents = latents.to(weight_dtype)
937
+
938
+ # Sample noise that we'll add to the latents
939
+ noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
940
+
941
+ # Sample a random timestep for each image
942
+ bsz = latents.shape[0] // 2
943
+ timesteps = torch.randint(
944
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
945
+ ).repeat(2)
946
+
947
+ # Add noise to the model input according to the noise magnitude at each timestep
948
+ # (this is the forward diffusion process)
949
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
950
+
951
+ # time ids
952
+ def compute_time_ids(original_size, crops_coords_top_left):
953
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
954
+ target_size = (args.resolution, args.resolution)
955
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
956
+ add_time_ids = torch.tensor([add_time_ids])
957
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
958
+ return add_time_ids
959
+
960
+ add_time_ids = torch.cat(
961
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
962
+ ).repeat(2, 1)
963
+
964
+ # Get the text embedding for conditioning
965
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
966
+ [text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
967
+ )
968
+ prompt_embeds = prompt_embeds.repeat(2, 1, 1)
969
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
970
+
971
+ # Predict the noise residual
972
+ model_pred = unet(
973
+ noisy_model_input,
974
+ timesteps,
975
+ prompt_embeds,
976
+ added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
977
+ ).sample
978
+
979
+ # Get the target for loss depending on the prediction type
980
+ if noise_scheduler.config.prediction_type == "epsilon":
981
+ target = noise
982
+ elif noise_scheduler.config.prediction_type == "v_prediction":
983
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
984
+ else:
985
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
986
+
987
+ # ODDS ratio loss.
988
+ # In the diffusion formulation, we're assuming that the MSE loss
989
+ # approximates the logp.
990
+ model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
991
+ model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
992
+ model_losses_w, model_losses_l = model_losses.chunk(2)
993
+ log_odds = model_losses_w - model_losses_l
994
+
995
+ # Ratio loss.
996
+ ratio = F.logsigmoid(log_odds)
997
+ ratio_losses = args.beta_orpo * ratio
998
+
999
+ # Full ORPO loss
1000
+ loss = model_losses_w.mean() - ratio_losses.mean()
1001
+
1002
+ # Backprop.
1003
+ accelerator.backward(loss)
1004
+ if accelerator.sync_gradients:
1005
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1006
+ optimizer.step()
1007
+ lr_scheduler.step()
1008
+ optimizer.zero_grad()
1009
+
1010
+ # Checks if the accelerator has performed an optimization step behind the scenes
1011
+ if accelerator.sync_gradients:
1012
+ progress_bar.update(1)
1013
+ global_step += 1
1014
+
1015
+ if accelerator.is_main_process:
1016
+ if global_step % args.checkpointing_steps == 0:
1017
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1018
+ if args.checkpoints_total_limit is not None:
1019
+ checkpoints = os.listdir(args.output_dir)
1020
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1021
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1022
+
1023
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1024
+ if len(checkpoints) >= args.checkpoints_total_limit:
1025
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1026
+ removing_checkpoints = checkpoints[0:num_to_remove]
1027
+
1028
+ logger.info(
1029
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1030
+ )
1031
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1032
+
1033
+ for removing_checkpoint in removing_checkpoints:
1034
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1035
+ shutil.rmtree(removing_checkpoint)
1036
+
1037
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1038
+ accelerator.save_state(save_path)
1039
+ logger.info(f"Saved state to {save_path}")
1040
+
1041
+ if args.run_validation and global_step % args.validation_steps == 0:
1042
+ log_validation(
1043
+ args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
1044
+ )
1045
+
1046
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1047
+ progress_bar.set_postfix(**logs)
1048
+ accelerator.log(logs, step=global_step)
1049
+
1050
+ if global_step >= args.max_train_steps:
1051
+ break
1052
+
1053
+ # Save the lora layers
1054
+ accelerator.wait_for_everyone()
1055
+ if accelerator.is_main_process:
1056
+ unet = accelerator.unwrap_model(unet)
1057
+ unet = unet.to(torch.float32)
1058
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
1059
+
1060
+ StableDiffusionXLLoraLoaderMixin.save_lora_weights(
1061
+ save_directory=args.output_dir,
1062
+ unet_lora_layers=unet_lora_state_dict,
1063
+ text_encoder_lora_layers=None,
1064
+ text_encoder_2_lora_layers=None,
1065
+ )
1066
+
1067
+ # Final validation?
1068
+ if args.run_validation:
1069
+ log_validation(
1070
+ args,
1071
+ unet=None,
1072
+ vae=vae,
1073
+ accelerator=accelerator,
1074
+ weight_dtype=weight_dtype,
1075
+ epoch=epoch,
1076
+ is_final_validation=True,
1077
+ )
1078
+
1079
+ if args.push_to_hub:
1080
+ upload_folder(
1081
+ repo_id=repo_id,
1082
+ folder_path=args.output_dir,
1083
+ commit_message="End of training",
1084
+ ignore_patterns=["step_*", "epoch_*"],
1085
+ )
1086
+
1087
+ accelerator.end_training()
1088
+
1089
+
1090
# Script entry point: parse CLI arguments, then launch training.
if __name__ == "__main__":
    args = parse_args()
    main(args)
diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py ADDED
@@ -0,0 +1,1095 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import contextlib
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ import transformers
30
+ import wandb
31
+ import webdataset as wds
32
+ from accelerate import Accelerator
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import ProjectConfiguration, set_seed
35
+ from huggingface_hub import create_repo, upload_folder
36
+ from packaging import version
37
+ from peft import LoraConfig, set_peft_model_state_dict
38
+ from peft.utils import get_peft_model_state_dict
39
+ from torchvision import transforms
40
+ from torchvision.transforms.functional import crop
41
+ from tqdm.auto import tqdm
42
+ from transformers import AutoTokenizer, PretrainedConfig
43
+
44
+ import diffusers
45
+ from diffusers import (
46
+ AutoencoderKL,
47
+ DDPMScheduler,
48
+ DiffusionPipeline,
49
+ UNet2DConditionModel,
50
+ )
51
+ from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
52
+ from diffusers.optimization import get_scheduler
53
+ from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
54
+ from diffusers.utils.import_utils import is_xformers_available
55
+
56
+
57
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
58
+ check_min_version("0.25.0.dev0")
59
+
60
+ logger = get_logger(__name__)
61
+
62
+
63
# Fixed prompts used for qualitative validation during/after training;
# `log_validation` generates one image per prompt and logs them to the
# configured trackers (tensorboard / wandb).
VALIDATION_PROMPTS = [
    "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
69
+
70
+
71
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Resolve the concrete `transformers` text-encoder class for a checkpoint.

    Reads the architecture name from the checkpoint's config and returns the
    matching class object (not an instance). Imports are kept local so the
    class is only loaded when actually needed.

    Raises:
        ValueError: if the architecture is not one of the two CLIP variants
            SDXL uses.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    raise ValueError(f"{model_class} is not supported.")
89
+
90
+
91
def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
    """Generate images for ``VALIDATION_PROMPTS`` and log them to the trackers.

    During training (``is_final_validation=False``) the in-memory ``unet`` is
    swapped into a freshly loaded pipeline; for the final validation the saved
    LoRA weights are loaded from ``args.output_dir`` instead, and a second set
    of images is generated with LoRA disabled for comparison.
    """
    logger.info(f"Running validation... \n Generating images with prompts:\n {VALIDATION_PROMPTS}.")

    # For the final run the VAE was kept in fp32 during training for
    # stability; cast it down to match the inference dtype.
    if is_final_validation:
        if args.mixed_precision == "fp16":
            vae.to(weight_dtype)

    # create pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    if not is_final_validation:
        # Use the (possibly wrapped) training UNet directly.
        pipeline.unet = accelerator.unwrap_model(unet)
    else:
        # `unet` is ignored here; weights come from the saved LoRA file.
        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
    images = []
    # Mid-training validation runs under autocast because the UNet may hold
    # mixed-precision weights; the final pipeline is already in weight_dtype.
    # NOTE(review): torch.cuda.amp.autocast is the legacy spelling — newer
    # torch prefers torch.autocast("cuda"); confirm against the pinned version.
    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()

    guidance_scale = 5.0
    num_inference_steps = 25
    for prompt in VALIDATION_PROMPTS:
        with context:
            image = pipeline(
                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images[0]
        images.append(image)

    tracker_key = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    tracker_key: [
                        wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
                    ]
                }
            )

    # Also log images without the LoRA params for comparison.
    if is_final_validation:
        pipeline.disable_lora()
        # Re-seed so the with/without-LoRA images are directly comparable.
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
        no_lora_images = [
            pipeline(
                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images[0]
            for prompt in VALIDATION_PROMPTS
        ]

        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in no_lora_images])
                tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
            if tracker.name == "wandb":
                tracker.log(
                    {
                        "test_without_lora": [
                            wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
                            for i, image in enumerate(no_lora_images)
                        ]
                    }
                )
166
+
167
+
168
def parse_args(input_args=None):
    """Build and parse the CLI arguments for the ORPO SDXL LoRA trainer.

    Args:
        input_args: Optional explicit argv list (useful for tests); falls back
            to ``sys.argv`` when ``None``.

    Returns:
        The parsed ``argparse.Namespace``, with ``local_rank`` reconciled
        against the ``LOCAL_RANK`` environment variable set by launchers.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # --- Model / checkpoint selection ---
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    # --- Data: a webdataset shard spec (default streams from S3 via pipe) ---
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -",
    )
    parser.add_argument(
        "--num_train_examples",
        type=int,
        default=1001352,
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--run_validation",
        default=False,
        action="store_true",
        help="Whether to run validation inference in between training and also after training. Helps to track progress.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=200,
        help="Run validation every X steps.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="diffusion-orpo-lora-sdxl",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--vae_encode_batch_size",
        type=int,
        default=8,
        help="Batch size to use for VAE encoding of the images for efficient processing.",
    )
    # NOTE(review): the help text reads inverted — per its use in get_loader
    # (`if not args.no_hflip and random.random() < 0.5`), this flag DISABLES
    # random horizontal flips.
    parser.add_argument(
        "--no_hflip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--random_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
        ),
    )
    # --- Batch sizing: global (all GPUs) vs per-GPU ---
    parser.add_argument("--global_batch_size", type=int, default=64, help="Total batch size.")
    parser.add_argument(
        "--per_gpu_batch_size", type=int, default=8, help="Number of samples in a batch for a single GPU."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    # --- ORPO loss weight (see the training loop's ratio_losses term) ---
    parser.add_argument(
        "--beta_orpo",
        type=float,
        default=0.1,
        help="ORPO contribution factor.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    # --- Optimizer ---
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # --- Hub / logging ---
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="diffusion-orpo-lora-sdxl",
        help=("The name of the tracker to report results to."),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Distributed launchers communicate the local rank via env; it wins over
    # the CLI flag when they disagree.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
441
+
442
+
443
def tokenize_captions(tokenizers, sample):
    """Tokenize the sample's original prompt with both SDXL tokenizers.

    Args:
        tokenizers: Sequence of the two tokenizers (order matters: tokenizer
            one, then tokenizer two).
        sample: Mapping with an ``"original_prompt"`` entry.

    Returns:
        Tuple ``(tokens_one, tokens_two)`` of ``input_ids`` from each
        tokenizer, padded/truncated to that tokenizer's ``model_max_length``.
    """
    # Both tokenizers get identical settings — loop instead of duplicating
    # the call block twice (the original had two copy-pasted calls).
    tokens = [
        tokenizer(
            sample["original_prompt"],
            truncation=True,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        for tokenizer in tokenizers
    ]
    return tokens[0], tokens[1]
460
+
461
+
462
@torch.no_grad()
def encode_prompt(text_encoders, text_input_ids_list):
    """Encode already-tokenized prompts with the SDXL text encoders.

    Per-encoder penultimate hidden states are concatenated along the feature
    dimension; the pooled embedding comes from the final encoder only (SDXL
    convention).

    Returns:
        Tuple ``(prompt_embeds, pooled_prompt_embeds)``.
    """
    per_encoder_embeds = []
    pooled = None
    batch_size = None

    for encoder, token_ids in zip(text_encoders, text_input_ids_list):
        output = encoder(
            token_ids.to(encoder.device),
            output_hidden_states=True,
        )

        # Overwritten each iteration on purpose: only the final text
        # encoder's pooled output is kept.
        pooled = output[0]
        hidden = output.hidden_states[-2]
        batch_size, seq_len = hidden.shape[0], hidden.shape[1]
        per_encoder_embeds.append(hidden.view(batch_size, seq_len, -1))

    prompt_embeds = torch.concat(per_encoder_embeds, dim=-1)
    pooled_prompt_embeds = pooled.view(batch_size, -1)
    return prompt_embeds, pooled_prompt_embeds
484
+
485
+
486
def get_dataset(args):
    """Build the streaming WebDataset pipeline for preference pairs.

    Samples are resampled indefinitely, shuffled with a 690-sample buffer,
    decoded to PIL images, and renamed to the keys the preprocessing code
    expects. Decoding/renaming errors are logged and skipped
    (``warn_and_continue``).
    """
    pipeline = wds.WebDataset(args.dataset_path, resampled=True, handler=wds.warn_and_continue)
    pipeline = pipeline.shuffle(690, handler=wds.warn_and_continue)
    pipeline = pipeline.decode("pil", handler=wds.warn_and_continue)
    pipeline = pipeline.rename(
        original_prompt="original_prompt.txt",
        jpg_0="jpg_0.jpg",
        jpg_1="jpg_1.jpg",
        label_0="label_0.txt",
        label_1="label_1.txt",
        handler=wds.warn_and_continue,
    )
    return pipeline
501
+
502
+
503
def get_loader(args, tokenizer_one, tokenizer_two):
    """Create the WebDataset-backed training dataloader for ORPO pairs.

    Each raw sample holds two images for the same prompt plus a preference
    label; the two images are stacked along the channel dimension (preferred
    first) so the training loop can split them back with ``chunk(2, dim=1)``.

    Args:
        args: Parsed CLI namespace (resolution, batch sizes, worker count...).
        tokenizer_one: First SDXL tokenizer.
        tokenizer_two: Second SDXL tokenizer.

    Returns:
        A ``wds.WebLoader`` with ``num_batches`` and ``num_samples`` attached
        as convenience metadata.
    """
    # Dataset size reference: 1,001,352 examples.
    # Guard against --dataloader_num_workers=0 (the argparse default): the
    # per-worker batch computation below would otherwise divide by zero.
    effective_workers = max(args.dataloader_num_workers, 1)
    # Batches each worker must yield so that one "epoch" across all workers
    # covers approximately args.num_train_examples samples.
    num_worker_batches = math.ceil(args.num_train_examples / (args.global_batch_size * effective_workers))
    num_batches = num_worker_batches * effective_workers
    num_samples = num_batches * args.global_batch_size

    dataset = get_dataset(args)

    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
    train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
    train_flip = transforms.RandomHorizontalFlip(p=1.0)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5], [0.5])

    def preprocess_images(sample):
        # Turn one raw webdataset sample into the tensors the trainer needs.
        jpg_0_image = sample["jpg_0"]
        original_size = (jpg_0_image.height, jpg_0_image.width)

        jpg_1_image = sample["jpg_1"]
        # Need to bring down the image to the same resolution.
        # This seems like the simplest reasonable approach.
        # "::-1" because PIL resize takes (width, height).
        jpg_1_image = jpg_1_image.resize(original_size[::-1])

        # We randomize selection and rejection on ties.
        # NOTE(review): label_0 arrives from a decoded ".txt" entry, which
        # webdataset yields as a string — coerce to float so the 0.5
        # comparison works; confirm against the shard format.
        label_0 = float(sample["label_0"])
        if label_0 == 0.5:
            label_0 = 0 if random.random() < 0.5 else 1

        # Double on channel dim, preferred image first.
        if label_0 == 0:
            pixel_values = torch.cat([to_tensor(image) for image in [jpg_1_image, jpg_0_image]])
        else:
            pixel_values = torch.cat([to_tensor(image) for image in [jpg_0_image, jpg_1_image]])

        # Resize.
        combined_im = train_resize(pixel_values)

        # Flipping.
        if not args.no_hflip and random.random() < 0.5:
            combined_im = train_flip(combined_im)

        # Cropping: record the top-left corner for SDXL micro-conditioning.
        if not args.random_crop:
            y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
            x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
            combined_im = train_crop(combined_im)
        else:
            y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
            combined_im = crop(combined_im, y1, x1, h, w)

        crop_top_left = (y1, x1)
        combined_im = normalize(combined_im)
        tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], sample)

        return {
            "pixel_values": combined_im,
            "original_size": original_size,
            "crop_top_left": crop_top_left,
            "tokens_one": tokens_one,
            "tokens_two": tokens_two,
        }

    def collate_fn(samples):
        # Stack per-sample tensors into one batch dict.
        pixel_values = torch.stack([sample["pixel_values"] for sample in samples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        original_sizes = [example["original_size"] for example in samples]
        crop_top_lefts = [example["crop_top_left"] for example in samples]
        input_ids_one = torch.stack([example["tokens_one"] for example in samples])
        input_ids_two = torch.stack([example["tokens_two"] for example in samples])

        return {
            "pixel_values": pixel_values,
            "input_ids_one": input_ids_one,
            "input_ids_two": input_ids_two,
            "original_sizes": original_sizes,
            "crop_top_lefts": crop_top_lefts,
        }

    dataset = dataset.map(preprocess_images, handler=wds.warn_and_continue)
    dataset = dataset.batched(args.per_gpu_batch_size, partial=False, collation_fn=collate_fn)
    dataset = dataset.with_epoch(num_worker_batches)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        # persistent_workers is only legal when there are worker processes;
        # with num_workers=0 DataLoader raises a ValueError.
        persistent_workers=args.dataloader_num_workers > 0,
    )
    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples
    return dataloader
606
+
607
+
608
def main(args):
    """Train LoRA adapters on an SDXL UNet with a diffusion ORPO objective.

    Pipeline: load SDXL components (two tokenizers/text encoders, VAE, UNet,
    DDPM scheduler), freeze everything, attach a LoRA adapter to the UNet
    attention projections, then optimize an ORPO-style preference loss over
    (preferred, rejected) image pairs coming from the webdataset loader.

    Args:
        args: Namespace produced by ``parse_args()`` (hyperparameters, paths,
            logging/checkpointing options).

    Raises:
        ValueError: if ``--report_to=wandb`` is combined with ``--hub_token``,
            or if xformers is requested but not installed.
    """
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load the tokenizers
    tokenizer_one = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )
    tokenizer_two = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=args.revision,
        use_fast=False,
    )

    # import correct text encoder classes
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
    )

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    text_encoder_one = text_encoder_cls_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
    )
    vae_path = (
        args.pretrained_model_name_or_path
        if args.pretrained_vae_model_name_or_path is None
        else args.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
        revision=args.revision,
        variant=args.variant,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
    )

    # We only train the additional adapter LoRA layers
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)
    unet.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet and text_encoders to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
    text_encoder_two.to(accelerator.device, dtype=weight_dtype)

    # The VAE is always in float32 to avoid NaN losses.
    vae.to(accelerator.device, dtype=torch.float32)

    # Set up LoRA.
    unet_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    # Add adapter and make sure the trainable params are in float32.
    unet.add_adapter(unet_lora_config)
    if args.mixed_precision == "fp16":
        for param in unet.parameters():
            # only upcast trainable parameters (LoRA) into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        # Persist only the UNet LoRA weights in the diffusers layout.
        if accelerator.is_main_process:
            # there are only two options here. Either are just the unet attn processor layers
            # or there are the unet and text encoder atten layers
            unet_lora_layers_to_save = None

            for model in models:
                if isinstance(model, type(accelerator.unwrap_model(unet))):
                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

            StableDiffusionXLLoraLoaderMixin.save_lora_weights(
                output_dir,
                unet_lora_layers=unet_lora_layers_to_save,
                text_encoder_lora_layers=None,
                text_encoder_2_lora_layers=None,
            )

    def load_model_hook(models, input_dir):
        # Counterpart of `save_model_hook`: re-injects the saved LoRA weights
        # into the (already prepared) UNet when `accelerator.load_state` runs.
        unet_ = None

        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(accelerator.unwrap_model(unet))):
                unet_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)

        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation: only the LoRA parameters require grad at this point.
    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
    train_dataloader = get_loader(args, tokenizer_one=tokenizer_one, tokenizer_two=tokenizer_two)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers(args.tracker_name, config=vars(args))

    # Train!
    total_batch_size = args.per_gpu_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {train_dataloader.num_samples}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_gpu_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    unet.train()
    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype, device=accelerator.device, non_blocking=True)
                feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))

                # Encode in sub-batches to bound VAE memory usage.
                latents = []
                for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
                    latents.append(
                        vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
                    )
                latents = torch.cat(latents, dim=0)
                latents = latents * vae.config.scaling_factor
                if args.pretrained_vae_model_name_or_path is None:
                    latents = latents.to(weight_dtype)

                # Sample noise that we'll add to the latents.
                # The same noise is reused for both halves of each preference
                # pair (chunk then repeat) so the pair differs only in content.
                noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)

                # Sample a random timestep for each image; shared across the
                # pair for the same reason as the noise above.
                bsz = latents.shape[0] // 2
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
                ).repeat(2)

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)

                # time ids
                def compute_time_ids(original_size, crops_coords_top_left):
                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                    target_size = (args.resolution, args.resolution)
                    add_time_ids = list(tuple(original_size) + tuple(crops_coords_top_left) + target_size)
                    add_time_ids = torch.tensor([add_time_ids])
                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
                    return add_time_ids

                add_time_ids = torch.cat(
                    [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
                ).repeat(2, 1)

                # Get the text embedding for conditioning
                prompt_embeds, pooled_prompt_embeds = encode_prompt(
                    [text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
                )
                prompt_embeds = prompt_embeds.repeat(2, 1, 1)
                pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)

                # Predict the noise residual
                model_pred = unet(
                    noisy_model_input,
                    timesteps,
                    prompt_embeds,
                    added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # ODDS ratio loss.
                # In the diffusion formulation, we're assuming that the MSE loss
                # approximates the logp.
                model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
                model_losses_w, model_losses_l = model_losses.chunk(2)
                log_odds = model_losses_w - model_losses_l

                # Ratio loss.
                ratio = F.logsigmoid(log_odds)
                ratio_losses = args.beta_orpo * ratio

                # Full ORPO loss
                loss = model_losses_w.mean() - ratio_losses.mean()

                # Backprop.
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if args.run_validation and global_step % args.validation_steps == 0:
                        log_validation(
                            args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
                        )

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

        StableDiffusionXLLoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
            unet_lora_layers=unet_lora_state_dict,
            text_encoder_lora_layers=None,
            text_encoder_2_lora_layers=None,
        )

        # Final validation?
        if args.run_validation:
            log_validation(
                args,
                unet=None,
                vae=vae,
                accelerator=accelerator,
                weight_dtype=weight_dtype,
                epoch=epoch,
                is_final_validation=True,
            )

        if args.push_to_hub:
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
1091
+
1092
+
1093
if __name__ == "__main__":
    # Parse the CLI flags and hand them straight to the trainer.
    main(parse_args())
diffusers/examples/research_projects/dreambooth_inpaint/README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dreambooth for the inpainting model
2
+
3
+ This script was added by @thedarkzeno.
4
+
5
+ Please note that this script is not actively maintained; you can open an issue and tag @thedarkzeno or @patil-suraj though.
6
+
7
+ ```bash
8
+ export MODEL_NAME="runwayml/stable-diffusion-inpainting"
9
+ export INSTANCE_DIR="path-to-instance-images"
10
+ export OUTPUT_DIR="path-to-save-model"
11
+
12
+ accelerate launch train_dreambooth_inpaint.py \
13
+ --pretrained_model_name_or_path=$MODEL_NAME \
14
+ --instance_data_dir=$INSTANCE_DIR \
15
+ --output_dir=$OUTPUT_DIR \
16
+ --instance_prompt="a photo of sks dog" \
17
+ --resolution=512 \
18
+ --train_batch_size=1 \
19
+ --gradient_accumulation_steps=1 \
20
+ --learning_rate=5e-6 \
21
+ --lr_scheduler="constant" \
22
+ --lr_warmup_steps=0 \
23
+ --max_train_steps=400
24
+ ```
25
+
26
+ ### Training with prior-preservation loss
27
+
28
+ Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
29
+ According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
30
+
31
+ ```bash
32
+ export MODEL_NAME="runwayml/stable-diffusion-inpainting"
33
+ export INSTANCE_DIR="path-to-instance-images"
34
+ export CLASS_DIR="path-to-class-images"
35
+ export OUTPUT_DIR="path-to-save-model"
36
+
37
+ accelerate launch train_dreambooth_inpaint.py \
38
+ --pretrained_model_name_or_path=$MODEL_NAME \
39
+ --instance_data_dir=$INSTANCE_DIR \
40
+ --class_data_dir=$CLASS_DIR \
41
+ --output_dir=$OUTPUT_DIR \
42
+ --with_prior_preservation --prior_loss_weight=1.0 \
43
+ --instance_prompt="a photo of sks dog" \
44
+ --class_prompt="a photo of dog" \
45
+ --resolution=512 \
46
+ --train_batch_size=1 \
47
+ --gradient_accumulation_steps=1 \
48
+ --learning_rate=5e-6 \
49
+ --lr_scheduler="constant" \
50
+ --lr_warmup_steps=0 \
51
+ --num_class_images=200 \
52
+ --max_train_steps=800
53
+ ```
54
+
55
+
56
+ ### Training with gradient checkpointing and 8-bit optimizer:
57
+
58
+ With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.
59
+
60
+ To install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
61
+
62
+ ```bash
63
+ export MODEL_NAME="runwayml/stable-diffusion-inpainting"
64
+ export INSTANCE_DIR="path-to-instance-images"
65
+ export CLASS_DIR="path-to-class-images"
66
+ export OUTPUT_DIR="path-to-save-model"
67
+
68
+ accelerate launch train_dreambooth_inpaint.py \
69
+ --pretrained_model_name_or_path=$MODEL_NAME \
70
+ --instance_data_dir=$INSTANCE_DIR \
71
+ --class_data_dir=$CLASS_DIR \
72
+ --output_dir=$OUTPUT_DIR \
73
+ --with_prior_preservation --prior_loss_weight=1.0 \
74
+ --instance_prompt="a photo of sks dog" \
75
+ --class_prompt="a photo of dog" \
76
+ --resolution=512 \
77
+ --train_batch_size=1 \
78
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
79
+ --use_8bit_adam \
80
+ --learning_rate=5e-6 \
81
+ --lr_scheduler="constant" \
82
+ --lr_warmup_steps=0 \
83
+ --num_class_images=200 \
84
+ --max_train_steps=800
85
+ ```
86
+
87
+ ### Fine-tune text encoder with the UNet.
88
+
89
+ The script also allows you to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
90
+ Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
91
+
92
+ ___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
93
+
94
+ ```bash
95
+ export MODEL_NAME="runwayml/stable-diffusion-inpainting"
96
+ export INSTANCE_DIR="path-to-instance-images"
97
+ export CLASS_DIR="path-to-class-images"
98
+ export OUTPUT_DIR="path-to-save-model"
99
+
100
+ accelerate launch train_dreambooth_inpaint.py \
101
+ --pretrained_model_name_or_path=$MODEL_NAME \
102
+ --train_text_encoder \
103
+ --instance_data_dir=$INSTANCE_DIR \
104
+ --class_data_dir=$CLASS_DIR \
105
+ --output_dir=$OUTPUT_DIR \
106
+ --with_prior_preservation --prior_loss_weight=1.0 \
107
+ --instance_prompt="a photo of sks dog" \
108
+ --class_prompt="a photo of dog" \
109
+ --resolution=512 \
110
+ --train_batch_size=1 \
111
+ --use_8bit_adam \
112
+ --gradient_checkpointing \
113
+ --learning_rate=2e-6 \
114
+ --lr_scheduler="constant" \
115
+ --lr_warmup_steps=0 \
116
+ --num_class_images=200 \
117
+ --max_train_steps=800
118
+ ```
diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ diffusers==0.9.0
2
+ accelerate>=0.16.0
3
+ torchvision
4
+ transformers>=4.21.0
5
+ ftfy
6
+ tensorboard
7
+ Jinja2
diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py ADDED
@@ -0,0 +1,812 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import math
4
+ import os
5
+ import random
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from accelerate import Accelerator
13
+ from accelerate.logging import get_logger
14
+ from accelerate.utils import ProjectConfiguration, set_seed
15
+ from huggingface_hub import create_repo, upload_folder
16
+ from huggingface_hub.utils import insecure_hashlib
17
+ from PIL import Image, ImageDraw
18
+ from torch.utils.data import Dataset
19
+ from torchvision import transforms
20
+ from tqdm.auto import tqdm
21
+ from transformers import CLIPTextModel, CLIPTokenizer
22
+
23
+ from diffusers import (
24
+ AutoencoderKL,
25
+ DDPMScheduler,
26
+ StableDiffusionInpaintPipeline,
27
+ StableDiffusionPipeline,
28
+ UNet2DConditionModel,
29
+ )
30
+ from diffusers.optimization import get_scheduler
31
+ from diffusers.utils import check_min_version
32
+
33
+
34
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.13.0.dev0")

# Accelerate-aware module logger (from accelerate.logging), used throughout training.
logger = get_logger(__name__)
38
+
39
+
40
def prepare_mask_and_masked_image(image, mask):
    """Convert a PIL image/mask pair into inpainting tensors.

    Args:
        image: PIL image (any mode); converted to RGB.
        mask: PIL image (any mode); converted to grayscale and binarized at 0.5.

    Returns:
        Tuple ``(mask, masked_image)`` where ``mask`` is a (1, 1, H, W)
        float32 tensor of 0/1 values and ``masked_image`` is a (1, 3, H, W)
        float32 tensor in [-1, 1] with the masked region zeroed out.
    """
    # Image -> NCHW float tensor scaled to [-1, 1].
    rgb = np.array(image.convert("RGB"))
    rgb = rgb[None].transpose(0, 3, 1, 2)
    image_tensor = torch.from_numpy(rgb).to(dtype=torch.float32) / 127.5 - 1.0

    # Mask -> binarized (threshold 0.5) NCHW float tensor.
    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    gray = gray[None, None]
    binary = np.where(gray < 0.5, 0.0, 1.0).astype(np.float32)
    mask_tensor = torch.from_numpy(binary)

    # Zero the pixels under the mask; keep the rest untouched.
    masked_image = image_tensor * (mask_tensor < 0.5)

    return mask_tensor, masked_image
55
+
56
+
57
# generate random masks
def random_mask(im_shape, ratio=1, mask_full_image=False):
    """Draw a random white rectangle or ellipse on a black "L"-mode canvas.

    Args:
        im_shape: (width, height) of the mask to create.
        ratio: maximum fraction of each image dimension the shape may cover.
        mask_full_image: if True, force a full-size rectangle.

    Returns:
        A PIL "L" image: 0 background, 255 inside the drawn shape.
    """
    canvas = Image.new("L", im_shape, 0)
    painter = ImageDraw.Draw(canvas)

    # Pick a random extent first. The full-image case overrides it below, but
    # the RNG draws still happen so the call sequence matches the original.
    extent = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
    if mask_full_image:
        extent = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))

    # Keep the shape fully inside the canvas by bounding the center.
    half_w, half_h = extent[0] // 2, extent[1] // 2
    center = (
        random.randint(half_w, im_shape[0] - half_w),
        random.randint(half_h, im_shape[1] - half_h),
    )
    bbox = (center[0] - half_w, center[1] - half_h, center[0] + half_w, center[1] + half_h)

    # Rectangle or ellipse at random; a full-image mask is always a rectangle.
    if random.randint(0, 1) == 0 or mask_full_image:
        painter.rectangle(bbox, fill=255)
    else:
        painter.ellipse(bbox, fill=255)

    return canvas
80
+
81
+
82
def parse_args():
    """Parse and validate the command-line arguments for DreamBooth inpainting training.

    Returns:
        argparse.Namespace: the parsed arguments, with ``local_rank`` overridden
        by the ``LOCAL_RANK`` environment variable when set, and with the
        prior-preservation options cross-validated.

    Raises:
        ValueError: if required data directories/prompts are missing for the
        selected mode.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
            " sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
            " using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    args = parser.parse_args()
    # Launchers (e.g. torchrun) communicate the rank via the environment;
    # let that override the CLI flag when both are present.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # NOTE(review): this check is unreachable since --instance_data_dir is
    # declared required=True above — argparse already rejects a missing value.
    if args.instance_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    # Prior preservation needs both a class-image directory and a class prompt.
    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")

    return args
295
+
296
+
297
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Each item is a dict with the resized/cropped PIL image (``PIL_images``, kept
    for mask generation in the collate function), the normalized tensor image
    (``instance_images``) and the tokenized prompt (``instance_prompt_ids``);
    when ``class_data_root`` is given, the corresponding ``class_*`` entries are
    added for prior preservation.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            # Cycle the shorter collection so one epoch covers the longer one.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        # Geometric transforms are applied first so the resulting PIL image can
        # also be returned (the collate function draws the random mask on it).
        self.image_transforms_resize_and_crop = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            ]
        )

        # Tensor conversion + normalization to [-1, 1].
        self.image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        # Index modulo the collection size so instance and class images can be
        # cycled independently when their counts differ.
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        instance_image = self.image_transforms_resize_and_crop(instance_image)

        example["PIL_images"] = instance_image
        example["instance_images"] = self.image_transforms(instance_image)

        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if class_image.mode != "RGB":
                class_image = class_image.convert("RGB")
            class_image = self.image_transforms_resize_and_crop(class_image)
            example["class_images"] = self.image_transforms(class_image)
            example["class_PIL_images"] = class_image
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example
385
+
386
+
387
class PromptDataset(Dataset):
    """Dataset yielding the same prompt ``num_samples`` times, each paired with
    its index, so class-image generation can be sharded across GPUs."""

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Every item carries the shared prompt plus its own position.
        return {"prompt": self.prompt, "index": index}
402
+
403
+
404
def main():
    """Entry point: DreamBooth fine-tuning of a Stable Diffusion inpainting model.

    Optionally pre-generates class images for prior preservation, then trains
    the UNet (and optionally the text encoder) on instance images combined with
    randomly generated inpainting masks, checkpointing and finally saving a
    full pipeline to ``args.output_dir``.
    """
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    project_config = ProjectConfiguration(
        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        project_config=project_config,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    # Generate any missing class images with the pretrained pipeline before
    # training starts (prior preservation).
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(
                sample_dataset, batch_size=args.sample_batch_size, num_workers=1
            )

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)
            transform_to_pil = transforms.ToPILImage()
            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                # Inpaint over random noise with a full-image mask so the
                # pipeline effectively generates class images from the prompt alone.
                fake_images = torch.rand((3, args.resolution, args.resolution))
                fake_pil_images = transform_to_pil(fake_images)

                fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)

                images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images

                for i, image in enumerate(images):
                    # Hash in the filename avoids collisions between identical indices
                    # across runs.
                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    # The VAE is only used for encoding; freeze it. Freeze the text encoder too
    # unless it is being fine-tuned.
    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        """Batch examples, drawing a fresh random inpainting mask per image."""
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            prior_pil = [example["class_PIL_images"] for example in examples]

        masks = []
        masked_images = []
        for example in examples:
            pil_image = example["PIL_images"]
            # generate a random mask
            mask = random_mask(pil_image.size, 1, False)
            # prepare mask and masked image
            mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)

            masks.append(mask)
            masked_images.append(masked_image)

        if args.with_prior_preservation:
            for pil_image in prior_pil:
                # generate a random mask
                mask = random_mask(pil_image.size, 1, False)
                # prepare mask and masked image
                mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)

                masks.append(mask)
                masked_images.append(masked_image)

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
        masks = torch.stack(masks)
        masked_images = torch.stack(masked_images)
        batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )
    accelerator.register_for_checkpointing(lr_scheduler)

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space

                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Convert masked images to latent space
                masked_latents = vae.encode(
                    batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                ).latent_dist.sample()
                masked_latents = masked_latents * vae.config.scaling_factor

                masks = batch["masks"]
                # resize the mask to latents shape as we concatenate the mask to the latents
                mask = torch.stack(
                    [
                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
                        for mask in masks
                    ]
                )
                mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # concatenate the noised latents with the mask and the masked latents
                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    # NOTE(review): saved as a text-to-image StableDiffusionPipeline even though
    # the UNet was trained for inpainting — confirm this is intentional.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
809
+
810
+
811
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()
diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py ADDED
@@ -0,0 +1,831 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+ import random
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from accelerate import Accelerator
12
+ from accelerate.logging import get_logger
13
+ from accelerate.utils import ProjectConfiguration, set_seed
14
+ from huggingface_hub import create_repo, upload_folder
15
+ from huggingface_hub.utils import insecure_hashlib
16
+ from PIL import Image, ImageDraw
17
+ from torch.utils.data import Dataset
18
+ from torchvision import transforms
19
+ from tqdm.auto import tqdm
20
+ from transformers import CLIPTextModel, CLIPTokenizer
21
+
22
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
23
+ from diffusers.loaders import AttnProcsLayers
24
+ from diffusers.models.attention_processor import LoRAAttnProcessor
25
+ from diffusers.optimization import get_scheduler
26
+ from diffusers.utils import check_min_version
27
+ from diffusers.utils.import_utils import is_xformers_available
28
+
29
+
30
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
31
+ check_min_version("0.13.0.dev0")
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
def prepare_mask_and_masked_image(image, mask):
    """Convert PIL inputs into the tensors the inpainting UNet consumes.

    Returns a ``(mask, masked_image)`` pair: the mask as a (1, 1, H, W)
    float32 tensor binarized to {0, 1}, and the image as a (1, 3, H, W)
    float32 tensor scaled to [-1, 1] with the masked region zeroed out.
    """
    # HWC uint8 -> NCHW float in [-1, 1].
    rgb = np.array(image.convert("RGB"))
    pixels = torch.from_numpy(rgb[None].transpose(0, 3, 1, 2)).to(dtype=torch.float32) / 127.5 - 1.0

    # Grayscale mask scaled to [0, 1], then hard-thresholded at 0.5.
    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    gray = gray[None, None]
    gray[gray < 0.5] = 0
    gray[gray >= 0.5] = 1
    mask_tensor = torch.from_numpy(gray)

    # Zero the pixels that fall inside the mask (mask == 1).
    masked = pixels * (mask_tensor < 0.5)

    return mask_tensor, masked
51
+
52
+
53
# generate random masks
def random_mask(im_shape, ratio=1, mask_full_image=False):
    """Return a single-channel ("L") PIL mask with one random white shape on black.

    A rectangle or ellipse of random size is drawn at a random position; with
    ``mask_full_image=True`` the shape covers ``ratio`` of the whole canvas and
    is always a rectangle.
    """
    canvas = Image.new("L", im_shape, 0)
    painter = ImageDraw.Draw(canvas)
    # Random extent. NOTE: these two RNG draws happen even when
    # mask_full_image is set, to keep the random stream consumption identical.
    width = random.randint(0, int(im_shape[0] * ratio))
    height = random.randint(0, int(im_shape[1] * ratio))
    # use this to always mask the whole image
    if mask_full_image:
        width, height = int(im_shape[0] * ratio), int(im_shape[1] * ratio)
    # Choose a center so the shape stays fully inside the canvas.
    cx = random.randint(width // 2, im_shape[0] - width // 2)
    cy = random.randint(height // 2, im_shape[1] - height // 2)
    bbox = (cx - width // 2, cy - height // 2, cx + width // 2, cy + height // 2)
    shape_choice = random.randint(0, 1)
    if shape_choice == 0 or mask_full_image:
        painter.rectangle(bbox, fill=255)
    else:
        painter.ellipse(bbox, fill=255)

    return canvas
76
+
77
+
78
def parse_args():
    """Parse command-line arguments for DreamBooth-inpainting LoRA training.

    Returns:
        argparse.Namespace: the parsed arguments, with ``local_rank`` reconciled
        against the ``LOCAL_RANK`` environment variable and the prior-preservation
        arguments validated.

    Raises:
        ValueError: if ``--instance_data_dir`` is missing, or if
            ``--with_prior_preservation`` is set without ``--class_data_dir``
            or ``--class_prompt``.
    """
    import warnings

    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
            " sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="dreambooth-inpaint-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
            " using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    args = parser.parse_args()

    # Allow torch.distributed launchers to override --local_rank via the environment.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.instance_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        # Fix: previously these arguments were silently ignored when prior
        # preservation was disabled; warn so the user notices the mismatch.
        if args.class_data_dir is not None:
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    return args
294
+
295
+
296
class DreamBoothDataset(Dataset):
    """Dataset pairing instance (and optionally class) images with tokenized prompts.

    Each item is a dict with keys ``PIL_images``, ``instance_images`` and
    ``instance_prompt_ids``; when a class data root is supplied it additionally
    carries ``class_images``, ``class_PIL_images`` and ``class_prompt_ids``.
    Images are resized/cropped to ``size`` and normalized to [-1, 1].
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(self.instance_data_root.iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is None:
            self.class_data_root = None
        else:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            # Cycle the shorter of the two image lists.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt

        crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        self.image_transforms_resize_and_crop = transforms.Compose(
            [transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), crop]
        )
        self.image_transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        )

    def __len__(self):
        return self._length

    def _tokenize(self, prompt):
        # Unpadded token ids; the collate function pads the batch later.
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def _load_rgb(self, path):
        # Force RGB so grayscale/palette files do not break the transforms.
        picture = Image.open(path)
        if picture.mode != "RGB":
            picture = picture.convert("RGB")
        return self.image_transforms_resize_and_crop(picture)

    def __getitem__(self, index):
        example = {}

        instance_image = self._load_rgb(self.instance_images_path[index % self.num_instance_images])
        example["PIL_images"] = instance_image
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root:
            class_image = self._load_rgb(self.class_images_path[index % self.num_class_images])
            example["class_images"] = self.image_transforms(class_image)
            example["class_PIL_images"] = class_image
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example
384
+
385
+
386
class PromptDataset(Dataset):
    """A trivial dataset yielding the same prompt ``num_samples`` times.

    Used to fan out class-image generation across multiple GPUs; each item is a
    dict with the prompt and its index.
    """

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}
401
+
402
+
403
def main():
    """Run DreamBooth-inpainting LoRA fine-tuning end to end.

    Pipeline: parse args -> (optionally) generate class images for prior
    preservation -> load tokenizer/text encoder/VAE/UNet (all frozen) ->
    attach LoRA attention processors to the UNet -> train only the LoRA
    parameters with randomly generated inpainting masks -> save the LoRA
    attention processors (and optionally push to the Hub).
    """
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(
        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        project_config=accelerator_project_config,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    # Generate any missing class images with the frozen inpainting pipeline
    # (prior-preservation setup). A random full-image mask plus random pixels
    # are fed in so the model effectively does text-to-image generation.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(
                sample_dataset, batch_size=args.sample_batch_size, num_workers=1
            )

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)
            transform_to_pil = transforms.ToPILImage()
            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                # NOTE(review): bsz is unused below, and transform_to_pil is
                # re-created each iteration even though it was built above.
                bsz = len(example["prompt"])
                fake_images = torch.rand((3, args.resolution, args.resolution))
                transform_to_pil = transforms.ToPILImage()
                fake_pil_images = transform_to_pil(fake_images)

                fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)

                images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images

                # Hash-based filenames avoid collisions across processes/runs.
                for i, image in enumerate(images):
                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            # Free the sampling pipeline before training starts.
            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    # We only train the additional adapter LoRA layers
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # now we will add new LoRA weights to the attention layers
    # It's important to realize here how many attention weights will be added and of which sizes
    # The sizes of the attention layers consist only of two different variables:
    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

    # Let's first see how many attention processors we will have to set.
    # For Stable Diffusion, it should be equal to:
    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
    # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
    # => 32 layers

    # Set correct lora layers
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention ("attn1") has no cross-attention conditioning dim.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

    unet.set_attn_processor(lora_attn_procs)
    # Wrap the processors so only the LoRA weights are optimized/checkpointed.
    lora_layers = AttnProcsLayers(unet.attn_processors)

    accelerator.register_for_checkpointing(lora_layers)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        """Batch examples: pad prompts, stack pixels, and build random inpainting masks."""
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            # (sic: "pior_pil" name kept as-is from the original.)
            pior_pil = [example["class_PIL_images"] for example in examples]

        masks = []
        masked_images = []
        for example in examples:
            pil_image = example["PIL_images"]
            # generate a random mask
            mask = random_mask(pil_image.size, 1, False)
            # prepare mask and masked image
            mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)

            masks.append(mask)
            masked_images.append(masked_image)

        if args.with_prior_preservation:
            for pil_image in pior_pil:
                # generate a random mask
                mask = random_mask(pil_image.size, 1, False)
                # prepare mask and masked image
                mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)

                masks.append(mask)
                masked_images.append(masked_image)

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
        masks = torch.stack(masks)
        masked_images = torch.stack(masked_images)
        batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )
    # accelerator.register_for_checkpointing(lr_scheduler)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth-inpaint-lora", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Resume bookkeeping: map the saved optimizer-step count back to an
    # (epoch, dataloader-step) position to skip already-consumed batches.
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space

                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Convert masked images to latent space
                masked_latents = vae.encode(
                    batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                ).latent_dist.sample()
                masked_latents = masked_latents * vae.config.scaling_factor

                masks = batch["masks"]
                # resize the mask to latents shape as we concatenate the mask to the latents
                # NOTE(review): assumes an 8x VAE downsampling factor — confirm
                # it matches the chosen checkpoint's VAE.
                mask = torch.stack(
                    [
                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
                        for mask in masks
                    ]
                ).to(dtype=weight_dtype)
                mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # concatenate the noised latents with the mask and the masked latents
                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    # Clip only the trainable LoRA parameters.
                    params_to_clip = lora_layers.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Save the lora layers
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)

        if args.push_to_hub:
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
828
+
829
+
830
+ if __name__ == "__main__":
831
+ main()
diffusers/examples/research_projects/flux_lora_quantization/README.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## LoRA fine-tuning Flux.1 Dev with quantization
2
+
3
+ > [!NOTE]
4
+ > This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further.
5
+
6
+ This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow:
7
+
8
+ * We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
9
+ * Though optional, we load the T5-xxl in NF4 to further reduce the memory footprint.
10
+ * `train_dreambooth_lora_flux_miniature.py` takes care of training:
11
+ * Since we already precomputed the text embeddings, we don't load the text encoders.
12
+ * We load the VAE and use it to precompute the image latents and we then delete it.
13
+ * Load the Flux transformer, quantize it with the [NF4 datatype](https://huggingface.co/papers/2305.14314) through `bitsandbytes`, prepare it for 4bit training.
14
+ * Add LoRA adapter layers to it and then ensure they are kept in FP32 precision.
15
+ * Train!
16
+
17
+ To run training in a memory-optimized manner, we additionally use:
18
+
19
+ * 8Bit Adam
20
+ * Gradient checkpointing
21
+
22
+ We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow.
23
+
24
+ ## Training
25
+
26
+ Ensure you have installed the required libraries:
27
+
28
+ ```bash
29
+ pip install -U transformers accelerate bitsandbytes peft datasets
30
+ pip install git+https://github.com/huggingface/diffusers -U
31
+ ```
32
+
33
+ Now, compute the text embeddings:
34
+
35
+ ```bash
36
+ python compute_embeddings.py
37
+ ```
38
+
39
+ It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
40
+
41
+ ```bash
42
+ huggingface-cli login
43
+ ```
44
+
45
+ Then launch:
46
+
47
+ ```bash
48
+ accelerate launch --config_file=accelerate.yaml \
49
+ train_dreambooth_lora_flux_miniature.py \
50
+ --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
51
+ --data_df_path="embeddings.parquet" \
52
+ --output_dir="yarn_art_lora_flux_nf4" \
53
+ --mixed_precision="fp16" \
54
+ --use_8bit_adam \
55
+ --weighting_scheme="none" \
56
+ --resolution=1024 \
57
+ --train_batch_size=1 \
58
+ --repeats=1 \
59
+ --learning_rate=1e-4 \
60
+ --guidance_scale=1 \
61
+ --report_to="wandb" \
62
+ --gradient_accumulation_steps=4 \
63
+ --gradient_checkpointing \
64
+ --lr_scheduler="constant" \
65
+ --lr_warmup_steps=0 \
66
+ --cache_latents \
67
+ --rank=4 \
68
+ --max_train_steps=700 \
69
+ --seed="0"
70
+ ```
71
+
72
+ We can directly pass a quantized checkpoint path, too:
73
+
74
+ ```diff
75
+ + --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
76
+ ```
77
+
78
+ Depending on the machine, training time will vary but for our case, it was 1.5 hours. It may be possible to speed this up by using `torch.bfloat16`.
79
+
80
+ We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed:
81
+
82
+ ```bash
83
+ pip install -Uq deepspeed
84
+ ```
85
+
86
+ And then launch:
87
+
88
+ ```bash
89
+ accelerate launch --config_file=ds2.yaml \
90
+ train_dreambooth_lora_flux_miniature.py \
91
+ --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
92
+ --data_df_path="embeddings.parquet" \
93
+ --output_dir="yarn_art_lora_flux_nf4" \
94
+ --mixed_precision="no" \
95
+ --use_8bit_adam \
96
+ --weighting_scheme="none" \
97
+ --resolution=1024 \
98
+ --train_batch_size=1 \
99
+ --repeats=1 \
100
+ --learning_rate=1e-4 \
101
+ --guidance_scale=1 \
102
+ --report_to="wandb" \
103
+ --gradient_accumulation_steps=4 \
104
+ --gradient_checkpointing \
105
+ --lr_scheduler="constant" \
106
+ --lr_warmup_steps=0 \
107
+ --cache_latents \
108
+ --rank=4 \
109
+ --max_train_steps=700 \
110
+ --seed="0"
111
+ ```
112
+
113
+ ## Inference
114
+
115
+ When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example:
116
+
117
+ 1. First, load the original model and merge the LoRA params into it:
118
+
119
+ ```py
120
+ from diffusers import FluxPipeline
121
+ import torch
122
+
123
+ ckpt_id = "black-forest-labs/FLUX.1-dev"
124
+ pipeline = FluxPipeline.from_pretrained(
125
+ ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
126
+ )
127
+ pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
128
+ pipeline.fuse_lora()
129
+ pipeline.unload_lora_weights()
130
+
131
+ pipeline.transformer.save_pretrained("fused_transformer")
132
+ ```
133
+
134
+ 2. Quantize the model and run inference
135
+
136
+ ```py
137
+ from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
138
+ import torch
139
+
140
+ ckpt_id = "black-forest-labs/FLUX.1-dev"
141
+ bnb_4bit_compute_dtype = torch.float16
142
+ nf4_config = BitsAndBytesConfig(
143
+ load_in_4bit=True,
144
+ bnb_4bit_quant_type="nf4",
145
+ bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
146
+ )
147
+ transformer = FluxTransformer2DModel.from_pretrained(
148
+ "fused_transformer",
149
+ quantization_config=nf4_config,
150
+ torch_dtype=bnb_4bit_compute_dtype,
151
+ )
152
+ pipeline = AutoPipelineForText2Image.from_pretrained(
153
+ ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
154
+ )
155
+ pipeline.enable_model_cpu_offload()
156
+
157
+ image = pipeline(
158
+ "a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
159
+ ).images[0]
160
+ image.save("yarn_merged.png")
161
+ ```
162
+
163
+ | Dequantize, merge, quantize | Merging directly into quantized model |
164
+ |-------|-------|
165
+ | ![Image A](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/merged.png) | ![Image B](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/unmerged.png) |
166
+
167
+ As we can notice the first column result follows the style more closely.
diffusers/examples/research_projects/flux_lora_quantization/accelerate.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: NO
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: true
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: bf16
10
+ num_machines: 1
11
+ num_processes: 1
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_env: []
15
+ tpu_use_cluster: false
16
+ tpu_use_sudo: false
17
+ use_cpu: false
diffusers/examples/research_projects/flux_lora_quantization/compute_embeddings.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+
19
+ import pandas as pd
20
+ import torch
21
+ from datasets import load_dataset
22
+ from huggingface_hub.utils import insecure_hashlib
23
+ from tqdm.auto import tqdm
24
+ from transformers import T5EncoderModel
25
+
26
+ from diffusers import FluxPipeline
27
+
28
+
29
+ MAX_SEQ_LENGTH = 77
30
+ OUTPUT_PATH = "embeddings.parquet"
31
+
32
+
33
def generate_image_hash(image):
    """Return a SHA-256 hex digest that uniquely identifies a PIL image's raw pixel bytes."""
    raw_bytes = image.tobytes()
    digest = insecure_hashlib.sha256(raw_bytes)
    return digest.hexdigest()
35
+
36
+
37
def load_flux_dev_pipeline():
    """Load the Flux.1 Dev pipeline with only its text-encoding components.

    The T5-XXL text encoder is loaded in 8-bit to reduce the memory footprint.
    The transformer and VAE are skipped (`transformer=None`, `vae=None`)
    because this script only computes prompt embeddings.

    Returns:
        `FluxPipeline`: pipeline holding the tokenizers and text encoders only.
    """
    # Renamed from `id` to avoid shadowing the builtin `id()`.
    model_id = "black-forest-labs/FLUX.1-dev"
    text_encoder = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto"
    )
    pipeline = FluxPipeline.from_pretrained(
        model_id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
    )
    return pipeline
44
+
45
+
46
@torch.no_grad()
def compute_embeddings(pipeline, prompts, max_sequence_length):
    """Encode every prompt with the pipeline's text encoders.

    Args:
        pipeline: a `FluxPipeline` (or any object exposing `encode_prompt`).
        prompts: iterable of prompt strings.
        max_sequence_length: token sequence length used by the T5 encoder.

    Returns:
        Three parallel lists (one entry per prompt): prompt embeddings,
        pooled prompt embeddings, and text ids.
    """
    all_prompt_embeds = []
    all_pooled_prompt_embeds = []
    all_text_ids = []
    for prompt in tqdm(prompts, desc="Encoding prompts."):
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
        all_prompt_embeds.append(prompt_embeds)
        all_pooled_prompt_embeds.append(pooled_prompt_embeds)
        all_text_ids.append(text_ids)

    # Only report peak GPU memory when CUDA is available; the original
    # unconditional call raises on CPU-only builds of PyTorch.
    if torch.cuda.is_available():
        max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
        print(f"Max memory allocated: {max_memory:.3f} GB")
    return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
64
+
65
+
66
def run(args):
    """Compute prompt embeddings for the Yarn-art dataset and serialize them to parquet.

    Each resulting row maps an image hash to its flattened prompt embeddings,
    pooled prompt embeddings, and text ids.
    """
    dataset = load_dataset("Norod78/Yarn-art-style", split="train")
    # Hash -> caption; hashing by image content also deduplicates identical images.
    image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
    all_prompts = list(image_prompts.values())
    print(f"{len(all_prompts)=}")

    pipeline = load_flux_dev_pipeline()
    all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
        pipeline, all_prompts, args.max_sequence_length
    )

    data = [
        (image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i])
        for i, image_hash in enumerate(image_prompts)
    ]
    print(f"{len(data)=}")

    # Create a DataFrame
    embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
    df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
    print(f"{len(df)=}")

    # Flatten each tensor into a plain Python list so parquet can store it.
    for col in embedding_cols:
        df[col] = df[col].apply(lambda t: t.cpu().numpy().flatten().tolist())

    # Save the dataframe to a parquet file
    df.to_parquet(args.output_path)
    print(f"Data successfully serialized to {args.output_path}")
94
+
95
+
96
if __name__ == "__main__":
    # CLI entry point; both flags default to the module-level constants.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=MAX_SEQ_LENGTH,
        help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
    )
    parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
    args = parser.parse_args()

    run(args)
diffusers/examples/research_projects/flux_lora_quantization/ds2.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ gradient_accumulation_steps: 1
5
+ gradient_clipping: 1.0
6
+ offload_optimizer_device: cpu
7
+ offload_param_device: cpu
8
+ zero3_init_flag: false
9
+ zero_stage: 2
10
+ distributed_type: DEEPSPEED
11
+ downcast_bf16: 'no'
12
+ enable_cpu_affinity: false
13
+ machine_rank: 0
14
+ main_training_function: main
15
+ mixed_precision: 'no'
16
+ num_machines: 1
17
+ num_processes: 1
18
+ rdzv_backend: static
19
+ same_network: true
20
+ tpu_env: []
21
+ tpu_use_cluster: false
22
+ tpu_use_sudo: false
23
+ use_cpu: false
diffusers/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py ADDED
@@ -0,0 +1,1200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import copy
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+ import torch
28
+ import torch.utils.checkpoint
29
+ import transformers
30
+ from accelerate import Accelerator, DistributedType
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
33
+ from datasets import load_dataset
34
+ from huggingface_hub import create_repo, upload_folder
35
+ from huggingface_hub.utils import insecure_hashlib
36
+ from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
37
+ from peft.utils import get_peft_model_state_dict
38
+ from PIL.ImageOps import exif_transpose
39
+ from torch.utils.data import Dataset
40
+ from torchvision import transforms
41
+ from torchvision.transforms.functional import crop
42
+ from tqdm.auto import tqdm
43
+
44
+ import diffusers
45
+ from diffusers import (
46
+ AutoencoderKL,
47
+ BitsAndBytesConfig,
48
+ FlowMatchEulerDiscreteScheduler,
49
+ FluxPipeline,
50
+ FluxTransformer2DModel,
51
+ )
52
+ from diffusers.optimization import get_scheduler
53
+ from diffusers.training_utils import (
54
+ cast_training_params,
55
+ compute_density_for_timestep_sampling,
56
+ compute_loss_weighting_for_sd3,
57
+ free_memory,
58
+ )
59
+ from diffusers.utils import (
60
+ check_min_version,
61
+ convert_unet_state_dict_to_peft,
62
+ is_wandb_available,
63
+ )
64
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
65
+ from diffusers.utils.torch_utils import is_compiled_module
66
+
67
+
68
+ if is_wandb_available():
69
+ pass
70
+
71
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
72
+ check_min_version("0.31.0.dev0")
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ def save_model_card(
78
+ repo_id: str,
79
+ base_model: str = None,
80
+ instance_prompt=None,
81
+ repo_folder=None,
82
+ quantization_config=None,
83
+ ):
84
+ widget_dict = []
85
+
86
+ model_description = f"""
87
+ # Flux DreamBooth LoRA - {repo_id}
88
+
89
+ <Gallery />
90
+
91
+ ## Model description
92
+
93
+ These are {repo_id} DreamBooth LoRA weights for {base_model}.
94
+
95
+ The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).
96
+
97
+ Was LoRA for the text encoder enabled? False.
98
+
99
+ Quantization config:
100
+
101
+ ```yaml
102
+ {quantization_config}
103
+ ```
104
+
105
+ ## Trigger words
106
+
107
+ You should use `{instance_prompt}` to trigger the image generation.
108
+
109
+ ## Download model
110
+
111
+ [Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
112
+
113
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
114
+
115
+ ## Usage
116
+
117
+ TODO
118
+
119
+ ## License
120
+
121
+ Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
122
+ """
123
+ model_card = load_or_create_model_card(
124
+ repo_id_or_path=repo_id,
125
+ from_training=True,
126
+ license="other",
127
+ base_model=base_model,
128
+ prompt=instance_prompt,
129
+ model_description=model_description,
130
+ widget=widget_dict,
131
+ )
132
+ tags = [
133
+ "text-to-image",
134
+ "diffusers-training",
135
+ "diffusers",
136
+ "lora",
137
+ "flux",
138
+ "flux-diffusers",
139
+ "template:sd-lora",
140
+ ]
141
+
142
+ model_card = populate_model_card(model_card, tags=tags)
143
+ model_card.save(os.path.join(repo_folder, "README.md"))
144
+
145
+
146
+ def parse_args(input_args=None):
147
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
148
+ parser.add_argument(
149
+ "--pretrained_model_name_or_path",
150
+ type=str,
151
+ default=None,
152
+ required=True,
153
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
154
+ )
155
+ parser.add_argument(
156
+ "--quantized_model_path",
157
+ type=str,
158
+ default=None,
159
+ help="Path to the quantized model.",
160
+ )
161
+ parser.add_argument(
162
+ "--revision",
163
+ type=str,
164
+ default=None,
165
+ required=False,
166
+ help="Revision of pretrained model identifier from huggingface.co/models.",
167
+ )
168
+ parser.add_argument(
169
+ "--variant",
170
+ type=str,
171
+ default=None,
172
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
173
+ )
174
+ parser.add_argument(
175
+ "--data_df_path",
176
+ type=str,
177
+ default=None,
178
+ help=("Path to the parquet file serialized with compute_embeddings.py."),
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--cache_dir",
183
+ type=str,
184
+ default=None,
185
+ help="The directory where the downloaded models and datasets will be stored.",
186
+ )
187
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
188
+
189
+ parser.add_argument(
190
+ "--max_sequence_length",
191
+ type=int,
192
+ default=77,
193
+ help="Used for reading the embeddings. Needs to be the same as used during `compute_embeddings.py`.",
194
+ )
195
+
196
+ parser.add_argument(
197
+ "--rank",
198
+ type=int,
199
+ default=4,
200
+ help=("The dimension of the LoRA update matrices."),
201
+ )
202
+
203
+ parser.add_argument(
204
+ "--output_dir",
205
+ type=str,
206
+ default="flux-dreambooth-lora-nf4",
207
+ help="The output directory where the model predictions and checkpoints will be written.",
208
+ )
209
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
210
+ parser.add_argument(
211
+ "--resolution",
212
+ type=int,
213
+ default=512,
214
+ help=(
215
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
216
+ " resolution"
217
+ ),
218
+ )
219
+ parser.add_argument(
220
+ "--center_crop",
221
+ default=False,
222
+ action="store_true",
223
+ help=(
224
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
225
+ " cropped. The images will be resized to the resolution first before cropping."
226
+ ),
227
+ )
228
+ parser.add_argument(
229
+ "--random_flip",
230
+ action="store_true",
231
+ help="whether to randomly flip images horizontally",
232
+ )
233
+ parser.add_argument(
234
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
235
+ )
236
+ parser.add_argument(
237
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
238
+ )
239
+ parser.add_argument("--num_train_epochs", type=int, default=1)
240
+ parser.add_argument(
241
+ "--max_train_steps",
242
+ type=int,
243
+ default=None,
244
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
245
+ )
246
+ parser.add_argument(
247
+ "--checkpointing_steps",
248
+ type=int,
249
+ default=500,
250
+ help=(
251
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
252
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
253
+ " training using `--resume_from_checkpoint`."
254
+ ),
255
+ )
256
+ parser.add_argument(
257
+ "--checkpoints_total_limit",
258
+ type=int,
259
+ default=None,
260
+ help=("Max number of checkpoints to store."),
261
+ )
262
+ parser.add_argument(
263
+ "--resume_from_checkpoint",
264
+ type=str,
265
+ default=None,
266
+ help=(
267
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
268
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
269
+ ),
270
+ )
271
+ parser.add_argument(
272
+ "--gradient_accumulation_steps",
273
+ type=int,
274
+ default=1,
275
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
276
+ )
277
+ parser.add_argument(
278
+ "--gradient_checkpointing",
279
+ action="store_true",
280
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
281
+ )
282
+ parser.add_argument(
283
+ "--learning_rate",
284
+ type=float,
285
+ default=1e-4,
286
+ help="Initial learning rate (after the potential warmup period) to use.",
287
+ )
288
+
289
+ parser.add_argument(
290
+ "--guidance_scale",
291
+ type=float,
292
+ default=3.5,
293
+ help="the FLUX.1 dev variant is a guidance distilled model",
294
+ )
295
+
296
+ parser.add_argument(
297
+ "--scale_lr",
298
+ action="store_true",
299
+ default=False,
300
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
301
+ )
302
+ parser.add_argument(
303
+ "--lr_scheduler",
304
+ type=str,
305
+ default="constant",
306
+ help=(
307
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
308
+ ' "constant", "constant_with_warmup"]'
309
+ ),
310
+ )
311
+ parser.add_argument(
312
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
313
+ )
314
+ parser.add_argument(
315
+ "--lr_num_cycles",
316
+ type=int,
317
+ default=1,
318
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
319
+ )
320
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
321
+ parser.add_argument(
322
+ "--dataloader_num_workers",
323
+ type=int,
324
+ default=0,
325
+ help=(
326
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
327
+ ),
328
+ )
329
+ parser.add_argument(
330
+ "--weighting_scheme",
331
+ type=str,
332
+ default="none",
333
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
334
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
335
+ )
336
+ parser.add_argument(
337
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
338
+ )
339
+ parser.add_argument(
340
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
341
+ )
342
+ parser.add_argument(
343
+ "--mode_scale",
344
+ type=float,
345
+ default=1.29,
346
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
347
+ )
348
+ parser.add_argument(
349
+ "--optimizer",
350
+ type=str,
351
+ default="AdamW",
352
+ choices=["AdamW", "Prodigy", "AdEMAMix"],
353
+ )
354
+
355
+ parser.add_argument(
356
+ "--use_8bit_adam",
357
+ action="store_true",
358
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
359
+ )
360
+ parser.add_argument(
361
+ "--use_8bit_ademamix",
362
+ action="store_true",
363
+ help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
364
+ )
365
+
366
+ parser.add_argument(
367
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
368
+ )
369
+ parser.add_argument(
370
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
371
+ )
372
+ parser.add_argument(
373
+ "--prodigy_beta3",
374
+ type=float,
375
+ default=None,
376
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
377
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
378
+ )
379
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
380
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
381
+
382
+ parser.add_argument(
383
+ "--adam_epsilon",
384
+ type=float,
385
+ default=1e-08,
386
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
387
+ )
388
+
389
+ parser.add_argument(
390
+ "--prodigy_use_bias_correction",
391
+ type=bool,
392
+ default=True,
393
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
394
+ )
395
+ parser.add_argument(
396
+ "--prodigy_safeguard_warmup",
397
+ type=bool,
398
+ default=True,
399
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
400
+ "Ignored if optimizer is adamW",
401
+ )
402
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
403
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
404
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
405
+ parser.add_argument(
406
+ "--hub_model_id",
407
+ type=str,
408
+ default=None,
409
+ help="The name of the repository to keep in sync with the local `output_dir`.",
410
+ )
411
+ parser.add_argument(
412
+ "--logging_dir",
413
+ type=str,
414
+ default="logs",
415
+ help=(
416
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
417
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
418
+ ),
419
+ )
420
+
421
+ parser.add_argument(
422
+ "--cache_latents",
423
+ action="store_true",
424
+ default=False,
425
+ help="Cache the VAE latents",
426
+ )
427
+ parser.add_argument(
428
+ "--report_to",
429
+ type=str,
430
+ default="tensorboard",
431
+ help=(
432
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
433
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
434
+ ),
435
+ )
436
+ parser.add_argument(
437
+ "--mixed_precision",
438
+ type=str,
439
+ default=None,
440
+ choices=["no", "fp16", "bf16"],
441
+ help=(
442
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
443
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
444
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
445
+ ),
446
+ )
447
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
448
+
449
+ if input_args is not None:
450
+ args = parser.parse_args(input_args)
451
+ else:
452
+ args = parser.parse_args()
453
+
454
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
455
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
456
+ args.local_rank = env_local_rank
457
+
458
+ return args
459
+
460
+
461
+ class DreamBoothDataset(Dataset):
462
+ def __init__(
463
+ self,
464
+ data_df_path,
465
+ dataset_name,
466
+ size=1024,
467
+ max_sequence_length=77,
468
+ center_crop=False,
469
+ ):
470
+ # Logistics
471
+ self.size = size
472
+ self.center_crop = center_crop
473
+ self.max_sequence_length = max_sequence_length
474
+
475
+ self.data_df_path = Path(data_df_path)
476
+ if not self.data_df_path.exists():
477
+ raise ValueError("`data_df_path` doesn't exists.")
478
+
479
+ # Load images.
480
+ dataset = load_dataset(dataset_name, split="train")
481
+ instance_images = [sample["image"] for sample in dataset]
482
+ image_hashes = [self.generate_image_hash(image) for image in instance_images]
483
+ self.instance_images = instance_images
484
+ self.image_hashes = image_hashes
485
+
486
+ # Image transformations
487
+ self.pixel_values = self.apply_image_transformations(
488
+ instance_images=instance_images, size=size, center_crop=center_crop
489
+ )
490
+
491
+ # Map hashes to embeddings.
492
+ self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)
493
+
494
+ self.num_instance_images = len(instance_images)
495
+ self._length = self.num_instance_images
496
+
497
+ def __len__(self):
498
+ return self._length
499
+
500
+ def __getitem__(self, index):
501
+ example = {}
502
+ instance_image = self.pixel_values[index % self.num_instance_images]
503
+ image_hash = self.image_hashes[index % self.num_instance_images]
504
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[image_hash]
505
+ example["instance_images"] = instance_image
506
+ example["prompt_embeds"] = prompt_embeds
507
+ example["pooled_prompt_embeds"] = pooled_prompt_embeds
508
+ example["text_ids"] = text_ids
509
+ return example
510
+
511
+ def apply_image_transformations(self, instance_images, size, center_crop):
512
+ pixel_values = []
513
+
514
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
515
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
516
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
517
+ train_transforms = transforms.Compose(
518
+ [
519
+ transforms.ToTensor(),
520
+ transforms.Normalize([0.5], [0.5]),
521
+ ]
522
+ )
523
+ for image in instance_images:
524
+ image = exif_transpose(image)
525
+ if not image.mode == "RGB":
526
+ image = image.convert("RGB")
527
+ image = train_resize(image)
528
+ if args.random_flip and random.random() < 0.5:
529
+ # flip
530
+ image = train_flip(image)
531
+ if args.center_crop:
532
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
533
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
534
+ image = train_crop(image)
535
+ else:
536
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
537
+ image = crop(image, y1, x1, h, w)
538
+ image = train_transforms(image)
539
+ pixel_values.append(image)
540
+
541
+ return pixel_values
542
+
543
+ def convert_to_torch_tensor(self, embeddings: list):
544
+ prompt_embeds = embeddings[0]
545
+ pooled_prompt_embeds = embeddings[1]
546
+ text_ids = embeddings[2]
547
+ prompt_embeds = np.array(prompt_embeds).reshape(self.max_sequence_length, 4096)
548
+ pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(768)
549
+ text_ids = np.array(text_ids).reshape(77, 3)
550
+ return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds), torch.from_numpy(text_ids)
551
+
552
+ def map_image_hash_embedding(self, data_df_path):
553
+ hashes_df = pd.read_parquet(data_df_path)
554
+ data_dict = {}
555
+ for i, row in hashes_df.iterrows():
556
+ embeddings = [row["prompt_embeds"], row["pooled_prompt_embeds"], row["text_ids"]]
557
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.convert_to_torch_tensor(embeddings=embeddings)
558
+ data_dict.update({row["image_hash"]: (prompt_embeds, pooled_prompt_embeds, text_ids)})
559
+ return data_dict
560
+
561
+ def generate_image_hash(self, image):
562
+ return insecure_hashlib.sha256(image.tobytes()).hexdigest()
563
+
564
+
565
def collate_fn(examples):
    """Collate dataset examples into a single training batch.

    Stacks image tensors (forced to float32, contiguous memory) and the cached
    prompt / pooled embeddings along a new batch dimension. ``text_ids`` is the
    same 2D tensor for every example, so only the first one is kept rather than
    batching it into a 3D tensor.
    """
    stacked_images = (
        torch.stack([example["instance_images"] for example in examples])
        .to(memory_format=torch.contiguous_format)
        .float()
    )
    stacked_prompt_embeds = torch.stack([example["prompt_embeds"] for example in examples])
    stacked_pooled_embeds = torch.stack([example["pooled_prompt_embeds"] for example in examples])
    # Keep a single 2D text-ids tensor (first example), as the model expects.
    shared_text_ids = torch.stack([example["text_ids"] for example in examples])[0]

    return {
        "pixel_values": stacked_images,
        "prompt_embeds": stacked_prompt_embeds,
        "pooled_prompt_embeds": stacked_pooled_embeds,
        "text_ids": shared_text_ids,
    }
584
+
585
+
586
def main(args):
    """Train Flux LoRA adapters on top of an NF4-quantized transformer (DreamBooth style).

    High-level flow:
      1. Set up Accelerate (mixed precision, logging, hub repo).
      2. Load scheduler/VAE and a 4-bit (NF4) quantized FluxTransformer2DModel,
         freeze the base weights, and attach trainable LoRA adapters.
      3. Register save/load hooks so checkpoints serialize only LoRA weights.
      4. Build the DreamBooth dataset (precomputed text embeddings), optionally
         cache VAE latents up front and free the VAE.
      5. Run the flow-matching training loop with periodic checkpointing.
      6. Save final LoRA weights and optionally push to the Hub.
    """
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    # find_unused_parameters=True: with LoRA, most base-model params get no grads.
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )
    # A copy is kept for timestep/sigma lookups so prepare() wrapping can't touch it.
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
    bnb_4bit_compute_dtype = torch.float32
    if args.mixed_precision == "fp16":
        bnb_4bit_compute_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        bnb_4bit_compute_dtype = torch.bfloat16
    if args.quantized_model_path is not None:
        # Pre-quantized checkpoint on disk: load directly.
        transformer = FluxTransformer2DModel.from_pretrained(
            args.quantized_model_path,
            subfolder="transformer",
            revision=args.revision,
            variant=args.variant,
            torch_dtype=bnb_4bit_compute_dtype,
        )
    else:
        # Quantize to NF4 on the fly while loading.
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
        )
        transformer = FluxTransformer2DModel.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="transformer",
            revision=args.revision,
            variant=args.variant,
            quantization_config=nf4_config,
            torch_dtype=bnb_4bit_compute_dtype,
        )
        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)

    # We only train the additional adapter LoRA layers
    transformer.requires_grad_(False)
    vae.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    vae.to(accelerator.device, dtype=weight_dtype)
    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer.add_adapter(transformer_lora_config)

    def unwrap_model(model):
        # Strip Accelerate wrapping and (if present) torch.compile wrapping.
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None

            for model in models:
                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                    model = unwrap_model(model)
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                if weights:
                    weights.pop()

            FluxPipeline.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
                text_encoder_lora_layers=None,
            )

    def load_model_hook(models, input_dir):
        transformer_ = None

        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
            while len(models) > 0:
                model = models.pop()

                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_ = model
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
        else:
            # Under DeepSpeed the model isn't passed in; rebuild it the same
            # way as in the main path before loading the LoRA weights.
            if args.quantized_model_path is not None:
                transformer_ = FluxTransformer2DModel.from_pretrained(
                    args.quantized_model_path,
                    subfolder="transformer",
                    revision=args.revision,
                    variant=args.variant,
                    torch_dtype=bnb_4bit_compute_dtype,
                )
            else:
                nf4_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
                )
                transformer_ = FluxTransformer2DModel.from_pretrained(
                    args.pretrained_model_name_or_path,
                    subfolder="transformer",
                    revision=args.revision,
                    variant=args.variant,
                    quantization_config=nf4_config,
                    torch_dtype=bnb_4bit_compute_dtype,
                )
                transformer_ = prepare_model_for_kbit_training(transformer_, use_gradient_checkpointing=False)
            transformer_.add_adapter(transformer_lora_config)

        lora_state_dict = FluxPipeline.lora_state_dict(input_dir)

        transformer_state_dict = {
            f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
        if args.mixed_precision == "fp16":
            models = [transformer_]
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params(models)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        models = [transformer]
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(models, dtype=torch.float32)

    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))

    # Optimization parameters
    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
    params_to_optimize = [transformer_parameters_with_lr]

    # Optimizer creation
    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
        logger.warning(
            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

    if args.optimizer.lower() == "adamw":
        if args.use_8bit_adam:
            try:
                import bitsandbytes as bnb
            except ImportError:
                raise ImportError(
                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
                )

            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )

    elif args.optimizer.lower() == "ademamix":
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
            )
        if args.use_8bit_ademamix:
            optimizer_class = bnb.optim.AdEMAMix8bit
        else:
            optimizer_class = bnb.optim.AdEMAMix

        optimizer = optimizer_class(params_to_optimize)

    # NOTE(review): plain `if` (not `elif`) — harmless today because "prodigy"
    # can't match the branches above, but an `elif` would be clearer.
    if args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
            )

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    # Dataset and DataLoaders creation:
    # NOTE(review): the dataset name is hard-coded here rather than exposed as
    # a CLI argument — confirm this is intentional for this example.
    train_dataset = DreamBoothDataset(
        data_df_path=args.data_df_path,
        dataset_name="Norod78/Yarn-art-style",
        size=args.resolution,
        max_sequence_length=args.max_sequence_length,
        center_crop=args.center_crop,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
    )

    # Keep the VAE config scalars around so the VAE itself can be freed below.
    vae_config_shift_factor = vae.config.shift_factor
    vae_config_scaling_factor = vae.config.scaling_factor
    vae_config_block_out_channels = vae.config.block_out_channels
    if args.cache_latents:
        # NOTE(review): latents are cached per dataloader batch and indexed by
        # `step` in the loop below; with shuffle=True the image/latent pairing
        # differs between the caching pass and training epochs — confirm this
        # is acceptable for this example.
        latents_cache = []
        for batch in tqdm(train_dataloader, desc="Caching latents"):
            with torch.no_grad():
                batch["pixel_values"] = batch["pixel_values"].to(
                    accelerator.device, non_blocking=True, dtype=weight_dtype
                )
                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

        del vae
        free_memory()

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_name = "dreambooth-flux-dev-lora-nf4"
        accelerator.init_trackers(tracker_name, config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
        # Look up the sigma for each sampled timestep and unsqueeze trailing
        # dims so it broadcasts against the (B, C, H, W) latent tensor.
        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            with accelerator.accumulate(models_to_accumulate):
                # Convert images to latent space
                if args.cache_latents:
                    model_input = latents_cache[step].sample()
                else:
                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                    model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                model_input = model_input.to(dtype=weight_dtype)

                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

                latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                    model_input.shape[0],
                    model_input.shape[2] // 2,
                    model_input.shape[3] // 2,
                    accelerator.device,
                    weight_dtype,
                )
                # Sample noise that we'll add to the latents
                noise = torch.randn_like(model_input)
                bsz = model_input.shape[0]

                # Sample a random timestep for each image
                # for weighting schemes where we sample timesteps non-uniformly
                u = compute_density_for_timestep_sampling(
                    weighting_scheme=args.weighting_scheme,
                    batch_size=bsz,
                    logit_mean=args.logit_mean,
                    logit_std=args.logit_std,
                    mode_scale=args.mode_scale,
                )
                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                # Add noise according to flow matching.
                # zt = (1 - texp) * x + texp * z1
                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

                packed_noisy_model_input = FluxPipeline._pack_latents(
                    noisy_model_input,
                    batch_size=model_input.shape[0],
                    num_channels_latents=model_input.shape[1],
                    height=model_input.shape[2],
                    width=model_input.shape[3],
                )

                # handle guidance
                if unwrap_model(transformer).config.guidance_embeds:
                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                    guidance = guidance.expand(model_input.shape[0])
                else:
                    guidance = None

                # Predict the noise
                prompt_embeds = batch["prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype)
                pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype)
                text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype)
                model_pred = transformer(
                    hidden_states=packed_noisy_model_input,
                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                    timestep=timesteps / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    return_dict=False,
                )[0]
                model_pred = FluxPipeline._unpack_latents(
                    model_pred,
                    height=model_input.shape[2] * vae_scale_factor,
                    width=model_input.shape[3] * vae_scale_factor,
                    vae_scale_factor=vae_scale_factor,
                )

                # these weighting schemes use a uniform timestep sampling
                # and instead post-weight the loss
                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

                # flow matching loss
                target = noise - model_input

                # Compute regular loss.
                loss = torch.mean(
                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                loss = loss.mean()
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = transformer.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            # NOTE(review): this breaks the inner (batch) loop only; the epoch
            # loop exits via the recalculated num_train_epochs — same pattern
            # as other diffusers training scripts.
            if global_step >= args.max_train_steps:
                break

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        transformer_lora_layers = get_peft_model_state_dict(transformer)

        FluxPipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
            text_encoder_lora_layers=None,
        )

        if args.push_to_hub:
            save_model_card(
                repo_id,
                base_model=args.pretrained_model_name_or_path,
                instance_prompt=None,
                repo_folder=args.output_dir,
                quantization_config=transformer.config["quantization_config"],
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
1196
+
1197
+
1198
if __name__ == "__main__":
    # Entry point. `args` must stay module-level: DreamBoothDataset reads the
    # global `args` inside apply_image_transformations.
    args = parse_args()
    main(args)
diffusers/examples/research_projects/intel_opts/README.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Diffusers examples with Intel optimizations
2
+
3
+ **This research project is not actively maintained by the diffusers team. For any questions or comments, please make sure to tag @hshen14 .**
4
+
5
+ This aims to provide diffusers examples with Intel optimizations such as Bfloat16 for training/fine-tuning acceleration and 8-bit integer (INT8) for inference acceleration on Intel platforms.
6
+
7
+ ## Accelerating the fine-tuning for textual inversion
8
+
9
+ We accelerate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.
10
+
11
+ ## Accelerating the inference for Stable Diffusion using Bfloat16
12
+
13
+ We start the inference acceleration with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is generally designed to support standard Stable Diffusion models with Bfloat16 support.
14
+ ```bash
15
+ pip install diffusers transformers accelerate scipy safetensors
16
+
17
+ export KMP_BLOCKTIME=1
18
+ export KMP_SETTINGS=1
19
+ export KMP_AFFINITY=granularity=fine,compact,1,0
20
+
21
+ # Intel OpenMP
22
+ export OMP_NUM_THREADS=< Cores to use >
23
+ export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libiomp5.so
24
+ # Jemalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support.
25
+ export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libjemalloc.so
26
+ export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:9000000000"
27
+
28
+ # Launch with default DDIM
29
+ numactl --membind <node N> -C <cpu list> python inference_bf16.py
30
+ # Launch with DPMSolverMultistepScheduler
31
+ numactl --membind <node N> -C <cpu list> python inference_bf16.py --dpm
32
+ ```
33
+
34
+ ## Accelerating the inference for Stable Diffusion using INT8
35
+
36
+ Coming soon ...
diffusers/examples/research_projects/intel_opts/inference_bf16.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Stable Diffusion CPU inference with Intel Extension for PyTorch (bfloat16).

Loads a (fine-tuned) Stable Diffusion pipeline, converts its submodules to
channels-last memory format, optimizes them with `ipex.optimize` for bfloat16,
and generates one image under CPU autocast. Run with `--dpm` to switch to the
DPMSolverMultistep scheduler and `--steps N` to override inference steps.
"""

import argparse

import intel_extension_for_pytorch as ipex
import torch

from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline


parser = argparse.ArgumentParser("Stable Diffusion script with intel optimization", add_help=False)
parser.add_argument("--dpm", action="store_true", help="Enable DPMSolver or not")
parser.add_argument("--steps", default=None, type=int, help="Num inference steps")
args = parser.parse_args()


device = "cpu"
# Prompt uses the <dicoo> placeholder token from the textual-inversion example.
prompt = "a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brightly buildings"

# Replace with the path to your fine-tuned model before running.
model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
if args.dpm:
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)

# to channels last
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
if pipe.requires_safety_checker:
    pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last)

# optimize with ipex
# The sample_input shapes match SD v1.x UNet inputs (2x4x64x64 latents,
# 77-token CLIP hidden states of dim 768) — presumably for trace-based
# optimization; falls back to shape-free optimization if it fails.
sample = torch.randn(2, 4, 64, 64)
timestep = torch.rand(1) * 999
encoder_hidden_status = torch.randn(2, 77, 768)
input_example = (sample, timestep, encoder_hidden_status)
try:
    pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
except Exception:
    pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)
pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)
pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
if pipe.requires_safety_checker:
    pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)

# compute
seed = 666  # fixed seed for reproducible output
generator = torch.Generator(device).manual_seed(seed)
generate_kwargs = {"generator": generator}
if args.steps is not None:
    generate_kwargs["num_inference_steps"] = args.steps

# Run the whole pipeline under CPU autocast so matmuls execute in bfloat16.
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    image = pipe(prompt, **generate_kwargs).images[0]

# save image
image.save("generated.png")
diffusers/examples/research_projects/intel_opts/textual_inversion/README.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Textual Inversion fine-tuning example
2
+
3
+ [Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
4
+ The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
5
+
6
+ ## Training with Intel Extension for PyTorch
7
+
8
+ Intel Extension for PyTorch provides the optimizations for faster training and inference on CPUs. You can leverage the training example "textual_inversion.py". Follow the [instructions](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) to get the model and [dataset](https://huggingface.co/sd-concepts-library/dicoo2) before running the script.
9
+
10
+ The example supports both single node and multi-node distributed training:
11
+
12
+ ### Single node training
13
+
14
+ ```bash
15
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
16
+ export DATA_DIR="path-to-dir-containing-dicoo-images"
17
+
18
+ python textual_inversion.py \
19
+ --pretrained_model_name_or_path=$MODEL_NAME \
20
+ --train_data_dir=$DATA_DIR \
21
+ --learnable_property="object" \
22
+ --placeholder_token="<dicoo>" --initializer_token="toy" \
23
+ --seed=7 \
24
+ --resolution=512 \
25
+ --train_batch_size=1 \
26
+ --gradient_accumulation_steps=1 \
27
+ --max_train_steps=3000 \
28
+ --learning_rate=2.5e-03 --scale_lr \
29
+ --output_dir="textual_inversion_dicoo"
30
+ ```
31
+
32
+ Note: Bfloat16 is available on Intel Xeon Scalable Processors Cooper Lake or Sapphire Rapids. You may not get performance speedup without Bfloat16 support.
33
+
34
+ ### Multi-node distributed training
35
+
36
+ Before running the scripts, make sure to install the library's training dependencies successfully:
37
+
38
+ ```bash
39
+ python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu
40
+ ```
41
+
42
+ ```bash
43
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
44
+ export DATA_DIR="path-to-dir-containing-dicoo-images"
45
+
46
+ oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
47
+ source $oneccl_bindings_for_pytorch_path/env/setvars.sh
48
+
49
+ python -m intel_extension_for_pytorch.cpu.launch --distributed \
50
+ --hostfile hostfile --nnodes 2 --nproc_per_node 2 textual_inversion.py \
51
+ --pretrained_model_name_or_path=$MODEL_NAME \
52
+ --train_data_dir=$DATA_DIR \
53
+ --learnable_property="object" \
54
+ --placeholder_token="<dicoo>" --initializer_token="toy" \
55
+ --seed=7 \
56
+ --resolution=512 \
57
+ --train_batch_size=1 \
58
+ --gradient_accumulation_steps=1 \
59
+ --max_train_steps=750 \
60
+ --learning_rate=2.5e-03 --scale_lr \
61
+ --output_dir="textual_inversion_dicoo"
62
+ ```
63
+ The above is a simple distributed training usage on 2 nodes with 2 processes on each node. Add the right hostname or ip address in the "hostfile" and make sure these 2 nodes are reachable from each other. For more details, please refer to the [user guide](https://github.com/intel/torch-ccl).
64
+
65
+
66
+ ### Reference
67
+
68
+ We publish a [Medium blog](https://medium.com/intel-analytics-software/personalized-stable-diffusion-with-few-shot-fine-tuning-on-a-single-cpu-f01a3316b13) on how to create your own Stable Diffusion model on CPUs using textual inversion. Try it out now, if you have interests.
diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.21.0
4
+ ftfy
5
+ tensorboard
6
+ Jinja2
7
+ intel_extension_for_pytorch>=1.13
diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import math
4
+ import os
5
+ import random
6
+ from pathlib import Path
7
+
8
+ import intel_extension_for_pytorch as ipex
9
+ import numpy as np
10
+ import PIL
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from accelerate import Accelerator
15
+ from accelerate.logging import get_logger
16
+ from accelerate.utils import ProjectConfiguration, set_seed
17
+ from huggingface_hub import create_repo, upload_folder
18
+
19
+ # TODO: remove and import from diffusers.utils when the new version of diffusers is released
20
+ from packaging import version
21
+ from PIL import Image
22
+ from torch.utils.data import Dataset
23
+ from torchvision import transforms
24
+ from tqdm.auto import tqdm
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
26
+
27
+ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
28
+ from diffusers.optimization import get_scheduler
29
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
30
+ from diffusers.utils import check_min_version
31
+
32
+
33
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
34
+ PIL_INTERPOLATION = {
35
+ "linear": PIL.Image.Resampling.BILINEAR,
36
+ "bilinear": PIL.Image.Resampling.BILINEAR,
37
+ "bicubic": PIL.Image.Resampling.BICUBIC,
38
+ "lanczos": PIL.Image.Resampling.LANCZOS,
39
+ "nearest": PIL.Image.Resampling.NEAREST,
40
+ }
41
+ else:
42
+ PIL_INTERPOLATION = {
43
+ "linear": PIL.Image.LINEAR,
44
+ "bilinear": PIL.Image.BILINEAR,
45
+ "bicubic": PIL.Image.BICUBIC,
46
+ "lanczos": PIL.Image.LANCZOS,
47
+ "nearest": PIL.Image.NEAREST,
48
+ }
49
+ # ------------------------------------------------------------------------------
50
+
51
+
52
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
53
+ check_min_version("0.13.0.dev0")
54
+
55
+
56
+ logger = get_logger(__name__)
57
+
58
+
59
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    """Persist the learned embedding of the placeholder token to `save_path`.

    Only the single embedding row being trained is saved, keyed by the
    placeholder token string, so the file stays tiny and portable.
    """
    logger.info("Saving embeddings")
    embedding_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    learned_embed = embedding_weight[placeholder_token_id]
    torch.save({args.placeholder_token: learned_embed.detach().cpu()}, save_path)
64
+
65
+
66
def parse_args(input_args=None):
    """Parse command-line arguments for textual-inversion training.

    Args:
        input_args: Optional list of argument strings; defaults to ``sys.argv``
            when ``None`` (useful for testing and programmatic invocation).

    Returns:
        argparse.Namespace with all training options.

    Raises:
        ValueError: if no training data directory is supplied.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    # Bug fix: main() reads args.report_to (wandb/hub_token guard and the
    # Accelerator's log_with argument) but this option was never defined,
    # which raised AttributeError at startup.
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args(input_args)
    # The launcher may communicate the rank via the environment; prefer it.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args
217
+
218
+
219
# Caption templates used when --learnable_property="object": each training
# caption is one of these with the placeholder token substituted for {}.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Caption templates used when --learnable_property="style".
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
270
+
271
+
272
class TextualInversionDataset(Dataset):
    """Pairs of tokenized templated captions and preprocessed concept images.

    Images under ``data_root`` are cycled through; each item's caption is a
    randomly chosen template with the placeholder token substituted in.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, name) for name in os.listdir(self.data_root)]
        self.num_images = len(self.image_paths)
        # In train mode each image is visited `repeats` times per epoch.
        self._length = self.num_images * repeats if set == "train" else self.num_images

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        if learnable_property == "style":
            self.templates = imagenet_style_templates_small
        else:
            self.templates = imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])
        if image.mode != "RGB":
            image = image.convert("RGB")

        caption = random.choice(self.templates).format(self.placeholder_token)
        example["input_ids"] = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        pixels = np.array(image).astype(np.uint8)

        if self.center_crop:
            h, w = pixels.shape[0], pixels.shape[1]
            crop = min(h, w)
            pixels = pixels[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(pixels)
        image = image.resize((self.size, self.size), resample=self.interpolation)
        image = self.flip_transform(image)

        arr = np.array(image).astype(np.uint8)
        arr = (arr / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(arr).permute(2, 0, 1)
        return example
356
+
357
+
358
def freeze_params(params):
    """Disable gradient tracking for every parameter in `params`."""
    for p in params:
        p.requires_grad = False
361
+
362
+
363
def main():
    """Run textual-inversion training of the placeholder-token embedding
    for a Stable Diffusion text encoder, using bfloat16 autocast and
    Intel Extension for PyTorch (IPEX) graph/optimizer optimizations on CPU.
    """
    args = parse_args()

    # NOTE(review): parse_args() in this file does not define --report_to,
    # so this access (and `log_with` below) raises AttributeError — confirm.
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load the tokenizer and add the placeholder token as a additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token, placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    # Freeze vae and unet
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    # Freeze all parameters except for the token embeddings in text encoder
    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    freeze_params(params_to_freeze)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # Move vae and unet to device
    vae.to(accelerator.device)
    unet.to(accelerator.device)

    # Keep vae and unet in eval model as we don't train these
    vae.eval()
    unet.eval()

    # IPEX graph-level optimization of the frozen models for bf16 inference.
    unet = ipex.optimize(unet, dtype=torch.bfloat16, inplace=True)
    vae = ipex.optimize(vae, dtype=torch.bfloat16, inplace=True)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    text_encoder.train()
    # IPEX also fuses/optimizes the trainable model together with its optimizer.
    text_encoder, optimizer = ipex.optimize(text_encoder, optimizer=optimizer, dtype=torch.bfloat16)

    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            # bf16 autocast on CPU wraps the whole accumulate step.
            with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
                with accelerator.accumulate(text_encoder):
                    # Convert images to latent space
                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                    latents = latents * vae.config.scaling_factor

                    # Sample noise that we'll add to the latents
                    noise = torch.randn(latents.shape).to(latents.device)
                    bsz = latents.shape[0]
                    # Sample a random timestep for each image
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                    ).long()

                    # Add noise to the latents according to the noise magnitude at each timestep
                    # (this is the forward diffusion process)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get the text embedding for conditioning
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                    # Predict the noise residual
                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    # Get the target for loss depending on the prediction type
                    if noise_scheduler.config.prediction_type == "epsilon":
                        target = noise
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
                    else:
                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                    loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
                    accelerator.backward(loss)

                    # Zero out the gradients for all token embeddings except the newly added
                    # embeddings for the concept, as we only want to optimize the concept embeddings
                    if accelerator.num_processes > 1:
                        grads = text_encoder.module.get_input_embeddings().weight.grad
                    else:
                        grads = text_encoder.get_input_embeddings().weight.grad
                    # Get the index for tokens that we want to zero the grads for
                    index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
                    grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % args.save_steps == 0:
                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using using the trained modules and save it.
    if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
            logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
            save_full_model = True
        else:
            save_full_model = not args.only_save_embeds
        if save_full_model:
            pipeline = StableDiffusionPipeline(
                text_encoder=accelerator.unwrap_model(text_encoder),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
                scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
                safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
                feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
            )
            pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

        if args.push_to_hub:
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
643
+
644
+
645
# Script entry point.
if __name__ == "__main__":
    main()
diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Distillation for quantization on Textual Inversion models to personalize text2image
2
+
3
+ [Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images. _By using just 3-5 images, new concepts can be taught to Stable Diffusion and the model personalized on your own images._
4
+ The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
5
+ We have enabled distillation for quantization in `textual_inversion.py` to do quantization aware training as well as distillation on the model generated by Textual Inversion method.
6
+
7
+ ## Installing the dependencies
8
+
9
+ Before running the scripts, make sure to install the library's training dependencies:
10
+
11
+ ```bash
12
+ pip install -r requirements.txt
13
+ ```
14
+
15
+ ## Prepare Datasets
16
+
17
+ One picture from the Hugging Face dataset [sd-concepts-library/dicoo2](https://huggingface.co/sd-concepts-library/dicoo2) is needed; save it to the `./dicoo` directory. The picture is shown below:
18
+
19
+ <a href="https://huggingface.co/sd-concepts-library/dicoo2/blob/main/concept_images/1.jpeg">
20
+ <img src="https://huggingface.co/sd-concepts-library/dicoo2/resolve/main/concept_images/1.jpeg" width = "300" height="300">
21
+ </a>
22
+
23
+ ## Get a FP32 Textual Inversion model
24
+
25
+ Use the following command to fine-tune the Stable Diffusion model on the above dataset to obtain the FP32 Textual Inversion model.
26
+
27
+ ```bash
28
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
29
+ export DATA_DIR="./dicoo"
30
+
31
+ accelerate launch textual_inversion.py \
32
+ --pretrained_model_name_or_path=$MODEL_NAME \
33
+ --train_data_dir=$DATA_DIR \
34
+ --learnable_property="object" \
35
+ --placeholder_token="<dicoo>" --initializer_token="toy" \
36
+ --resolution=512 \
37
+ --train_batch_size=1 \
38
+ --gradient_accumulation_steps=4 \
39
+ --max_train_steps=3000 \
40
+ --learning_rate=5.0e-04 --scale_lr \
41
+ --lr_scheduler="constant" \
42
+ --lr_warmup_steps=0 \
43
+ --output_dir="dicoo_model"
44
+ ```
45
+
46
+ ## Do distillation for quantization
47
+
48
+ Distillation for quantization is a method that combines [intermediate layer knowledge distillation](https://github.com/intel/neural-compressor/blob/master/docs/source/distillation.md#intermediate-layer-knowledge-distillation) and [quantization aware training](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization.md#quantization-aware-training) in the same training process to improve the performance of the quantized model. Provided an FP32 model, the distillation for quantization approach will take this model itself as the teacher model and transfer the knowledge of the specified layers to the student model, i.e. the quantized version of the FP32 model, during the quantization aware training process.
49
+
50
+ Once you have the FP32 Textual Inversion model, the following command will take the FP32 Textual Inversion model as input to do distillation for quantization and generate the INT8 Textual Inversion model.
51
+
52
+ ```bash
53
+ export FP32_MODEL_NAME="./dicoo_model"
54
+ export DATA_DIR="./dicoo"
55
+
56
+ accelerate launch textual_inversion.py \
57
+ --pretrained_model_name_or_path=$FP32_MODEL_NAME \
58
+ --train_data_dir=$DATA_DIR \
59
+ --use_ema --learnable_property="object" \
60
+ --placeholder_token="<dicoo>" --initializer_token="toy" \
61
+ --resolution=512 \
62
+ --train_batch_size=1 \
63
+ --gradient_accumulation_steps=4 \
64
+ --max_train_steps=300 \
65
+ --learning_rate=5.0e-04 --max_grad_norm=3 \
66
+ --lr_scheduler="constant" \
67
+ --lr_warmup_steps=0 \
68
+ --output_dir="int8_model" \
69
+ --do_quantization --do_distillation --verify_loading
70
+ ```
71
+
72
+ After the distillation for quantization process, the quantized UNet would be 4 times smaller (3279MB -> 827MB).
73
+
74
+ ## Inference
75
+
76
+ Once you have trained an INT8 model with the above command, inference can be done simply using the `text2images.py` script. Make sure to include the `placeholder_token` in your prompt.
77
+
78
+ ```bash
79
+ export INT8_MODEL_NAME="./int8_model"
80
+
81
+ python text2images.py \
82
+ --pretrained_model_name_or_path=$INT8_MODEL_NAME \
83
+ --caption "a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brightly buildings." \
84
+ --images_num 4
85
+ ```
86
+
87
+ Here is the comparison of images generated by the FP32 model (left) and INT8 model (right) respectively:
88
+
89
+ <p float="left">
90
+ <img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/FP32.png" width = "300" height = "300" alt="FP32" align=center />
91
+ <img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/INT8.png" width = "300" height = "300" alt="INT8" align=center />
92
+ </p>
93
+
diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ torchvision
3
+ transformers>=4.25.0
4
+ ftfy
5
+ tensorboard
6
+ modelcards
7
+ neural-compressor
diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+
5
+ import torch
6
+ from neural_compressor.utils.pytorch import load
7
+ from PIL import Image
8
+ from transformers import CLIPTextModel, CLIPTokenizer
9
+
10
+ from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
11
+
12
+
13
def parse_args():
    """Parse command-line arguments for FP32/INT8 text-to-image generation.

    Returns:
        argparse.Namespace with ``pretrained_model_name_or_path`` (required),
        ``caption``, ``images_num``, ``seed`` and ``cuda_id``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "-c",
        "--caption",
        type=str,
        default="robotic cat with wings",
        help="Text used to generate images.",
    )
    parser.add_argument(
        "-n",
        "--images_num",
        type=int,
        default=4,
        # Fixed grammar in user-facing help text ("How much" -> "How many").
        help="How many images to generate.",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed for random process.",
    )
    parser.add_argument(
        "-ci",
        "--cuda_id",
        type=int,
        default=0,
        help="cuda_id.",
    )
    args = parser.parse_args()
    return args
53
+
54
+
55
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` into a single ``rows`` x ``cols`` grid image.

    Args:
        imgs: sequence of PIL images; all are assumed to share the size of
            ``imgs[0]`` (cells are laid out on that assumption — TODO confirm).
        rows: number of grid rows.
        cols: number of grid columns; ``rows * cols`` must equal ``len(imgs)``.

    Returns:
        A new RGB PIL ``Image`` containing the grid.

    Raises:
        ValueError: if ``len(imgs) != rows * cols``.
    """
    # Idiomatic comparison (was `if not len(imgs) == rows * cols`).
    if len(imgs) != rows * cols:
        raise ValueError("The specified number of rows and columns are not correct.")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    # Removed unused locals `grid_w, grid_h = grid.size`.

    # Fill the grid left-to-right, top-to-bottom.
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
66
+
67
+
68
def generate_images(
    pipeline,
    prompt="robotic cat with wings",
    guidance_scale=7.5,
    num_inference_steps=50,
    num_images_per_prompt=1,
    seed=42,
):
    """Run ``pipeline`` on ``prompt`` and return ``(grid_image, images)``.

    A seeded ``torch.Generator`` on the pipeline's device makes the sampling
    reproducible; the generated images are also arranged into a single grid.
    """
    rng = torch.Generator(pipeline.device).manual_seed(seed)
    outputs = pipeline(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=rng,
        num_images_per_prompt=num_images_per_prompt,
    )
    images = outputs.images
    n_rows = int(math.sqrt(num_images_per_prompt))
    n_cols = num_images_per_prompt // n_rows
    grid = image_grid(images, rows=n_rows, cols=n_cols)
    return grid, images
87
+
88
+
89
# --- Script entry point ------------------------------------------------------
# Build a Stable Diffusion pipeline from the given model directory, swap in an
# INT8 UNet if a neural-compressor checkpoint is present, then generate and
# save images for the requested caption.
args = parse_args()
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
)
# Disable the safety checker by replacing it with a pass-through callable.
pipeline.safety_checker = lambda images, clip_input: (images, False)
if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
    # "best_model.pt" marks an INT8 checkpoint produced by neural-compressor:
    # reload the quantized UNet and splice it into the pipeline (the pipeline
    # stays on its current device — presumably CPU for INT8; verify).
    unet = load(args.pretrained_model_name_or_path, model=unet)
    unet.eval()
    setattr(pipeline, "unet", unet)
else:
    # FP32 path: move the UNet and the whole pipeline onto the chosen GPU.
    unet = unet.to(torch.device("cuda", args.cuda_id))
    pipeline = pipeline.to(unet.device)
grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
# Save the grid next to the model, plus each individual image in a folder
# named after the caption (spaces replaced with underscores).
grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
os.makedirs(dirname, exist_ok=True)
for idx, image in enumerate(images):
    image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py ADDED
@@ -0,0 +1,996 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import math
4
+ import os
5
+ import random
6
+ from pathlib import Path
7
+ from typing import Iterable
8
+
9
+ import numpy as np
10
+ import PIL
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from accelerate import Accelerator
15
+ from accelerate.utils import ProjectConfiguration, set_seed
16
+ from huggingface_hub import create_repo, upload_folder
17
+ from neural_compressor.utils import logger
18
+ from packaging import version
19
+ from PIL import Image
20
+ from torch.utils.data import Dataset
21
+ from torchvision import transforms
22
+ from tqdm.auto import tqdm
23
+ from transformers import CLIPTextModel, CLIPTokenizer
24
+
25
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
26
+ from diffusers.optimization import get_scheduler
27
+ from diffusers.utils import make_image_grid
28
+
29
+
30
# Map resampling-filter names to PIL constants. Pillow >= 9.1.0 moved the
# filters into the `Image.Resampling` enum and deprecated the module-level
# attributes, so choose the right namespace once at import time.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------
47
+
48
+
49
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    """Serialize the learned placeholder-token embedding to ``save_path``.

    Extracts the single embedding row for ``placeholder_token_id`` from the
    (unwrapped) text encoder and saves it as ``{placeholder_token: tensor}``.
    """
    logger.info("Saving embeddings")
    embedding_layer = accelerator.unwrap_model(text_encoder).get_input_embeddings()
    learned_embedding = embedding_layer.weight[placeholder_token_id]
    torch.save({args.placeholder_token: learned_embedding.detach().cpu()}, save_path)
54
+
55
+
56
def parse_args():
    """Parse command-line options for distillation-for-quantization textual inversion.

    Returns:
        argparse.Namespace with all training / quantization options. If the
        ``LOCAL_RANK`` environment variable is set by a distributed launcher,
        it overrides the ``--local_rank`` flag.
    """
    parser = argparse.ArgumentParser(description="Example of distillation for quantization on Textual Inversion.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--do_quantization", action="store_true", help="Whether or not to do quantization.")
    parser.add_argument("--do_distillation", action="store_true", help="Whether or not to do distillation.")
    parser.add_argument(
        "--verify_loading", action="store_true", help="Whether or not to verify the loading of the quantized model."
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    # A distributed launcher may export LOCAL_RANK; prefer it over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # NOTE: --train_data_dir is declared required=True above, so argparse
    # already rejects a missing value; the former `if args.train_data_dir is
    # None: raise ValueError(...)` re-check was unreachable and was removed.
    return args
214
+
215
+
216
# Prompt templates for "object" concepts; "{}" is replaced with the
# placeholder token at training time.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Prompt templates for "style" concepts (used when learnable_property="style").
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
267
+
268
+
269
+ # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
270
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        # Snapshot the current parameter values as the initial shadow copies.
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        # `decay` starts as the target EMA decay rate; note that `step`
        # overwrites it with `get_decay`'s return value on every call.
        self.decay = decay
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        # Warm-up schedule: (1 + t) / (10 + t) grows toward 1 as t increases.
        # NOTE(review): because `step` stores this function's return value back
        # into `self.decay`, successive calls feed back into the `min` below,
        # so the effective factor does not simply converge to 1 - decay —
        # matches the adapted torch-ema source, but confirm it is intended.
        value = (1 + optimization_step) / (10 + optimization_step)
        return 1 - min(self.decay, value)

    @torch.no_grad()
    def step(self, parameters):
        # Pull the shadow copies toward the current live parameters.
        parameters = list(parameters)

        self.optimization_step += 1
        self.decay = self.get_decay(self.optimization_step)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                # In-place interpolation: shadow -= factor * (shadow - param).
                tmp = self.decay * (s_param - param)
                s_param.sub_(tmp)
            else:
                # Non-trainable parameters are mirrored verbatim.
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        # Only floating-point buffers take the dtype; integer buffers keep theirs.
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]
328
+
329
+
330
class TextualInversionDataset(Dataset):
    """Folder of concept images paired with randomly chosen prompt templates.

    Each item is a dict with ``input_ids`` (tokenized caption containing the
    placeholder token) and ``pixel_values`` (CHW float32 tensor in [-1, 1]).
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        # Every file in data_root is treated as an image (no extension filter).
        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        # In training mode, virtually repeat the dataset to lengthen an epoch.
        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        # "style" selects painting-style templates; anything else uses object templates.
        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        """Build one training example: tokenized prompt + normalized image tensor."""
        example = {}
        # Index wraps modulo the real image count (the dataset is repeated).
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        # Substitute the placeholder token into a randomly chosen template.
        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            # Crop the largest centered square before resizing.
            crop = min(img.shape[0], img.shape[1])
            (
                h,
                w,
            ) = (
                img.shape[0],
                img.shape[1],
            )
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        # Random horizontal flip, then scale uint8 [0, 255] -> float32 [-1, 1].
        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW, as expected by the VAE encoder.
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
414
+
415
+
416
def freeze_params(params):
    """Disable gradient tracking for every parameter in ``params``."""
    for p in params:
        p.requires_grad = False
419
+
420
+
421
def generate_images(pipeline, prompt="", guidance_scale=7.5, num_inference_steps=50, num_images_per_prompt=1, seed=42):
    """Sample ``num_images_per_prompt`` images and return them as one grid image.

    Sampling is seeded via a ``torch.Generator`` on the pipeline's device so
    repeated calls with the same seed reproduce the same images.
    """
    rng = torch.Generator(pipeline.device).manual_seed(seed)
    outputs = pipeline(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=rng,
        num_images_per_prompt=num_images_per_prompt,
    )
    n_rows = int(math.sqrt(num_images_per_prompt))
    n_cols = num_images_per_prompt // n_rows
    return make_image_grid(outputs.images, rows=n_rows, cols=n_cols)
433
+
434
+
435
+ def main():
436
+ args = parse_args()
437
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
438
+
439
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
440
+
441
+ accelerator = Accelerator(
442
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
443
+ mixed_precision=args.mixed_precision,
444
+ log_with="tensorboard",
445
+ project_config=accelerator_project_config,
446
+ )
447
+
448
+ # If passed along, set the training seed now.
449
+ if args.seed is not None:
450
+ set_seed(args.seed)
451
+
452
+ # Handle the repository creation
453
+ if accelerator.is_main_process:
454
+ if args.output_dir is not None:
455
+ os.makedirs(args.output_dir, exist_ok=True)
456
+
457
+ if args.push_to_hub:
458
+ repo_id = create_repo(
459
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
460
+ ).repo_id
461
+
462
+ # Load the tokenizer and add the placeholder token as a additional special token
463
+ if args.tokenizer_name:
464
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
465
+ elif args.pretrained_model_name_or_path:
466
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
467
+
468
+ # Load models and create wrapper for stable diffusion
469
+ noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
470
+ text_encoder = CLIPTextModel.from_pretrained(
471
+ args.pretrained_model_name_or_path,
472
+ subfolder="text_encoder",
473
+ revision=args.revision,
474
+ )
475
+ vae = AutoencoderKL.from_pretrained(
476
+ args.pretrained_model_name_or_path,
477
+ subfolder="vae",
478
+ revision=args.revision,
479
+ )
480
+ unet = UNet2DConditionModel.from_pretrained(
481
+ args.pretrained_model_name_or_path,
482
+ subfolder="unet",
483
+ revision=args.revision,
484
+ )
485
+
486
+ train_unet = False
487
+ # Freeze vae and unet
488
+ freeze_params(vae.parameters())
489
+ if not args.do_quantization and not args.do_distillation:
490
+ # Add the placeholder token in tokenizer
491
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
492
+ if num_added_tokens == 0:
493
+ raise ValueError(
494
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
495
+ " `placeholder_token` that is not already in the tokenizer."
496
+ )
497
+
498
+ # Convert the initializer_token, placeholder_token to ids
499
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
500
+ # Check if initializer_token is a single token or a sequence of tokens
501
+ if len(token_ids) > 1:
502
+ raise ValueError("The initializer token must be a single token.")
503
+
504
+ initializer_token_id = token_ids[0]
505
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
506
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
507
+ text_encoder.resize_token_embeddings(len(tokenizer))
508
+
509
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
510
+ token_embeds = text_encoder.get_input_embeddings().weight.data
511
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
512
+
513
+ freeze_params(unet.parameters())
514
+ # Freeze all parameters except for the token embeddings in text encoder
515
+ params_to_freeze = itertools.chain(
516
+ text_encoder.text_model.encoder.parameters(),
517
+ text_encoder.text_model.final_layer_norm.parameters(),
518
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
519
+ )
520
+ freeze_params(params_to_freeze)
521
+ else:
522
+ train_unet = True
523
+ freeze_params(text_encoder.parameters())
524
+
525
+ if args.scale_lr:
526
+ args.learning_rate = (
527
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
528
+ )
529
+
530
+ # Initialize the optimizer
531
+ optimizer = torch.optim.AdamW(
532
+ # only optimize the unet or embeddings of text_encoder
533
+ unet.parameters() if train_unet else text_encoder.get_input_embeddings().parameters(),
534
+ lr=args.learning_rate,
535
+ betas=(args.adam_beta1, args.adam_beta2),
536
+ weight_decay=args.adam_weight_decay,
537
+ eps=args.adam_epsilon,
538
+ )
539
+
540
+ train_dataset = TextualInversionDataset(
541
+ data_root=args.train_data_dir,
542
+ tokenizer=tokenizer,
543
+ size=args.resolution,
544
+ placeholder_token=args.placeholder_token,
545
+ repeats=args.repeats,
546
+ learnable_property=args.learnable_property,
547
+ center_crop=args.center_crop,
548
+ set="train",
549
+ )
550
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
551
+
552
+ # Scheduler and math around the number of training steps.
553
+ overrode_max_train_steps = False
554
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
555
+ if args.max_train_steps is None:
556
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
557
+ overrode_max_train_steps = True
558
+
559
+ lr_scheduler = get_scheduler(
560
+ args.lr_scheduler,
561
+ optimizer=optimizer,
562
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
563
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
564
+ )
565
+
566
+ if not train_unet:
567
+ text_encoder = accelerator.prepare(text_encoder)
568
+ unet.to(accelerator.device)
569
+ unet.eval()
570
+ else:
571
+ unet = accelerator.prepare(unet)
572
+ text_encoder.to(accelerator.device)
573
+ text_encoder.eval()
574
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
575
+
576
+ # Move vae to device
577
+ vae.to(accelerator.device)
578
+
579
+ # Keep vae in eval model as we don't train these
580
+ vae.eval()
581
+
582
+ compression_manager = None
583
+
584
+ def train_func(model):
585
+ if train_unet:
586
+ unet_ = model
587
+ text_encoder_ = text_encoder
588
+ else:
589
+ unet_ = unet
590
+ text_encoder_ = model
591
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
592
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
593
+ if overrode_max_train_steps:
594
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
595
+ # Afterwards we recalculate our number of training epochs
596
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
597
+
598
+ # We need to initialize the trackers we use, and also store our configuration.
599
+ # The trackers initializes automatically on the main process.
600
+ if accelerator.is_main_process:
601
+ accelerator.init_trackers("textual_inversion", config=vars(args))
602
+
603
+ # Train!
604
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
605
+
606
+ logger.info("***** Running training *****")
607
+ logger.info(f" Num examples = {len(train_dataset)}")
608
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
609
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
610
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
611
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
612
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
613
+ # Only show the progress bar once on each machine.
614
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
615
+ progress_bar.set_description("Steps")
616
+ global_step = 0
617
+
618
+ if train_unet and args.use_ema:
619
+ ema_unet = EMAModel(unet_.parameters())
620
+
621
+ for epoch in range(args.num_train_epochs):
622
+ model.train()
623
+ train_loss = 0.0
624
+ for step, batch in enumerate(train_dataloader):
625
+ with accelerator.accumulate(model):
626
+ # Convert images to latent space
627
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
628
+ latents = latents * 0.18215
629
+
630
+ # Sample noise that we'll add to the latents
631
+ noise = torch.randn(latents.shape).to(latents.device)
632
+ bsz = latents.shape[0]
633
+ # Sample a random timestep for each image
634
+ timesteps = torch.randint(
635
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
636
+ ).long()
637
+
638
+ # Add noise to the latents according to the noise magnitude at each timestep
639
+ # (this is the forward diffusion process)
640
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
641
+
642
+ # Get the text embedding for conditioning
643
+ encoder_hidden_states = text_encoder_(batch["input_ids"])[0]
644
+
645
+ # Predict the noise residual
646
+ model_pred = unet_(noisy_latents, timesteps, encoder_hidden_states).sample
647
+
648
+ loss = F.mse_loss(model_pred, noise, reduction="none").mean([1, 2, 3]).mean()
649
+ if train_unet and compression_manager:
650
+ unet_inputs = {
651
+ "sample": noisy_latents,
652
+ "timestep": timesteps,
653
+ "encoder_hidden_states": encoder_hidden_states,
654
+ }
655
+ loss = compression_manager.callbacks.on_after_compute_loss(unet_inputs, model_pred, loss)
656
+
657
+ # Gather the losses across all processes for logging (if we use distributed training).
658
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
659
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
660
+
661
+ # Backpropagate
662
+ accelerator.backward(loss)
663
+
664
+ if train_unet:
665
+ if accelerator.sync_gradients:
666
+ accelerator.clip_grad_norm_(unet_.parameters(), args.max_grad_norm)
667
+ else:
668
+ # Zero out the gradients for all token embeddings except the newly added
669
+ # embeddings for the concept, as we only want to optimize the concept embeddings
670
+ if accelerator.num_processes > 1:
671
+ grads = text_encoder_.module.get_input_embeddings().weight.grad
672
+ else:
673
+ grads = text_encoder_.get_input_embeddings().weight.grad
674
+ # Get the index for tokens that we want to zero the grads for
675
+ index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
676
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
677
+
678
+ optimizer.step()
679
+ lr_scheduler.step()
680
+ optimizer.zero_grad()
681
+
682
+ # Checks if the accelerator has performed an optimization step behind the scenes
683
+ if accelerator.sync_gradients:
684
+ if train_unet and args.use_ema:
685
+ ema_unet.step(unet_.parameters())
686
+ progress_bar.update(1)
687
+ global_step += 1
688
+ accelerator.log({"train_loss": train_loss}, step=global_step)
689
+ train_loss = 0.0
690
+ if not train_unet and global_step % args.save_steps == 0:
691
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
692
+ save_progress(text_encoder_, placeholder_token_id, accelerator, args, save_path)
693
+
694
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
695
+ progress_bar.set_postfix(**logs)
696
+ accelerator.log(logs, step=global_step)
697
+
698
+ if global_step >= args.max_train_steps:
699
+ break
700
+ accelerator.wait_for_everyone()
701
+
702
+ if train_unet and args.use_ema:
703
+ ema_unet.copy_to(unet_.parameters())
704
+
705
+ if not train_unet:
706
+ return text_encoder_
707
+
708
+ if not train_unet:
709
+ text_encoder = train_func(text_encoder)
710
+ else:
711
+ import copy
712
+
713
+ model = copy.deepcopy(unet)
714
+ confs = []
715
+ if args.do_quantization:
716
+ from neural_compressor import QuantizationAwareTrainingConfig
717
+
718
+ q_conf = QuantizationAwareTrainingConfig()
719
+ confs.append(q_conf)
720
+
721
+ if args.do_distillation:
722
+ teacher_model = copy.deepcopy(model)
723
+
724
+ def attention_fetcher(x):
725
+ return x.sample
726
+
727
+ layer_mappings = [
728
+ [
729
+ [
730
+ "conv_in",
731
+ ]
732
+ ],
733
+ [
734
+ [
735
+ "time_embedding",
736
+ ]
737
+ ],
738
+ [["down_blocks.0.attentions.0", attention_fetcher]],
739
+ [["down_blocks.0.attentions.1", attention_fetcher]],
740
+ [
741
+ [
742
+ "down_blocks.0.resnets.0",
743
+ ]
744
+ ],
745
+ [
746
+ [
747
+ "down_blocks.0.resnets.1",
748
+ ]
749
+ ],
750
+ [
751
+ [
752
+ "down_blocks.0.downsamplers.0",
753
+ ]
754
+ ],
755
+ [["down_blocks.1.attentions.0", attention_fetcher]],
756
+ [["down_blocks.1.attentions.1", attention_fetcher]],
757
+ [
758
+ [
759
+ "down_blocks.1.resnets.0",
760
+ ]
761
+ ],
762
+ [
763
+ [
764
+ "down_blocks.1.resnets.1",
765
+ ]
766
+ ],
767
+ [
768
+ [
769
+ "down_blocks.1.downsamplers.0",
770
+ ]
771
+ ],
772
+ [["down_blocks.2.attentions.0", attention_fetcher]],
773
+ [["down_blocks.2.attentions.1", attention_fetcher]],
774
+ [
775
+ [
776
+ "down_blocks.2.resnets.0",
777
+ ]
778
+ ],
779
+ [
780
+ [
781
+ "down_blocks.2.resnets.1",
782
+ ]
783
+ ],
784
+ [
785
+ [
786
+ "down_blocks.2.downsamplers.0",
787
+ ]
788
+ ],
789
+ [
790
+ [
791
+ "down_blocks.3.resnets.0",
792
+ ]
793
+ ],
794
+ [
795
+ [
796
+ "down_blocks.3.resnets.1",
797
+ ]
798
+ ],
799
+ [
800
+ [
801
+ "up_blocks.0.resnets.0",
802
+ ]
803
+ ],
804
+ [
805
+ [
806
+ "up_blocks.0.resnets.1",
807
+ ]
808
+ ],
809
+ [
810
+ [
811
+ "up_blocks.0.resnets.2",
812
+ ]
813
+ ],
814
+ [
815
+ [
816
+ "up_blocks.0.upsamplers.0",
817
+ ]
818
+ ],
819
+ [["up_blocks.1.attentions.0", attention_fetcher]],
820
+ [["up_blocks.1.attentions.1", attention_fetcher]],
821
+ [["up_blocks.1.attentions.2", attention_fetcher]],
822
+ [
823
+ [
824
+ "up_blocks.1.resnets.0",
825
+ ]
826
+ ],
827
+ [
828
+ [
829
+ "up_blocks.1.resnets.1",
830
+ ]
831
+ ],
832
+ [
833
+ [
834
+ "up_blocks.1.resnets.2",
835
+ ]
836
+ ],
837
+ [
838
+ [
839
+ "up_blocks.1.upsamplers.0",
840
+ ]
841
+ ],
842
+ [["up_blocks.2.attentions.0", attention_fetcher]],
843
+ [["up_blocks.2.attentions.1", attention_fetcher]],
844
+ [["up_blocks.2.attentions.2", attention_fetcher]],
845
+ [
846
+ [
847
+ "up_blocks.2.resnets.0",
848
+ ]
849
+ ],
850
+ [
851
+ [
852
+ "up_blocks.2.resnets.1",
853
+ ]
854
+ ],
855
+ [
856
+ [
857
+ "up_blocks.2.resnets.2",
858
+ ]
859
+ ],
860
+ [
861
+ [
862
+ "up_blocks.2.upsamplers.0",
863
+ ]
864
+ ],
865
+ [["up_blocks.3.attentions.0", attention_fetcher]],
866
+ [["up_blocks.3.attentions.1", attention_fetcher]],
867
+ [["up_blocks.3.attentions.2", attention_fetcher]],
868
+ [
869
+ [
870
+ "up_blocks.3.resnets.0",
871
+ ]
872
+ ],
873
+ [
874
+ [
875
+ "up_blocks.3.resnets.1",
876
+ ]
877
+ ],
878
+ [
879
+ [
880
+ "up_blocks.3.resnets.2",
881
+ ]
882
+ ],
883
+ [["mid_block.attentions.0", attention_fetcher]],
884
+ [
885
+ [
886
+ "mid_block.resnets.0",
887
+ ]
888
+ ],
889
+ [
890
+ [
891
+ "mid_block.resnets.1",
892
+ ]
893
+ ],
894
+ [
895
+ [
896
+ "conv_out",
897
+ ]
898
+ ],
899
+ ]
900
+ layer_names = [layer_mapping[0][0] for layer_mapping in layer_mappings]
901
+ if not set(layer_names).issubset([n[0] for n in model.named_modules()]):
902
+ raise ValueError(
903
+ "Provided model is not compatible with the default layer_mappings, "
904
+ 'please use the model fine-tuned from "CompVis/stable-diffusion-v1-4", '
905
+ "or modify the layer_mappings variable to fit your model."
906
+ f"\nDefault layer_mappings are as such:\n{layer_mappings}"
907
+ )
908
+ from neural_compressor.config import DistillationConfig, IntermediateLayersKnowledgeDistillationLossConfig
909
+
910
+ distillation_criterion = IntermediateLayersKnowledgeDistillationLossConfig(
911
+ layer_mappings=layer_mappings,
912
+ loss_types=["MSE"] * len(layer_mappings),
913
+ loss_weights=[1.0 / len(layer_mappings)] * len(layer_mappings),
914
+ add_origin_loss=True,
915
+ )
916
+ d_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)
917
+ confs.append(d_conf)
918
+
919
+ from neural_compressor.training import prepare_compression
920
+
921
+ compression_manager = prepare_compression(model, confs)
922
+ compression_manager.callbacks.on_train_begin()
923
+ model = compression_manager.model
924
+ train_func(model)
925
+ compression_manager.callbacks.on_train_end()
926
+
927
+ # Save the resulting model and its corresponding configuration in the given directory
928
+ model.save(args.output_dir)
929
+
930
+ logger.info(f"Optimized model saved to: {args.output_dir}.")
931
+
932
+ # change to framework model for further use
933
+ model = model.model
934
+
935
+ # Create the pipeline using using the trained modules and save it.
936
+ templates = imagenet_style_templates_small if args.learnable_property == "style" else imagenet_templates_small
937
+ prompt = templates[0].format(args.placeholder_token)
938
+ if accelerator.is_main_process:
939
+ pipeline = StableDiffusionPipeline.from_pretrained(
940
+ args.pretrained_model_name_or_path,
941
+ text_encoder=accelerator.unwrap_model(text_encoder),
942
+ vae=vae,
943
+ unet=accelerator.unwrap_model(unet),
944
+ tokenizer=tokenizer,
945
+ )
946
+ pipeline.save_pretrained(args.output_dir)
947
+ pipeline = pipeline.to(unet.device)
948
+ baseline_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
949
+ baseline_model_images.save(
950
+ os.path.join(args.output_dir, "{}_baseline_model.png".format("_".join(prompt.split())))
951
+ )
952
+
953
+ if not train_unet:
954
+ # Also save the newly trained embeddings
955
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
956
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
957
+ else:
958
+ setattr(pipeline, "unet", accelerator.unwrap_model(model))
959
+ if args.do_quantization:
960
+ pipeline = pipeline.to(torch.device("cpu"))
961
+
962
+ optimized_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
963
+ optimized_model_images.save(
964
+ os.path.join(args.output_dir, "{}_optimized_model.png".format("_".join(prompt.split())))
965
+ )
966
+
967
+ if args.push_to_hub:
968
+ upload_folder(
969
+ repo_id=repo_id,
970
+ folder_path=args.output_dir,
971
+ commit_message="End of training",
972
+ ignore_patterns=["step_*", "epoch_*"],
973
+ )
974
+
975
+ accelerator.end_training()
976
+
977
+ if args.do_quantization and args.verify_loading:
978
+ # Load the model obtained after Intel Neural Compressor quantization
979
+ from neural_compressor.utils.pytorch import load
980
+
981
+ loaded_model = load(args.output_dir, model=unet)
982
+ loaded_model.eval()
983
+
984
+ setattr(pipeline, "unet", loaded_model)
985
+ if args.do_quantization:
986
+ pipeline = pipeline.to(torch.device("cpu"))
987
+
988
+ loaded_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
989
+ if loaded_model_images != optimized_model_images:
990
+ logger.info("The quantized model was not successfully loaded.")
991
+ else:
992
+ logger.info("The quantized model was successfully loaded.")
993
+
994
+
995
# Script entry point: run the training/compression pipeline only when executed
# directly, not when this module is imported.
if __name__ == "__main__":
    main()
diffusers/examples/research_projects/ip_adapter/README.md ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IP Adapter Training Example
2
+
3
+ [IP Adapter](https://huggingface.co/papers/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.
4
+
5
+ ## Training locally with PyTorch
6
+
7
+ ### Installing the dependencies
8
+
9
+ Before running the scripts, make sure to install the library's training dependencies:
10
+
11
+ **Important**
12
+
13
+ To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
14
+
15
+ ```bash
16
+ git clone https://github.com/huggingface/diffusers
17
+ cd diffusers
18
+ pip install -e .
19
+ ```
20
+
21
+ Then cd in the example folder and run
22
+
23
+ ```bash
24
+ pip install -r requirements.txt
25
+ ```
26
+
27
+ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
28
+
29
+ ```bash
30
+ accelerate config
31
+ ```
32
+
33
+ Or for a default accelerate configuration without answering questions about your environment
34
+
35
+ ```bash
36
+ accelerate config default
37
+ ```
38
+
39
+ Or if your environment doesn't support an interactive shell e.g. a notebook
40
+
41
+ ```python
42
+ from accelerate.utils import write_basic_config
43
+ write_basic_config()
44
+ ```
45
+
46
+ The Accelerate launch command used for training is documented below.
47
+
48
+ ### Accelerate Launch Command Documentation
49
+
50
+ #### Description:
51
+ The Accelerate launch command is used to train a model using multiple GPUs and mixed precision training. It launches the training script `tutorial_train_ip-adapter.py` with specified parameters and configurations.
52
+
53
+ #### Usage Example:
54
+
55
+ ```
56
+ accelerate launch --mixed_precision "fp16" \
57
+ tutorial_train_ip-adapter.py \
58
+ --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
59
+ --image_encoder_path="{image_encoder_path}" \
60
+ --data_json_file="{data.json}" \
61
+ --data_root_path="{image_path}" \
62
+ --mixed_precision="fp16" \
63
+ --resolution=512 \
64
+ --train_batch_size=8 \
65
+ --dataloader_num_workers=4 \
66
+ --learning_rate=1e-04 \
67
+ --weight_decay=0.01 \
68
+ --output_dir="{output_dir}" \
69
+ --save_steps=10000
70
+ ```
71
+
72
+ ### Multi-GPU Script:
73
+ ```
74
+ accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
75
+ tutorial_train_ip-adapter.py \
76
+ --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
77
+ --image_encoder_path="{image_encoder_path}" \
78
+ --data_json_file="{data.json}" \
79
+ --data_root_path="{image_path}" \
80
+ --mixed_precision="fp16" \
81
+ --resolution=512 \
82
+ --train_batch_size=8 \
83
+ --dataloader_num_workers=4 \
84
+ --learning_rate=1e-04 \
85
+ --weight_decay=0.01 \
86
+ --output_dir="{output_dir}" \
87
+ --save_steps=10000
88
+ ```
89
+
90
+ #### Parameters:
91
+ - `--num_processes`: Number of processes to launch for distributed training (in this example, 8 processes).
92
+ - `--multi_gpu`: Flag indicating the usage of multiple GPUs for training.
93
+ - `--mixed_precision "fp16"`: Enables mixed precision training with 16-bit floating-point precision.
94
+ - `tutorial_train_ip-adapter.py`: Name of the training script to be executed.
95
+ - `--pretrained_model_name_or_path`: Path or identifier for a pretrained model.
96
+ - `--image_encoder_path`: Path to the CLIP image encoder.
97
+ - `--data_json_file`: Path to the training data in JSON format.
98
+ - `--data_root_path`: Root path where training images are located.
99
+ - `--resolution`: Resolution of input images (512x512 in this example).
100
+ - `--train_batch_size`: Batch size for training data (8 in this example).
101
+ - `--dataloader_num_workers`: Number of subprocesses for data loading (4 in this example).
102
+ - `--learning_rate`: Learning rate for training (1e-04 in this example).
103
+ - `--weight_decay`: Weight decay for regularization (0.01 in this example).
104
+ - `--output_dir`: Directory to save model checkpoints and predictions.
105
+ - `--save_steps`: Frequency of saving checkpoints during training (10000 in this example).
106
+
107
+ ### Inference
108
+
109
+ #### Description:
110
+ The provided inference code is used to load a trained model checkpoint and extract the components related to image projection and the IP (Image Prompt) adapter. These components are then saved into separate safetensors files for later use in inference.
111
+
112
+ #### Usage Example:
113
+ ```python
114
+ from safetensors.torch import load_file, save_file
115
+
116
+ # Load the trained model checkpoint in safetensors format
117
+ ckpt = "checkpoint-50000/pytorch_model.safetensors"
118
+ sd = load_file(ckpt) # Using safetensors load function
119
+
120
+ # Extract image projection and IP adapter components
121
+ image_proj_sd = {}
122
+ ip_sd = {}
123
+
124
+ for k in sd:
125
+ if k.startswith("unet"):
126
+ pass # Skip unet-related keys
127
+ elif k.startswith("image_proj_model"):
128
+ image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
129
+ elif k.startswith("adapter_modules"):
130
+ ip_sd[k.replace("adapter_modules.", "")] = sd[k]
131
+
132
+ # Save the components into separate safetensors files
133
+ save_file(image_proj_sd, "image_proj.safetensors")
134
+ save_file(ip_sd, "ip_adapter.safetensors")
135
+ ```
136
+
137
+ ### Sample Inference Script using the CLIP Model
138
+
139
+ ```python
140
+
141
+ import torch
142
+ from safetensors.torch import load_file
143
+ from transformers import CLIPProcessor, CLIPModel # Using the Hugging Face CLIP model
144
+
145
+ # Load model components from safetensors
146
+ image_proj_ckpt = "image_proj.safetensors"
147
+ ip_adapter_ckpt = "ip_adapter.safetensors"
148
+
149
+ # Load the saved weights
150
+ image_proj_sd = load_file(image_proj_ckpt)
151
+ ip_adapter_sd = load_file(ip_adapter_ckpt)
152
+
153
+ # Define the model Parameters
154
+ class ImageProjectionModel(torch.nn.Module):
155
+ def __init__(self, input_dim=768, output_dim=512): # CLIP's default embedding size is 768
156
+ super().__init__()
157
+ self.model = torch.nn.Linear(input_dim, output_dim)
158
+
159
+ def forward(self, x):
160
+ return self.model(x)
161
+
162
+ class IPAdapterModel(torch.nn.Module):
163
+ def __init__(self, input_dim=512, output_dim=10): # Example for 10 classes
164
+ super().__init__()
165
+ self.model = torch.nn.Linear(input_dim, output_dim)
166
+
167
+ def forward(self, x):
168
+ return self.model(x)
169
+
170
+ # Initialize models
171
+ image_proj_model = ImageProjectionModel()
172
+ ip_adapter_model = IPAdapterModel()
173
+
174
+ # Load weights into models
175
+ image_proj_model.load_state_dict(image_proj_sd)
176
+ ip_adapter_model.load_state_dict(ip_adapter_sd)
177
+
178
+ # Set models to evaluation mode
179
+ image_proj_model.eval()
180
+ ip_adapter_model.eval()
181
+
182
+ #Inference pipeline
183
+ def inference(image_tensor):
184
+ """
185
+ Run inference using the loaded models.
186
+
187
+ Args:
188
+ image_tensor: Preprocessed image tensor from CLIPProcessor
189
+
190
+ Returns:
191
+ Final inference results
192
+ """
193
+ with torch.no_grad():
194
+ # Step 1: Project the image features
195
+ image_proj = image_proj_model(image_tensor)
196
+
197
+ # Step 2: Pass the projected features through the IP Adapter
198
+ result = ip_adapter_model(image_proj)
199
+
200
+ return result
201
+
202
+ # Using CLIP for image preprocessing
203
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
204
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
205
+
206
+ #Image file path
207
+ image_path = "path/to/image.jpg"
208
+
209
+ # Preprocess the image
210
+ inputs = processor(images=image_path, return_tensors="pt")
211
+ image_features = clip_model.get_image_features(inputs["pixel_values"])
212
+
213
+ # Normalize the image features as per CLIP's recommendations
214
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
215
+
216
+ # Run inference
217
+ output = inference(image_features)
218
+ print("Inference output:", output)
219
+ ```
220
+
221
+ #### Parameters:
222
+ - `ckpt`: Path to the trained model checkpoint file.
223
+ - `load_file(ckpt)`: Loads the safetensors checkpoint; tensors are loaded onto the CPU by default.
224
+ - `image_proj_sd`: Dictionary to store the components related to image projection.
225
+ - `ip_sd`: Dictionary to store the components related to the IP adapter.
226
+ - `"unet"`, `"image_proj_model"`, `"adapter_modules"`: Prefixes indicating components of the model.
diffusers/examples/research_projects/ip_adapter/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ accelerate
2
+ torchvision
3
+ transformers>=4.25.1
4
+ ip_adapter
diffusers/examples/research_projects/ip_adapter/tutorial_train_faceid.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from accelerate import Accelerator
12
+ from accelerate.utils import ProjectConfiguration
13
+ from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
14
+ from ip_adapter.ip_adapter_faceid import MLPProjModel
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from transformers import CLIPTextModel, CLIPTokenizer
18
+
19
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
20
+
21
+
22
+ # Dataset
23
class MyDataset(torch.utils.data.Dataset):
    """Training dataset pairing images, captions, and precomputed face-ID embeddings.

    Args:
        json_file: Path to a JSON manifest — a list of dicts such as
            ``{"image_file": "1.png", "text": "...", "id_embed_file": "faceid.bin"}``.
        tokenizer: Tokenizer used to convert captions into fixed-length input ids.
        size: Side length to which images are resized and center-cropped.
        t_drop_rate: Probability of dropping only the text prompt.
        i_drop_rate: Probability of dropping only the image embedding.
        ti_drop_rate: Probability of dropping both text and image embedding.
        image_root_path: Directory that ``image_file`` paths are relative to.
    """

    def __init__(
        self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
    ):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}]
        # Use a context manager so the manifest file handle is closed
        # deterministically (the previous `json.load(open(json_file))` relied
        # on garbage collection to close it).
        with open(json_file) as f:
            self.data = json.load(f)

        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(self.size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __getitem__(self, idx):
        """Return one example: image tensor, token ids, face embedding, and drop flag."""
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))

        # id_embed_file is expected to contain a NumPy array (torch.load returns
        # it, then it is converted to a tensor) — TODO confirm against the
        # embedding-extraction pipeline.
        face_id_embed = torch.load(item["id_embed_file"], map_location="cpu")
        face_id_embed = torch.from_numpy(face_id_embed)

        # Randomly drop the image embedding, the text, or both, so the model
        # also learns the unconditional branches needed for classifier-free
        # guidance at inference time.
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1
        if drop_image_embed:
            face_id_embed = torch.zeros_like(face_id_embed)
        # get text and tokenize
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "face_id_embed": face_id_embed,
            "drop_image_embed": drop_image_embed,
        }

    def __len__(self):
        """Number of records in the manifest."""
        return len(self.data)
91
+
92
+
93
def collate_fn(data):
    """Assemble per-example dicts from ``MyDataset`` into a single batch.

    Returns a dict with stacked image tensors, concatenated token ids,
    stacked face-ID embeddings, and the per-example drop flags as a list.
    """
    return {
        "images": torch.stack([sample["image"] for sample in data]),
        "text_input_ids": torch.cat([sample["text_input_ids"] for sample in data], dim=0),
        "face_id_embed": torch.stack([sample["face_id_embed"] for sample in data]),
        "drop_image_embeds": [sample["drop_image_embed"] for sample in data],
    }
105
+
106
+
107
class IPAdapter(torch.nn.Module):
    """IP-Adapter: a UNet wrapper that conditions on projected image embeddings.

    The image-projection model turns face-ID embeddings into extra tokens that
    are appended to the text-encoder hidden states before the UNet forward pass.
    """

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        # Optionally warm-start the projection head and adapter modules.
        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        """Predict the noise residual conditioned on text tokens plus image tokens."""
        image_tokens = self.image_proj_model(image_embeds)
        # Append projected image tokens to the text tokens along the sequence axis.
        conditioning = torch.cat([encoder_hidden_states, image_tokens], dim=1)
        # Predict the noise residual
        return self.unet(noisy_latents, timesteps, conditioning).sample

    def load_from_checkpoint(self, ckpt_path: str):
        """Load image_proj/ip_adapter weights and assert they actually changed."""

        def _param_checksum(module):
            # Cheap scalar fingerprint of a module's parameters.
            return torch.sum(torch.stack([torch.sum(p) for p in module.parameters()]))

        # Checksums before loading, for the sanity check below.
        before_proj = _param_checksum(self.image_proj_model)
        before_adapter = _param_checksum(self.adapter_modules)

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Verify that loading actually replaced the randomly initialized weights.
        assert _param_checksum(self.image_proj_model) != before_proj, "Weights of image_proj_model did not change!"
        assert _param_checksum(self.adapter_modules) != before_adapter, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
146
+
147
+
148
def parse_args():
    """Parse command-line arguments for the face-ID IP-Adapter training script.

    Returns:
        argparse.Namespace: parsed arguments, with ``local_rank`` reconciled
        against the ``LOCAL_RANK`` environment variable for distributed runs.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Training data",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        required=True,
        help="Training data root path",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to CLIP image encoder",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=("The resolution for input images"),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help=("Save a checkpoint of the training state every X updates"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    # The launcher (torchrun / accelerate) communicates the rank via LOCAL_RANK;
    # it takes precedence over the CLI flag when the two disagree.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
258
+
259
+
260
def main():
    """Fine-tune face-ID IP-Adapter modules on top of a frozen Stable Diffusion UNet.

    Loads the frozen SD components (scheduler, tokenizer, text encoder, VAE,
    UNet), installs LoRA/IP attention processors on the UNet, and trains only
    the image-projection model and the adapter modules with AdamW.
    """
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    # image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    # image_encoder.requires_grad_(False)

    # ip-adapter: projects the 512-dim face-ID embedding into `num_tokens`
    # cross-attention tokens.
    image_proj_model = MLPProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        id_embeddings_dim=512,
        num_tokens=4,
    )
    # init adapter modules: self-attention (attn1) gets a plain LoRA processor;
    # cross-attention gets an IP processor whose to_k_ip/to_v_ip weights are
    # initialized from the frozen to_k/to_v weights.
    lora_rank = 128
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        # Derive the block's hidden size from its position in the UNet.
        # NOTE(review): assumes every processor name starts with mid_block,
        # up_blocks. or down_blocks. — otherwise hidden_size would be unbound.
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
            )
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = LoRAIPAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
            )
            attn_procs[name].load_state_dict(weights, strict=False)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)

    # Frozen models run in reduced precision when mixed precision is enabled;
    # the trainable adapter stays in fp32 (handled by accelerator.prepare).
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    # unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    # image_encoder.to(accelerator.device, dtype=weight_dtype)

    # optimizer: only the projection model and adapter modules are trained.
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)

    # dataloader
    train_dataset = MyDataset(
        args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

    global_step = 0
    for epoch in range(0, args.num_train_epochs):
        begin = time.perf_counter()
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(
                        batch["images"].to(accelerator.device, dtype=weight_dtype)
                    ).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image. Read
                # num_train_timesteps from the scheduler config: direct
                # attribute access on the scheduler is deprecated in diffusers.
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                )
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype)

                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()

                # Backpropagate
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                if accelerator.is_main_process:
                    print(
                        "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
                        )
                    )

            global_step += 1

            if global_step % args.save_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

            begin = time.perf_counter()
412
+
413
+
414
# Standard entry-point guard so importing this module has no side effects.
if __name__ == "__main__":
    main()
diffusers/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from accelerate import Accelerator
12
+ from accelerate.utils import ProjectConfiguration
13
+ from ip_adapter.ip_adapter import ImageProjModel
14
+ from ip_adapter.utils import is_torch2_available
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
18
+
19
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
20
+
21
+
22
+ if is_torch2_available():
23
+ from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
24
+ from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
25
+ else:
26
+ from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
27
+
28
+
29
+ # Dataset
30
class MyDataset(torch.utils.data.Dataset):
    """Image/caption pairs for IP-Adapter training.

    Each item yields the normalized RGB tensor for the VAE, CLIP-processed
    pixel values for the image encoder, padded caption token ids, and a flag
    marking whether the image conditioning should be dropped (so the model
    also learns the unconditional branches used by classifier-free guidance).

    Args:
        json_file: Path to a JSON list of ``{"image_file": ..., "text": ...}``.
        tokenizer: Text tokenizer (e.g. ``CLIPTokenizer``).
        size: Square resolution used for resize + center crop.
        t_drop_rate: Probability of dropping only the text condition.
        i_drop_rate: Probability of dropping only the image condition.
        ti_drop_rate: Probability of dropping both conditions.
        image_root_path: Directory prepended to each ``image_file``.
    """

    def __init__(
        self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
    ):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        # list of dict: [{"image_file": "1.png", "text": "A dog"}]
        # Fix: close the JSON file deterministically instead of leaking the
        # handle opened by json.load(open(...)).
        with open(json_file) as f:
            self.data = json.load(f)

        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(self.size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

    def __getitem__(self, idx):
        """Return one training example as a dict of tensors plus a drop flag."""
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))
        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # Randomly drop image and/or text conditioning (classifier-free
        # guidance style). Exactly one of three mutually exclusive branches
        # fires per draw of rand_num.
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1
        # get text and tokenize (padded to the tokenizer's max length)
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "clip_image": clip_image,
            "drop_image_embed": drop_image_embed,
        }

    def __len__(self):
        return len(self.data)
93
+
94
+
95
def collate_fn(data):
    """Batch a list of dataset items into tensors.

    ``image`` tensors share one shape, so they are stacked along a new batch
    dimension; tokenizer and CLIP-processor outputs already carry a leading
    singleton batch dimension, so those are concatenated instead. The drop
    flags remain a plain Python list.
    """
    return {
        "images": torch.stack([item["image"] for item in data]),
        "text_input_ids": torch.cat([item["text_input_ids"] for item in data], dim=0),
        "clip_images": torch.cat([item["clip_image"] for item in data], dim=0),
        "drop_image_embeds": [item["drop_image_embed"] for item in data],
    }
107
+
108
+
109
class IPAdapter(torch.nn.Module):
    """IP-Adapter: a (frozen) UNet plus the trainable image-projection model
    and the adapter cross-attention modules, bundled so a single module can be
    handed to the optimizer/accelerator."""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        """Predict the noise residual conditioned on text tokens plus the
        projected image tokens (appended along the sequence dimension)."""
        image_tokens = self.image_proj_model(image_embeds)
        conditioning = torch.cat([encoder_hidden_states, image_tokens], dim=1)
        return self.unet(noisy_latents, timesteps, conditioning).sample

    def load_from_checkpoint(self, ckpt_path: str):
        """Load projection/adapter weights from ``ckpt_path`` and verify that
        the load actually changed the parameters."""

        def _param_checksum(module):
            # Cheap scalar fingerprint of a module's weights, used to detect
            # a no-op load below.
            return torch.sum(torch.stack([torch.sum(p) for p in module.parameters()]))

        before_proj = _param_checksum(self.image_proj_model)
        before_adapter = _param_checksum(self.adapter_modules)

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Verify if the weights have changed
        assert _param_checksum(self.image_proj_model) != before_proj, "Weights of image_proj_model did not change!"
        assert _param_checksum(self.adapter_modules) != before_adapter, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
148
+
149
+
150
def parse_args():
    """Parse the command-line arguments for IP-Adapter training.

    Returns:
        argparse.Namespace: parsed options. ``local_rank`` is overridden by
        the ``LOCAL_RANK`` environment variable when a distributed launcher
        sets it.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Training data",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        required=True,
        help="Training data root path",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to CLIP image encoder",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=("The resolution for input images"),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help=("Save a checkpoint of the training state every X updates"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    # Distributed launchers communicate the rank via the LOCAL_RANK env var;
    # let it take precedence over the CLI flag when they disagree.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
260
+
261
+
262
def main():
    """Train the IP-Adapter modules on top of a frozen Stable Diffusion model.

    The UNet, VAE, text encoder and CLIP image encoder stay frozen; only the
    ``ImageProjModel`` and the IP cross-attention processors receive
    gradients. Checkpoints of the full training state are written every
    ``args.save_steps`` optimizer steps.
    """
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)

    # ip-adapter: project the pooled CLIP image embedding into a few extra
    # context tokens for cross-attention.
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=4,
    )
    # init adapter modules: swap every attention processor in the UNet, using
    # IP processors (initialized from the existing to_k/to_v weights) for the
    # cross-attention layers and plain processors for self-attention.
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    # unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)

    # optimizer: only the projection model and adapter modules are trained
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)

    # dataloader
    train_dataset = MyDataset(
        args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

    global_step = 0
    for epoch in range(0, args.num_train_epochs):
        begin = time.perf_counter()
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space (VAE is frozen, no grads)
                with torch.no_grad():
                    latents = vae.encode(
                        batch["images"].to(accelerator.device, dtype=weight_dtype)
                    ).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image.
                # Fix: read num_train_timesteps from the scheduler config —
                # direct attribute access on schedulers is deprecated in
                # diffusers and removed in recent releases.
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                with torch.no_grad():
                    image_embeds = image_encoder(
                        batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
                    ).image_embeds
                # Zero out the embeddings flagged for dropping (classifier-free
                # guidance unconditional branch).
                image_embeds_ = []
                for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
                    if drop_image_embed == 1:
                        image_embeds_.append(torch.zeros_like(image_embed))
                    else:
                        image_embeds_.append(image_embed)
                image_embeds = torch.stack(image_embeds_)

                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()

                # Backpropagate
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                if accelerator.is_main_process:
                    print(
                        "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
                        )
                    )

            global_step += 1

            if global_step % args.save_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

            begin = time.perf_counter()
419
+
420
+
421
+ if __name__ == "__main__":
422
+ main()
diffusers/examples/research_projects/ip_adapter/tutorial_train_plus.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from accelerate import Accelerator
12
+ from accelerate.utils import ProjectConfiguration
13
+ from ip_adapter.resampler import Resampler
14
+ from ip_adapter.utils import is_torch2_available
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
18
+
19
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
20
+
21
+
22
+ if is_torch2_available():
23
+ from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
24
+ from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
25
+ else:
26
+ from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
27
+
28
+
29
+ # Dataset
30
class MyDataset(torch.utils.data.Dataset):
    """Image/caption pairs for IP-Adapter-Plus training.

    Each item yields the normalized RGB tensor for the VAE, CLIP-processed
    pixel values for the image encoder, padded caption token ids, and a flag
    marking whether the image conditioning should be dropped (so the model
    also learns the unconditional branches used by classifier-free guidance).

    Args:
        json_file: Path to a JSON list of ``{"image_file": ..., "text": ...}``.
        tokenizer: Text tokenizer (e.g. ``CLIPTokenizer``).
        size: Square resolution used for resize + center crop.
        t_drop_rate: Probability of dropping only the text condition.
        i_drop_rate: Probability of dropping only the image condition.
        ti_drop_rate: Probability of dropping both conditions.
        image_root_path: Directory prepended to each ``image_file``.
    """

    def __init__(
        self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
    ):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        # list of dict: [{"image_file": "1.png", "text": "A dog"}]
        # Fix: close the JSON file deterministically instead of leaking the
        # handle opened by json.load(open(...)).
        with open(json_file) as f:
            self.data = json.load(f)

        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(self.size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

    def __getitem__(self, idx):
        """Return one training example as a dict of tensors plus a drop flag."""
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))
        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # Randomly drop image and/or text conditioning (classifier-free
        # guidance style). Exactly one of three mutually exclusive branches
        # fires per draw of rand_num.
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1
        # get text and tokenize (padded to the tokenizer's max length)
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "clip_image": clip_image,
            "drop_image_embed": drop_image_embed,
        }

    def __len__(self):
        return len(self.data)
93
+
94
+
95
def collate_fn(data):
    """Batch a list of dataset items into tensors.

    ``image`` tensors share one shape, so they are stacked along a new batch
    dimension; tokenizer and CLIP-processor outputs already carry a leading
    singleton batch dimension, so those are concatenated instead. The drop
    flags remain a plain Python list.
    """
    return {
        "images": torch.stack([item["image"] for item in data]),
        "text_input_ids": torch.cat([item["text_input_ids"] for item in data], dim=0),
        "clip_images": torch.cat([item["clip_image"] for item in data], dim=0),
        "drop_image_embeds": [item["drop_image_embed"] for item in data],
    }
107
+
108
+
109
class IPAdapter(torch.nn.Module):
    """IP-Adapter: a (frozen) UNet plus the trainable image-projection model
    (a Resampler in the "plus" variant) and the adapter cross-attention
    modules, bundled so a single module can be handed to the optimizer."""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        """Predict the noise residual conditioned on the text tokens plus the
        projected image tokens (appended along the sequence dimension)."""
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

    def load_from_checkpoint(self, ckpt_path: str):
        """Load projection/adapter weights from ``ckpt_path``.

        Tolerates a ``latents`` shape mismatch in the projection model (e.g.
        a checkpoint trained with a different number of query tokens) by
        dropping that entry and loading non-strictly. Asserts that the load
        actually changed the parameters.
        """
        # Calculate original checksums (scalar fingerprints used to detect a
        # no-op load after load_state_dict).
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # NOTE(review): torch.load without weights_only un-pickles arbitrary
        # objects — only load checkpoints from trusted sources.
        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Check if 'latents' exists in both the saved state_dict and the current model's state_dict
        strict_load_image_proj_model = True
        if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
            # Check if the shapes are mismatched
            if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
                print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
                print("Removing 'latents' from checkpoint and loading the rest of the weights.")
                del state_dict["image_proj"]["latents"]
                strict_load_image_proj_model = False

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
158
+
159
+
160
def parse_args():
    """Parse the command-line arguments for IP-Adapter-Plus training.

    Returns:
        argparse.Namespace: parsed options. ``local_rank`` is overridden by
        the ``LOCAL_RANK`` environment variable when a distributed launcher
        sets it.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
    )
    parser.add_argument(
        "--num_tokens",
        type=int,
        default=16,
        help="Number of tokens to query from the CLIP image encoding.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Training data",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        required=True,
        help="Training data root path",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to CLIP image encoder",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=("The resolution for input images"),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help=("Save a checkpoint of the training state every X updates"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    # Distributed launchers communicate the rank via the LOCAL_RANK env var;
    # let it take precedence over the CLI flag when they disagree.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
276
+
277
+
278
+ def main():
279
+ args = parse_args()
280
+ logging_dir = Path(args.output_dir, args.logging_dir)
281
+
282
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
283
+
284
+ accelerator = Accelerator(
285
+ mixed_precision=args.mixed_precision,
286
+ log_with=args.report_to,
287
+ project_config=accelerator_project_config,
288
+ )
289
+
290
+ if accelerator.is_main_process:
291
+ if args.output_dir is not None:
292
+ os.makedirs(args.output_dir, exist_ok=True)
293
+
294
+ # Load scheduler, tokenizer and models.
295
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
296
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
297
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
298
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
299
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
300
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
301
+ # freeze parameters of models to save more memory
302
+ unet.requires_grad_(False)
303
+ vae.requires_grad_(False)
304
+ text_encoder.requires_grad_(False)
305
+ image_encoder.requires_grad_(False)
306
+
307
+ # ip-adapter-plus
308
+ image_proj_model = Resampler(
309
+ dim=unet.config.cross_attention_dim,
310
+ depth=4,
311
+ dim_head=64,
312
+ heads=12,
313
+ num_queries=args.num_tokens,
314
+ embedding_dim=image_encoder.config.hidden_size,
315
+ output_dim=unet.config.cross_attention_dim,
316
+ ff_mult=4,
317
+ )
318
+ # init adapter modules
319
+ attn_procs = {}
320
+ unet_sd = unet.state_dict()
321
+ for name in unet.attn_processors.keys():
322
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
323
+ if name.startswith("mid_block"):
324
+ hidden_size = unet.config.block_out_channels[-1]
325
+ elif name.startswith("up_blocks"):
326
+ block_id = int(name[len("up_blocks.")])
327
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
328
+ elif name.startswith("down_blocks"):
329
+ block_id = int(name[len("down_blocks.")])
330
+ hidden_size = unet.config.block_out_channels[block_id]
331
+ if cross_attention_dim is None:
332
+ attn_procs[name] = AttnProcessor()
333
+ else:
334
+ layer_name = name.split(".processor")[0]
335
+ weights = {
336
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
337
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
338
+ }
339
+ attn_procs[name] = IPAttnProcessor(
340
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens
341
+ )
342
+ attn_procs[name].load_state_dict(weights)
343
+ unet.set_attn_processor(attn_procs)
344
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
345
+
346
+ ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
347
+
348
+ weight_dtype = torch.float32
349
+ if accelerator.mixed_precision == "fp16":
350
+ weight_dtype = torch.float16
351
+ elif accelerator.mixed_precision == "bf16":
352
+ weight_dtype = torch.bfloat16
353
+ # unet.to(accelerator.device, dtype=weight_dtype)
354
+ vae.to(accelerator.device, dtype=weight_dtype)
355
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
356
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
357
+
358
+ # optimizer
359
+ params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
360
+ optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
361
+
362
+ # dataloader
363
+ train_dataset = MyDataset(
364
+ args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
365
+ )
366
+ train_dataloader = torch.utils.data.DataLoader(
367
+ train_dataset,
368
+ shuffle=True,
369
+ collate_fn=collate_fn,
370
+ batch_size=args.train_batch_size,
371
+ num_workers=args.dataloader_num_workers,
372
+ )
373
+
374
+ # Prepare everything with our `accelerator`.
375
+ ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
376
+
377
+ global_step = 0
378
+ for epoch in range(0, args.num_train_epochs):
379
+ begin = time.perf_counter()
380
+ for step, batch in enumerate(train_dataloader):
381
+ load_data_time = time.perf_counter() - begin
382
+ with accelerator.accumulate(ip_adapter):
383
+ # Convert images to latent space
384
+ with torch.no_grad():
385
+ latents = vae.encode(
386
+ batch["images"].to(accelerator.device, dtype=weight_dtype)
387
+ ).latent_dist.sample()
388
+ latents = latents * vae.config.scaling_factor
389
+
390
+ # Sample noise that we'll add to the latents
391
+ noise = torch.randn_like(latents)
392
+ bsz = latents.shape[0]
393
+ # Sample a random timestep for each image
394
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
395
+ timesteps = timesteps.long()
396
+
397
+ # Add noise to the latents according to the noise magnitude at each timestep
398
+ # (this is the forward diffusion process)
399
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
400
+
401
+ clip_images = []
402
+ for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
403
+ if drop_image_embed == 1:
404
+ clip_images.append(torch.zeros_like(clip_image))
405
+ else:
406
+ clip_images.append(clip_image)
407
+ clip_images = torch.stack(clip_images, dim=0)
408
+ with torch.no_grad():
409
+ image_embeds = image_encoder(
410
+ clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True
411
+ ).hidden_states[-2]
412
+
413
+ with torch.no_grad():
414
+ encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
415
+
416
+ noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
417
+
418
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
419
+
420
+ # Gather the losses across all processes for logging (if we use distributed training).
421
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
422
+
423
+ # Backpropagate
424
+ accelerator.backward(loss)
425
+ optimizer.step()
426
+ optimizer.zero_grad()
427
+
428
+ if accelerator.is_main_process:
429
+ print(
430
+ "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
431
+ epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
432
+ )
433
+ )
434
+
435
+ global_step += 1
436
+
437
+ if global_step % args.save_steps == 0:
438
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
439
+ accelerator.save_state(save_path)
440
+
441
+ begin = time.perf_counter()
442
+
443
+
444
# Standard script entry point: run the training routine only when this file
# is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()