| # Würstchen text-to-image fine-tuning | |
| ## Running locally with PyTorch | |
| Before running the scripts, make sure to install the library's training dependencies: | |
| **Important** | |
| To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment: | |
| ```bash | |
| git clone https://github.com/huggingface/diffusers | |
| cd diffusers | |
| pip install . | |
| ``` | |
| Then `cd` into the example folder and install its requirements: | |
| ```bash | |
| cd examples/wuerstchen/text_to_image | |
| pip install -r requirements.txt | |
| ``` | |
| And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: | |
| ```bash | |
| accelerate config | |
| ``` | |
| For this example we want to store the trained weights directly on the Hub, so we need to be logged in and pass the `--push_to_hub` flag to the training script. To log in, run: | |
| ```bash | |
| huggingface-cli login | |
| ``` | |
| ## Prior training | |
| You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. The script supports `--gradient_checkpointing` for prior model fine-tuning, which is useful when GPU memory is limited. | |
| <br> | |
| <!-- accelerate_snippet_start --> | |
| ```bash | |
| export DATASET_NAME="lambdalabs/naruto-blip-captions" | |
| accelerate launch train_text_to_image_prior.py \ | |
| --mixed_precision="fp16" \ | |
| --dataset_name=$DATASET_NAME \ | |
| --resolution=768 \ | |
| --train_batch_size=4 \ | |
| --gradient_accumulation_steps=4 \ | |
| --gradient_checkpointing \ | |
| --dataloader_num_workers=4 \ | |
| --max_train_steps=15000 \ | |
| --learning_rate=1e-05 \ | |
| --max_grad_norm=1 \ | |
| --checkpoints_total_limit=3 \ | |
| --lr_scheduler="constant" --lr_warmup_steps=0 \ | |
| --validation_prompts="A robot naruto, 4k photo" \ | |
| --report_to="wandb" \ | |
| --push_to_hub \ | |
| --output_dir="wuerstchen-prior-naruto-model" | |
| ``` | |
| <!-- accelerate_snippet_end --> | |
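As a quick sanity check on the command above, the number of samples that contribute to each optimizer step is the product of the per-device batch size, the gradient accumulation steps, and the number of processes Accelerate launches. The sketch below is plain arithmetic and not part of the training script. `effective_batch_size` and `num_processes` are hypothetical names, a single process is assumed, and `--max_train_steps` is assumed to count optimizer updates.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Samples that contribute to one optimizer step (hypothetical helper)."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values taken from the command above; num_processes=1 is an assumption.
per_step = effective_batch_size(train_batch_size=4, gradient_accumulation_steps=4)

# Samples seen over the whole run, assuming --max_train_steps counts optimizer updates.
total_samples = per_step * 15000

print(per_step)
print(total_samples)
```

If memory is tight, lowering `--train_batch_size` while raising `--gradient_accumulation_steps` by the same factor keeps this product unchanged.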
| ## Training with LoRA | |
| Low-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. | |
| In a nutshell, LoRA adapts a pretrained model by adding pairs of rank-decomposition matrices to existing weights and training **only** those newly added weights. This has several advantages: | |
| - Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). | |
| - Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable. | |
| - LoRA attention layers let you control how strongly the model is adapted toward the new training images via a `scale` parameter. | |
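To make the points above concrete, here is a minimal pure-Python sketch of the idea, not the diffusers implementation. It assumes a single weight matrix `W` of shape `d_out x d_in` with a LoRA pair `B` (`d_out x rank`) and `A` (`rank x d_in`), so the adapted weight is `W + scale * (B @ A)`. The helper names and the 1024 x 1024 layer size are hypothetical, chosen only for illustration; rank 4 matches the `--rank` value used in this README's LoRA command.

```python
def lora_param_count(d_out, d_in, rank):
    """Trainable parameters in one LoRA pair: B is d_out x rank, A is rank x d_in."""
    return rank * (d_out + d_in)


def matmul(X, Y):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def effective_weight(W, B, A, scale=1.0):
    """Frozen weight W plus the scaled low-rank update: W + scale * (B @ A)."""
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]


# Parameter count for a hypothetical 1024 x 1024 layer at rank 4.
full = 1024 * 1024
lora = lora_param_count(1024, 1024, rank=4)
print(full, lora)

# Tiny example: W stays frozen, only B and A would be trained.
W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [2.0]]  # 2 x 1
A = [[0.5, 0.5]]    # 1 x 2
print(effective_weight(W, B, A, scale=1.0))
print(effective_weight(W, B, A, scale=0.0))  # scale 0 recovers the original W
```

Because only `B` and `A` are saved, the file holding the adapter is a small fraction of the full model, and setting `scale` between 0 and 1 blends between the original and the adapted behavior.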
| ### Prior training | |
| First, you need to set up your development environment as explained in the [installation](#running-locally-with-pytorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Naruto captions dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions). | |
| ```bash | |
| export DATASET_NAME="lambdalabs/naruto-blip-captions" | |
| accelerate launch train_text_to_image_lora_prior.py \ | |
| --mixed_precision="fp16" \ | |
| --dataset_name=$DATASET_NAME --caption_column="text" \ | |
| --resolution=768 \ | |
| --train_batch_size=8 \ | |
| --num_train_epochs=100 --checkpointing_steps=5000 \ | |
| --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ | |
| --seed=42 \ | |
| --rank=4 \ | |
| --validation_prompt="cute dragon creature" \ | |
| --report_to="wandb" \ | |
| --push_to_hub \ | |
| --output_dir="wuerstchen-prior-naruto-lora" | |
| ``` | |