Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content='{"title":"从头开始训练因果语言模型","local":"从头开始训练因果语言模型","sections":[{"title":"收集数据","local":"收集数据","sections":[],"depth":2},{"title":"准备数据集","local":"准备数据集","sections":[],"depth":2},{"title":"初始化一个新模型","local":"初始化一个新模型","sections":[],"depth":2},{"title":"使用 pipeline 进行代码生成","local":"使用管道生成代码","sections":[],"depth":2},{"title":"使用🤗 Accelerate 进行训练","local":"使用🤗 Accelerate 进行训练","sections":[],"depth":2}],"depth":1}'> | |
| <link href="/docs/course/pr_1021/zh-CN/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/entry/start.f3a1a511.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/singletons.9bf55235.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/paths.0ba10750.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/entry/app.c39e37cf.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/index.2bf4358c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/nodes/0.dad18ce3.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/nodes/58.eab5671b.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/Tip.363c041f.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/Youtube.1e50a667.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/CodeBlock.4e987730.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/DocNotebookDropdown.efc1fb7c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/FrameworkSwitchCourse.8d4d4ab6.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/getInferenceSnippets.ebf8be91.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"从头开始训练因果语言模型","local":"从头开始训练因果语言模型","sections":[{"title":"收集数据","local":"收集数据","sections":[],"depth":2},{"title":"准备数据集","local":"准备数据集","sections":[],"depth":2},{"title":"初始化一个新模型","local":"初始化一个新模型","sections":[],"depth":2},{"title":"使用 pipeline 进行代码生成","local":"使用管道生成代码","sections":[],"depth":2},{"title":"使用🤗 Accelerate 进行训练","local":"使用🤗 Accelerate 进行训练","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> Pytorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 
21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="从头开始训练因果语言模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#从头开始训练因果语言模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>从头开始训练因果语言模型</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/chapter7/section6_pt.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a 
href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/chapter7/section6_pt.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <p data-svelte-h="svelte-1t2m835">到目前为止,我们主要使用预训练模型,并通过复用预训练的权重,然后使用新的数据对它们进行微调,以适应新的应用场景。正如我们在 <a href="/course/chapter1">第一章</a> 中看到的,这通常称为 <code>迁移学习(transfer learning)</code> ,对于大多数标注数据稀缺的应用场景,它是一种将 Transformer 模型应用到大部分真实的应用场景中的一个非常成功的策略。在本章中,我们将采用不同的方法并从头开始训练一个全新的模型。如果你有大量数据而且这些数据与可用模型的预训练数据差异很大,那么这是一个很好的方法。然而,相比仅微调现有模型,预训练语言模型需要更多的计算资源。训练一个新模型可能是有意义的示例包括由音乐符号、DNA 等分子序列或编程语言组成的数据集。编程语言组成的数据集最近广泛地受到关注,这要归功于 TabNine 和 GitHub 的 Copilot 等工具的流行,它们由 OpenAI 的 Codex 模型提供支持,可以生成长代码序列。这种文本生成任务最适合使用自回归或因果语言模型(例如 GPT-2)。</p> <p data-svelte-h="svelte-z2rcsh">在这一节,我们将构建一个精简版的代码生成模型:使用 Python 代码的一个数据集,来实现一行代码的补全,而不是直接生成完整的函数或类。当你使用 Python 处理数据时,你经常会接触到 Python 数据科学栈,包括 <code>matplotlib</code> , <code>seaborn</code> , <code>pandas</code> ,和 <code>scikit-learn</code> 这些库。当使用这些框架时,经常需要查找特定的命令,如果我们能够用模型来自动给出恰当的推荐命令就太好了!</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/Vpjb1lu0MDk" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-ginij8">在 <a href="/course/chapter6">第六章</a> 中,我们创建了一个高效的 tokenizer 来处理 Python 源代码,但我们还需要一个大规模的数据集来预训练模型。在这里,我们将使用 tokenizer 处理一个来自 GitHub 仓库的 Python 代码语料库。然后,我们将使用 Trainer API 和 🤗 Accelerate 来训练模型。让我们开始吧!</p> <iframe src="https://course-demos-codeparrot-ds.hf.space" frameborder="0" height="300" title="Gradio app" class="block dark:hidden container p-0 flex-grow space-iframe" allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; 
publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads"></iframe> <p data-svelte-h="svelte-1k1xdgt">这里展示的是一个已经训练并上传到 Hub 的模型,它就是使用本节中的代码训练的。你可以在 <a href="https://huggingface.co/huggingface-course/codeparrot-ds?text=plt.imshow%28" rel="nofollow">这里</a> 找到它。注意,由于文本生成过程中有一些随机性,你可能会得到稍微不同的结果。</p> <h2 class="relative group"><a id="收集数据" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#收集数据"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>收集数据</span></h2> <p data-svelte-h="svelte-13yypc1">我们可以从诸如 GitHub 这样的代码仓库中获取丰富的 Python 代码,通过对每个 Python 仓库进行抓取,我们就可以创建一个数据集。这就是在 <a href="https://learning.oreilly.com/library/view/natural-language-processing/9781098136789/" rel="nofollow">Transformers textbook</a> 中预训练一个大型 GPT-2 模型的方法。开发者整理了名为 <code>codeparrot</code> 的一个大约为 180GB 的 GitHub 数据集,其中包含大约 2,000 万个 Python 文件。开发者用这些文件构建了一个数据集,并在 <a href="https://huggingface.co/datasets/transformersbook/codeparrot" rel="nofollow">Hugging Face Hub</a> 上分享了这个数据集。</p> <p 
data-svelte-h="svelte-1eicdj2">然而,使用完整语料库的训练既耗时又费力,我们只需要找到 Python 数据科学栈相关的数据集子集。所以,让我们从 <code>codeparrot</code> 数据集中筛选出包含这个栈中任意一个库的所有文件。由于数据集太大,我们希望避免直接把全部的数据集下载下来;因此,我们将使用流式传输的方法来动态过滤它。为了使用上述的库来筛选代码样本,我们将使用以下函数:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">any_keyword_in_string</span>(<span class="hljs-params">string, keywords</span>): | |
| <span class="hljs-keyword">for</span> keyword <span class="hljs-keyword">in</span> keywords: | |
| <span class="hljs-keyword">if</span> keyword <span class="hljs-keyword">in</span> string: | |
| <span class="hljs-keyword">return</span> <span class="hljs-literal">True</span> | |
| <span class="hljs-keyword">return</span> <span class="hljs-literal">False</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-188avx">让我们用两个例子来测试一下:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->filters = [<span class="hljs-string">"pandas"</span>, <span class="hljs-string">"sklearn"</span>, <span class="hljs-string">"matplotlib"</span>, <span class="hljs-string">"seaborn"</span>] | |
| example_1 = <span class="hljs-string">"import numpy as np"</span> | |
| example_2 = <span class="hljs-string">"import pandas as pd"</span> | |
| <span class="hljs-built_in">print</span>( | |
| any_keyword_in_string(example_1, filters), any_keyword_in_string(example_2, filters) | |
| )<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-literal">False</span> <span class="hljs-literal">True</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1eff2hf">我们可以使用这个函数来创建一个新的函数,该函数将流式传输数据集并过滤我们想要的元素:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> collections <span class="hljs-keyword">import</span> defaultdict | |
| <span class="hljs-keyword">from</span> tqdm <span class="hljs-keyword">import</span> tqdm | |
| <span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">filter_streaming_dataset</span>(<span class="hljs-params">dataset, filters</span>): | |
| filtered_dict = defaultdict(<span class="hljs-built_in">list</span>) | |
| total = <span class="hljs-number">0</span> | |
| <span class="hljs-keyword">for</span> sample <span class="hljs-keyword">in</span> tqdm(<span class="hljs-built_in">iter</span>(dataset)): | |
| total += <span class="hljs-number">1</span> | |
| <span class="hljs-keyword">if</span> any_keyword_in_string(sample[<span class="hljs-string">"content"</span>], filters): | |
| <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> sample.items(): | |
| filtered_dict[k].append(v) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"<span class="hljs-subst">{<span class="hljs-built_in">len</span>(filtered_dict[<span class="hljs-string">'content'</span>])/total:<span class="hljs-number">.2</span>%}</span> of data after filtering."</span>) | |
| <span class="hljs-keyword">return</span> Dataset.from_dict(filtered_dict)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1th7mdr">然后我们可以直接使用这个函数流式处理数据集:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 执行这个代码块需要非常长的时间,因此你可以跳过它,继续执行下一个!</span> | |
| <span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| split = <span class="hljs-string">"train"</span> <span class="hljs-comment"># "valid"</span> | |
| filters = [<span class="hljs-string">"pandas"</span>, <span class="hljs-string">"sklearn"</span>, <span class="hljs-string">"matplotlib"</span>, <span class="hljs-string">"seaborn"</span>] | |
| data = load_dataset(<span class="hljs-string">f"transformersbook/codeparrot-<span class="hljs-subst">{split}</span>"</span>, split=split, streaming=<span class="hljs-literal">True</span>) | |
| filtered_data = filter_streaming_dataset(data, filters)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-number">3.26</span>% of data after filtering.<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xkdug1">完成这个操作后,我们过滤后的数据集只有原始数据集的大约 3%,但这仍然是相当可观的大小——最终的数据集是 6GB,由 600,000 个 Python 脚本组成!</p> <p data-svelte-h="svelte-13urty5">过滤完整的数据集可能需要 2-3 小时,这取决于你的机器性能和带宽。如果你不想亲自经历这个漫长的过程,我们在 Hub 上提供了过滤后的数据集供你下载:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" 
fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset, DatasetDict | |
| ds_train = load_dataset(<span class="hljs-string">"huggingface-course/codeparrot-ds-train"</span>, split=<span class="hljs-string">"train"</span>) | |
| ds_valid = load_dataset(<span class="hljs-string">"huggingface-course/codeparrot-ds-valid"</span>, split=<span class="hljs-string">"validation"</span>) | |
| raw_datasets = DatasetDict( | |
| { | |
| <span class="hljs-string">"train"</span>: ds_train, <span class="hljs-comment"># .shuffle().select(range(50000)),</span> | |
| <span class="hljs-string">"valid"</span>: ds_valid, <span class="hljs-comment"># .shuffle().select(range(500))</span> | |
| } | |
| ) | |
| raw_datasets<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({ | |
| train: Dataset({ | |
| features: [<span class="hljs-string">'repo_name'</span>, <span class="hljs-string">'path'</span>, <span class="hljs-string">'copies'</span>, <span class="hljs-string">'size'</span>, <span class="hljs-string">'content'</span>, <span class="hljs-string">'license'</span>], | |
| num_rows: <span class="hljs-number">606720</span> | |
| }) | |
| valid: Dataset({ | |
| features: [<span class="hljs-string">'repo_name'</span>, <span class="hljs-string">'path'</span>, <span class="hljs-string">'copies'</span>, <span class="hljs-string">'size'</span>, <span class="hljs-string">'content'</span>, <span class="hljs-string">'license'</span>], | |
| num_rows: <span class="hljs-number">3322</span> | |
| }) | |
| })<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zsplm1">让我们看一个来自数据集的例子。我们将只显示每个字段的前 200 个字符:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> key <span class="hljs-keyword">in</span> raw_datasets[<span class="hljs-string">"train"</span>][<span class="hljs-number">0</span>]: | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"<span class="hljs-subst">{key.upper()}</span>: <span class="hljs-subst">{raw_datasets[<span class="hljs-string">'train'</span>][<span class="hljs-number">0</span>][key][:<span class="hljs-number">200</span>]}</span>"</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">'REPO_NAME: kmike/scikit-learn'</span> | |
| <span class="hljs-string">'PATH: sklearn/utils/__init__.py'</span> | |
| <span class="hljs-string">'COPIES: 3'</span> | |
| <span class="hljs-string">'SIZE: 10094'</span> | |
| <span class="hljs-string">'''CONTENT: """ | |
| The :mod:`sklearn.utils` module includes various utilites. | |
| """ | |
| from collections import Sequence | |
| import numpy as np | |
| from scipy.sparse import issparse | |
| import warnings | |
| from .murmurhash import murm | |
| LICENSE: bsd-3-clause'''</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-r7e49a">我们可以看到, <code>content</code> 字段包含了我们希望模型训练的代码。有了这个数据集之后,我们需要对文本进行一些处理,以便它们适合于预训练。</p> <h2 class="relative group"><a id="准备数据集" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#准备数据集"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>准备数据集</span></h2> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/ma1TrR7gE7I" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-14jihz1">首先,我们需要将数据进行分词处理,这样才能进行训练。由于我们的主要目标是自动补全短的函数调用,因此我们可以将上下文大小设置得相对较小。这样做的好处是我们可以更快地训练模型,而且需要的内存也大大减少。如果你的应用需要更多的上下文(比如,你希望模型根据包含函数定义的文件编写单元测试),那么应该增大该数字,但是也要记住这会增加 GPU 显存的占用。现在,我们将上下文大小固定为 128 个 tokens,而不是在 GPT-2 或 GPT-3 中使用的 1,024 或 2,048 个 tokens。</p> <p data-svelte-h="svelte-13spu2c">大多数文档都包含超过 128 个 tokens,因此简单地将输入截断到最大长度会删除我们数据集的很大一部分。因此,我们将使用 <code>return_overflowing_tokens</code> 选项将整个输入进行分词处理,并将其分割为几个块,正如我们在 <a href="/course/chapter6/4">第六章</a> 中所做的那样。我们还将使用 <code>return_length</code> 
选项自动返回创建的每个块的长度。通常,最后一个块的大小会小于上下文大小,我们将去掉最后一块以避免填充问题;因为我们已经有足够的数据,所以不需要它们。</p> <div class="flex justify-center" data-svelte-h="svelte-1mf8pz"><img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/chunking_texts.svg" alt="Chunking a large texts in several pieces."> <img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/chunking_texts-dark.svg" alt="Chunking a large texts in several pieces."></div> <p data-svelte-h="svelte-1jx1efh">让我们通过查看前两个示例来具体了解结果怎么样:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer | |
| context_length = <span class="hljs-number">128</span> | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"huggingface-course/code-search-net-tokenizer"</span>) | |
| outputs = tokenizer( | |
| raw_datasets[<span class="hljs-string">"train"</span>][:<span class="hljs-number">2</span>][<span class="hljs-string">"content"</span>], | |
| truncation=<span class="hljs-literal">True</span>, | |
| max_length=context_length, | |
| return_overflowing_tokens=<span class="hljs-literal">True</span>, | |
| return_length=<span class="hljs-literal">True</span>, | |
| ) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Input IDs length: <span class="hljs-subst">{<span class="hljs-built_in">len</span>(outputs[<span class="hljs-string">'input_ids'</span>])}</span>"</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Input chunk lengths: <span class="hljs-subst">{(outputs[<span class="hljs-string">'length'</span>])}</span>"</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Chunk mapping: <span class="hljs-subst">{outputs[<span class="hljs-string">'overflow_to_sample_mapping'</span>]}</span>"</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Input IDs length: <span class="hljs-number">34</span> | |
| Input chunk lengths: [<span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">117</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">41</span>] | |
| Chunk mapping: [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1v8sp6o">我们可以看到,这两个例子总共得到了 34 个块。查看块长度,我们可以看到两个文档末端的块少于 128 个 tokens (分别为 117 和 41)。不过这些只占我们所拥有的总块数的一小部分,因此我们可以放心地丢掉它们。通过 <code>overflow_to_sample_mapping</code> 字段,我们还可以分辨出哪些块属于哪个样本。</p> <p data-svelte-h="svelte-1lj93sp">在这个操作中,我们使用了🤗 Datasets 中的 <code>Dataset.map()</code> 函数的一个便捷的特性,即它并不需要一对一地设置分块后和分块前的映射关系;正如我们在 <a href="/course/chapter7/3">第三节</a> 中看到的,我们可以自由地将一个样本拆分或者删除部分样本来创建比输入的 <code>batch_size</code> 更多或更少元素的 batch。 <code>Dataset.map()</code> 函数会自动帮我们关联映射关系,当进行像数据增强或数据过滤这样改变元素数量的操作时非常有用。在我们的情况下,当将每个样本分词并分割成指定上下文大小的块时,我们从每个样本中创建了许多样本。我们需要删除原本的列,因为它们的大小和我们分割后的大小不一样。如果我们想保留它们,我们可以复制它们来填充,并在 <code>Dataset.map()</code> 调用中返回它们。</p> <div class="code-block relative "><div class="absolute 
top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize</span>(<span class="hljs-params">element</span>): | |
| outputs = tokenizer( | |
| element[<span class="hljs-string">"content"</span>], | |
| truncation=<span class="hljs-literal">True</span>, | |
| max_length=context_length, | |
| return_overflowing_tokens=<span class="hljs-literal">True</span>, | |
| return_length=<span class="hljs-literal">True</span>, | |
| ) | |
| input_batch = [] | |
| <span class="hljs-keyword">for</span> length, input_ids <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(outputs[<span class="hljs-string">"length"</span>], outputs[<span class="hljs-string">"input_ids"</span>]): | |
| <span class="hljs-keyword">if</span> length == context_length: | |
| input_batch.append(input_ids) | |
| <span class="hljs-keyword">return</span> {<span class="hljs-string">"input_ids"</span>: input_batch} | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>( | |
| tokenize, batched=<span class="hljs-literal">True</span>, remove_columns=raw_datasets[<span class="hljs-string">"train"</span>].column_names | |
| ) | |
| tokenized_datasets<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({ | |
| train: Dataset({ | |
| features: [<span class="hljs-string">'input_ids'</span>], | |
| num_rows: <span class="hljs-number">16702061</span> | |
| }) | |
| valid: Dataset({ | |
| features: [<span class="hljs-string">'input_ids'</span>], | |
| num_rows: <span class="hljs-number">93164</span> | |
| }) | |
| })<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vpqfku">我们现在有 1670 万个样本,每个样本有 128 个 tokens,总共相当于大约 21 亿个 tokens。作为参考,OpenAI 的 GPT-3 和 Codex 模型分别在 3000 亿和 1000 亿个 tokens 上进行了训练,其中 Codex 模型从 GPT-3 checkpoint 初始化。本节的目标不是与这些能生成长且连贯文本的模型竞争,而是创建一个能为数据科学家提供快速自动代码补全功能的精简版本。</p> <p data-svelte-h="svelte-1bmkjtp">既然我们已经准备好了数据集,那就来设置模型吧!</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-t8kag1">✏️ <strong>试一试!</strong>这里我们删除了所有小于设定的上下文大小的块,并不会造成大问题,因为我们使用的是比较小的上下文窗口。随着增大上下文大小(或者语料库中的文档长度都很短),被抛弃的块的比例也会增加。更有效的方法是将所有 tokenize 后的样本拼接起来加入一个 batch 中,每个样本之间有一个 <code>eos_token_id</code> token 作为分隔,然后对连接后的序列进行切块处理。作为练习,修改 <code>tokenize()</code> 函数以利用这种方法。请注意,为了获取完整的 token ID 序列你需要设置 <code>truncation=False</code> ,并删除 tokenizer 中的其他参数。</p></div> <h2 class="relative group"><a id="初始化一个新模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#初始化一个新模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>初始化一个新模型</span></h2> <p 
data-svelte-h="svelte-3zroq5">我们的第一步是初始化一个全新的 GPT-2 模型。我们可以通过加载预训练配置来初始化一个与 GPT-2 small 配置相同的模型,并确保 tokenizer 大小与模型的词汇表大小匹配,以及设置 <code>bos</code> 和 <code>eos</code> (序列的开始和结束) token IDs:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, GPT2LMHeadModel, AutoConfig | |
| config = AutoConfig.from_pretrained( | |
| <span class="hljs-string">"gpt2"</span>, | |
| vocab_size=<span class="hljs-built_in">len</span>(tokenizer), | |
| n_ctx=context_length, | |
| bos_token_id=tokenizer.bos_token_id, | |
| eos_token_id=tokenizer.eos_token_id, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1nsrsjx">有了这个配置对象,我们就可以加载一个全新的 GPT-2 模型。注意,这是我们第一次不使用 <code>from_pretrained()</code> 函数,因为我们实际上是自己初始化一个全新的模型而不是从一个预训练的模型继续训练:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = GPT2LMHeadModel(config) | |
| model_size = <span class="hljs-built_in">sum</span>(t.numel() <span class="hljs-keyword">for</span> t <span class="hljs-keyword">in</span> model.parameters()) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"GPT-2 size: <span class="hljs-subst">{model_size/<span class="hljs-number">1000</span>**<span class="hljs-number">2</span>:<span class="hljs-number">.1</span>f}</span>M parameters"</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->GPT-<span class="hljs-number">2</span> size: <span class="hljs-number">124.2</span>M parameters<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15c2pk7">我们的新模型有 124M 个参数需要训练。在开始训练之前,我们需要设置一个数据整理器(DataCollator),它将负责创建 Batch。我们可以使用 <code>DataCollatorForLanguageModeling</code> ,顾名思义,它专门用于语言建模。除了堆叠和填充创建 Batch 之外,它还负责创建语言模型的待预测的标签 —— 在因果语言建模中,输入就是待预测的标签(只是偏移一个元素),而这个数据整理器(DataCollator)会在训练过程中实时创建这些标签,因此我们不需要事先复制 <code>input_ids</code> 。</p> <p data-svelte-h="svelte-sbssvb">注意, 
<code>DataCollatorForLanguageModeling</code> 同时支持掩码语言建模 (MLM) 和因果语言建模 (CLM)。默认情况下它按照 MLM 需要的格式准备数据,但我们可以通过设置 <code>mlm=False</code> 参数切换到 CLM。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> DataCollatorForLanguageModeling | |
| tokenizer.pad_token = tokenizer.eos_token | |
| data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=<span class="hljs-literal">False</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ccdb1w">让我们看一个例子:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->out = data_collator([tokenized_datasets[<span class="hljs-string">"train"</span>][i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">5</span>)]) | |
| <span class="hljs-keyword">for</span> key <span class="hljs-keyword">in</span> out: | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"<span class="hljs-subst">{key}</span> shape: <span class="hljs-subst">{out[key].shape}</span>"</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->input_ids shape: torch.Size([<span class="hljs-number">5</span>, <span class="hljs-number">128</span>]) | |
| attention_mask shape: torch.Size([<span class="hljs-number">5</span>, <span class="hljs-number">128</span>]) | |
| labels shape: torch.Size([<span class="hljs-number">5</span>, <span class="hljs-number">128</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18al98k">我们可以看到示例的数据已经处理好了,并且所有 tensor 都具有相同的形状。</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1x635zg">⚠️ 输入序列和目标序列对齐将在模型内部自动进行,所以数据整理器只需复制输入序列来创建目标序列。</p></div> <p data-svelte-h="svelte-14vmza0">现在我们已经准备好了所有东西,可以开始训练我们的模型了——好像也不是那么困难!在我们开始训练之前,我们应该登录到 Hugging Face。如果你正在使用 Notebook 运行代码,你可以使用下面的实用函数进行登录:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span 
class="hljs-keyword">import</span> notebook_login | |
| notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ui4fcw">这将显示一个小部件,你可以在其中输入你的 Hugging Face 登录凭据。</p> <p data-svelte-h="svelte-v3d1p4">如果你不是在 Notebook 上工作,只需在终端中输入以下行:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->huggingface-cli login<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6iztyo">剩下要做的就是配置训练参数并启动 <code>Trainer</code> 。本次的训练中我们将使用余弦学习率调度,并进行一些 Warmup。训练的 batch size 是 256 ( <code>per_device_train_batch_size</code> * <code>gradient_accumulation_steps</code> )。当单个 batch 无法放入内存时,可以使用梯度累积,并通过多次向前/向后传递逐步累积梯度。当我们在本节最后使用 🤗 Accelerate 创建训练循环时,我们将看到这一点。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 
mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer, TrainingArguments | |
| args = TrainingArguments( | |
| output_dir=<span class="hljs-string">"codeparrot-ds"</span>, | |
| per_device_train_batch_size=<span class="hljs-number">32</span>, | |
| per_device_eval_batch_size=<span class="hljs-number">32</span>, | |
| evaluation_strategy=<span class="hljs-string">"steps"</span>, | |
| eval_steps=<span class="hljs-number">5_000</span>, | |
| logging_steps=<span class="hljs-number">5_000</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">8</span>, | |
| num_train_epochs=<span class="hljs-number">1</span>, | |
| weight_decay=<span class="hljs-number">0.1</span>, | |
| warmup_steps=<span class="hljs-number">1_000</span>, | |
| lr_scheduler_type=<span class="hljs-string">"cosine"</span>, | |
| learning_rate=<span class="hljs-number">5e-4</span>, | |
| save_steps=<span class="hljs-number">5_000</span>, | |
| fp16=<span class="hljs-literal">True</span>, | |
| push_to_hub=<span class="hljs-literal">True</span>, | |
| ) | |
| trainer = Trainer( | |
| model=model, | |
| tokenizer=tokenizer, | |
| args=args, | |
| data_collator=data_collator, | |
| train_dataset=tokenized_datasets[<span class="hljs-string">"train"</span>], | |
| eval_dataset=tokenized_datasets[<span class="hljs-string">"valid"</span>], | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1snttie">现在我们只需启动 <code>Trainer</code> 并等待训练完成。根据你是在整个训练集还是在训练集的一个子集上运行它,这将分别需要 20 或 2 个小时,因此请喝杯咖啡或者找一本好书来阅读!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-pbt7ds">训练完成后,我们可以将模型和 tokenizer 推送到 Hub:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.push_to_hub()<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-hphc86">✏️ <strong>试试看!</strong> 除了 <code>TrainingArguments</code> 之外,我们只需要大约 30 行代码就可以从原始文本到训练 GPT-2。用你自己的数据集试试看,看看你能不能得到好的结果!</p></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1783bui">💡 如果你能使用多 GPU 的机器,尝试在那里运行代码。 <code>Trainer</code> 自动管理多台机器,这能极大地加快训练速度。</p></div> <h2 class="relative group"><a id="使用管道生成代码" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#使用管道生成代码"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 
1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>使用 pipeline 进行代码生成</span></h2> <p data-svelte-h="svelte-6igul9">现在是见证奇迹的时刻:我们来看看训练好的模型到底表现如何!我们可以在日志中看到损失持续下降,但要测试模型的效果,我们就看看它对一些提示的反应如何。为此,我们将模型包装在一个文本生成的 <code>pipeline</code> 中,并如果有 GPU 可用,我们将把它放在 GPU 上加快生成速度:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> pipeline | |
| device = torch.device(<span class="hljs-string">"cuda"</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">"cpu"</span>) | |
| pipe = pipeline( | |
| <span class="hljs-string">"text-generation"</span>, model=<span class="hljs-string">"huggingface-course/codeparrot-ds"</span>, device=device | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-p85nsj">让我们从简单的创建散点图任务开始:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->txt = <span class="hljs-string">"""\ | |
| # 创建一些数据 | |
| x = np.random.randn(100) | |
| y = np.random.randn(100) | |
| # 使用 x,y 创建散点图 | |
| """</span> | |
| <span class="hljs-built_in">print</span>(pipe(txt, num_return_sequences=<span class="hljs-number">1</span>)[<span class="hljs-number">0</span>][<span class="hljs-string">"generated_text"</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 创建一些数据</span> | |
| x = np.random.randn(<span class="hljs-number">100</span>) | |
| y = np.random.randn(<span class="hljs-number">100</span>) | |
| <span class="hljs-comment"># 使用 x,y 创建散点图</span> | |
| plt.scatter(x, y) | |
| <span class="hljs-comment"># 创建散点</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-4055bd">结果看起来是正确的。那么对于 <code>pandas</code> 操作也可以吗?让我们看看是否能从两个数组创建一个 <code>DataFrame</code> :</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->txt = <span class="hljs-string">"""\ | |
| # 创建一些数据 | |
| x = np.random.randn(100) | |
| y = np.random.randn(100) | |
| # 从 x 和 y 创建 dataframe | |
| """</span> | |
| <span class="hljs-built_in">print</span>(pipe(txt, num_return_sequences=<span class="hljs-number">1</span>)[<span class="hljs-number">0</span>][<span class="hljs-string">"generated_text"</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 创建一些数据</span> | |
| x = np.random.randn(<span class="hljs-number">100</span>) | |
| y = np.random.randn(<span class="hljs-number">100</span>) | |
| <span class="hljs-comment"># 从 x 和 y 创建 dataframe</span> | |
| df = pd.DataFrame({<span class="hljs-string">'x'</span>: x, <span class="hljs-string">'y'</span>: y}) | |
| df.insert(<span class="hljs-number">0</span>,<span class="hljs-string">'x'</span>, x) | |
| <span class="hljs-keyword">for</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-11bk8i6">很好,这是正确的答案——尽管它又把 <code>x</code> 重复插入了一次。而且由于生成的 token 数量有限,所以下面的 <code>for</code> 循环被切断了。让我们看看我们是否能做些更复杂的事情,让模型帮助我们使用 <code>groupby</code> 操作:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->txt = <span class="hljs-string">"""\ | |
| # 有职业,收入和名字的 dataframe | |
| df = pd.DataFrame({'profession': x, 'income':y, 'name': z}) | |
| # 计算每个职业的平均收入 | |
| """</span> | |
| <span class="hljs-built_in">print</span>(pipe(txt, num_return_sequences=<span class="hljs-number">1</span>)[<span class="hljs-number">0</span>][<span class="hljs-string">"generated_text"</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 有职业,收入和名字的 dataframe</span> | |
| df = pd.DataFrame({<span class="hljs-string">'profession'</span>: x, <span class="hljs-string">'income'</span>:y, <span class="hljs-string">'name'</span>: z}) | |
| <span class="hljs-comment"># 计算每个职业的平均收入</span> | |
| profession = df.groupby([<span class="hljs-string">'profession'</span>]).mean() | |
| <span class="hljs-comment"># 计算</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15u88hv">不错;是正确的。最后,让我们看看是否能引导模型使用 <code>scikit-learn</code> 并建立一个随机森林模型:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->txt = <span class="hljs-string">""" | |
| # 从 scikit-learn 导入随机森林回归器 | |
| from sklearn.ensemble import RandomForestRegressor | |
| # 用 X, y 拟合带有 300 个估算器的随机森林模型: | |
| """</span> | |
| <span class="hljs-built_in">print</span>(pipe(txt, num_return_sequences=<span class="hljs-number">1</span>)[<span class="hljs-number">0</span>][<span class="hljs-string">"generated_text"</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 从 scikit-learn 导入随机森林回归器</span> | |
| <span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestRegressor | |
| <span class="hljs-comment"># 用 X, y 拟合带有 300 个估算器的随机森林模型:</span> | |
| rf = RandomForestRegressor(n_estimators=<span class="hljs-number">300</span>, random_state=random_state, max_depth=<span class="hljs-number">3</span>) | |
| rf.fit(X, y) | |
| rf<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-rm1ynw">从这几个例子来看,模型似乎已经学习了 Python 数据科学堆栈的一些语法(当然,在将模型部署到现实世界之前,我们需要对其进行更全面的评估)。然而,有时候它需要更多的模型训练定制来达到特定情境所需的性能。例如,如果我们想动态更新 <code>batch_size</code> 或添加一个条件训练循环来跳过坏示例怎么办?一种选择是修改 <code>Trainer</code> 添加新的功能,但有时从头开始编写训练循环会更简单。这就是🤗 Accelerate 的用武之地。</p> <h2 class="relative group"><a id="使用🤗 Accelerate 进行训练" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#使用🤗 Accelerate 进行训练"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>使用🤗 Accelerate 进行训练</span></h2> <p data-svelte-h="svelte-xjh1eu">我们已经看到了如何使用 <code>Trainer</code> 训练模型,在 <code>Trainer</code> 中可以通过修改一些参数对训练过程进行一些定制。然而,有时我们想要完全控制训练循环,或者我们想要进行一些更自由的更改。在这种情况下 🤗 Accelerate 是一个不错的选择,本节我们将介绍如何使用它来训练我们的模型。为了让事情变得更有趣,相比于上面的 <code>Trainer</code> 我们还将在训练循环中添加一些修改。</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/Hm8_PgVTFuc" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p 
data-svelte-h="svelte-fh2nz1">由于我们主要关注的是为数据科学库提供合理的代码自动补充功能,因此对于更多使用这些库的训练样本赋予更高的权重是有意义的。我们可以通过使用 <code>plt</code> 、 <code>pd</code> 、 <code>sk</code> 、 <code>fit</code> 和 <code>predict</code> 等关键词来轻松地识别出这些例子,这些关键词是 <code>matplotlib.pyplot</code> 、 <code>pandas</code> 和 <code>sklearn</code> 导入后最常用重命名的名称,以及 <code>sklearn</code> 的 <code>fit/predict</code> 方法。如果这些在模型的内部是用单一的一个 <code>token</code> 表示的,我们可以通过 token 的 id 轻松地检查它们是否出现在输入序列中。然而,Tokens 有可能有空格前缀,所以我们也需要在 tokenizer 词汇表中检查这些关键词。为了验证这个策略的有效性,我们会在测试样本中添加一个应该被分割为多个 tokens 的测试 token:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->keytoken_ids = [] | |
| <span class="hljs-keyword">for</span> keyword <span class="hljs-keyword">in</span> [ | |
| <span class="hljs-string">"plt"</span>, | |
| <span class="hljs-string">"pd"</span>, | |
| <span class="hljs-string">"sk"</span>, | |
| <span class="hljs-string">"fit"</span>, | |
| <span class="hljs-string">"predict"</span>, | |
| <span class="hljs-string">" plt"</span>, | |
| <span class="hljs-string">" pd"</span>, | |
| <span class="hljs-string">" sk"</span>, | |
| <span class="hljs-string">" fit"</span>, | |
| <span class="hljs-string">" predict"</span>, | |
| <span class="hljs-string">"testtest"</span>, | |
| ]: | |
| ids = tokenizer([keyword]).input_ids[<span class="hljs-number">0</span>] | |
| <span class="hljs-keyword">if</span> <span class="hljs-built_in">len</span>(ids) == <span class="hljs-number">1</span>: | |
| keytoken_ids.append(ids[<span class="hljs-number">0</span>]) | |
| <span class="hljs-keyword">else</span>: | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"Keyword has not single token: <span class="hljs-subst">{keyword}</span>"</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">'Keyword has not single token: testtest'</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ptcm2l">太好了,这个方法似乎很有效!我们现在可以编写一个自定义的损失函数,它的输入有输入序列、logits 和我们刚刚选择的关键字。首先需要对齐 <code>logits</code> 和 <code>inputs</code>:将输入序列右移一个单位就形成了标签,因为下一个 <code>token</code> 就是当前 <code>token</code> 的预测目标。我们可以通过从输入序列的第二个 <code>token</code> 开始设置标签来实现这一点,因为模型本来就不会预测第一个 <code>token</code>。然后我们截断最后一个 <code>logit</code>,因为完整输入序列之后的那个 token 没有对应的标签。有了这些,我们就可以计算每个样本的损失,并计算每个样本中所有关键词的出现次数。最后,我们使用出现次数作为权重,计算所有样本的加权平均值。由于我们不想抛弃所有没有关键词的样本,我们将所有的权重都加 1:</p> <div class="code-block relative "><div 
class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.nn <span class="hljs-keyword">import</span> CrossEntropyLoss | |
| <span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">keytoken_weighted_loss</span>(<span class="hljs-params">inputs, logits, keytoken_ids, alpha=<span class="hljs-number">1.0</span></span>): | |
| <span class="hljs-comment"># 移位,使位置 &lt; n 的 tokens 预测第 n 个 token</span> | |
| shift_labels = inputs[..., <span class="hljs-number">1</span>:].contiguous() | |
| shift_logits = logits[..., :-<span class="hljs-number">1</span>, :].contiguous() | |
| <span class="hljs-comment"># 计算每一个token的loss</span> | |
| loss_fct = CrossEntropyLoss(reduction=<span class="hljs-string">"none"</span>) | |
| loss = loss_fct(shift_logits.view(-<span class="hljs-number">1</span>, shift_logits.size(-<span class="hljs-number">1</span>)), shift_labels.view(-<span class="hljs-number">1</span>)) | |
| <span class="hljs-comment"># 对于每个样本重新调整大小并平均</span> | |
| loss_per_sample = loss.view(shift_logits.size(<span class="hljs-number">0</span>), shift_logits.size(<span class="hljs-number">1</span>)).mean(axis=<span class="hljs-number">1</span>) | |
| <span class="hljs-comment"># 计算并缩放权重</span> | |
| weights = torch.stack([(inputs == kt).<span class="hljs-built_in">float</span>() <span class="hljs-keyword">for</span> kt <span class="hljs-keyword">in</span> keytoken_ids]).<span class="hljs-built_in">sum</span>( | |
| axis=[<span class="hljs-number">0</span>, <span class="hljs-number">2</span>] | |
| ) | |
| weights = alpha * (<span class="hljs-number">1.0</span> + weights) | |
| <span class="hljs-comment"># 计算加权平均值</span> | |
| weighted_loss = (loss_per_sample * weights).mean() | |
| <span class="hljs-keyword">return</span> weighted_loss<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-30wr61">在我们开始使用这个精妙的新损失函数进行训练之前,我们需要准备一些事情:</p> <ul data-svelte-h="svelte-hscu3g"><li>我们需要数据加载器来批量加载数据。</li> <li>我们需要设置权重衰减参数。</li> <li>有时我们在调试模型的时候可能需要临时评估,所以将评估代码包装在一个函数中。</li></ul> <p data-svelte-h="svelte-8e1ssw">让我们从数据加载器开始。我们只需要将数据集的格式设置为 <code>"torch"</code> ,然后我们就可以将它传递给一个具有适当 <code>batch size</code> 的 PyTorch 的 <code>DataLoader</code> :</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.utils.data.dataloader <span class="hljs-keyword">import</span> DataLoader | |
| tokenized_datasets.set_format(<span class="hljs-string">"torch"</span>) | |
| train_dataloader = DataLoader(tokenized_datasets[<span class="hljs-string">"train"</span>], batch_size=<span class="hljs-number">32</span>, shuffle=<span class="hljs-literal">True</span>) | |
| eval_dataloader = DataLoader(tokenized_datasets[<span class="hljs-string">"valid"</span>], batch_size=<span class="hljs-number">32</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-13o0n8c">接下来,我们将参数分组,以便优化器知道哪些参数需要进行额外的权重衰减。通常,所有的偏置和 LayerNorm 权重项都不需要进行权重衰减;因此我们可以这样做:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->weight_decay = <span class="hljs-number">0.1</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">get_grouped_params</span>(<span class="hljs-params">model, no_decay=[<span class="hljs-string">"bias"</span>, <span class="hljs-string">"LayerNorm.weight"</span>]</span>): | |
| params_with_wd, params_without_wd = [], [] | |
| <span class="hljs-keyword">for</span> n, p <span class="hljs-keyword">in</span> model.named_parameters(): | |
| <span class="hljs-keyword">if</span> <span class="hljs-built_in">any</span>(nd <span class="hljs-keyword">in</span> n <span class="hljs-keyword">for</span> nd <span class="hljs-keyword">in</span> no_decay): | |
| params_without_wd.append(p) | |
| <span class="hljs-keyword">else</span>: | |
| params_with_wd.append(p) | |
| <span class="hljs-keyword">return</span> [ | |
| {<span class="hljs-string">"params"</span>: params_with_wd, <span class="hljs-string">"weight_decay"</span>: weight_decay}, | |
| {<span class="hljs-string">"params"</span>: params_without_wd, <span class="hljs-string">"weight_decay"</span>: <span class="hljs-number">0.0</span>}, | |
| ]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-19lt4yo">我们希望在训练过程中定期在验证集上评估模型,让我们为此编写一个函数。它只需遍历评估数据加载器,并收集所有进程中的损失值:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">evaluate</span>(): | |
| model.<span class="hljs-built_in">eval</span>() | |
| losses = [] | |
| <span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(eval_dataloader): | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| outputs = model(batch[<span class="hljs-string">"input_ids"</span>], labels=batch[<span class="hljs-string">"input_ids"</span>]) | |
| losses.append(accelerator.gather(outputs.loss)) | |
| loss = torch.mean(torch.cat(losses)) | |
| <span class="hljs-keyword">try</span>: | |
| perplexity = torch.exp(loss) | |
| <span class="hljs-keyword">except</span> OverflowError: | |
| perplexity = <span class="hljs-built_in">float</span>(<span class="hljs-string">"inf"</span>) | |
| <span class="hljs-keyword">return</span> loss.item(), perplexity.item()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6srwzq">通过 <code>evaluate()</code> 函数,我们可以定期获取损失值和 <a href="/course/chapter7/3">困惑度(perplexity)</a> 。接下来,我们重新加载我们的模型以确保我们再次从头开始训练,而不是从上面的 <code>Trainer</code> 继续微调:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = GPT2LMHeadModel(config)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-55938r">然后我们可以定义我们的优化器,使用之前的函数来分割权重衰减的参数:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" 
fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW | |
| optimizer = AdamW(get_grouped_params(model), lr=<span class="hljs-number">5e-4</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1jozmoj">现在让我们准备模型、优化器和数据加载器,然后我们可以开始训练:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| accelerator = Accelerator(fp16=<span class="hljs-literal">True</span>) | |
| model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( | |
| model, optimizer, train_dataloader, eval_dataloader | |
| )<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-vxy3ez">🚨 如果你在 TPU 上训练,你需要将从上述单元格开始的所有代码移到一个专门的训练函数中。更多详情请参阅 <a href="/course/chapter3">第三章</a> 。</p></div> <p data-svelte-h="svelte-1eft1a0">现在我们已经将我们的 <code>train_dataloader</code> 传递给了 <code>accelerator.prepare()</code> ,我们可以使用 <code>len()</code> 来计算训练步骤的数量。请记住,我们应该在准备好 <code>dataloader</code> 后再使用 <code>len()</code> ,因为 <code>accelerator.prepare()</code> 会改变 <code>dataloader</code> 的长度。我们使用经典的线性学习率调度,让学习率从初始值衰减到 0:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->num_train_epochs = <span class="hljs-number">1</span> | |
| num_update_steps_per_epoch = <span class="hljs-built_in">len</span>(train_dataloader) | |
| num_training_steps = num_train_epochs * num_update_steps_per_epoch | |
| lr_scheduler = get_scheduler( | |
| name=<span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">1_000</span>, | |
| num_training_steps=num_training_steps, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-y3aow6">最后,为了将我们的模型推送到 Hub,我们需要在一个工作文件夹中创建一个 <code>Repository</code> 对象。如果你还没有登录的话,首先需要登录到 Hugging Face。我们将根据模型 ID 来确定仓库名称(你可以使用你喜欢的名字替换 <code>repo_name</code> ;它只需要包含你的用户名,可以使用 <code>get_full_repo_name()</code> 函数查看当前的 <code>repo_name</code>):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> Repository, get_full_repo_name | |
| model_name = <span class="hljs-string">"codeparrot-ds-accelerate"</span> | |
| repo_name = get_full_repo_name(model_name) | |
| repo_name<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">'sgugger/codeparrot-ds-accelerate'</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1u17xr7">然后我们可以将该仓库克隆到本地文件夹中。如果本地已经存在一个同名的文件夹,那么它应该是我们正在使用的仓库的本地克隆:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->output_dir = <span class="hljs-string">"codeparrot-ds-accelerate"</span> | |
| repo = Repository(output_dir, clone_from=repo_name)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hiar3m">我们现在可以通过调用 <code>repo.push_to_hub()</code> 方法上传保存在 <code>output_dir</code> 中的所有内容。这将帮助我们在每个训练周期结束时上传中间模型。在开始训练之前,让我们先快速运行一次 <code>evaluate()</code> ,检查评估函数是否正常工作:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->evaluate()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->(<span class="hljs-number">10.934126853942871</span>, <span class="hljs-number">56057.14453125</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1q6x48q">目前的损失和困惑度都是非常高的值,但这并不奇怪,因为我们还没有训练模型。到现在为止,我们已经为编写训练脚本的核心部分——训练循环——做好了准备。在训练循环中,我们迭代遍历数据加载器并将成批量的数据传递给模型。有了模型输出的 logits,我们就可以使用自定义损失函数计算损失。我们通过梯度累积步骤的数量来缩放损失,以避免在聚合更多步骤时产生更大的损失。在我们优化之前,我们也会剪裁梯度来更好地收敛。最后,每隔一段步数,我们用新的 <code>evaluate()</code> 函数在评估集上评估模型:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight 
rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm.notebook <span class="hljs-keyword">import</span> tqdm | |
| gradient_accumulation_steps = <span class="hljs-number">8</span> | |
| eval_steps = <span class="hljs-number">5_000</span> | |
| model.train() | |
| completed_steps = <span class="hljs-number">0</span> | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_train_epochs): | |
| <span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> tqdm( | |
| <span class="hljs-built_in">enumerate</span>(train_dataloader, start=<span class="hljs-number">1</span>), total=num_training_steps | |
| ): | |
| logits = model(batch[<span class="hljs-string">"input_ids"</span>]).logits | |
| loss = keytoken_weighted_loss(batch[<span class="hljs-string">"input_ids"</span>], logits, keytoken_ids) | |
| <span class="hljs-keyword">if</span> step % <span class="hljs-number">100</span> == <span class="hljs-number">0</span>: | |
| accelerator.<span class="hljs-built_in">print</span>( | |
| { | |
| <span class="hljs-string">"samples"</span>: step * samples_per_step, | |
| <span class="hljs-string">"steps"</span>: completed_steps, | |
| <span class="hljs-string">"loss/train"</span>: loss.item() * gradient_accumulation_steps, | |
| } | |
| ) | |
| loss = loss / gradient_accumulation_steps | |
| accelerator.backward(loss) | |
| <span class="hljs-keyword">if</span> step % gradient_accumulation_steps == <span class="hljs-number">0</span>: | |
| accelerator.clip_grad_norm_(model.parameters(), <span class="hljs-number">1.0</span>) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| completed_steps += <span class="hljs-number">1</span> | |
| <span class="hljs-keyword">if</span> (step % (eval_steps * gradient_accumulation_steps)) == <span class="hljs-number">0</span>: | |
| eval_loss, perplexity = evaluate() | |
| accelerator.<span class="hljs-built_in">print</span>({<span class="hljs-string">"loss/eval"</span>: eval_loss, <span class="hljs-string">"perplexity"</span>: perplexity}) | |
| model.train() | |
| accelerator.wait_for_everyone() | |
| unwrapped_model = accelerator.unwrap_model(model) | |
| unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save) | |
| <span class="hljs-keyword">if</span> accelerator.is_main_process: | |
| tokenizer.save_pretrained(output_dir) | |
| repo.push_to_hub( | |
| commit_message=<span class="hljs-string">f"Training in progress step <span class="hljs-subst">{step}</span>"</span>, blocking=<span class="hljs-literal">False</span> | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-al5zgg">就是这样 - 你现在拥有自己的因果语言模型(例如 GPT-2)的自定义训练循环,你可以根据自己的需要进一步定制。</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1vr69hw">✏️ <strong>试试看!</strong> 创建适合你的用例的自定义损失函数,或在训练循环中添加另一个自定义步骤。</p></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-ry3l0i">✏️ <strong>试试看!</strong> 当运行长时间的训练实验时,使用 TensorBoard 或 Weights &amp; Biases 等工具记录重要指标是个好主意。向训练循环中添加适当的日志记录,这样你可以随时检查训练进度。</p></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/zh-CN/chapter7/6.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_9aawfw = { | |
| assets: "/docs/course/pr_1021/zh-CN", | |
| base: "/docs/course/pr_1021/zh-CN", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/course/pr_1021/zh-CN/_app/immutable/entry/start.f3a1a511.js"), | |
| import("/docs/course/pr_1021/zh-CN/_app/immutable/entry/app.c39e37cf.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 58], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 132 kB
- Xet hash:
- 68cc53be4e2e502f93697c223f84f00b7e55fa16f09537e555410e8b87e4e751
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.