Buckets:

hf-doc-build/doc / diffusers /v0.19.2 /en /training /custom_diffusion.html
rtrm's picture
download
raw
48.7 kB
<meta charset="utf-8" /><meta http-equiv="content-security-policy" content=""><meta name="hf:doc:metadata" content="{&quot;local&quot;:&quot;custom-diffusion-training-example&quot;,&quot;sections&quot;:[{&quot;local&quot;:&quot;running-locally-with-pytorch&quot;,&quot;sections&quot;:[{&quot;local&quot;:&quot;installing-the-dependencies&quot;,&quot;title&quot;:&quot;Installing the dependencies&quot;},{&quot;local&quot;:&quot;cat-example&quot;,&quot;title&quot;:&quot;Cat example 😺&quot;},{&quot;local&quot;:&quot;training-on-multiple-concepts&quot;,&quot;title&quot;:&quot;Training on multiple concepts 🐱🪵&quot;},{&quot;local&quot;:&quot;training-on-human-faces&quot;,&quot;title&quot;:&quot;Training on human faces&quot;}],&quot;title&quot;:&quot;Running locally with PyTorch&quot;},{&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[{&quot;local&quot;:&quot;inference-from-a-training-checkpoint&quot;,&quot;title&quot;:&quot;Inference from a training checkpoint&quot;}],&quot;title&quot;:&quot;Inference&quot;},{&quot;local&quot;:&quot;set-grads-to-none&quot;,&quot;title&quot;:&quot;Set grads to none&quot;},{&quot;local&quot;:&quot;experimental-results&quot;,&quot;title&quot;:&quot;Experimental results&quot;}],&quot;title&quot;:&quot;Custom Diffusion training example &quot;}" data-svelte="svelte-1phssyn">
<link rel="modulepreload" href="/docs/diffusers/v0.19.2/en/_app/assets/pages/__layout.svelte-hf-doc-builder.css">
<link rel="modulepreload" href="/docs/diffusers/v0.19.2/en/_app/start-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.19.2/en/_app/chunks/vendor-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.19.2/en/_app/chunks/paths-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.19.2/en/_app/pages/__layout.svelte-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.19.2/en/_app/pages/training/custom_diffusion.mdx-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.19.2/en/_app/chunks/IconCopyLink-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/diffusers/v0.19.2/en/_app/chunks/CodeBlock-hf-doc-builder.js">
<h1 class="relative group"><a id="custom-diffusion-training-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#custom-diffusion-training-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Custom Diffusion training example
</span></h1>
<p><a href="https://arxiv.org/abs/2212.04488" rel="nofollow">Custom Diffusion</a> is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
The <code>train_custom_diffusion.py</code> script shows how to implement the training procedure and adapt it for stable diffusion.</p>
<p>This training example was contributed by <a href="https://nupurkmr9.github.io/" rel="nofollow">Nupur Kumari</a> (one of the authors of Custom Diffusion). </p>
<h2 class="relative group"><a id="running-locally-with-pytorch" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#running-locally-with-pytorch"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Running locally with PyTorch
</span></h2>
<h3 class="relative group"><a id="installing-the-dependencies" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#installing-the-dependencies"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Installing the dependencies
</span></h3>
<p>Before running the scripts, make sure to install the library’s training dependencies:</p>
<p><strong>Important</strong></p>
<p>To make sure you can successfully run the latest versions of the example scripts, we highly recommend <strong>installing from source</strong> and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->git <span class="hljs-built_in">clone</span> https://github.com/huggingface/diffusers
<span class="hljs-built_in">cd</span> diffusers
pip install -e .<!-- HTML_TAG_END --></pre></div>
<p>Then cd into the <a href="https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion" rel="nofollow">example folder</a></p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-built_in">cd</span> examples/custom_diffusion<!-- HTML_TAG_END --></pre></div>
<p>Now run</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->pip install -r requirements.txt
pip install clip-retrieval <!-- HTML_TAG_END --></pre></div>
<p>And initialize an <a href="https://github.com/huggingface/accelerate/" rel="nofollow">🤗Accelerate</a> environment with:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div>
<p>Or for a default accelerate configuration without answering questions about your environment</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->accelerate config default<!-- HTML_TAG_END --></pre></div>
<p>Or if your environment doesn’t support an interactive shell e.g. a notebook</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate.utils <span class="hljs-keyword">import</span> write_basic_config
write_basic_config()<!-- HTML_TAG_END --></pre></div>
<h3 class="relative group"><a id="cat-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cat-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Cat example 😺
</span></h3>
<p>Now let’s get our dataset. Download dataset from <a href="https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip" rel="nofollow">here</a> and unzip it. To use your own dataset, take a look at the <a href="create_dataset">Create a dataset for training</a> guide.</p>
<p>We also collect 200 real images using <code>clip-retrieval</code> which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. The following flags enable the regularization: <code>with_prior_preservation</code>, <code>real_prior</code> with <code>prior_loss_weight=1.0</code>.
The <code>class_prompt</code> should be the same category name as the target image. The collected real images have text captions similar to the <code>class_prompt</code>. The retrieved images are saved in <code>class_data_dir</code>. You can disable <code>real_prior</code> to use generated images as regularization. To collect the real images use this command first before training. </p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->pip install clip-retrieval
python retrieve.py --class_prompt <span class="hljs-built_in">cat</span> --class_data_dir real_reg/samples_cat --num_class_images 200<!-- HTML_TAG_END --></pre></div>
<p><strong><em>Note: Change the <code>resolution</code> to 768 if you are using the <a href="https://huggingface.co/stabilityai/stable-diffusion-2" rel="nofollow">stable-diffusion-2</a> 768x768 model.</em></strong></p>
<p>The script creates and saves model checkpoints and a <code>pytorch_custom_diffusion_weights.bin</code> file in your repository.</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> MODEL_NAME=<span class="hljs-string">&quot;CompVis/stable-diffusion-v1-4&quot;</span>
<span class="hljs-built_in">export</span> OUTPUT_DIR=<span class="hljs-string">&quot;path-to-save-model&quot;</span>
<span class="hljs-built_in">export</span> INSTANCE_DIR=<span class="hljs-string">&quot;./data/cat&quot;</span>
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=<span class="hljs-variable">$MODEL_NAME</span> \
--instance_data_dir=<span class="hljs-variable">$INSTANCE_DIR</span> \
--output_dir=<span class="hljs-variable">$OUTPUT_DIR</span> \
--class_data_dir=./real_reg/samples_cat/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt=<span class="hljs-string">&quot;cat&quot;</span> --num_class_images=200 \
--instance_prompt=<span class="hljs-string">&quot;photo of a &lt;new1&gt; cat&quot;</span> \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=250 \
--scale_lr --hflip \
--modifier_token <span class="hljs-string">&quot;&lt;new1&gt;&quot;</span> \
--push_to_hub<!-- HTML_TAG_END --></pre></div>
<p><strong>Use <code>--enable_xformers_memory_efficient_attention</code> for faster training with lower VRAM requirement (16GB per GPU). Follow <a href="https://github.com/facebookresearch/xformers" rel="nofollow">this guide</a> for installation instructions.</strong></p>
<p>To track your experiments using Weights and Biases (<code>wandb</code>) and to save intermediate results (which we HIGHLY recommend), follow these steps:</p>
<ul><li>Install <code>wandb</code>: <code>pip install wandb</code>.</li>
<li>Authorize: <code>wandb login</code>. </li>
<li>Then specify a <code>validation_prompt</code> and set <code>report_to</code> to <code>wandb</code> while launching training. You can also configure the following related arguments:<ul><li><code>num_validation_images</code></li>
<li><code>validation_steps</code></li></ul></li></ul>
<p>Here is an example command:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=<span class="hljs-variable">$MODEL_NAME</span> \
--instance_data_dir=<span class="hljs-variable">$INSTANCE_DIR</span> \
--output_dir=<span class="hljs-variable">$OUTPUT_DIR</span> \
--class_data_dir=./real_reg/samples_cat/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt=<span class="hljs-string">&quot;cat&quot;</span> --num_class_images=200 \
--instance_prompt=<span class="hljs-string">&quot;photo of a &lt;new1&gt; cat&quot;</span> \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=250 \
--scale_lr --hflip \
--modifier_token <span class="hljs-string">&quot;&lt;new1&gt;&quot;</span> \
--validation_prompt=<span class="hljs-string">&quot;&lt;new1&gt; cat sitting in a bucket&quot;</span> \
--report_to=<span class="hljs-string">&quot;wandb&quot;</span> \
--push_to_hub<!-- HTML_TAG_END --></pre></div>
<p>Here is an example <a href="https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau" rel="nofollow">Weights and Biases page</a> where you can check out the intermediate results along with other training details. </p>
<p>If you specify <code>--push_to_hub</code>, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an <a href="https://huggingface.co/sayakpaul/custom-diffusion-cat" rel="nofollow">example repository</a>.</p>
<h3 class="relative group"><a id="training-on-multiple-concepts" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-on-multiple-concepts"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Training on multiple concepts 🐱🪵
</span></h3>
<p>Provide a <a href="https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json" rel="nofollow">json</a> file with the info about each concept, similar to <a href="https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py" rel="nofollow">this</a>.</p>
<p>To collect the real images run this command for each concept in the json file. </p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->pip install clip-retrieval
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200<!-- HTML_TAG_END --></pre></div>
<p>And then we’re ready to start training!</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> MODEL_NAME=<span class="hljs-string">&quot;CompVis/stable-diffusion-v1-4&quot;</span>
<span class="hljs-built_in">export</span> OUTPUT_DIR=<span class="hljs-string">&quot;path-to-save-model&quot;</span>
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=<span class="hljs-variable">$MODEL_NAME</span> \
--output_dir=<span class="hljs-variable">$OUTPUT_DIR</span> \
--concepts_list=./concept_list.json \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--num_class_images=200 \
--scale_lr --hflip \
--modifier_token <span class="hljs-string">&quot;&lt;new1&gt;+&lt;new2&gt;&quot;</span> \
--push_to_hub<!-- HTML_TAG_END --></pre></div>
<p>Here is an example <a href="https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg" rel="nofollow">Weights and Biases page</a> where you can check out the intermediate results along with other training details. </p>
<h3 class="relative group"><a id="training-on-human-faces" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-on-human-faces"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Training on human faces
</span></h3>
<p>For fine-tuning on human faces we found the following configuration to work better: <code>learning_rate=5e-6</code>, <code>max_train_steps=1000 to 2000</code>, and <code>freeze_model=crossattn</code> with at least 15-20 images. </p>
<p>To collect the real images use this command first before training. </p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START -->pip install clip-retrieval
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200<!-- HTML_TAG_END --></pre></div>
<p>Then start training!</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> MODEL_NAME=<span class="hljs-string">&quot;CompVis/stable-diffusion-v1-4&quot;</span>
<span class="hljs-built_in">export</span> OUTPUT_DIR=<span class="hljs-string">&quot;path-to-save-model&quot;</span>
<span class="hljs-built_in">export</span> INSTANCE_DIR=<span class="hljs-string">&quot;path-to-images&quot;</span>
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=<span class="hljs-variable">$MODEL_NAME</span> \
--instance_data_dir=<span class="hljs-variable">$INSTANCE_DIR</span> \
--output_dir=<span class="hljs-variable">$OUTPUT_DIR</span> \
--class_data_dir=./real_reg/samples_person/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt=<span class="hljs-string">&quot;person&quot;</span> --num_class_images=200 \
--instance_prompt=<span class="hljs-string">&quot;photo of a &lt;new1&gt; person&quot;</span> \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=5e-6 \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--scale_lr --hflip --noaug \
--freeze_model crossattn \
--modifier_token <span class="hljs-string">&quot;&lt;new1&gt;&quot;</span> \
--enable_xformers_memory_efficient_attention \
--push_to_hub<!-- HTML_TAG_END --></pre></div>
<h2 class="relative group"><a id="inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Inference
</span></h2>
<p>Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the <code>modifier token</code> (e.g. <code>&lt;new1&gt;</code> in the above example) in your prompt.</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(<span class="hljs-string">&quot;CompVis/stable-diffusion-v1-4&quot;</span>, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipe.unet.load_attn_procs(<span class="hljs-string">&quot;path-to-save-model&quot;</span>, weight_name=<span class="hljs-string">&quot;pytorch_custom_diffusion_weights.bin&quot;</span>)
pipe.load_textual_inversion(<span class="hljs-string">&quot;path-to-save-model&quot;</span>, weight_name=<span class="hljs-string">&quot;&lt;new1&gt;.bin&quot;</span>)
image = pipe(
<span class="hljs-string">&quot;&lt;new1&gt; cat sitting in a bucket&quot;</span>,
num_inference_steps=<span class="hljs-number">100</span>,
guidance_scale=<span class="hljs-number">6.0</span>,
eta=<span class="hljs-number">1.0</span>,
).images[<span class="hljs-number">0</span>]
image.save(<span class="hljs-string">&quot;cat.png&quot;</span>)<!-- HTML_TAG_END --></pre></div>
<p>It’s possible to directly load these parameters from a Hub repository:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> huggingface_hub.repocard <span class="hljs-keyword">import</span> RepoCard
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
model_id = <span class="hljs-string">&quot;sayakpaul/custom-diffusion-cat&quot;</span>
card = RepoCard.load(model_id)
base_model_id = card.data.to_dict()[<span class="hljs-string">&quot;base_model&quot;</span>]
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipe.unet.load_attn_procs(model_id, weight_name=<span class="hljs-string">&quot;pytorch_custom_diffusion_weights.bin&quot;</span>)
pipe.load_textual_inversion(model_id, weight_name=<span class="hljs-string">&quot;&lt;new1&gt;.bin&quot;</span>)
image = pipe(
<span class="hljs-string">&quot;&lt;new1&gt; cat sitting in a bucket&quot;</span>,
num_inference_steps=<span class="hljs-number">100</span>,
guidance_scale=<span class="hljs-number">6.0</span>,
eta=<span class="hljs-number">1.0</span>,
).images[<span class="hljs-number">0</span>]
image.save(<span class="hljs-string">&quot;cat.png&quot;</span>)<!-- HTML_TAG_END --></pre></div>
<p>Here is an example of performing inference with multiple concepts:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> huggingface_hub.repocard <span class="hljs-keyword">import</span> RepoCard
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
model_id = <span class="hljs-string">&quot;sayakpaul/custom-diffusion-cat-wooden-pot&quot;</span>
card = RepoCard.load(model_id)
base_model_id = card.data.to_dict()[<span class="hljs-string">&quot;base_model&quot;</span>]
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipe.unet.load_attn_procs(model_id, weight_name=<span class="hljs-string">&quot;pytorch_custom_diffusion_weights.bin&quot;</span>)
pipe.load_textual_inversion(model_id, weight_name=<span class="hljs-string">&quot;&lt;new1&gt;.bin&quot;</span>)
pipe.load_textual_inversion(model_id, weight_name=<span class="hljs-string">&quot;&lt;new2&gt;.bin&quot;</span>)
image = pipe(
<span class="hljs-string">&quot;the &lt;new1&gt; cat sculpture in the style of a &lt;new2&gt; wooden pot&quot;</span>,
num_inference_steps=<span class="hljs-number">100</span>,
guidance_scale=<span class="hljs-number">6.0</span>,
eta=<span class="hljs-number">1.0</span>,
).images[<span class="hljs-number">0</span>]
image.save(<span class="hljs-string">&quot;multi-subject.png&quot;</span>)<!-- HTML_TAG_END --></pre></div>
<p>Here, <code>cat</code> and <code>wooden pot</code> refer to the multiple concepts.</p>
<h3 class="relative group"><a id="inference-from-a-training-checkpoint" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference-from-a-training-checkpoint"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Inference from a training checkpoint
</span></h3>
<p>You can also perform inference from one of the complete checkpoints saved during the training process, if you used the <code>--checkpointing_steps</code> argument.</p>
<p>TODO.</p>
<h2 class="relative group"><a id="set-grads-to-none" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#set-grads-to-none"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Set grads to none
</span></h2>
<p>To save even more memory, pass the <code>--set_grads_to_none</code> argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.</p>
<p>More info: <a href="https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" rel="nofollow">https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html</a></p>
<h2 class="relative group"><a id="experimental-results" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#experimental-results"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Experimental results
</span></h2>
<p>You can refer to <a href="https://www.cs.cmu.edu/~custom-diffusion/" rel="nofollow">our webpage</a> that discusses our experiments in detail. </p>
<script type="module" data-hydrate="1vikd0u">
// Client-side bootstrap for this page: pulls in the doc-builder entry point
// and calls it with the mount target, URL prefixes and the modules for this route.
import { start } from "/docs/diffusers/v0.19.2/en/_app/start-hf-doc-builder.js";

// Mount target: parent node of the first element matching the hydration marker.
const hydrationRoot = document.querySelector('[data-hydrate="1vikd0u"]').parentNode;

// Base and asset URL prefixes handed to the entry point.
const sitePaths = {"base":"/docs/diffusers/v0.19.2/en","assets":"/docs/diffusers/v0.19.2/en"};

// Dynamically imported modules for this route: the layout, then the page itself.
const routeNodes = [
	import("/docs/diffusers/v0.19.2/en/_app/pages/__layout.svelte-hf-doc-builder.js"),
	import("/docs/diffusers/v0.19.2/en/_app/pages/training/custom_diffusion.mdx-hf-doc-builder.js")
];

start({
	target: hydrationRoot,
	paths: sitePaths,
	session: {},
	route: false,
	spa: false,
	trailing_slash: "never",
	hydrate: {
		status: 200,
		error: null,
		nodes: routeNodes,
		params: {}
	}
});
</script>

Xet Storage Details

Size:
48.7 kB
·
Xet hash:
da14355bb442df0b3633628bbcce31b92a76b03dc059deb5c993a61a2c06e564

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.