Buckets:

download
raw
71.1 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;训练扩散模型&quot;,&quot;local&quot;:&quot;训练扩散模型&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;训练配置&quot;,&quot;local&quot;:&quot;训练配置&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;加载数据集&quot;,&quot;local&quot;:&quot;加载数据集&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;创建 UNet2DModel&quot;,&quot;local&quot;:&quot;创建-unet2dmodel&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;创建调度器&quot;,&quot;local&quot;:&quot;创建调度器&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;训练模型&quot;,&quot;local&quot;:&quot;训练模型&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;下一步&quot;,&quot;local&quot;:&quot;下一步&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_13098/zh/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/entry/start.4dec4f79.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/scheduler.e4ff9b64.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/singletons.07a1ec04.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/index.f9be34a7.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/paths.1c31cc7a.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/entry/app.077c1f41.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/preload-helper.c70cb3ab.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/index.09f1bca0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/nodes/0.caca1171.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/nodes/52.f4c0aecc.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.b06b56b0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/CodeBlock.d66d98da.js">
<link rel="modulepreload" href="/docs/diffusers/pr_13098/zh/_app/immutable/chunks/DocNotebookDropdown.02241b22.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;训练扩散模型&quot;,&quot;local&quot;:&quot;训练扩散模型&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;训练配置&quot;,&quot;local&quot;:&quot;训练配置&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;加载数据集&quot;,&quot;local&quot;:&quot;加载数据集&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;创建 UNet2DModel&quot;,&quot;local&quot;:&quot;创建-unet2dmodel&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;创建调度器&quot;,&quot;local&quot;:&quot;创建调度器&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;训练模型&quot;,&quot;local&quot;:&quot;训练模型&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;下一步&quot;,&quot;local&quot;:&quot;下一步&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <div class="flex space-x-1 " style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <h1 class="relative group"><a id="训练扩散模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#训练扩散模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 
11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>训练扩散模型</span></h1> <p data-svelte-h="svelte-xqzsao">无条件图像生成是扩散模型最常见的应用之一,它会生成与训练数据集风格相似的图像。通常来说,在某个特定数据集上微调预训练模型能得到最好的结果。你可以在 <a href="https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model" rel="nofollow">Hub</a> 上找到很多现成检查点;如果找不到满意的,也完全可以自己训练一个!</p> <p data-svelte-h="svelte-1bznuf8">这篇教程会教你如何在 <a href="https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset" rel="nofollow">Smithsonian Butterflies</a> 数据集的一个子集上,从零开始训练一个 <code>UNet2DModel</code>,生成属于你自己的 🦋 蝴蝶图像 🦋。</p> <blockquote class="tip" data-svelte-h="svelte-eplgla"><p>💡 这篇训练教程基于 <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb" rel="nofollow">Training with 🧨 Diffusers</a> notebook 编写。如果你想了解更多背景,例如扩散模型的工作原理,也推荐一起看看这个 notebook。</p></blockquote> <p data-svelte-h="svelte-1nahceu">开始之前,请确认已经安装了 🤗 Datasets,用来加载和预处理图像数据集;以及 🤗 Accelerate,用来简化任意数量 GPU 上的训练。下面这条命令也会安装 <a href="https://www.tensorflow.org/tensorboard" rel="nofollow">TensorBoard</a> 来可视化训练指标(你也可以使用 <a href="https://docs.wandb.ai/" rel="nofollow">Weights &amp; Biases</a> 跟踪训练)。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" 
fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-comment"># 如果你在 Colab 中运行,请取消注释来安装所需依赖</span>
<span class="hljs-comment">#!pip install diffusers[training]</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15ss5vv">我们也很鼓励你把模型分享给社区。为此,你需要登录自己的 Hugging Face 账号(如果还没有,可以在 <a href="https://hf.co/join" rel="nofollow">这里</a> 创建)。你可以在 notebook 中登录,系统会提示你输入 token。请确保这个 token 具有写入权限。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login
<span class="hljs-meta">&gt;&gt;&gt; </span>notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-m5ch90">或者在终端里登录:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-bash "><!-- HTML_TAG_START -->hf auth login<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1d2k38t">由于模型检查点通常比较大,建议安装 <a href="https://git-lfs.com/" rel="nofollow">Git-LFS</a> 来管理这些大文件:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" 
viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-bash "><!-- HTML_TAG_START -->!sudo apt -qq install git-lfs
!git config --global credential.helper store<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="训练配置" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#训练配置"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>训练配置</span></h2> <p data-svelte-h="svelte-1dir3dp">为了方便起见,我们先创建一个 <code>TrainingConfig</code> 类,把训练超参数放在一起(你可以按需调整):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none 
transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> dataclasses <span class="hljs-keyword">import</span> dataclass
<span class="hljs-meta">&gt;&gt;&gt; </span>@dataclass
<span class="hljs-meta">... </span><span class="hljs-keyword">class</span> <span class="hljs-title class_">TrainingConfig</span>:
<span class="hljs-meta">... </span> image_size = <span class="hljs-number">128</span> <span class="hljs-comment"># 生成图像的分辨率</span>
<span class="hljs-meta">... </span> train_batch_size = <span class="hljs-number">16</span>
<span class="hljs-meta">... </span> eval_batch_size = <span class="hljs-number">16</span> <span class="hljs-comment"># 评估时每次采样多少张图像</span>
<span class="hljs-meta">... </span> num_epochs = <span class="hljs-number">50</span>
<span class="hljs-meta">... </span> gradient_accumulation_steps = <span class="hljs-number">1</span>
<span class="hljs-meta">... </span> learning_rate = <span class="hljs-number">1e-4</span>
<span class="hljs-meta">... </span> lr_warmup_steps = <span class="hljs-number">500</span>
<span class="hljs-meta">... </span> save_image_epochs = <span class="hljs-number">10</span>
<span class="hljs-meta">... </span> save_model_epochs = <span class="hljs-number">30</span>
<span class="hljs-meta">... </span> mixed_precision = <span class="hljs-string">&quot;fp16&quot;</span> <span class="hljs-comment"># float32 用 `no`,自动混合精度用 `fp16`</span>
<span class="hljs-meta">... </span> output_dir = <span class="hljs-string">&quot;ddpm-butterflies-128&quot;</span> <span class="hljs-comment"># 本地和 HF Hub 上的模型名称</span>
<span class="hljs-meta">... </span> push_to_hub = <span class="hljs-literal">True</span> <span class="hljs-comment"># 是否将保存后的模型上传到 HF Hub</span>
<span class="hljs-meta">... </span> hub_model_id = <span class="hljs-string">&quot;&lt;your-username&gt;/&lt;my-awesome-model&gt;&quot;</span> <span class="hljs-comment"># 在 HF Hub 上创建的仓库名称</span>
<span class="hljs-meta">... </span> hub_private_repo = <span class="hljs-literal">None</span>
<span class="hljs-meta">... </span> overwrite_output_dir = <span class="hljs-literal">True</span> <span class="hljs-comment"># 重新运行 notebook 时是否覆盖旧模型</span>
<span class="hljs-meta">... </span> seed = <span class="hljs-number">0</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>config = TrainingConfig()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="加载数据集" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#加载数据集"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>加载数据集</span></h2> <p data-svelte-h="svelte-1dwdgw3">你可以很轻松地通过 🤗 Datasets 加载 <a href="https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset" rel="nofollow">Smithsonian Butterflies</a> 数据集:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" 
transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>config.dataset_name = <span class="hljs-string">&quot;huggan/smithsonian_butterflies_subset&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset = load_dataset(config.dataset_name, split=<span class="hljs-string">&quot;train&quot;</span>)<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-7ae2dp"><p>💡 你也可以从 <a href="https://huggingface.co/huggan" rel="nofollow">HugGan Community Event</a> 找到更多数据集,或者通过本地 <a href="https://huggingface.co/docs/datasets/image_dataset#imagefolder" rel="nofollow"><code>ImageFolder</code></a> 使用自己的数据集。如果你使用 HugGan Community Event 里的数据集,把 <code>config.dataset_name</code> 设为对应数据集的 repository id;如果你使用自己的图像,就设为 <code>imagefolder</code>。</p></blockquote> <p data-svelte-h="svelte-zzj0e2">🤗 Datasets 使用 <code>Image</code> 特性自动解码图像数据,并将其加载为 <a href="https://pillow.readthedocs.io/en/stable/reference/Image.html" rel="nofollow"><code>PIL.Image</code></a>,所以我们可以直接可视化:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt
<span class="hljs-meta">&gt;&gt;&gt; </span>fig, axs = plt.subplots(<span class="hljs-number">1</span>, <span class="hljs-number">4</span>, figsize=(<span class="hljs-number">16</span>, <span class="hljs-number">4</span>))
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">for</span> i, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(dataset[:<span class="hljs-number">4</span>][<span class="hljs-string">&quot;image&quot;</span>]):
<span class="hljs-meta">... </span> axs[i].imshow(image)
<span class="hljs-meta">... </span> axs[i].set_axis_off()
<span class="hljs-meta">&gt;&gt;&gt; </span>fig.show()<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" data-svelte-h="svelte-1wjbouq"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_ds.png" alt="Smithsonian Butterflies 数据集中的蝴蝶示例图像"></div> <p data-svelte-h="svelte-1q7nc6k">不过这些图像的尺寸各不相同,所以你需要先做预处理:</p> <ul data-svelte-h="svelte-9t42wx"><li><code>Resize</code> 把图像缩放到 <code>config.image_size</code> 中定义的大小。</li> <li><code>RandomHorizontalFlip</code> 通过随机水平翻转图像来做数据增强。</li> <li><code>Normalize</code> 很重要,它会把像素值缩放到 <code>[-1, 1]</code> 区间,这是模型期望的输入范围。</li></ul> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> torchvision <span class="hljs-keyword">import</span> transforms
<span class="hljs-meta">&gt;&gt;&gt; </span>preprocess = transforms.Compose(
<span class="hljs-meta">... </span> [
<span class="hljs-meta">... </span> transforms.Resize((config.image_size, config.image_size)),
<span class="hljs-meta">... </span> transforms.RandomHorizontalFlip(),
<span class="hljs-meta">... </span> transforms.ToTensor(),
<span class="hljs-meta">... </span> transforms.Normalize([<span class="hljs-number">0.5</span>], [<span class="hljs-number">0.5</span>]),
<span class="hljs-meta">... </span> ]
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1rrvmkf">使用 🤗 Datasets 的 <code>set_transform</code> 方法,在训练过程中按需应用 <code>preprocess</code> 函数:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">transform</span>(<span class="hljs-params">examples</span>):
<span class="hljs-meta">... </span> images = [preprocess(image.convert(<span class="hljs-string">&quot;RGB&quot;</span>)) <span class="hljs-keyword">for</span> image <span class="hljs-keyword">in</span> examples[<span class="hljs-string">&quot;image&quot;</span>]]
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;images&quot;</span>: images}
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset.set_transform(transform)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lzlik">你也可以再次可视化图像,确认它们已经被调整到目标尺寸。接下来,就可以把数据集封装成一个 <a href="https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader" rel="nofollow">DataLoader</a> 来训练了!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span>train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="创建-unet2dmodel" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#创建-unet2dmodel"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>创建 UNet2DModel</span></h2> <p data-svelte-h="svelte-1uyi8q8">在 🧨 Diffusers 中,可以很方便地通过模型类和参数创建预训练模型。例如,下面创建一个 <code>UNet2DModel</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> UNet2DModel
<span class="hljs-meta">&gt;&gt;&gt; </span>model = UNet2DModel(
<span class="hljs-meta">... </span> sample_size=config.image_size, <span class="hljs-comment"># 目标图像分辨率</span>
<span class="hljs-meta">... </span> in_channels=<span class="hljs-number">3</span>, <span class="hljs-comment"># 输入通道数,RGB 图像为 3</span>
<span class="hljs-meta">... </span> out_channels=<span class="hljs-number">3</span>, <span class="hljs-comment"># 输出通道数</span>
<span class="hljs-meta">... </span> layers_per_block=<span class="hljs-number">2</span>, <span class="hljs-comment"># 每个 UNet block 中使用多少个 ResNet 层</span>
<span class="hljs-meta">... </span> block_out_channels=(<span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">256</span>, <span class="hljs-number">256</span>, <span class="hljs-number">512</span>, <span class="hljs-number">512</span>), <span class="hljs-comment"># 每个 UNet block 的输出通道数</span>
<span class="hljs-meta">... </span> down_block_types=(
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>, <span class="hljs-comment"># 标准的 ResNet 下采样块</span>
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;AttnDownBlock2D&quot;</span>, <span class="hljs-comment"># 带空间自注意力的 ResNet 下采样块</span>
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>,
<span class="hljs-meta">... </span> ),
<span class="hljs-meta">... </span> up_block_types=(
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>, <span class="hljs-comment"># 标准的 ResNet 上采样块</span>
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;AttnUpBlock2D&quot;</span>, <span class="hljs-comment"># 带空间自注意力的 ResNet 上采样块</span>
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>,
<span class="hljs-meta">... </span> ),
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-y8x571">通常最好先快速检查一下,样本图像的形状和模型输出形状是否一致:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>sample_image = dataset[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;images&quot;</span>].unsqueeze(<span class="hljs-number">0</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Input shape:&quot;</span>, sample_image.shape)
Input shape: torch.Size([<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Output shape:&quot;</span>, model(sample_image, timestep=<span class="hljs-number">0</span>).sample.shape)
Output shape: torch.Size([<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1yxiknp">很好!接下来,你还需要一个调度器为图像添加噪声。</p> <h2 class="relative group"><a id="创建调度器" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#创建调度器"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>创建调度器</span></h2> <p data-svelte-h="svelte-1lqt886">调度器在训练和推理中的行为不同。推理时,调度器会从噪声中生成图像;训练时,调度器会取扩散过程某一步的模型输出或样本,并根据<em>噪声日程</em>和<em>更新规则</em>对图像加噪。</p> <p data-svelte-h="svelte-7uzu0g">我们先看看 <code>DDPMScheduler</code>,并使用 <code>add_noise</code> 方法给前面的 <code>sample_image</code> 添加一些随机噪声:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DDPMScheduler
<span class="hljs-meta">&gt;&gt;&gt; </span>noise_scheduler = DDPMScheduler(num_train_timesteps=<span class="hljs-number">1000</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>noise = torch.randn(sample_image.shape)
<span class="hljs-meta">&gt;&gt;&gt; </span>timesteps = torch.LongTensor([<span class="hljs-number">50</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)
<span class="hljs-meta">&gt;&gt;&gt; </span>Image.fromarray(((noisy_image.permute(<span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>, <span class="hljs-number">1</span>) + <span class="hljs-number">1.0</span>) * <span class="hljs-number">127.5</span>).<span class="hljs-built_in">type</span>(torch.uint8).numpy()[<span class="hljs-number">0</span>])<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" data-svelte-h="svelte-qcxk5"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/noisy_butterfly.png"></div> <p data-svelte-h="svelte-6ma4h6">模型训练的目标,就是预测添加到图像中的噪声。当前步骤的损失可以这样计算:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span 
class="hljs-keyword">import</span> torch.nn.functional <span class="hljs-keyword">as</span> F
<span class="hljs-meta">&gt;&gt;&gt; </span>noise_pred = model(noisy_image, timesteps).sample
<span class="hljs-meta">&gt;&gt;&gt; </span>loss = F.mse_loss(noise_pred, noise)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="训练模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#训练模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>训练模型</span></h2> <p data-svelte-h="svelte-123jz3t">到这里,启动训练所需的大部分组件都准备好了,剩下的就是把它们拼起来。</p> <p data-svelte-h="svelte-1rt5439">首先,你需要一个优化器和一个学习率调度器:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" 
height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers.optimization <span class="hljs-keyword">import</span> get_cosine_schedule_with_warmup
<span class="hljs-meta">&gt;&gt;&gt; </span>optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
<span class="hljs-meta">&gt;&gt;&gt; </span>lr_scheduler = get_cosine_schedule_with_warmup(
<span class="hljs-meta">... </span> optimizer=optimizer,
<span class="hljs-meta">... </span> num_warmup_steps=config.lr_warmup_steps,
<span class="hljs-meta">... </span> num_training_steps=(<span class="hljs-built_in">len</span>(train_dataloader) * config.num_epochs),
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1osxiju">接着,你还需要一种评估模型的方法。评估时,我们可以使用 <code>DDPMPipeline</code> 生成一批示例图像,并把它们保存成一个网格图:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DDPMPipeline
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers.utils <span class="hljs-keyword">import</span> make_image_grid
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> os
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">evaluate</span>(<span class="hljs-params">config, epoch, pipeline</span>):
<span class="hljs-meta">... </span> <span class="hljs-comment"># 从随机噪声采样图像(这就是反向扩散过程)</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># 管道默认输出类型是 `List[PIL.Image]`</span>
<span class="hljs-meta">... </span> images = pipeline(
<span class="hljs-meta">... </span> batch_size=config.eval_batch_size,
<span class="hljs-meta">... </span> generator=torch.Generator(device=<span class="hljs-string">&#x27;cpu&#x27;</span>).manual_seed(config.seed), <span class="hljs-comment"># 单独使用一个 torch generator,避免回退主训练循环的随机状态</span>
<span class="hljs-meta">... </span> ).images
<span class="hljs-meta">... </span> <span class="hljs-comment"># 把图像拼成网格</span>
<span class="hljs-meta">... </span> image_grid = make_image_grid(images, rows=<span class="hljs-number">4</span>, cols=<span class="hljs-number">4</span>)
<span class="hljs-meta">... </span> <span class="hljs-comment"># 保存图像</span>
<span class="hljs-meta">... </span> test_dir = os.path.join(config.output_dir, <span class="hljs-string">&quot;samples&quot;</span>)
<span class="hljs-meta">... </span> os.makedirs(test_dir, exist_ok=<span class="hljs-literal">True</span>)
<span class="hljs-meta">... </span> image_grid.save(<span class="hljs-string">f&quot;<span class="hljs-subst">{test_dir}</span>/<span class="hljs-subst">{epoch:04d}</span>.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1t8i912">现在,你可以用 🤗 Accelerate 把这些组件包装进一个训练循环中,轻松实现 TensorBoard 日志记录、梯度累积和混合精度训练。为了把模型上传到 Hub,还需要写一个函数来创建仓库并将训练结果推送到 Hub。</p> <blockquote class="tip" data-svelte-h="svelte-fcxys8"><p>💡 下面的训练循环看起来可能有点长,也有点吓人,但等你真正只用一行代码启动训练时,就会觉得很值得!如果你现在只想快点开始生成图像,也可以先直接复制运行下面的代码,之后再回头仔细研究训练循环,比如等模型训练完成的时候。🤗</p></blockquote> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> create_repo, upload_folder
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> pathlib <span class="hljs-keyword">import</span> Path
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> os
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">train_loop</span>(<span class="hljs-params">config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler</span>):
<span class="hljs-meta">... </span> <span class="hljs-comment"># 初始化 accelerator 和 tensorboard 日志</span>
<span class="hljs-meta">... </span> accelerator = Accelerator(
<span class="hljs-meta">... </span> mixed_precision=config.mixed_precision,
<span class="hljs-meta">... </span> gradient_accumulation_steps=config.gradient_accumulation_steps,
<span class="hljs-meta">... </span> log_with=<span class="hljs-string">&quot;tensorboard&quot;</span>,
<span class="hljs-meta">... </span> project_dir=os.path.join(config.output_dir, <span class="hljs-string">&quot;logs&quot;</span>),
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> accelerator.is_main_process:
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.output_dir <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>:
<span class="hljs-meta">... </span> os.makedirs(config.output_dir, exist_ok=<span class="hljs-literal">True</span>)
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.push_to_hub:
<span class="hljs-meta">... </span> repo_id = create_repo(
<span class="hljs-meta">... </span> repo_id=config.hub_model_id <span class="hljs-keyword">or</span> Path(config.output_dir).name, exist_ok=<span class="hljs-literal">True</span>
<span class="hljs-meta">... </span> ).repo_id
<span class="hljs-meta">... </span> accelerator.init_trackers(<span class="hljs-string">&quot;train_example&quot;</span>)
<span class="hljs-meta">... </span> <span class="hljs-comment"># 准备所有对象</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># 不需要记住固定顺序,只要解包时和传给 prepare 的顺序一致即可。</span>
<span class="hljs-meta">... </span> model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
<span class="hljs-meta">... </span> model, optimizer, train_dataloader, lr_scheduler
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> global_step = <span class="hljs-number">0</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># 开始训练模型</span>
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(config.num_epochs):
<span class="hljs-meta">... </span> progress_bar = tqdm(total=<span class="hljs-built_in">len</span>(train_dataloader), disable=<span class="hljs-keyword">not</span> accelerator.is_local_main_process)
<span class="hljs-meta">... </span> progress_bar.set_description(<span class="hljs-string">f&quot;Epoch <span class="hljs-subst">{epoch}</span>&quot;</span>)
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(train_dataloader):
<span class="hljs-meta">... </span> clean_images = batch[<span class="hljs-string">&quot;images&quot;</span>]
<span class="hljs-meta">... </span> <span class="hljs-comment"># 为图像采样噪声</span>
<span class="hljs-meta">... </span> noise = torch.randn(clean_images.shape, device=clean_images.device)
<span class="hljs-meta">... </span> bs = clean_images.shape[<span class="hljs-number">0</span>]
<span class="hljs-meta">... </span> <span class="hljs-comment"># 为每张图像随机采样一个时间步</span>
<span class="hljs-meta">... </span> timesteps = torch.randint(
<span class="hljs-meta">... </span> <span class="hljs-number">0</span>, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
<span class="hljs-meta">... </span> dtype=torch.int64
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> <span class="hljs-comment"># 按照每个时间步对应的噪声强度给干净图像加噪</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># (这就是前向扩散过程)</span>
<span class="hljs-meta">... </span> noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
<span class="hljs-meta">... </span> <span class="hljs-keyword">with</span> accelerator.accumulate(model):
<span class="hljs-meta">... </span> <span class="hljs-comment"># 预测噪声残差</span>
<span class="hljs-meta">... </span> noise_pred = model(noisy_images, timesteps, return_dict=<span class="hljs-literal">False</span>)[<span class="hljs-number">0</span>]
<span class="hljs-meta">... </span> loss = F.mse_loss(noise_pred, noise)
<span class="hljs-meta">... </span> accelerator.backward(loss)
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> accelerator.sync_gradients:
<span class="hljs-meta">... </span> accelerator.clip_grad_norm_(model.parameters(), <span class="hljs-number">1.0</span>)
<span class="hljs-meta">... </span> optimizer.step()
<span class="hljs-meta">... </span> lr_scheduler.step()
<span class="hljs-meta">... </span> optimizer.zero_grad()
<span class="hljs-meta">... </span> progress_bar.update(<span class="hljs-number">1</span>)
<span class="hljs-meta">... </span> logs = {<span class="hljs-string">&quot;loss&quot;</span>: loss.detach().item(), <span class="hljs-string">&quot;lr&quot;</span>: lr_scheduler.get_last_lr()[<span class="hljs-number">0</span>], <span class="hljs-string">&quot;step&quot;</span>: global_step}
<span class="hljs-meta">... </span> progress_bar.set_postfix(**logs)
<span class="hljs-meta">... </span> accelerator.log(logs, step=global_step)
<span class="hljs-meta">... </span> global_step += <span class="hljs-number">1</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># 每个 epoch 后可以选择用 evaluate() 采样一些演示图像,并保存模型</span>
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> accelerator.is_main_process:
<span class="hljs-meta">... </span> pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> (epoch + <span class="hljs-number">1</span>) % config.save_image_epochs == <span class="hljs-number">0</span> <span class="hljs-keyword">or</span> epoch == config.num_epochs - <span class="hljs-number">1</span>:
<span class="hljs-meta">... </span> evaluate(config, epoch, pipeline)
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> (epoch + <span class="hljs-number">1</span>) % config.save_model_epochs == <span class="hljs-number">0</span> <span class="hljs-keyword">or</span> epoch == config.num_epochs - <span class="hljs-number">1</span>:
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.push_to_hub:
<span class="hljs-meta">... </span> upload_folder(
<span class="hljs-meta">... </span> repo_id=repo_id,
<span class="hljs-meta">... </span> folder_path=config.output_dir,
<span class="hljs-meta">... </span> commit_message=<span class="hljs-string">f&quot;Epoch <span class="hljs-subst">{epoch}</span>&quot;</span>,
<span class="hljs-meta">... </span> ignore_patterns=[<span class="hljs-string">&quot;step_*&quot;</span>, <span class="hljs-string">&quot;epoch_*&quot;</span>],
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> <span class="hljs-keyword">else</span>:
<span class="hljs-meta">... </span> pipeline.save_pretrained(config.output_dir)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18jog22">呼,这段代码确实不少!不过现在你终于可以用 🤗 Accelerate 的 <code>notebook_launcher</code> 函数启动训练了。把训练循环函数、所有训练参数以及进程数(你可以改成自己可用 GPU 的数量)传进去即可:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher
<span class="hljs-meta">&gt;&gt;&gt; </span>args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
<span class="hljs-meta">&gt;&gt;&gt; </span>notebook_launcher(train_loop, args, num_processes=<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6zpbxv">训练完成后,来看看你的扩散模型最终生成的 🦋 蝴蝶图像 🦋 吧!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> glob
<span class="hljs-meta">&gt;&gt;&gt; </span>sample_images = <span class="hljs-built_in">sorted</span>(glob.glob(<span class="hljs-string">f&quot;<span class="hljs-subst">{config.output_dir}</span>/samples/*.png&quot;</span>))
<span class="hljs-meta">&gt;&gt;&gt; </span>Image.<span class="hljs-built_in">open</span>(sample_images[-<span class="hljs-number">1</span>])<!-- HTML_TAG_END --></pre></div> <div class="flex justify-center" data-svelte-h="svelte-tzqtub"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_final.png"></div> <h2 class="relative group"><a id="下一步" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#下一步"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>下一步</span></h2> <p data-svelte-h="svelte-exe5vc">无条件图像生成只是可训练任务中的一个例子。你可以继续访问 <a href="../training/overview">🧨 Diffusers 训练示例</a> 页面,探索更多任务和训练技术。比如:</p> <ul data-svelte-h="svelte-vw4xsz"><li><a href="../training/text_inversion">Textual Inversion</a>:教会模型一个特定的视觉概念,并把它融入生成结果中。</li> <li><a href="../training/dreambooth">DreamBooth</a>:给定某个主体的若干输入图像,生成该主体的个性化图像。</li> <li><a href="../training/text2image">文生图微调指南</a>:在你自己的数据集上微调 Stable Diffusion 模型。</li> <li><a href="../training/lora">LoRA 微调指南</a>:使用 LoRA 这种更省内存的方法,更快地微调超大模型。</li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" 
href="https://github.com/huggingface/diffusers/blob/main/docs/source/zh/tutorials/basic_training.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
	// Client bootstrap: publish the path configuration on a global, then load
	// the two entry modules and hand the app to the runtime's start() call.
	const docsRoot = "/docs/diffusers/pr_13098/zh";
	__sveltekit_1l6n7ys = { assets: docsRoot, base: docsRoot, env: {} };

	// Read the mount point before any asynchronous work: it is derived from
	// the <script> element that is executing right now.
	const mountPoint = document.currentScript.parentElement;

	// Options forwarded unchanged to kit.start().
	const startOptions = {
		node_ids: [0, 52],
		data: [null, null],
		form: null,
		error: null
	};

	// Both entry modules are requested in parallel; start() runs once both
	// have resolved.
	const pendingModules = [
		import("/docs/diffusers/pr_13098/zh/_app/immutable/entry/start.4dec4f79.js"),
		import("/docs/diffusers/pr_13098/zh/_app/immutable/entry/app.077c1f41.js")
	];

	Promise.all(pendingModules).then((loaded) => {
		const kit = loaded[0];
		const app = loaded[1];
		kit.start(app, mountPoint, startOptions);
	});
}
</script>

Xet Storage Details

Size:
71.1 kB
·
Xet hash:
f20f9b78f0c8ce751d6520011a13efa5b40ed73f8fe1930ca5018336bb2842f6

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.