Buckets:
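| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Denoising Diffusion Policy Optimization&quot;,&quot;local&quot;:&quot;denoising-diffusion-policy-optimization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;The why&quot;,&quot;local&quot;:&quot;the-why&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Getting started with Stable Diffusion finetuning with reinforcement learning&quot;,&quot;local&quot;:&quot;getting-started-with-stable-diffusion-finetuning-with-reinforcement-learning&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Getting started with examples/scripts/ddpo.py&quot;,&quot;local&quot;:&quot;getting-started-with-examplesscriptsddpopy&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Setting up the image logging hook function&quot;,&quot;local&quot;:&quot;setting-up-the-image-logging-hook-function&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Key terms&quot;,&quot;local&quot;:&quot;key-terms&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Using the finetuned model&quot;,&quot;local&quot;:&quot;using-the-finetuned-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Credits&quot;,&quot;local&quot;:&quot;credits&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"> | |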
| <link rel="stylesheet" href="/docs/trl/main/en/_app/immutable/assets/0.e3b0c442.css"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/scheduler.85c25b89.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/singletons.98fe034d.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/paths.eb9df337.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/index.c142fe32.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/0.5efac18d.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/9.8d95c463.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/CodeBlock.a5e95a57.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/EditOnGithub.a592e7aa.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Denoising Diffusion Policy Optimization","local":"denoising-diffusion-policy-optimization","sections":[{"title":"The why","local":"the-why","sections":[],"depth":2},{"title":"Getting started with Stable Diffusion finetuning with reinforcement learning","local":"getting-started-with-stable-diffusion-finetuning-with-reinforcement-learning","sections":[],"depth":2},{"title":"Getting started with examples/scripts/ddpo.py","local":"getting-started-with-examplesscriptsddpopy","sections":[],"depth":2},{"title":"Setting up the image logging hook function","local":"setting-up-the-image-logging-hook-function","sections":[{"title":"Key terms","local":"key-terms","sections":[],"depth":3},{"title":"Using the finetuned model","local":"using-the-finetuned-model","sections":[],"depth":3}],"depth":2},{"title":"Credits","local":"credits","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="denoising-diffusion-policy-optimization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#denoising-diffusion-policy-optimization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 
11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Denoising Diffusion Policy Optimization</span></h1> <h2 class="relative group"><a id="the-why" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-why"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The why</span></h2> <table data-svelte-h="svelte-k42nnq"><thead><tr><th>Before</th> <th>After DDPO finetuning</th></tr></thead> <tbody><tr><td><div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"></div></td> <td><div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"></div></td></tr> <tr><td><div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"></div></td> <td><div style="text-align: center"><img 
src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"></div></td></tr> <tr><td><div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"></div></td> <td><div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"></div></td></tr></tbody></table> <h2 class="relative group"><a id="getting-started-with-stable-diffusion-finetuning-with-reinforcement-learning" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#getting-started-with-stable-diffusion-finetuning-with-reinforcement-learning"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Getting started with Stable Diffusion finetuning with reinforcement learning</span></h2> <p data-svelte-h="svelte-1ut0u9">The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace’s <code>diffusers</code> | |
| library. Getting started therefore requires some familiarity with two core <code>diffusers</code> concepts: pipelines and schedulers. | |
| Out of the box, the <code>diffusers</code> library provides neither a <code>Pipeline</code> nor a <code>Scheduler</code> instance that is suitable for finetuning with reinforcement learning, so some adjustments need to be made.</p> <p data-svelte-h="svelte-xp6gwm">This library provides a pipeline interface that must be implemented by any pipeline used with the <code>DDPOTrainer</code>, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. <strong>Note: Only the Stable Diffusion architecture is supported at this point.</strong> | |
| A default implementation of this interface is provided and can be used out of the box. If the default implementation is sufficient, or you just want to get things moving, refer to the training example alongside this guide.</p> <p data-svelte-h="svelte-a4w6o0">The point of the interface is to fuse the pipeline and the scheduler into one object, which keeps all of the constraints in one place. The interface was designed in the hope of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at the time of writing. The scheduler step is also a method of this pipeline interface. This may seem redundant, given that the raw scheduler is accessible via the interface, but it is the only way to constrain the scheduler step output to an output type suited to the algorithm at hand (DDPO).</p> <p data-svelte-h="svelte-t0ve8a">For a more detailed look at the interface and the associated default implementation, see <a href="https://github.com/huggingface/trl/tree/main/trl/models/modeling_sd_base.py" rel="nofollow"><code>trl/models/modeling_sd_base.py</code></a>.</p> <p data-svelte-h="svelte-1s71aap">Note that the default implementation has both a LoRA-based and a non-LoRA-based implementation path. LoRA is enabled by default and can be turned off by passing in the corresponding flag. LoRA-based training is faster, and the LoRA-associated hyperparameters responsible for model convergence are less finicky than those of non-LoRA-based training.</p> <p data-svelte-h="svelte-1e16bv7">In addition, you are expected to provide a reward function and a prompt function. 
The reward function scores the generated images, and the prompt function produces the prompts used to generate those images.</p> <h2 class="relative group"><a id="getting-started-with-examplesscriptsddpopy" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#getting-started-with-examplesscriptsddpopy"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Getting started with examples/scripts/ddpo.py</span></h2> <p data-svelte-h="svelte-1cfw9tr">The <code>ddpo.py</code> script is a working example of using the <code>DDPO</code> trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (<code>DDPOConfig</code>).</p> <p data-svelte-h="svelte-xqnbvi"><strong>Note:</strong> one A100 GPU is recommended to get this running. A GPU less capable than an A100 is unlikely to be able to run this example script, and even if it does with relatively smaller parameter values, the results will most likely be poor.</p> <p data-svelte-h="svelte-1iay88v">Almost every configuration parameter has a default. 
There is only one command-line flag that the user is required to pass to get things up and running. The user is expected to have a <a href="https://huggingface.co/docs/hub/security-tokens" rel="nofollow">Hugging Face user access token</a>, which will be used to upload the model to the Hugging Face Hub after finetuning. Enter the following bash command to get things running:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">python</span> ddpo.<span class="hljs-keyword">py</span> --hf_user_access_token <span class="hljs-symbol">&lt;token&gt;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mmabua">To obtain the documentation of <code>ddpo.py</code>, please run <code>python ddpo.py --help</code>.</p> <p data-svelte-h="svelte-gn7d4">The following are 
things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)</p> <ul data-svelte-h="svelte-1wxt0fz"><li>The configurable sample batch size (<code>--ddpo_config.sample_batch_size=6</code>) should be greater than or equal to the configurable training batch size (<code>--ddpo_config.train_batch_size=3</code>)</li> <li>The configurable sample batch size (<code>--ddpo_config.sample_batch_size=6</code>) must be divisible by the configurable train batch size (<code>--ddpo_config.train_batch_size=3</code>)</li> <li>The configurable sample batch size (<code>--ddpo_config.sample_batch_size=6</code>) must be divisible by both the configurable gradient accumulation steps (<code>--ddpo_config.train_gradient_accumulation_steps=1</code>) and the configurable accelerator processes count</li></ul> <h2 class="relative group"><a id="setting-up-the-image-logging-hook-function" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#setting-up-the-image-logging-hook-function"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Setting up the image logging hook 
function</span></h2> <p data-svelte-h="svelte-171n2tq">Expect the function to be given a list of lists of the form</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[[image, prompt, prompt_metadata, rewards, reward_metadata], ...] | |
| <!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qc0duo">and <code>image</code>, <code>prompt</code>, <code>prompt_metadata</code>, <code>rewards</code>, <code>reward_metadata</code> are batched. | |
| The last list in the list of lists represents the last sample batch, which is likely the one you want to log. | |
| While you are free to log however you want, the use of <code>wandb</code> or <code>tensorboard</code> is recommended.</p> <h3 class="relative group"><a id="key-terms" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#key-terms"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Key terms</span></h3> <ul data-svelte-h="svelte-hf6txl"><li><code>rewards</code> : The rewards/score is a numerical value associated with the generated image and is key to steering the RL process</li> <li><code>reward_metadata</code> : The reward metadata is the metadata associated with the reward. Think of this as an extra information payload delivered alongside the reward</li> <li><code>prompt</code> : The prompt is the text that is used to generate the image</li> <li><code>prompt_metadata</code> : The prompt metadata is the metadata associated with the prompt. 
A situation where this will not be empty is when the reward model consists of a <a href="https://huggingface.co/docs/transformers/model_doc/llava" rel="nofollow"><code>LLaVA</code></a> setup where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see: <a href="https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45" rel="nofollow">https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45</a>)</li> <li><code>image</code> : The image generated by the Stable Diffusion model</li></ul> <p data-svelte-h="svelte-1r6a60u">Example code for logging sampled images with <code>wandb</code> is given below.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># for logging these images to wandb</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">image_outputs_hook</span>(<span class="hljs-params">image_data, global_step, accelerate_logger</span>): | |
| <span class="hljs-comment"># For the sake of this example, we only care about the last batch</span> | |
| <span class="hljs-comment"># hence we extract the last element of the list</span> | |
| result = {} | |
| images, prompts, _, rewards, _ = image_data[-<span class="hljs-number">1</span>] | |
| <span class="hljs-keyword">for</span> i, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(images): | |
| pil = Image.fromarray( | |
| (image.cpu().numpy().transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">0</span>) * <span class="hljs-number">255</span>).astype(np.uint8) | |
| ) | |
| pil = pil.resize((<span class="hljs-number">256</span>, <span class="hljs-number">256</span>)) | |
| result[<span class="hljs-string">f"<span class="hljs-subst">{prompts[i]:<span class="hljs-number">.25</span>}</span> | <span class="hljs-subst">{rewards[i]:<span class="hljs-number">.2</span>f}</span>"</span>] = [pil] | |
| accelerate_logger.log_images( | |
| result, | |
| step=global_step, | |
| ) | |
| <!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="using-the-finetuned-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-the-finetuned-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using the finetuned model</span></h3> <p data-svelte-h="svelte-eqxomu">Assuming you’re done with all the epochs and have pushed your model to the Hub, you can use the finetuned model as follows:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" 
height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> | |
| <span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> DefaultDDPOStableDiffusionPipeline | |
| pipeline = DefaultDDPOStableDiffusionPipeline(<span class="hljs-string">"metric-space/ddpo-finetuned-sd-model"</span>) | |
| device = torch.device(<span class="hljs-string">"cuda"</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">"cpu"</span>) | |
| <span class="hljs-comment"># memory optimization</span> | |
| pipeline.vae.to(device, torch.float16) | |
| pipeline.text_encoder.to(device, torch.float16) | |
| pipeline.unet.to(device, torch.float16) | |
| prompts = [<span class="hljs-string">"squirrel"</span>, <span class="hljs-string">"crab"</span>, <span class="hljs-string">"starfish"</span>, <span class="hljs-string">"whale"</span>,<span class="hljs-string">"sponge"</span>, <span class="hljs-string">"plankton"</span>] | |
| results = pipeline(prompts) | |
| <span class="hljs-keyword">for</span> prompt, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(prompts,results.images): | |
| image.save(<span class="hljs-string">f"<span class="hljs-subst">{prompt}</span>.png"</span>) | |
| <!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="credits" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#credits"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Credits</span></h2> <p data-svelte-h="svelte-16vahxi">This work is heavily influenced by the repo <a href="https://github.com/kvablack/ddpo-pytorch" rel="nofollow">here</a> and the associated paper <a href="https://huggingface.co/papers/2305.13301" rel="nofollow">Training Diffusion Models | |
| with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine</a>.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/trl/blob/main/docs/source/ddpo_trainer.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| // SvelteKit client bootstrap for this page: publish the runtime path config, then start the client app on this script's parent element. | |
| // Global path config (presumably read by the entry modules imported below — TODO confirm); both paths use the /docs/trl/main/en prefix. | |
| __sveltekit_5yobsv = { | |
| assets: "/docs/trl/main/en", | |
| base: "/docs/trl/main/en", | |
| env: {} | |
| }; | |
| // Captured synchronously: document.currentScript is only set while this classic script is executing, so it would be null inside the async .then() callback below. | |
| const element = document.currentScript.parentElement; | |
| // Passed to kit.start as `data`; one null per entry in node_ids (NOTE(review): assumed to mean "no inlined server-loaded data" — confirm against SvelteKit's start() contract). | |
| const data = [null,null]; | |
| // Load the start.*.js and app.*.js entry modules in parallel; both also appear as modulepreload links at the top of this document. | |
| Promise.all([ | |
| import("/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"), | |
| import("/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js") | |
| ]).then(([kit, app]) => { | |
| // Start the client on `element`; form and error are null, presumably meaning no form result or error state is handed over from the server render. | |
| kit.start(app, element, { | |
| // These ids appear to correspond to the preloaded nodes/0.*.js and nodes/9.*.js modules. | |
| node_ids: [0, 9], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 30 kB
- Xet hash:
- 50f7c5e388a613e38338531f714734c1ec11916865f68a05c6951f4bbe340648
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.