Buckets:

rtrm's picture
download
raw
24.5 kB
<meta charset="utf-8" /><meta http-equiv="content-security-policy" content=""><meta name="hf:doc:metadata" content="{&quot;local&quot;:&quot;stable-diffusion-texttoimage-finetuning&quot;,&quot;sections&quot;:[{&quot;local&quot;:&quot;running-locally&quot;,&quot;sections&quot;:[{&quot;local&quot;:&quot;installing-the-dependencies&quot;,&quot;title&quot;:&quot;Installing the dependencies&quot;},{&quot;local&quot;:&quot;hardware-requirements-for-finetuning&quot;,&quot;title&quot;:&quot;Hardware Requirements for Fine-tuning&quot;},{&quot;local&quot;:&quot;finetuning-example&quot;,&quot;title&quot;:&quot;Fine-tuning Example&quot;},{&quot;local&quot;:&quot;flax-jax-finetuning&quot;,&quot;title&quot;:&quot;Flax / JAX fine-tuning&quot;}],&quot;title&quot;:&quot;Running locally &quot;}],&quot;title&quot;:&quot;Stable Diffusion text-to-image fine-tuning&quot;}" data-svelte="svelte-1phssyn">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/assets/pages/__layout.svelte-hf-doc-builder.css">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/start-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/chunks/vendor-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/chunks/paths-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/pages/__layout.svelte-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/pages/training/text2image.mdx-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/chunks/Tip-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/chunks/IconCopyLink-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.7.0/en/_app/chunks/CodeBlock-hf-doc-builder.js">
<h1 class="relative group"><a id="stable-diffusion-texttoimage-finetuning" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#stable-diffusion-texttoimage-finetuning"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Stable Diffusion text-to-image fine-tuning
</span></h1>
<p>The <a href="https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion" rel="nofollow"><code>train_text_to_image.py</code></a> script shows how to fine-tune the stable diffusion model on your own dataset.</p>
<div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p>The text-to-image fine-tuning script is experimental. It’s easy to overfit and run into issues like catastrophic forgetting. We recommend to explore different hyperparameters to get the best results on your dataset.</p></div>
<h2 class="relative group"><a id="running-locally" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#running-locally"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Running locally
</span></h2>
<h3 class="relative group"><a id="installing-the-dependencies" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#installing-the-dependencies"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Installing the dependencies
</span></h3>
<p>Before running the scripts, make sure to install the library’s training dependencies:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt<!-- HTML_TAG_END --></pre></div>
<p>And initialize an <a href="https://github.com/huggingface/accelerate/" rel="nofollow">🤗Accelerate</a> environment with:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div>
<p>You need to accept the model license before downloading or using the weights. In this example we’ll use model version <code>v1-4</code>, so you’ll need to visit <a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" rel="nofollow">its card</a>, read the license and tick the checkbox if you agree. </p>
<p>You have to be a registered user in 🤗 Hugging Face Hub, and you’ll also need to use an access token for the code to work. For more information on access tokens, please refer to <a href="https://huggingface.co/docs/hub/security-tokens" rel="nofollow">this section of the documentation</a>.</p>
<p>Run the following command to authenticate your token</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->huggingface-cli login<!-- HTML_TAG_END --></pre></div>
<p>If you have already cloned the repo, then you won’t need to go through these steps. Instead, you can pass the path to your local checkout to the training script and it will be loaded from there.</p>
<h3 class="relative group"><a id="hardware-requirements-for-finetuning" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#hardware-requirements-for-finetuning"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Hardware Requirements for Fine-tuning
</span></h3>
<p>Using <code>gradient_checkpointing</code> and <code>mixed_precision</code> it should be possible to fine tune the model on a single 24GB GPU. For higher <code>batch_size</code> and faster training it’s better to use GPUs with more than 30GB of GPU memory. You can also use JAX / Flax for fine-tuning on TPUs or GPUs, see <a href="#flax-jax-finetuning">below</a> for details.</p>
<h3 class="relative group"><a id="finetuning-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#finetuning-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Fine-tuning Example
</span></h3>
<p>The following script will launch a fine-tuning run using <a href="https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions" rel="nofollow">Justin Pinkneys’ captioned Pokemon dataset</a>, available in Hugging Face Hub.</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> MODEL_NAME=<span class="hljs-string">&quot;CompVis/stable-diffusion-v1-4&quot;</span>
<span class="hljs-built_in">export</span> dataset_name=<span class="hljs-string">&quot;lambdalabs/pokemon-blip-captions&quot;</span>
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=<span class="hljs-variable">$MODEL_NAME</span> \
--dataset_name=<span class="hljs-variable">$dataset_name</span> \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision=<span class="hljs-string">&quot;fp16&quot;</span> \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler=<span class="hljs-string">&quot;constant&quot;</span> --lr_warmup_steps=0 \
--output_dir=<span class="hljs-string">&quot;sd-pokemon-model&quot;</span> <!-- HTML_TAG_END --></pre></div>
<p>To run on your own training files you need to prepare the dataset according to the format required by <code>datasets</code>. You can upload your dataset to the Hub, or you can prepare a local folder with your files. <a href="https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata" rel="nofollow">This documentation</a> explains how to do it.</p>
<p>You should modify the script if you wish to use custom loading logic. We have left pointers in the code in the appropriate places :)</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> MODEL_NAME=<span class="hljs-string">&quot;CompVis/stable-diffusion-v1-4&quot;</span>
<span class="hljs-built_in">export</span> TRAIN_DIR=<span class="hljs-string">&quot;path_to_your_dataset&quot;</span>
<span class="hljs-built_in">export</span> OUTPUT_DIR=<span class="hljs-string">&quot;path_to_save_model&quot;</span>
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=<span class="hljs-variable">$MODEL_NAME</span> \
--train_data_dir=<span class="hljs-variable">$TRAIN_DIR</span> \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision=<span class="hljs-string">&quot;fp16&quot;</span> \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler=<span class="hljs-string">&quot;constant&quot;</span> --lr_warmup_steps=0 \
--output_dir=<span class="hljs-variable">${OUTPUT_DIR}</span><!-- HTML_TAG_END --></pre></div>
<p>Once training is finished the model will be saved to the <code>OUTPUT_DIR</code> specified in the command. To load the fine-tuned model for inference, just pass that path to <code>StableDiffusionPipeline</code>:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline
model_path = <span class="hljs-string">&quot;path_to_saved_model&quot;</span>
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to(<span class="hljs-string">&quot;cuda&quot;</span>)
image = pipe(prompt=<span class="hljs-string">&quot;yoda&quot;</span>).images[<span class="hljs-number">0</span>]
image.save(<span class="hljs-string">&quot;yoda-pokemon.png&quot;</span>)<!-- HTML_TAG_END --></pre></div>
<h3 class="relative group"><a id="flax-jax-finetuning" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#flax-jax-finetuning"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Flax / JAX fine-tuning
</span></h3>
<p>Thanks to <a href="https://github.com/duongna21" rel="nofollow">@duongna211</a> it’s possible to fine-tune Stable Diffusion using Flax! This is very efficient on TPU hardware but works great on GPUs too. You can use the <a href="https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py" rel="nofollow">Flax training script</a> like this:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->export MODEL_NAME=<span class="hljs-string">&quot;runwayml/stable-diffusion-v1-5&quot;</span>
export dataset_name=<span class="hljs-string">&quot;lambdalabs/pokemon-blip-captions&quot;</span>
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--resolution=<span class="hljs-number">512</span> --center_crop --random_flip \
--train_batch_size=<span class="hljs-number">1</span> \
--max_train_steps=<span class="hljs-number">15000</span> \
--learning_rate=<span class="hljs-number">1e-05</span> \
--max_grad_norm=<span class="hljs-number">1</span> \
--output_dir=<span class="hljs-string">&quot;sd-pokemon-model&quot;</span> <!-- HTML_TAG_END --></pre></div>
<script type="module" data-hydrate="16ldxk8">
import { start } from "/docs/diffusers/v0.7.0/en/_app/start-hf-doc-builder.js";
start({
target: document.querySelector('[data-hydrate="16ldxk8"]').parentNode,
paths: {"base":"/docs/diffusers/v0.7.0/en","assets":"/docs/diffusers/v0.7.0/en"},
session: {},
route: false,
spa: false,
trailing_slash: "never",
hydrate: {
status: 200,
error: null,
nodes: [
import("/docs/diffusers/v0.7.0/en/_app/pages/__layout.svelte-hf-doc-builder.js"),
import("/docs/diffusers/v0.7.0/en/_app/pages/training/text2image.mdx-hf-doc-builder.js")
],
params: {}
}
});
</script>

Xet Storage Details

Size:
24.5 kB
·
Xet hash:
e55b4fd68b65e50e0e47885fc48c0dc591c98c2319ed973ef055dc937d167d0c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.