| <meta charset="utf-8" /><meta name="hf:doc:metadata" content='{"title":"Aligning Text-to-Image Diffusion Models with Reward Backpropagation","local":"aligning-text-to-image-diffusion-models-with-reward-backpropagation","sections":[{"title":"The why","local":"the-why","sections":[],"depth":2},{"title":"Getting started with examples/scripts/alignprop.py","local":"getting-started-with-examplesscriptsalignproppy","sections":[],"depth":2},{"title":"Setting up the image logging hook function","local":"setting-up-the-image-logging-hook-function","sections":[{"title":"Key terms","local":"key-terms","sections":[],"depth":3},{"title":"Using the finetuned model","local":"using-the-finetuned-model","sections":[],"depth":3}],"depth":2},{"title":"Credits","local":"credits","sections":[],"depth":2}],"depth":1}'> | |
| <link href="/docs/trl/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/scheduler.85c25b89.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/singletons.98fe034d.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/paths.eb9df337.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/index.c142fe32.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/0.5efac18d.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/2.be5f3526.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/CodeBlock.a5e95a57.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/EditOnGithub.a592e7aa.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content='{"title":"Aligning Text-to-Image Diffusion Models with Reward Backpropagation","local":"aligning-text-to-image-diffusion-models-with-reward-backpropagation","sections":[{"title":"The why","local":"the-why","sections":[],"depth":2},{"title":"Getting started with examples/scripts/alignprop.py","local":"getting-started-with-examplesscriptsalignproppy","sections":[],"depth":2},{"title":"Setting up the image logging hook function","local":"setting-up-the-image-logging-hook-function","sections":[{"title":"Key terms","local":"key-terms","sections":[],"depth":3},{"title":"Using the finetuned model","local":"using-the-finetuned-model","sections":[],"depth":3}],"depth":2},{"title":"Credits","local":"credits","sections":[],"depth":2}],"depth":1}'><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="aligning-text-to-image-diffusion-models-with-reward-backpropagation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#aligning-text-to-image-diffusion-models-with-reward-backpropagation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 
11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Aligning Text-to-Image Diffusion Models with Reward Backpropagation</span></h1> <h2 class="relative group"><a id="the-why" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-why"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The why</span></h2> <p data-svelte-h="svelte-1ekl94b">If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample- and compute-efficient (25x) than using a policy gradient algorithm like DDPO. | |
| AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.</p> <div style="text-align: center" data-svelte-h="svelte-x64r6o"><img src="https://align-prop.github.io/reward_tuning.png"></div> <h2 class="relative group"><a id="getting-started-with-examplesscriptsalignproppy" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#getting-started-with-examplesscriptsalignproppy"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Getting started with examples/scripts/alignprop.py</span></h2> <p data-svelte-h="svelte-5pt88m">The <code>alignprop.py</code> script is a working example of using the <code>AlignProp</code> trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (<code>AlignPropConfig</code>).</p> <p data-svelte-h="svelte-13c8m5g"><strong>Note:</strong> One A100 GPU is recommended to get this running. For a lower-memory setup, consider setting <code>truncated_backprop_rand</code> to False. 
With the default settings, this performs truncated backpropagation with K=1.</p> <p data-svelte-h="svelte-1iay88v">Almost every configuration parameter has a default. Only one command-line flag is required to get things up and running: a <a href="https://huggingface.co/docs/hub/security-tokens" rel="nofollow">Hugging Face user access token</a>, which is used to upload the model to the Hugging Face Hub after finetuning. Run the following bash command to get started:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">python</span> alignprop.<span class="hljs-keyword">py</span> --hf_user_access_token <span class="hljs-symbol">&lt;token&gt;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mmabua">To obtain the documentation 
of <code>alignprop.py</code>, please run <code>python alignprop.py --help</code></p> <p data-svelte-h="svelte-gn7d4">Keep the following in mind when configuring the trainer in general, beyond the example script (the code checks these for you as well):</p> <ul data-svelte-h="svelte-641i8w"><li>The configurable randomized truncation range (<code>--alignprop_config.truncated_rand_backprop_minmax=(0,50)</code>): the first number should be greater than or equal to 0, while the second number should be less than or equal to the number of diffusion timesteps (<code>sample_num_steps</code>).</li> <li>The configurable absolute truncation step (<code>--alignprop_config.truncated_backprop_timestep=49</code>): the number should be less than the number of diffusion timesteps (<code>sample_num_steps</code>). It only matters when <code>truncated_backprop_rand</code> is set to False.</li></ul> <h2 class="relative group"><a id="setting-up-the-image-logging-hook-function" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#setting-up-the-image-logging-hook-function"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Setting up the 
image logging hook function</span></h2> <p data-svelte-h="svelte-irm1pa">Expect the function to be given a dictionary with keys</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-string">'image'</span>, <span class="hljs-string">'prompt'</span>, <span class="hljs-string">'prompt_metadata'</span>, <span class="hljs-string">'rewards'</span>] | |
| <!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1txp1uc">and <code>image</code>, <code>prompt</code>, <code>prompt_metadata</code>, <code>rewards</code> are batched. | |
| You are free to log however you want; the use of <code>wandb</code> or <code>tensorboard</code> is recommended.</p> <h3 class="relative group"><a id="key-terms" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#key-terms"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Key terms</span></h3> <ul data-svelte-h="svelte-1rki34a"><li><code>rewards</code> : The reward/score is a numerical value associated with the generated image and is key to steering the RL process</li> <li><code>prompt</code> : The prompt is the text that is used to generate the image</li> <li><code>prompt_metadata</code> : The prompt metadata is the metadata associated with the prompt. 
A situation where this will not be empty is when the reward model comprises a <a href="https://huggingface.co/docs/transformers/model_doc/flava" rel="nofollow"><code>FLAVA</code></a> setup where questions and ground-truth answers (linked to the generated image) are expected with the generated image (see <a href="https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45" rel="nofollow">https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45</a>)</li> <li><code>image</code> : The image generated by the Stable Diffusion model</li></ul> <p data-svelte-h="svelte-1r6a60u">Example code for logging sampled images with <code>wandb</code> is given below.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># for logging these images to wandb</span> | |
| <span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np | |
| <span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image | |
| | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">image_outputs_hook</span>(<span class="hljs-params">image_data, global_step, accelerate_logger</span>): | |
|     <span class="hljs-comment"># image_data holds the batched images, prompts and rewards;</span> | |
|     <span class="hljs-comment"># every image in the batch is logged under its own key</span> | |
|     result = {} | |
|     images, prompts, rewards = image_data[<span class="hljs-string">'images'</span>], image_data[<span class="hljs-string">'prompts'</span>], image_data[<span class="hljs-string">'rewards'</span>] | |
|     <span class="hljs-keyword">for</span> i, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(images): | |
|         <span class="hljs-comment"># CHW float tensor (assumed to be in [0, 1]) -> HWC uint8 array;</span> | |
|         <span class="hljs-comment"># detach() guards against tensors that still track gradients</span> | |
|         pil = Image.fromarray( | |
|             (image.detach().cpu().numpy().transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">0</span>) * <span class="hljs-number">255</span>).astype(np.uint8) | |
|         ) | |
|         pil = pil.resize((<span class="hljs-number">256</span>, <span class="hljs-number">256</span>)) | |
|         <span class="hljs-comment"># key: prompt truncated to 25 characters plus the reward to 2 decimals</span> | |
|         result[<span class="hljs-string">f"<span class="hljs-subst">{prompts[i]:<span class="hljs-number">.25</span>}</span> | <span class="hljs-subst">{rewards[i]:<span class="hljs-number">.2</span>f}</span>"</span>] = [pil] | |
|     accelerate_logger.log_images( | |
|         result, | |
|         step=global_step, | |
|     ) | |
| <!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="using-the-finetuned-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-the-finetuned-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using the finetuned model</span></h3> <p data-svelte-h="svelte-eqxomu">Assuming you’ve finished all the epochs and pushed your model to the hub, you can use the finetuned model as follows:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" 
height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline | |
| <span class="hljs-keyword">import</span> os | |
| | |
| pipeline = StableDiffusionPipeline.from_pretrained(<span class="hljs-string">"runwayml/stable-diffusion-v1-5"</span>) | |
| pipeline.to(<span class="hljs-string">"cuda"</span>) | |
| pipeline.load_lora_weights(<span class="hljs-string">'mihirpd/alignprop-trl-aesthetics'</span>) | |
| prompts = [<span class="hljs-string">"squirrel"</span>, <span class="hljs-string">"crab"</span>, <span class="hljs-string">"starfish"</span>, <span class="hljs-string">"whale"</span>, <span class="hljs-string">"sponge"</span>, <span class="hljs-string">"plankton"</span>] | |
| results = pipeline(prompts) | |
| <span class="hljs-comment"># create the output directory if it does not exist yet</span> | |
| os.makedirs(<span class="hljs-string">"dump"</span>, exist_ok=<span class="hljs-literal">True</span>) | |
| <span class="hljs-keyword">for</span> prompt, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(prompts, results.images): | |
|     image.save(<span class="hljs-string">f"dump/<span class="hljs-subst">{prompt}</span>.png"</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="credits" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#credits"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Credits</span></h2> <p data-svelte-h="svelte-8i7v5m">This work is heavily influenced by the repo <a href="https://github.com/mihirp1998/AlignProp/" rel="nofollow">here</a> and the associated paper <a href="https://huggingface.co/papers/2310.03739" rel="nofollow">Aligning Text-to-Image Diffusion Models with Reward Backpropagation | |
| by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki</a>.</p> <p></p> | |
| <script> | |
| { | |
| __sveltekit_5yobsv = { | |
| assets: "/docs/trl/main/en", | |
| base: "/docs/trl/main/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"), | |
| import("/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 2], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |