| <!--Copyright 2024 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| --> | |
| # Text-to-image | |
| <Tip warning={true}> | |
| The text-to-image fine-tuning script is experimental. It overfits easily and is prone to problems such as catastrophic forgetting. To get the best results on your own dataset, explore a range of hyperparameters. | |
| </Tip> | |
| Text-to-image models such as Stable Diffusion generate an image from a text prompt. This guide shows how to fine-tune the [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) model on your own dataset with PyTorch and Flax. All the training scripts for text-to-image fine-tuning used in this guide are available in this [repository](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) if you want a closer look. | |
| Before running the scripts, install the library's training dependencies: | |
| ```bash | |
| pip install git+https://github.com/huggingface/diffusers.git | |
| pip install -U -r requirements.txt | |
| ``` | |
| Then initialize a [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment: | |
| ```bash | |
| accelerate config | |
| ``` | |
| If you have already cloned the repository, you do not need to go through these steps. Instead, you can pass the path of your local checkout to the training script, and it will be loaded from there. | |
| ### Hardware requirements | |
| With `gradient_checkpointing` and `mixed_precision`, the model can be fine-tuned on a single 24GB GPU. For a larger `batch_size` and faster training, use a GPU with at least 30GB of memory. You can also use JAX or Flax to fine-tune on TPUs or GPUs. See [below](#flax-jax-finetuning) for details. | |
| You can reduce memory usage considerably further by enabling memory-efficient attention with xFormers. Make sure [xFormers is installed](./optimization/xformers) and pass `--enable_xformers_memory_efficient_attention` to the training script. | |
| xFormers is not available for Flax. | |
| ## Upload the model to the Hub | |
| Add the following argument to the training script to store the model on the Hub: | |
| ```bash | |
| --push_to_hub | |
| ``` | |
| ## Save and load checkpoints | |
| It is a good idea to save checkpoints regularly in case something goes wrong during training. To save checkpoints, pass the following argument to the training script: | |
| ```bash | |
| --checkpointing_steps=500 | |
| ``` | |
| Every 500 steps, the full training state is saved in a subfolder of `output_dir`. Checkpoint folders are named `checkpoint-` followed by the number of steps trained so far. For example, `checkpoint-1500` is the checkpoint saved after 1500 training steps. | |
| To load a checkpoint and resume training, pass the `--resume_from_checkpoint` argument to the training script and specify the checkpoint to resume from. For example, the following argument resumes training from the checkpoint saved after 1500 training steps: | |
| ```bash | |
| --resume_from_checkpoint="checkpoint-1500" | |
| ``` | |
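The naming scheme described above (`checkpoint-` followed by a step count) can be used to locate the most recent checkpoint before resuming. The following is a minimal sketch, not part of the training script: the helper name `find_latest_checkpoint` is hypothetical, and it only assumes that checkpoints are subfolders of `output_dir` named as described.

```python
import re
from pathlib import Path
from typing import Optional

# Matches folder names such as "checkpoint-1500" and captures the step count.
CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the name of the checkpoint subfolder with the highest step, or None.

    Hypothetical helper for illustration; it only inspects folder names.
    """
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_step, best_name = -1, None
    for entry in root.iterdir():
        match = CHECKPOINT_PATTERN.match(entry.name)
        if entry.is_dir() and match:
            # Compare numerically: as plain strings, "checkpoint-500"
            # would sort after "checkpoint-1500".
            step = int(match.group(1))
            if step > best_step:
                best_step, best_name = step, entry.name
    return best_name
```

The returned folder name is the value you would pass to `--resume_from_checkpoint`.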
| ## Fine-tuning | |
| <frameworkcontent> | |
| <pt> | |
| Run the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) to start a fine-tuning run on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset as follows: | |
| ```bash | |
| export MODEL_NAME="CompVis/stable-diffusion-v1-4" | |
| export dataset_name="lambdalabs/naruto-blip-captions" | |
| accelerate launch train_text_to_image.py \ | |
| --pretrained_model_name_or_path=$MODEL_NAME \ | |
| --dataset_name=$dataset_name \ | |
| --use_ema \ | |
| --resolution=512 --center_crop --random_flip \ | |
| --train_batch_size=1 \ | |
| --gradient_accumulation_steps=4 \ | |
| --gradient_checkpointing \ | |
| --mixed_precision="fp16" \ | |
| --max_train_steps=15000 \ | |
| --learning_rate=1e-05 \ | |
| --max_grad_norm=1 \ | |
| --lr_scheduler="constant" --lr_warmup_steps=0 \ | |
| --output_dir="sd-naruto-model" | |
| ``` | |
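The command above combines `--train_batch_size=1` with `--gradient_accumulation_steps=4`, so each optimizer update accumulates gradients from several small batches. The sketch below works through that arithmetic. It assumes that `--max_train_steps` counts optimizer updates and that `num_processes` (a name introduced here for illustration) is the number of devices selected in `accelerate config`.

```python
def effective_batch_size(
    train_batch_size: int,
    gradient_accumulation_steps: int,
    num_processes: int = 1,
) -> int:
    """Number of samples that contribute to a single optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def samples_processed(max_train_steps: int, batch_per_update: int) -> int:
    """Total samples seen, assuming one optimizer update per training step."""
    return max_train_steps * batch_per_update


# Values taken from the command above, on a single device.
per_update = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
total = samples_processed(max_train_steps=15000, batch_per_update=per_update)
print(per_update, total)
```

Under these assumptions, raising `--gradient_accumulation_steps` increases the effective batch size while still holding only one small batch in memory at a time, at the cost of more forward and backward passes per update.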
| To fine-tune on your own dataset, prepare it in the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload the dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder). | |
| To use custom loading logic, modify the script. We left pointers at the appropriate places in the code to help you. 🤗 The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR` and where to save the model via `OUTPUT_DIR`: | |
| ```bash | |
| export MODEL_NAME="CompVis/stable-diffusion-v1-4" | |
| export TRAIN_DIR="path_to_your_dataset" | |
| export OUTPUT_DIR="path_to_save_model" | |
| accelerate launch train_text_to_image.py \ | |
| --pretrained_model_name_or_path=$MODEL_NAME \ | |
| --train_data_dir=$TRAIN_DIR \ | |
| --use_ema \ | |
| --resolution=512 --center_crop --random_flip \ | |
| --train_batch_size=1 \ | |
| --gradient_accumulation_steps=4 \ | |
| --gradient_checkpointing \ | |
| --mixed_precision="fp16" \ | |
| --max_train_steps=15000 \ | |
| --learning_rate=1e-05 \ | |
| --max_grad_norm=1 \ | |
| --lr_scheduler="constant" --lr_warmup_steps=0 \ | |
| --output_dir=${OUTPUT_DIR} | |
| ``` | |
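For the local-folder option, captions need to be stored next to the images. The sketch below writes a `metadata.jsonl` file, which is the convention the 🤗 Datasets `imagefolder` loader uses to attach extra columns to image files. The helper name `write_metadata` is hypothetical, and the caption column name `text` is an assumption about what the training script reads by default; adjust it if your setup differs.

```python
import json
from pathlib import Path
from typing import Dict


def write_metadata(train_dir: str, captions: Dict[str, str]) -> Path:
    """Write one JSON line per image, mapping its file name to a caption.

    Hypothetical helper: `captions` maps image file names (relative to
    `train_dir`) to caption strings. The "text" key is an assumed column name.
    """
    root = Path(train_dir)
    root.mkdir(parents=True, exist_ok=True)
    metadata_path = root / "metadata.jsonl"
    with metadata_path.open("w", encoding="utf-8") as f:
        for file_name, caption in sorted(captions.items()):
            record = {"file_name": file_name, "text": caption}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return metadata_path
```

After placing the image files in the same folder, that folder is what `TRAIN_DIR` would point to.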
| </pt> | |
| <jax> | |
| Thanks to a contribution from [@duongna21](https://github.com/duongna21), you can train Stable Diffusion models faster on TPUs and GPUs with Flax. It is very efficient on TPU hardware and also works well on GPUs. The Flax training script does not yet support features such as gradient checkpointing or gradient accumulation, so you need a GPU with at least 30GB of memory or a TPU v3. | |
| Before running the script, make sure the requirements are installed: | |
| ```bash | |
| pip install -U -r requirements_flax.txt | |
| ``` | |
| You can then run the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) as follows: | |
| ```bash | |
| export MODEL_NAME="runwayml/stable-diffusion-v1-5" | |
| export dataset_name="lambdalabs/naruto-blip-captions" | |
| python train_text_to_image_flax.py \ | |
| --pretrained_model_name_or_path=$MODEL_NAME \ | |
| --dataset_name=$dataset_name \ | |
| --resolution=512 --center_crop --random_flip \ | |
| --train_batch_size=1 \ | |
| --max_train_steps=15000 \ | |
| --learning_rate=1e-05 \ | |
| --max_grad_norm=1 \ | |
| --output_dir="sd-naruto-model" | |
| ``` | |
| To fine-tune on your own dataset, prepare it in the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload the dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder). | |
| To use custom loading logic, modify the script. We left pointers at the appropriate places in the code to help you. 🤗 The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR`: | |
| ```bash | |
| export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" | |
| export TRAIN_DIR="path_to_your_dataset" | |
| python train_text_to_image_flax.py \ | |
| --pretrained_model_name_or_path=$MODEL_NAME \ | |
| --train_data_dir=$TRAIN_DIR \ | |
| --resolution=512 --center_crop --random_flip \ | |
| --train_batch_size=1 \ | |
| --mixed_precision="fp16" \ | |
| --max_train_steps=15000 \ | |
| --learning_rate=1e-05 \ | |
| --max_grad_norm=1 \ | |
| --output_dir="sd-naruto-model" | |
| ``` | |
| </jax> | |
| </frameworkcontent> | |
| ## LoRA | |
| To fine-tune text-to-image models, you can also use LoRA (Low-Rank Adaptation of Large Language Models), a fine-tuning technique that speeds up training of large models. See the [LoRA training](lora#text-to-image) guide for details. | |
| ## Inference | |
| You can load the fine-tuned model for inference by passing the model path, or the model name on the Hub, to [`StableDiffusionPipeline`]: | |
| <frameworkcontent> | |
| <pt> | |
| ```python | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| model_path = "path_to_saved_model" | |
| pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) | |
| pipe.to("cuda") | |
| image = pipe(prompt="yoda").images[0] | |
| image.save("yoda-naruto.png") | |
| ``` | |
| </pt> | |
| <jax> | |
| ```python | |
| import jax | |
| import numpy as np | |
| from flax.jax_utils import replicate | |
| from flax.training.common_utils import shard | |
| from diffusers import FlaxStableDiffusionPipeline | |
| model_path = "path_to_saved_model" | |
| pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16) | |
| prompt = "yoda naruto" | |
| prng_seed = jax.random.PRNGKey(0) | |
| num_inference_steps = 50 | |
| # generate one image per available device | |
| num_samples = jax.device_count() | |
| prompt = num_samples * [prompt] | |
| prompt_ids = pipeline.prepare_inputs(prompt) | |
| # replicate params and shard inputs and rng across devices | |
| params = replicate(params) | |
| prng_seed = jax.random.split(prng_seed, jax.device_count()) | |
| prompt_ids = shard(prompt_ids) | |
| images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images | |
| images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) | |
| # numpy_to_pil returns a list of PIL images; save the first one | |
| images[0].save("yoda-naruto.png") | |
| ``` | |
| </jax> | |
| </frameworkcontent> |