# LoRA training example for Stable Diffusion XL (SDXL)
Low-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:

- The pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- The rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers let you control the extent to which the model is adapted toward new training images via a `scale` parameter.
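To make the idea concrete, here is a minimal NumPy sketch of the LoRA update (an illustration of the math, not the diffusers implementation): a frozen weight matrix `W` is adapted as `W + scale * (B @ A)`, where only the small matrices `A` and `B` are trained. All names below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

d, r = 8, 2                      # model dimension and LoRA rank (r << d)
W = rng.standard_normal((d, d))  # frozen pretrained weight
A = rng.standard_normal((r, d))  # trainable down-projection
B = np.zeros((d, r))             # trainable up-projection, initialized to zero

scale = 1.0
W_adapted = W + scale * (B @ A)  # LoRA-adapted weight

# B starts at zero, so before training the adapted weight equals the pretrained one
assert np.allclose(W_adapted, W)

# The LoRA pair stores far fewer parameters than the full weight matrix
lora_params = A.size + B.size  # 2 * d * r = 32
full_params = W.size           # d * d = 64
print(lora_params, full_params)
```

Setting `scale` to 0 recovers the original model exactly, while larger values push the output further toward the fine-tuned behavior; this is the knob the attention layers expose.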
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.

With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption dataset on GPUs like the Tesla T4 or Tesla V100.
## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, since we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then `cd` into the `examples/text_to_image` folder and run

```bash
pip install -r requirements_sdxl.txt
```
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default Accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config

write_basic_config()
```
When running `accelerate config`, setting torch compile mode to `True` can give dramatic speedups.
### Training

First, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokémon BLIP captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).

**___Note: It is quite useful to monitor training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
```
For this example, we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag:

```bash
huggingface-cli login
```
Now we can start training!

```bash
accelerate launch train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=1024 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --output_dir="sd-pokemon-model-lora-sdxl" \
  --validation_prompt="cute dragon creature" --report_to="wandb" \
  --push_to_hub
```
The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
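As a sanity check on checkpoint frequency, the total number of optimizer steps can be estimated from the dataset size, batch size, and epoch count. The sketch below assumes roughly 833 image-caption pairs (the approximate size of the Pokémon dataset; verify against the dataset card) and the flags used above:

```python
import math

num_examples = 833         # assumed dataset size; check your dataset's card
train_batch_size = 1       # matches --train_batch_size above
num_train_epochs = 2       # matches --num_train_epochs above
checkpointing_steps = 500  # matches --checkpointing_steps above

steps_per_epoch = math.ceil(num_examples / train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
num_checkpoints = total_steps // checkpointing_steps

# With these assumptions: 1666 total steps and 3 intermediate checkpoints
print(total_steps, num_checkpoints)
```

If the checkpoints land too sparsely or too densely for your run, adjust `--checkpointing_steps` accordingly.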
### Finetuning the text encoder and UNet

The script also allows you to finetune the `text_encoder` along with the `unet`.

🚨 Training the text encoder requires additional memory.

Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:
```bash
accelerate launch train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=1024 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --output_dir="sd-pokemon-model-lora-sdxl-txt" \
  --train_text_encoder \
  --validation_prompt="cute dragon creature" --report_to="wandb" \
  --push_to_hub
```
### Inference

Once you have trained a model using the above command, inference can be done simply with the `DiffusionPipeline` after loading the trained LoRA weights. You need to pass the `output_dir` used during training (in this case, `sd-pokemon-model-lora-sdxl`) or, since we trained with `--push_to_hub`, the corresponding Hub repository id.
```python
import torch
from diffusers import DiffusionPipeline

# Hub repository containing the trained LoRA weights (or a local output_dir)
model_path = "takuoko/sd-pokemon-model-lora-sdxl"

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_lora_weights(model_path)

prompt = "A pokemon with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```