Buckets:

rtrm's picture
download
raw
195 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;抽取式问答问答&quot;,&quot;local&quot;:&quot;抽取式问答&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;准备数据&quot;,&quot;local&quot;:&quot;准备数据&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;SQuAD 数据集&quot;,&quot;local&quot;:&quot;SQuAD 数据集&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;处理训练数据&quot;,&quot;local&quot;:&quot;处理训练数据&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;处理验证数据&quot;,&quot;local&quot;:&quot;处理验证数据&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;使用 Trainer API 微调模型&quot;,&quot;local&quot;:&quot;使用 Trainer API 微调模型&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;使用 Keras 微调模型&quot;,&quot;local&quot;:&quot;使用 Keras 微调模型&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;后处理&quot;,&quot;local&quot;:&quot;后处理&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;微调模型&quot;,&quot;local&quot;:&quot;微调模型&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;自定义训练循环&quot;,&quot;local&quot;:&quot;自定义训练循环&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;为训练做准备&quot;,&quot;local&quot;:&quot;为训练做准备&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;训练循环&quot;,&quot;local&quot;:&quot;训练循环&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;使用微调模型&quot;,&quot;local&quot;:&quot;使用微调模型&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/course/pr_1021/zh-CN/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/entry/start.f3a1a511.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/scheduler.37c15a92.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/singletons.9bf55235.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/index.18351ede.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/paths.0ba10750.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/entry/app.c39e37cf.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/index.2bf4358c.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/nodes/0.dad18ce3.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/nodes/59.032738c7.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/Tip.363c041f.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/Youtube.1e50a667.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/CodeBlock.4e987730.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/DocNotebookDropdown.efc1fb7c.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/FrameworkSwitchCourse.8d4d4ab6.js">
<link rel="modulepreload" href="/docs/course/pr_1021/zh-CN/_app/immutable/chunks/getInferenceSnippets.ebf8be91.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;抽取式问答问答&quot;,&quot;local&quot;:&quot;抽取式问答&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;准备数据&quot;,&quot;local&quot;:&quot;准备数据&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;SQuAD 数据集&quot;,&quot;local&quot;:&quot;SQuAD 数据集&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;处理训练数据&quot;,&quot;local&quot;:&quot;处理训练数据&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;处理验证数据&quot;,&quot;local&quot;:&quot;处理验证数据&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;使用 Trainer API 微调模型&quot;,&quot;local&quot;:&quot;使用 Trainer API 微调模型&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;使用 Keras 微调模型&quot;,&quot;local&quot;:&quot;使用 Keras 微调模型&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;后处理&quot;,&quot;local&quot;:&quot;后处理&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;微调模型&quot;,&quot;local&quot;:&quot;微调模型&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;自定义训练循环&quot;,&quot;local&quot;:&quot;自定义训练循环&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;为训练做准备&quot;,&quot;local&quot;:&quot;为训练做准备&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;训练循环&quot;,&quot;local&quot;:&quot;训练循环&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;使用微调模型&quot;,&quot;local&quot;:&quot;使用微调模型&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> Pytorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="抽取式问答" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#抽取式问答"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>抽取式问答问答</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/chapter7/section7_pt.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/chapter7/section7_pt.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <p data-svelte-h="svelte-17nre2f">现在我们来看看问答这个任务!这个任务有很多种类型,但我们在本节将要关注的是称为 <code>抽取式(extractive)</code> 问题回答的形式。会有一些问题和文档,其中答案就在文档段落之内。</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/ajPx5LwJD-I" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-fj1yq1">我们将使用 <a href="https://rajpurkar.github.io/SQuAD-explorer/" rel="nofollow">SQuAD 数据集</a> 微调一个 BERT 模型,其中包括群众工作者对一组维基百科文章提出的问题。以下是一个小的测试样例:</p> <iframe src="https://course-demos-bert-finetuned-squad.hf.space" frameborder="0" height="450" title="Gradio app" class="block dark:hidden container p-0 flex-grow space-iframe" allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads"></iframe> <p data-svelte-h="svelte-1lpqctr">本节使用的代码已经上传到了 Hub。你可以在 <a href="https://huggingface.co/huggingface-course/bert-finetuned-squad?context=%F0%9F%A4%97+Transformers+is+backed+by+the+three+most+popular+deep+learning+libraries+%E2%80%94+Jax%2C+PyTorch+and+TensorFlow+%E2%80%94+with+a+seamless+integration+between+them.+It%27s+straightforward+to+train+your+models+with+one+before+loading+them+for+inference+with+the+other.&question=Which+deep+learning+libraries+back+%F0%9F%A4%97+Transformers%3F" rel="nofollow">这里</a> 找到它并尝试用它进行预测。</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1s3rar2">💡 像 BERT 这样的纯编码器模型往往很擅长提取诸如 “谁发明了 Transformer 架构?”之类的事实性问题的答案。但在给出诸如 “为什么天空是蓝色的?” 之类的开放式问题时表现不佳。在这些更具挑战性的情况下,通常使用编码器-解码器模型如 T5 和 BART 来以类似于 <a href="https://chat.openai.com/course/chapter7/5" rel="nofollow">文本摘要</a> 的方式整合信息。如果你对这种 <code>生成式(generative)</code> 问答感兴趣,我们推荐你查看我们做的基于 <a href="https://huggingface.co/datasets/eli5" rel="nofollow">ELI5 数据集</a><a href="https://yjernite.github.io/lfqa.html" rel="nofollow">演示demo</a></p></div> <h2 class="relative group"><a id="准备数据" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#准备数据"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>准备数据</span></h2> <p data-svelte-h="svelte-3pmurx">作为抽取式问题回答的学术基准最常用的数据集是 <a href="https://rajpurkar.github.io/SQuAD-explorer/" rel="nofollow">SQuAD</a> ,所以我们在这里将使用它。还有一个更难的 <a href="https://huggingface.co/datasets/squad_v2" rel="nofollow">SQuAD v2</a> 基准,其中包含一些没有答案的问题。你也可以使用自己的数据集,只要你自己的数据集包含了 Context 列、问题列和答案列,应该也能够适用下面的步骤。</p> <h3 class="relative group"><a id="SQuAD 数据集" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#SQuAD 数据集"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>SQuAD 数据集</span></h3> <p data-svelte-h="svelte-1fzlnr5">像往常一样,我们可以使用 <code>load_dataset()</code> 在一行中下载和缓存数据集:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
raw_datasets = load_dataset(<span class="hljs-string">&quot;squad&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-w5a09o">我们可以查看这个 <code>raw_datasets</code> 对象来了解关于 SQuAD 数据集的更多信息:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->raw_datasets<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({
train: Dataset({
features: [<span class="hljs-string">&#x27;id&#x27;</span>, <span class="hljs-string">&#x27;title&#x27;</span>, <span class="hljs-string">&#x27;context&#x27;</span>, <span class="hljs-string">&#x27;question&#x27;</span>, <span class="hljs-string">&#x27;answers&#x27;</span>],
num_rows: <span class="hljs-number">87599</span>
})
validation: Dataset({
features: [<span class="hljs-string">&#x27;id&#x27;</span>, <span class="hljs-string">&#x27;title&#x27;</span>, <span class="hljs-string">&#x27;context&#x27;</span>, <span class="hljs-string">&#x27;question&#x27;</span>, <span class="hljs-string">&#x27;answers&#x27;</span>],
num_rows: <span class="hljs-number">10570</span>
})
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-147znmv">看起来我们的数据集拥有所需的 <code>context</code><code>question</code><code>answers</code> 字段,所以让我们打印训练集的第一个元素:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Context: &quot;</span>, raw_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">0</span>][<span class="hljs-string">&quot;context&quot;</span>])
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Question: &quot;</span>, raw_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">0</span>][<span class="hljs-string">&quot;question&quot;</span>])
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Answer: &quot;</span>, raw_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">0</span>][<span class="hljs-string">&quot;answers&quot;</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Context: <span class="hljs-string">&#x27;Architecturally, the school has a Catholic character. Atop the Main Building\&#x27;s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend &quot;Venite Ad Me Omnes&quot;. Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.&#x27;</span>
Question: <span class="hljs-string">&#x27;To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?&#x27;</span>
Answer: {<span class="hljs-string">&#x27;text&#x27;</span>: [<span class="hljs-string">&#x27;Saint Bernadette Soubirous&#x27;</span>], <span class="hljs-string">&#x27;answer_start&#x27;</span>: [<span class="hljs-number">515</span>]}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10b1ues"><code>context</code><code>question</code> 字段的使用非常直观。 <code>answers</code> 字段相对复杂一些,因为它是一个字典,包含两个列表的字段。这是在评估时 <code>squad</code> 指标需要的格式;如果你使用的是你自己的数据,不需要将答案处理成完全相同的格式。 <code>text</code> 字段是非常明显的答案文本,而 <code>answer_start</code> 字段包含了 Context 中每个答案开始的索引。</p> <p data-svelte-h="svelte-f2gx7w">在训练过程中,只有一个可能的答案。我们也可以使用 <code>Dataset.filter()</code> 方法来进行检查:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->raw_datasets[<span class="hljs-string">&quot;train&quot;</span>].<span class="hljs-built_in">filter</span>(<span class="hljs-keyword">lambda</span> x: <span class="hljs-built_in">len</span>(x[<span class="hljs-string">&quot;answers&quot;</span>][<span class="hljs-string">&quot;text&quot;</span>]) != <span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Dataset({
features: [<span class="hljs-string">&#x27;id&#x27;</span>, <span class="hljs-string">&#x27;title&#x27;</span>, <span class="hljs-string">&#x27;context&#x27;</span>, <span class="hljs-string">&#x27;question&#x27;</span>, <span class="hljs-string">&#x27;answers&#x27;</span>],
num_rows: <span class="hljs-number">0</span>
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-csutur">然而,在评估过程中,每个样本可能有多个答案,这些答案可能相同或不同:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>][<span class="hljs-number">0</span>][<span class="hljs-string">&quot;answers&quot;</span>])
<span class="hljs-built_in">print</span>(raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>][<span class="hljs-number">2</span>][<span class="hljs-string">&quot;answers&quot;</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;text&#x27;</span>: [<span class="hljs-string">&#x27;Denver Broncos&#x27;</span>, <span class="hljs-string">&#x27;Denver Broncos&#x27;</span>, <span class="hljs-string">&#x27;Denver Broncos&#x27;</span>], <span class="hljs-string">&#x27;answer_start&#x27;</span>: [<span class="hljs-number">177</span>, <span class="hljs-number">177</span>, <span class="hljs-number">177</span>]}
{<span class="hljs-string">&#x27;text&#x27;</span>: [<span class="hljs-string">&#x27;Santa Clara, California&#x27;</span>, <span class="hljs-string">&quot;Levi&#x27;s Stadium&quot;</span>, <span class="hljs-string">&quot;Levi&#x27;s Stadium in the San Francisco Bay Area at Santa Clara, California.&quot;</span>], <span class="hljs-string">&#x27;answer_start&#x27;</span>: [<span class="hljs-number">403</span>, <span class="hljs-number">355</span>, <span class="hljs-number">355</span>]}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1sw0li3">我们不会深入探究评估的代码,因为所有的东西都将由🤗 Datasets metric 帮我们完成,但简单来说,一些问题可能有多个可能的答案,而该评估代码将把预测的答案与所有可接受的答案进行比较,并选择最佳分数。例如,让我们看一下索引为 2 的样本:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>][<span class="hljs-number">2</span>][<span class="hljs-string">&quot;context&quot;</span>])
<span class="hljs-built_in">print</span>(raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>][<span class="hljs-number">2</span>][<span class="hljs-string">&quot;question&quot;</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\&#x27;s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the &quot;golden anniversary&quot; with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as &quot;Super Bowl L&quot;), so that the logo could prominently feature the Arabic numerals 50.&#x27;</span>
<span class="hljs-string">&#x27;Where did Super Bowl 50 take place?&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-4l8p3u">我们可以看到,答案的确可能是我们之前看到的三个可能选择 <code>[&#39;Denver Broncos&#39;, &#39;Denver Broncos&#39;, &#39;Denver Broncos&#39;]</code> 的之一。</p> <h3 class="relative group"><a id="处理训练数据" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#处理训练数据"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>处理训练数据</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/qgaM0weJHpA" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-15fd3xb">我们从预处理训练数据开始。最困难的部分将是生成问题答案的位置,即找到 Context 中对应答案 token 的起始和结束位置。</p> <p data-svelte-h="svelte-whz6e0">但我们不要急于求成。首先,我们需要使用 tokenizer 将输入中的文本转换为模型可以理解的 ID:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer
model_checkpoint = <span class="hljs-string">&quot;bert-base-cased&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10bozhx">如前所述,我们将对 BERT 模型进行微调,但你可以使用任何其他模型类型,只要它实现了快速 tokenizer 即可。你可以在 <a href="https://huggingface.co/transformers/#supported-frameworks" rel="nofollow">支持快速 tokenizer 的框架</a> 表中看到所有带有快速版本的架构,要检查你正在使用的 <code>tokenizer</code> 对象是否真的是由🤗 Tokenizers 支持的,你可以查看它的 <code>is_fast</code> 属性:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenizer.is_fast<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-literal">True</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-isih9e">我们可以将 question 和 context 一起传递给我们的 tokenizer 它会正确插入特殊 tokens 形成如下句子:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-selector-attr">[CLS]</span> question <span class="hljs-selector-attr">[SEP]</span> context <span class="hljs-selector-attr">[SEP]</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ok0ucp">让我们检查一下处理后的样本:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->context = raw_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">0</span>][<span class="hljs-string">&quot;context&quot;</span>]
question = raw_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">0</span>][<span class="hljs-string">&quot;question&quot;</span>]
inputs = tokenizer(question, context)
tokenizer.decode(inputs[<span class="hljs-string">&quot;input_ids&quot;</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, &#x27;</span>
<span class="hljs-string">&#x27;the school has a Catholic character. Atop the Main Building\&#x27;s gold dome is a golden statue of the Virgin &#x27;</span>
<span class="hljs-string">&#x27;Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms &#x27;</span>
<span class="hljs-string">&#x27;upraised with the legend &quot; Venite Ad Me Omnes &quot;. Next to the Main Building is the Basilica of the Sacred &#x27;</span>
<span class="hljs-string">&#x27;Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a &#x27;</span>
<span class="hljs-string">&#x27;replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette &#x27;</span>
<span class="hljs-string">&#x27;Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues &#x27;</span>
<span class="hljs-string">&#x27;and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dfn802">需要预测的是答案起始和结束 token 的索引,模型的任务是为输入中的每个标记预测一个起始和结束的 logit 值,理论上的预测的结果如下所示:</p> <div class="flex justify-center" data-svelte-h="svelte-cxbodn"><img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/qa_labels.svg" alt="One-hot encoded labels for question answering."> <img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/qa_labels-dark.svg" alt="One-hot encoded labels for question answering."></div> <p data-svelte-h="svelte-1n72rdl">在做个例子中,Context 没有很长,但是数据集中的一些示例的 Context 会很长,会超过我们设置的最大长度(本例中为 384)。正如我们在 <a href="/course/chapter6/4">第六章</a> 中所看到的,当我们探索 <code>question-answering</code> 管道的内部结构时,我们会通过将一个样本的较长的 Context 划分成多个片段,并在这些片段之间使用滑动窗口,来处理较长的 Context。</p> <p data-svelte-h="svelte-19o372i">要了解在这个过程中对当前的训练样本进行了哪些处理,我们可以将长度限制为 100,并使用长度为 50 的 token 窗口。我们将设置以下的参数:</p> <ul data-svelte-h="svelte-h3zbmt"><li><code>max_length</code> 来设置最大长度 (这里为 100)</li> <li><code>truncation=&quot;only_second&quot;</code> 在问题和 Context 过长时截断 Context(Context 位于第二个位置,第一个是 Question)</li> <li><code>stride</code> 设置两个连续块之间的重叠 tokens 数 (这里为 50)</li> <li><code>return_overflowing_tokens=True</code> 告诉 tokenizer 我们想要保留超过长度的 tokens</li></ul> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->inputs = tokenizer(
question,
context,
max_length=<span class="hljs-number">100</span>,
truncation=<span class="hljs-string">&quot;only_second&quot;</span>,
stride=<span class="hljs-number">50</span>,
return_overflowing_tokens=<span class="hljs-literal">True</span>,
)
<span class="hljs-keyword">for</span> ids <span class="hljs-keyword">in</span> inputs[<span class="hljs-string">&quot;input_ids&quot;</span>]:
<span class="hljs-built_in">print</span>(tokenizer.decode(ids))<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\&#x27;s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend &quot; Venite Ad Me Omnes &quot;. Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]&#x27;</span>
<span class="hljs-string">&#x27;[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend &quot; Venite Ad Me Omnes &quot;. Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]&#x27;</span>
<span class="hljs-string">&#x27;[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 [SEP]&#x27;</span>
<span class="hljs-string">&#x27;[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP]. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1pbcg9l">如我们所见,示例文本被拆分成四个输入,每个输入都包含问题和 Context 的一部分。请注意,问题的答案 (“Bernadette Soubirous”) 仅出现在第三个和最后一个片段中,因此通过以这种方式处理较长的 Context 时,我们可能创建一些 Context 中不包含答案的训练样本。我们把这些样本的标签设置为 <code>start_position = end_position = 0</code> (这样的话,实际上我们的答案指向了 <code>[CLS]</code> tokens)。如果答案被截断,那么只在这一部分预测答案的起始(或结束)的token 的索引。对于答案完全在 Context 中的示例,标签将是答案起始的 token 的索引和答案结束的 token 的索引。</p> <p data-svelte-h="svelte-1ae3boe">数据集为我们提供了 Context 中答案的起始的位置索引,加上答案的长度,我们可以找到 Context 中的结束索引。要将它们映射到 tokens 索引,我们将需要使用我们在 <a href="/course/chapter6/4">第六章</a> 中学到的偏移映射。我们可以通过使用 <code>return_offsets_mapping=True</code>,让我们的 tokenizer 返回偏移后的映射:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->inputs = tokenizer(
question,
context,
max_length=<span class="hljs-number">100</span>,
truncation=<span class="hljs-string">&quot;only_second&quot;</span>,
stride=<span class="hljs-number">50</span>,
return_overflowing_tokens=<span class="hljs-literal">True</span>,
return_offsets_mapping=<span class="hljs-literal">True</span>,
)
inputs.keys()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->dict_keys([<span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;token_type_ids&#x27;</span>, <span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;offset_mapping&#x27;</span>, <span class="hljs-string">&#x27;overflow_to_sample_mapping&#x27;</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s5v9op">如我们所见,我们得到了 inputs ID、tokens 类型 ID 和注意力掩码,以及我们所需的偏移映射和一个额外的 <code>overflow_to_sample_mapping</code> 。当我们同时对多个文本并行 tokenize 时,为了从支持 Rust 中受益,这个键的值对我们很有用。由于一个长的样本可以切分为多个短的样本,它保存了这些短的样本是来自于哪个长的样本。因为这里我们只对一个样本进行了 tokenize,所以我们得到一个由 <code>0</code> 组成的列表:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->inputs[<span class="hljs-string">&quot;overflow_to_sample_mapping&quot;</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-py56t3">但是,如果我们对更多的示例进行 tokenize ,它会变得更加有用:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->inputs = tokenizer(
raw_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">2</span>:<span class="hljs-number">6</span>][<span class="hljs-string">&quot;question&quot;</span>],
raw_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">2</span>:<span class="hljs-number">6</span>][<span class="hljs-string">&quot;context&quot;</span>],
max_length=<span class="hljs-number">100</span>,
truncation=<span class="hljs-string">&quot;only_second&quot;</span>,
stride=<span class="hljs-number">50</span>,
return_overflowing_tokens=<span class="hljs-literal">True</span>,
return_offsets_mapping=<span class="hljs-literal">True</span>,
)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;The 4 examples gave <span class="hljs-subst">{<span class="hljs-built_in">len</span>(inputs[<span class="hljs-string">&#x27;input_ids&#x27;</span>])}</span> features.&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Here is where each comes from: <span class="hljs-subst">{inputs[<span class="hljs-string">&#x27;overflow_to_sample_mapping&#x27;</span>]}</span>.&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;The 4 examples gave 19 features.&#x27;</span>
<span class="hljs-string">&#x27;Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3].&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15cxksl">在我们的这个例子中,前三条数据 (在训练集中的索引 2、3 和 4 处) 每条数据被拆分为4个样本,最后一条数据(在训练集中的索引 5 处) 拆分为了5个样本。</p> <p data-svelte-h="svelte-pv7tg2">这些信息将有助于将我们拆分后的文本块映射到其相应的标签。如前所述,这些标签的规则是:</p> <ul data-svelte-h="svelte-17h1ibo"><li>如果答案不在相应上下文的范围内,则为 <code>(0, 0)</code></li> <li>如果答案在相应上下文的范围内,则为 <code>(start_position, end_position)</code> ,其中 <code>start_position</code> 是答案起始处的 token 索引(在 inputs ID 中), <code>end_position</code> 是答案结束处的 token 索引(在 inputs ID 中)</li></ul> <p data-svelte-h="svelte-nzwgtk">为了确定这两种情况中的哪一种,并且如果是第二种,则需要确定 token 的位置,我们首先找到在输入 ID 中起始和结束上下文的索引。我们首先找到拆分后的每一个部分在 <code>Context</code> 起始和结束的索引,可以使用 token 类型 ID 来完成此操作,但由于并非所有模型都支持这样的操作(如DistilBERT),因此可以使用 <code>tokenizer</code><code>sequence_ids()</code> 函数返回的 BatchEncoding 对象。</p> <p data-svelte-h="svelte-1o3rbgc">有了这些 tokens 的索引之后,我们就可以计算相应的偏移量了,它们是两个整数的元组,表示原始 Context 中的字符范围。因此,我们可以检测每个分块中的 Context 块是在答案之后起始还是在答案起始之前结束(在这种情况下,标签是 <code>(0, 0)</code> )。如果答案就在 Context 里,我们就循环查找答案的第一个和最后一个 token:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->answers = raw_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">2</span>:<span class="hljs-number">6</span>][<span class="hljs-string">&quot;answers&quot;</span>]
start_positions = []
end_positions = []
<span class="hljs-keyword">for</span> i, offset <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(inputs[<span class="hljs-string">&quot;offset_mapping&quot;</span>]):
sample_idx = inputs[<span class="hljs-string">&quot;overflow_to_sample_mapping&quot;</span>][i]
answer = answers[sample_idx]
start_char = answer[<span class="hljs-string">&quot;answer_start&quot;</span>][<span class="hljs-number">0</span>]
end_char = answer[<span class="hljs-string">&quot;answer_start&quot;</span>][<span class="hljs-number">0</span>] + <span class="hljs-built_in">len</span>(answer[<span class="hljs-string">&quot;text&quot;</span>][<span class="hljs-number">0</span>])
sequence_ids = inputs.sequence_ids(i)
<span class="hljs-comment"># 找到上下文的起始和结束</span>
idx = <span class="hljs-number">0</span>
<span class="hljs-keyword">while</span> sequence_ids[idx] != <span class="hljs-number">1</span>:
idx += <span class="hljs-number">1</span>
context_start = idx
<span class="hljs-keyword">while</span> sequence_ids[idx] == <span class="hljs-number">1</span>:
idx += <span class="hljs-number">1</span>
context_end = idx - <span class="hljs-number">1</span>
<span class="hljs-comment"># 如果答案不完全在上下文内,标签为(0, 0)</span>
<span class="hljs-keyword">if</span> offset[context_start][<span class="hljs-number">0</span>] &gt; start_char <span class="hljs-keyword">or</span> offset[context_end][<span class="hljs-number">1</span>] &lt; end_char:
start_positions.append(<span class="hljs-number">0</span>)
end_positions.append(<span class="hljs-number">0</span>)
<span class="hljs-keyword">else</span>:
<span class="hljs-comment"># 否则,它就是起始和结束 token 的位置</span>
idx = context_start
<span class="hljs-keyword">while</span> idx &lt;= context_end <span class="hljs-keyword">and</span> offset[idx][<span class="hljs-number">0</span>] &lt;= start_char:
idx += <span class="hljs-number">1</span>
start_positions.append(idx - <span class="hljs-number">1</span>)
idx = context_end
<span class="hljs-keyword">while</span> idx &gt;= context_start <span class="hljs-keyword">and</span> offset[idx][<span class="hljs-number">1</span>] &gt;= end_char:
idx -= <span class="hljs-number">1</span>
end_positions.append(idx + <span class="hljs-number">1</span>)
start_positions, end_positions<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->([<span class="hljs-number">83</span>, <span class="hljs-number">51</span>, <span class="hljs-number">19</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">64</span>, <span class="hljs-number">27</span>, <span class="hljs-number">0</span>, <span class="hljs-number">34</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">67</span>, <span class="hljs-number">34</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>],
[<span class="hljs-number">85</span>, <span class="hljs-number">53</span>, <span class="hljs-number">21</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">70</span>, <span class="hljs-number">33</span>, <span class="hljs-number">0</span>, <span class="hljs-number">40</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">68</span>, <span class="hljs-number">35</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-oybube">让我们查看一些结果来验证一下我们的方法是否正确。在拆分后的第一个部分的文本中,我们看到了 <code>(83, 85)</code> 是待预测的标签值,因此让我们将理论答案与从 83 到 85(包括 85)的 tokens 解码的结果进行比较:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->idx = <span class="hljs-number">0</span>
sample_idx = inputs[<span class="hljs-string">&quot;overflow_to_sample_mapping&quot;</span>][idx]
answer = answers[sample_idx][<span class="hljs-string">&quot;text&quot;</span>][<span class="hljs-number">0</span>]
start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs[<span class="hljs-string">&quot;input_ids&quot;</span>][idx][start : end + <span class="hljs-number">1</span>])
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Theoretical answer: <span class="hljs-subst">{answer}</span>, labels give: <span class="hljs-subst">{labeled_answer}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;Theoretical answer: the Main Building, labels give: the Main Building&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mib9oo">很好!寻找的答案是正确的!现在让我们来看一下拆分后的第4个文本块,我们我们得到的标签是 <code>(0, 0)</code> ,这意味着答案不在这个文本块中:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->idx = <span class="hljs-number">4</span>
sample_idx = inputs[<span class="hljs-string">&quot;overflow_to_sample_mapping&quot;</span>][idx]
answer = answers[sample_idx][<span class="hljs-string">&quot;text&quot;</span>][<span class="hljs-number">0</span>]
decoded_example = tokenizer.decode(inputs[<span class="hljs-string">&quot;input_ids&quot;</span>][idx])
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Theoretical answer: <span class="hljs-subst">{answer}</span>, decoded example: <span class="hljs-subst">{decoded_example}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;Theoretical answer: a Marian place of prayer and reflection, decoded example: [CLS] What is the Grotto at Notre Dame? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\&#x27;s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend &quot; Venite Ad Me Omnes &quot;. Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grot [SEP]&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1huhakg">确实,我们在 Context 中没有看到答案。</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-o6ip24">✏️ <strong>轮你来了!</strong> 在使用 XLNet 架构时,如果截取后的文本长度没有达到设定的最大长度,需要在左侧进行填充,并且需要交互问题和 Context 的顺序。尝试将我们刚刚看到的所有代码调整为 XLNet 架构(并添加 <code>padding=True</code> )。请注意,因为是在左侧填充的,所以填充后的 <code>[CLS]</code> tokens 可能不在索引为 0 的位置。</p></div> <p data-svelte-h="svelte-147kkd5">现在,我们已经逐步了解了如何预处理我们的训练数据,接下来可以将其组合到一个函数中,并使用该函数处理整个训练数据集。我们将每个拆分后的样本都填充到我们设置的最大长度,因为大多数上下文都很长(相应的样本会被分割成几小块),所以在这里进行动态填充的所带来的增益不是很大。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->max_length = <span class="hljs-number">384</span>
stride = <span class="hljs-number">128</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_training_examples</span>(<span class="hljs-params">examples</span>):
questions = [q.strip() <span class="hljs-keyword">for</span> q <span class="hljs-keyword">in</span> examples[<span class="hljs-string">&quot;question&quot;</span>]]
inputs = tokenizer(
questions,
examples[<span class="hljs-string">&quot;context&quot;</span>],
max_length=max_length,
truncation=<span class="hljs-string">&quot;only_second&quot;</span>,
stride=stride,
return_overflowing_tokens=<span class="hljs-literal">True</span>,
return_offsets_mapping=<span class="hljs-literal">True</span>,
padding=<span class="hljs-string">&quot;max_length&quot;</span>,
)
offset_mapping = inputs.pop(<span class="hljs-string">&quot;offset_mapping&quot;</span>)
sample_map = inputs.pop(<span class="hljs-string">&quot;overflow_to_sample_mapping&quot;</span>)
answers = examples[<span class="hljs-string">&quot;answers&quot;</span>]
start_positions = []
end_positions = []
<span class="hljs-keyword">for</span> i, offset <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(offset_mapping):
sample_idx = sample_map[i]
answer = answers[sample_idx]
start_char = answer[<span class="hljs-string">&quot;answer_start&quot;</span>][<span class="hljs-number">0</span>]
end_char = answer[<span class="hljs-string">&quot;answer_start&quot;</span>][<span class="hljs-number">0</span>] + <span class="hljs-built_in">len</span>(answer[<span class="hljs-string">&quot;text&quot;</span>][<span class="hljs-number">0</span>])
sequence_ids = inputs.sequence_ids(i)
<span class="hljs-comment"># 找到上下文的起始和结束</span>
idx = <span class="hljs-number">0</span>
<span class="hljs-keyword">while</span> sequence_ids[idx] != <span class="hljs-number">1</span>:
idx += <span class="hljs-number">1</span>
context_start = idx
<span class="hljs-keyword">while</span> sequence_ids[idx] == <span class="hljs-number">1</span>:
idx += <span class="hljs-number">1</span>
context_end = idx - <span class="hljs-number">1</span>
<span class="hljs-comment"># 如果答案不完全在上下文内,标签为(0, 0)</span>
<span class="hljs-keyword">if</span> offset[context_start][<span class="hljs-number">0</span>] &gt; start_char <span class="hljs-keyword">or</span> offset[context_end][<span class="hljs-number">1</span>] &lt; end_char:
start_positions.append(<span class="hljs-number">0</span>)
end_positions.append(<span class="hljs-number">0</span>)
<span class="hljs-keyword">else</span>:
<span class="hljs-comment"># 否则,它就是起始和结束 tokens 的位置</span>
idx = context_start
<span class="hljs-keyword">while</span> idx &lt;= context_end <span class="hljs-keyword">and</span> offset[idx][<span class="hljs-number">0</span>] &lt;= start_char:
idx += <span class="hljs-number">1</span>
start_positions.append(idx - <span class="hljs-number">1</span>)
idx = context_end
<span class="hljs-keyword">while</span> idx &gt;= context_start <span class="hljs-keyword">and</span> offset[idx][<span class="hljs-number">1</span>] &gt;= end_char:
idx -= <span class="hljs-number">1</span>
end_positions.append(idx + <span class="hljs-number">1</span>)
inputs[<span class="hljs-string">&quot;start_positions&quot;</span>] = start_positions
inputs[<span class="hljs-string">&quot;end_positions&quot;</span>] = end_positions
<span class="hljs-keyword">return</span> inputs<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-sm3cnw">请注意,我们定义了两个常量来确定所使用的最大长度以及滑动窗口的长度,并且在之前 tokenize 之前对数据进行了一些清洗:SQuAD 数据集中的一些问题在开头和结尾有额外的空格,这些空格没有任何意义(如果你使用像 RoBERTa 这样的模型,它们会占用 tokenize 的长度),因此我们去掉了这些额外的空格。</p> <p data-svelte-h="svelte-jwligw">要使用该函数处理整个训练集,我们可以使用 <code>Dataset.map()</code> 方法并设置 <code>batched=True</code> 参数。这是必要的,因为我们正在更改数据集的长度(因为一个样本可能会产生多个子样本):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->train_dataset = raw_datasets[<span class="hljs-string">&quot;train&quot;</span>].<span class="hljs-built_in">map</span>(
preprocess_training_examples,
batched=<span class="hljs-literal">True</span>,
remove_columns=raw_datasets[<span class="hljs-string">&quot;train&quot;</span>].column_names,
)
<span class="hljs-built_in">len</span>(raw_datasets[<span class="hljs-string">&quot;train&quot;</span>]), <span class="hljs-built_in">len</span>(train_dataset)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->(<span class="hljs-number">87599</span>, <span class="hljs-number">88729</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1wsc9tv">如我们所见,预处理增加了大约 1000 个样本。我们的训练集现在已经准备好使用了——让我们深入研究一下验证集的预处理!</p> <h3 class="relative group"><a id="处理验证数据" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#处理验证数据"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>处理验证数据</span></h3> <p data-svelte-h="svelte-vgiq74">验证集的预处理会更加容易,因为我们不需要生成标签(除非我们想计算验证损失,但那个数字并不能真正帮助我们了解模型的好坏,如果要评估模型更好的方式使用我们之前提到的<code>squad</code> 指标)。真正的挑战在于将模型的预测转化为为原始 Context 的片段。为此,我们只需要存储偏移映射并且找到一种方法来将每个分割后的样本与分割前的原始片段匹配起来。由于原始数据集中有一个 ID 列,我们可以使用ID来代表原始的片段。</p> <p data-svelte-h="svelte-1xg5guh">我们唯一需要做的是对偏移映射进行一些微小修改。偏移映射包含问题和 Context 的偏移量(问题的偏移量是0,Context 是1),但当我们进入后处理阶段,我们将无法知道 inputs ID 的哪个部分对应于 Context,哪个部分是问题(我们使用的 <code>sequence_ids()</code> 方法仅可用于 tokenizer 的输出)。因此,我们将将与问题对应的偏移设置为 <code>None</code> Context 对应的偏移量保持不变:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_validation_examples</span>(<span class="hljs-params">examples</span>):
questions = [q.strip() <span class="hljs-keyword">for</span> q <span class="hljs-keyword">in</span> examples[<span class="hljs-string">&quot;question&quot;</span>]]
inputs = tokenizer(
questions,
examples[<span class="hljs-string">&quot;context&quot;</span>],
max_length=max_length,
truncation=<span class="hljs-string">&quot;only_second&quot;</span>,
stride=stride,
return_overflowing_tokens=<span class="hljs-literal">True</span>,
return_offsets_mapping=<span class="hljs-literal">True</span>,
padding=<span class="hljs-string">&quot;max_length&quot;</span>,
)
sample_map = inputs.pop(<span class="hljs-string">&quot;overflow_to_sample_mapping&quot;</span>)
example_ids = []
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-built_in">len</span>(inputs[<span class="hljs-string">&quot;input_ids&quot;</span>])):
sample_idx = sample_map[i]
example_ids.append(examples[<span class="hljs-string">&quot;id&quot;</span>][sample_idx])
sequence_ids = inputs.sequence_ids(i)
offset = inputs[<span class="hljs-string">&quot;offset_mapping&quot;</span>][i]
inputs[<span class="hljs-string">&quot;offset_mapping&quot;</span>][i] = [
o <span class="hljs-keyword">if</span> sequence_ids[k] == <span class="hljs-number">1</span> <span class="hljs-keyword">else</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">for</span> k, o <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(offset)
]
inputs[<span class="hljs-string">&quot;example_id&quot;</span>] = example_ids
<span class="hljs-keyword">return</span> inputs<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-unawa0">我们可以像处理训练集一样使用此函数处理整个验证数据集:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->validation_dataset = raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>].<span class="hljs-built_in">map</span>(
preprocess_validation_examples,
batched=<span class="hljs-literal">True</span>,
remove_columns=raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>].column_names,
)
<span class="hljs-built_in">len</span>(raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>]), <span class="hljs-built_in">len</span>(validation_dataset)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->(<span class="hljs-number">10570</span>, <span class="hljs-number">10822</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xk9f54">从最终的结果来看,我们只添加了几百个样本,因此验证数据集中的 Context 似乎要短一些。</p> <p data-svelte-h="svelte-1i50qrr">现在我们已经对所有数据进行了预处理,我们可以开始训练了。</p> <h2 class="relative group"><a id="使用 Trainer API 微调模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#使用 Trainer API 微调模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>使用 Trainer API 微调模型</span></h2> <p data-svelte-h="svelte-1lfzb1t">这个例子的训练代码与前面的部分非常相似,最困难的部分是编写 <code>compute_metrics()</code> 评估指标函数。由于我们将所有样本填充到了我们设置的最大长度,所以没有需要定义的数据整理器,因此我们唯一需要担心的事情是如何计算评估指标。比较困难的部分将是将模型预测的结果还原到原始示例中的文本片段;一旦我们完成了这一步骤,🤗 Datasets 库中的 metric 就可以帮助我们做大部分工作。</p> <h3 class="relative group"><a id="后处理" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#后处理"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>后处理</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/BNy08iIWVJM" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-1md09k8">模型将输出答案在 inputs ID 中起始和结束位置的 logit,正如我们在探索 <a href="/course/chapter6/3b"><code>question-answering</code> pipeline</a> 时看到的那样。后处理步骤与我们在那里所做的很相似,所以这里简单回顾一下我们所采取的操作:</p> <ul data-svelte-h="svelte-s2nj44"><li>我们屏蔽了除了 Context 之外的 tokens 对应的起始和结束 logit。</li> <li>然后,我们使用 softmax 将起始和结束 logits 转换为概率。</li> <li>我们通过将两个概率对应的乘积来为每个 <code>(start_token, end_token)</code> 对计算一个分数。</li> <li>我们寻找具有最大分数且产生有效答案(例如, <code>start_token</code> 小于 <code>end_token</code> )的对。</li></ul> <p data-svelte-h="svelte-11icqdr">这次我们将稍微改变这个流程,因为我们不需要计算实际分数(只需要预测的答案的文本)。这意味着我们可以跳过 softmax 步骤(因为 softmax 并不会改变分数大小的排序)。为了加快计算速度,我们也不会为所有可能的 <code>(start_token, end_token)</code> 对计算分数,而只会计算与最高的 <code>n_best</code> 对应的 logit 分数(其中 <code>n_best=20</code> )。由于我们将跳过 softmax,这些分数将是 logit 分数,而且是起始和结束对数概率的和(而不是乘积,因为对数运算规则: ($\log(ab) = \log(a) + \log(b))$。</p> <p data-svelte-h="svelte-3r3ncy">为了验证猜想,我们需要一些预测。由于我们还没有训练我们的模型,我们将使用 QA 管道的默认模型对一小部分验证集生成一些预测。我们可以使用和之前一样的处理函数;因为它依赖于全局常量 <code>tokenizer</code> ,我们只需将该对象更改为我们要临时使用的模型的 tokenizer</p> <p data-svelte-h="svelte-1su9sro">为了测试这些代码,我们需要一些预测结果。由于我们还没有训练模型,我们将使用 QA pipeline 的默认模型在验证集的一小部分上生成一些预测结果。我们可以使用与之前相同的处理函数;因为它依赖于全局常量 <code>tokenizer</code> ,所以只需将其更改为这次临时使用的模型对应的 <code>tokenizer</code> 即可。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->small_eval_set = raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>].select(<span class="hljs-built_in">range</span>(<span class="hljs-number">100</span>))
trained_checkpoint = <span class="hljs-string">&quot;distilbert-base-cased-distilled-squad&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = small_eval_set.<span class="hljs-built_in">map</span>(
preprocess_validation_examples,
batched=<span class="hljs-literal">True</span>,
remove_columns=raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>].column_names,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1n19di9">现在预处理已经完成,我们将 <code>tokenizer</code> 改回我们最初选择的 <code>tokenizer</code></p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yg0nuj">然后我们移除 <code>eval_set</code> 中模型不需要的列,构建一个包含所有小型验证集数据的 batch,并将其传递给模型。如果有可用的 GPU,我们将使用 GPU 以加快计算:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForQuestionAnswering
eval_set_for_model = eval_set.remove_columns([<span class="hljs-string">&quot;example_id&quot;</span>, <span class="hljs-string">&quot;offset_mapping&quot;</span>])
eval_set_for_model.set_format(<span class="hljs-string">&quot;torch&quot;</span>)
device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">&quot;cpu&quot;</span>)
batch = {k: eval_set_for_model[k].to(device) <span class="hljs-keyword">for</span> k <span class="hljs-keyword">in</span> eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(
device
)
<span class="hljs-keyword">with</span> torch.no_grad():
outputs = trained_model(**batch)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1qvg8hb">为了便于实验,让我们将这些输出转换为 NumPy 数组:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-wv9v8x">现在,我们需要找到 <code>small_eval_set</code> 中每个样本的预测答案。一个样本可能会被拆分成 <code>eval_set</code> 中的多个子样本,所以第一步是将 <code>small_eval_set</code> 中的每个样本映射到 <code>eval_set</code> 中对应的子样本:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> collections
example_to_features = collections.defaultdict(<span class="hljs-built_in">list</span>)
<span class="hljs-keyword">for</span> idx, feature <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(eval_set):
example_to_features[feature[<span class="hljs-string">&quot;example_id&quot;</span>]].append(idx)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1osctm8">有了这个映射,我们可以通过循环遍历所有样本,并遍历每个样本的所有子样本。正如之前所说,我们将查看 <code>n_best</code> 个起始 logit 和结束 logit 的得分,排除以下情况:</p> <ul data-svelte-h="svelte-4gflxf"><li>答案不在上下文中</li> <li>答案长度为负数</li> <li>答案过长(我们将长度限制为 <code>max_answer_length=30</code></li></ul> <p data-svelte-h="svelte-qna47y">当我们得到一个样本的所有得分可能答案,我们只需选择具有最佳 logit 得分的答案:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
n_best = <span class="hljs-number">20</span>
max_answer_length = <span class="hljs-number">30</span>
predicted_answers = []
<span class="hljs-keyword">for</span> example <span class="hljs-keyword">in</span> small_eval_set:
example_id = example[<span class="hljs-string">&quot;id&quot;</span>]
context = example[<span class="hljs-string">&quot;context&quot;</span>]
answers = []
<span class="hljs-keyword">for</span> feature_index <span class="hljs-keyword">in</span> example_to_features[example_id]:
start_logit = start_logits[feature_index]
end_logit = end_logits[feature_index]
offsets = eval_set[<span class="hljs-string">&quot;offset_mapping&quot;</span>][feature_index]
start_indexes = np.argsort(start_logit)[-<span class="hljs-number">1</span> : -n_best - <span class="hljs-number">1</span> : -<span class="hljs-number">1</span>].tolist()
end_indexes = np.argsort(end_logit)[-<span class="hljs-number">1</span> : -n_best - <span class="hljs-number">1</span> : -<span class="hljs-number">1</span>].tolist()
<span class="hljs-keyword">for</span> start_index <span class="hljs-keyword">in</span> start_indexes:
<span class="hljs-keyword">for</span> end_index <span class="hljs-keyword">in</span> end_indexes:
<span class="hljs-comment"># 跳过不完全在上下文中的答案</span>
<span class="hljs-keyword">if</span> offsets[start_index] <span class="hljs-keyword">is</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">or</span> offsets[end_index] <span class="hljs-keyword">is</span> <span class="hljs-literal">None</span>:
<span class="hljs-keyword">continue</span>
<span class="hljs-comment"># 跳过长度为负数或大于 max_answer_length 的答案。</span>
<span class="hljs-keyword">if</span> (
end_index &lt; start_index
<span class="hljs-keyword">or</span> end_index - start_index + <span class="hljs-number">1</span> &gt; max_answer_length
):
<span class="hljs-keyword">continue</span>
answers.append(
{
<span class="hljs-string">&quot;text&quot;</span>: context[offsets[start_index][<span class="hljs-number">0</span>] : offsets[end_index][<span class="hljs-number">1</span>]],
<span class="hljs-string">&quot;logit_score&quot;</span>: start_logit[start_index] + end_logit[end_index],
}
)
best_answer = <span class="hljs-built_in">max</span>(answers, key=<span class="hljs-keyword">lambda</span> x: x[<span class="hljs-string">&quot;logit_score&quot;</span>])
predicted_answers.append({<span class="hljs-string">&quot;id&quot;</span>: example_id, <span class="hljs-string">&quot;prediction_text&quot;</span>: best_answer[<span class="hljs-string">&quot;text&quot;</span>]})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15fxtrp">完成上述处理后,预测答案就变成了我们将使用的评估指标所要求的输入的格式,在这种情况下可以借助🤗 Evaluate 库来加载它。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> evaluate
metric = evaluate.load(<span class="hljs-string">&quot;squad&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-q37ipu">这个评估指标一个如上所示格式(一个包含示例 ID 和预测文本的字典列表)的预测答案,同时也需要一个如下格式(一个包含示例 ID 和可能答案的字典列表)的参考答案:</p> <p data-svelte-h="svelte-lurpes">该评估指标需要一个由样本 ID 和预测文本字典的列表组成预测答案,同时也需要一个由参考ID 和可能答案字典的列表组成参考答案。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->theoretical_answers = [
{<span class="hljs-string">&quot;id&quot;</span>: ex[<span class="hljs-string">&quot;id&quot;</span>], <span class="hljs-string">&quot;answers&quot;</span>: ex[<span class="hljs-string">&quot;answers&quot;</span>]} <span class="hljs-keyword">for</span> ex <span class="hljs-keyword">in</span> small_eval_set
]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-iyw4yz">现在,我们可以通过查看两个列表中的第一个元素来检查是否符合评估指标的要求:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(predicted_answers[<span class="hljs-number">0</span>])
<span class="hljs-built_in">print</span>(theoretical_answers[<span class="hljs-number">0</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;id&#x27;</span>: <span class="hljs-string">&#x27;56be4db0acb8001400a502ec&#x27;</span>, <span class="hljs-string">&#x27;prediction_text&#x27;</span>: <span class="hljs-string">&#x27;Denver Broncos&#x27;</span>}
{<span class="hljs-string">&#x27;id&#x27;</span>: <span class="hljs-string">&#x27;56be4db0acb8001400a502ec&#x27;</span>, <span class="hljs-string">&#x27;answers&#x27;</span>: {<span class="hljs-string">&#x27;text&#x27;</span>: [<span class="hljs-string">&#x27;Denver Broncos&#x27;</span>, <span class="hljs-string">&#x27;Denver Broncos&#x27;</span>, <span class="hljs-string">&#x27;Denver Broncos&#x27;</span>], <span class="hljs-string">&#x27;answer_start&#x27;</span>: [<span class="hljs-number">177</span>, <span class="hljs-number">177</span>, <span class="hljs-number">177</span>]}}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fi3r0u">还不错!现在让我们看一下评估指标给出的分数:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->metric.compute(predictions=predicted_answers, references=theoretical_answers)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;exact_match&#x27;</span>: <span class="hljs-number">83.0</span>, <span class="hljs-string">&#x27;f1&#x27;</span>: <span class="hljs-number">88.25</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-9anhfa">根据 <a href="https://arxiv.org/abs/1910.01108v2" rel="nofollow">DistilBERT 的论文</a> 所述,DistilBERT 在 SQuAD 上微调后整体数据集的得分为 79.1 和 86.9,相比之下我们取得的结果相当不错。</p> <p data-svelte-h="svelte-l7z453">现在,让我们将刚才所做的放入 <code>compute_metrics()</code> 函数中,就可以在 <code>Trainer</code> 中使用它了。通常, <code>compute_metrics()</code> 函数只接收一个包含 logits 和带预测标签组成的 <code>eval_preds</code> 元组。但是在这里,我们需要更多的信息才能评估结果,因为我们需要在分割后的数据集中查找偏移量,并在原始数据集中查找原始 Context,因此我们无法在训练过程中使用此函数来获取常规的评估结果。我们只会在训练结束时使用它来检查训练的结果。
<code>compute_metrics()</code> 函数与之前的步骤相同;我们只是添加了一个小的检查,以防我们找不到任何有效的答案(在这种情况下,我们的预测会输出一个空字符串)。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">start_logits, end_logits, features, examples</span>):
example_to_features = collections.defaultdict(<span class="hljs-built_in">list</span>)
<span class="hljs-keyword">for</span> idx, feature <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(features):
example_to_features[feature[<span class="hljs-string">&quot;example_id&quot;</span>]].append(idx)
predicted_answers = []
<span class="hljs-keyword">for</span> example <span class="hljs-keyword">in</span> tqdm(examples):
example_id = example[<span class="hljs-string">&quot;id&quot;</span>]
context = example[<span class="hljs-string">&quot;context&quot;</span>]
answers = []
<span class="hljs-comment"># 循环遍历与该示例相关联的所有特征</span>
<span class="hljs-keyword">for</span> feature_index <span class="hljs-keyword">in</span> example_to_features[example_id]:
start_logit = start_logits[feature_index]
end_logit = end_logits[feature_index]
offsets = features[feature_index][<span class="hljs-string">&quot;offset_mapping&quot;</span>]
start_indexes = np.argsort(start_logit)[-<span class="hljs-number">1</span> : -n_best - <span class="hljs-number">1</span> : -<span class="hljs-number">1</span>].tolist()
end_indexes = np.argsort(end_logit)[-<span class="hljs-number">1</span> : -n_best - <span class="hljs-number">1</span> : -<span class="hljs-number">1</span>].tolist()
<span class="hljs-keyword">for</span> start_index <span class="hljs-keyword">in</span> start_indexes:
<span class="hljs-keyword">for</span> end_index <span class="hljs-keyword">in</span> end_indexes:
<span class="hljs-comment"># 跳过不完全位于上下文中的答案</span>
<span class="hljs-keyword">if</span> offsets[start_index] <span class="hljs-keyword">is</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">or</span> offsets[end_index] <span class="hljs-keyword">is</span> <span class="hljs-literal">None</span>:
<span class="hljs-keyword">continue</span>
<span class="hljs-comment"># 跳过长度小于 0 或大于 max_answer_length 的答案</span>
<span class="hljs-keyword">if</span> (
end_index &lt; start_index
<span class="hljs-keyword">or</span> end_index - start_index + <span class="hljs-number">1</span> &gt; max_answer_length
):
<span class="hljs-keyword">continue</span>
answer = {
<span class="hljs-string">&quot;text&quot;</span>: context[offsets[start_index][<span class="hljs-number">0</span>] : offsets[end_index][<span class="hljs-number">1</span>]],
<span class="hljs-string">&quot;logit_score&quot;</span>: start_logit[start_index] + end_logit[end_index],
}
answers.append(answer)
<span class="hljs-comment"># 选择得分最高的答案</span>
<span class="hljs-keyword">if</span> <span class="hljs-built_in">len</span>(answers) &gt; <span class="hljs-number">0</span>:
best_answer = <span class="hljs-built_in">max</span>(answers, key=<span class="hljs-keyword">lambda</span> x: x[<span class="hljs-string">&quot;logit_score&quot;</span>])
predicted_answers.append(
{<span class="hljs-string">&quot;id&quot;</span>: example_id, <span class="hljs-string">&quot;prediction_text&quot;</span>: best_answer[<span class="hljs-string">&quot;text&quot;</span>]}
)
<span class="hljs-keyword">else</span>:
predicted_answers.append({<span class="hljs-string">&quot;id&quot;</span>: example_id, <span class="hljs-string">&quot;prediction_text&quot;</span>: <span class="hljs-string">&quot;&quot;</span>})
theoretical_answers = [{<span class="hljs-string">&quot;id&quot;</span>: ex[<span class="hljs-string">&quot;id&quot;</span>], <span class="hljs-string">&quot;answers&quot;</span>: ex[<span class="hljs-string">&quot;answers&quot;</span>]} <span class="hljs-keyword">for</span> ex <span class="hljs-keyword">in</span> examples]
<span class="hljs-keyword">return</span> metric.compute(predictions=predicted_answers, references=theoretical_answers)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1cugmqy">我们可以评估我们模型在评估数据集输出的结果:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->compute_metrics(start_logits, end_logits, eval_set, small_eval_set)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;exact_match&#x27;</span>: <span class="hljs-number">83.0</span>, <span class="hljs-string">&#x27;f1&#x27;</span>: <span class="hljs-number">88.25</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gjqe5r">看起来不错!现在让我们使用它来微调我们的模型。</p> <h3 class="relative group"><a id="微调模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#微调模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>微调模型</span></h3> <p data-svelte-h="svelte-1u66doy">现在我们已经准备好训练我们的模型了。首先,让我们像之前一样使用 <code>AutoModelForQuestionAnswering</code> 类创建模型:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-wtml0r">和之前一样,我们会收到一个警告,提示有些权重没有被使用(来自预训练头部的权重),而其他一些权重是随机初始化的(用于问答头部的权重)。你现在应该已经习惯了这种情况,但这意味着这个模型还没有准备好使用,需要进行微调——好在这正是我们接下来要做的事情!</p> <p data-svelte-h="svelte-evk12l">为了能够将我们的模型推送到 Hub,我们需要登录 Hugging Face。如果你在 Notebook 中运行此代码,则可以使用以下的函数执行此操作,该函数会显示一个小部件,你可以在其中输入登录凭据进行登陆:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login
notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xgqbsi">如果你不在 Notebook 中工作,只需在终端中输入以下行:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->huggingface-cli login<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ej8xuu">完成后,我们就可以定义我们的 <code>TrainingArguments</code> 。正如我们在定义计算评估函数时所说的,由于 <code>compute_metrics()</code> 函数的输入参数限制,我们无法使用常规的方法来编写评估循环。不过,我们可以编写自己的 <code>Trainer</code> 子类来实现这一点(你可以在 <a href="https://github.com/huggingface/transformers/blob/master/examples/pytorch/question-answering/trainer_qa.py" rel="nofollow">问答示例代码</a> 中找到该方法),但放在本节中会有些冗长。因此,我们在这里将仅在训练结束时评估模型,并在下面的“自定义训练循环”中向你展示如何使用常规的方法进行评估。</p> <p data-svelte-h="svelte-1r5rzrn">这确实是 <code>Trainer</code> API 局限性的地方,而🤗 Accelerate 库则非常适合处理这种情况:定制化特定用例的类可能会很繁琐,但定制化调整训练循环却很简单。</p> <p data-svelte-h="svelte-1f1cwm7">让我们来看看我们的 <code>TrainingArguments</code></p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
args = TrainingArguments(
<span class="hljs-string">&quot;bert-finetuned-squad&quot;</span>,
evaluation_strategy=<span class="hljs-string">&quot;no&quot;</span>,
save_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
learning_rate=<span class="hljs-number">2e-5</span>,
num_train_epochs=<span class="hljs-number">3</span>,
weight_decay=<span class="hljs-number">0.01</span>,
fp16=<span class="hljs-literal">True</span>,
push_to_hub=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ktnidl">我们之前已经见过其中大部分内容:我们设置了一些超参数(如学习率、训练的周期数和一些权重衰减),并设定我们想在每个周期结束时保存模型、跳过评估,并将结果上传到模型中心。我们还启用了 <code>fp16=True</code> 的混合精度训练,因为它可以在最新的 GPU 上加快训练速度。</p> <p data-svelte-h="svelte-66ivgf">默认情况下,使用的仓库将保存在你的账户中,并以你设置的输出目录命名,所以在我们的例子中它将位于 <code>&quot;sgugger/bert-finetuned-squad&quot;</code> 中。我们可以通过传递一个 <code>hub_model_id</code> 参数来覆盖这个设置;例如,要将模型推送到我们使用的 <code>huggingface_course</code> 组织中,我们使用了 <code>hub_model_id=&quot;huggingface_course/bert-finetuned-squad&quot;</code> (这是我们在本节开始时演示的模型)。</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1rlvdw5">💡 如果你正在使用的输出目录已经存在一个同名的文件,则它需要是你要推送到的存储库克隆在本地的版本(因此,如果在定义你的 <code>Trainer</code> 时出现错误,请设置一个新的名称)。</p></div> <p data-svelte-h="svelte-1jvw8yp">最后,我们只需将所有内容传递给 <code>Trainer</code> 类并启动训练:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
tokenizer=tokenizer,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qrpkyl">请注意,在训练过程中,每次模型保存(例如,每个 epoch 结束时),模型都会在后台上传到 Hub。这样,如果需要的话,你就可以在另一台机器上恢复训练。整个训练过程需要一些时间(在 Titan RTX 上略超过一个小时),所以你可以喝杯咖啡或者重新阅读一些你觉得更具挑战性的课程部分来消磨时间。还要注意,在第一个 epoch 完成后,你可以看到一些权重上传到 Hub,并且你可以在其页面上开始使用你的模型进行测试。</p> <p data-svelte-h="svelte-mf8pxg">训练完成后,我们就可以评估我们最终的模型了(并祈祷我们可以一次成功)。 <code>Trainer</code><code>predict()</code> 方法将返回一个元组,其中第一个元素将是模型的预测结果(在这里是一个包含起始和结束 logits 的数值对)。我们将这个结果传递给我们的 <code>compute_metrics()</code> 函数:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;exact_match&#x27;</span>: <span class="hljs-number">81.18259224219489</span>, <span class="hljs-string">&#x27;f1&#x27;</span>: <span class="hljs-number">88.67381321905516</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1miqnxd">很棒!作为对比,BERT 文章中报告的该模型的基准分数分别为 80.8 和 88.5,所以我们的结果正好达到了预期分数。</p> <p data-svelte-h="svelte-1o94ozx">最后,我们使用 <code>push_to_hub()</code> 方法确保上传模型的最新版本:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.push_to_hub(commit_message=<span class="hljs-string">&quot;Training complete&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-tvkzhk">如果你想检查它,上面的代码返回它刚刚执行的提交的 URL:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;https://huggingface.co/sgugger/bert-finetuned-squad/commit/9dcee1fbc25946a6ed4bb32efb1bd71d5fa90b68&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-eijuu7"><code>Trainer</code> 还会创建一个包含所有评估结果的模型卡片,并将其上传。</p> <p data-svelte-h="svelte-l4fjyh">在这个阶段,你可以使用模型库中的推理小部件来测试模型,并与你的朋友、家人和同伴分享。恭喜你成功地在问答任务上对模型进行了微调!</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1agsdtz">✏️ <strong>轮到你了!</strong> 尝试使用另一个模型架构,看看它在这个任务上表现得是否更好!</p></div> <p data-svelte-h="svelte-1o8bvx4">如果你想更深入地了解训练循环,我们现在将向你展示如何使用 🤗 Accelerate 来做同样的事情。</p> <h2 class="relative group"><a id="自定义训练循环" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#自定义训练循环"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>自定义训练循环</span></h2> <p data-svelte-h="svelte-gpqxca">现在,让我们来看一下完整的训练循环,这样你就可以轻松地自定义所需的部分。它看起来很像 <a href="https://www.hubchat.top/course/chapter3/4" rel="nofollow">第三章</a> 中的训练循环,只是评估的过程有所不同。由于我们不再受 <code>Trainer</code> 类的限制,因此我们可以在模型训练的过程中定期评估模型。</p> <h3 class="relative group"><a id="为训练做准备" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#为训练做准备"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>为训练做准备</span></h3> <p data-svelte-h="svelte-goanit">首先,我们需要使用数据集构建 <code>DataLoader</code> 。我们将这些数据集的格式设置为 <code>&quot;torch&quot;</code> ,并删除模型不使用的验证集的列。然后,我们可以使用 Transformers 提供的 <code>default_data_collator</code> 作为 <code>collate_fn</code> ,并打乱训练集,但不打乱验证集:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> default_data_collator
train_dataset.set_format(<span class="hljs-string">&quot;torch&quot;</span>)
validation_set = validation_dataset.remove_columns([<span class="hljs-string">&quot;example_id&quot;</span>, <span class="hljs-string">&quot;offset_mapping&quot;</span>])
validation_set.set_format(<span class="hljs-string">&quot;torch&quot;</span>)
train_dataloader = DataLoader(
train_dataset,
shuffle=<span class="hljs-literal">True</span>,
collate_fn=default_data_collator,
batch_size=<span class="hljs-number">8</span>,
)
eval_dataloader = DataLoader(
validation_set, collate_fn=default_data_collator, batch_size=<span class="hljs-number">8</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jnp5ix">接下来,我们重新实例化我们的模型,以确保我们不是从上面的微调继续训练,而是从原始的 BERT 预训练模型重新开始训练:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ds4q0b">然后,我们需要一个优化器。通常我们使用经典的 <code>AdamW</code> 优化器,它与 Adam 类似,不过在权重衰减的方式上有些不同:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">2e-5</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-h5ww4q">当我们拥有了所有这些对象,我们可以将它们发送到 <code>accelerator.prepare()</code> 方法。请记住,如果你想在 Colab Notebook 上使用 TPU 进行训练,你需要将所有这些代码移到一个训练函数中,不要在 Colab Notebook 的单元格中直接实例化 <code>Accelerator</code> 对象。这是因为在 TPU 环境下,直接在单元格中实例化可能会导致资源分配和初始化的问题。此外我们还可以通过向 <code>Accelerator</code> 传递 <code>fp16=True</code> 来强制使用混合精度训练(或者,如果你想要将代码作为脚本执行,只需确保填写正确的🤗 Accelerate <code>config</code> )。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
accelerator = Accelerator(fp16=<span class="hljs-literal">True</span>)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zwgz4i">从前面几节中你应该知道,我们只有在 <code>train_dataloader</code> 通过 <code>accelerator.prepare()</code> 方法后才能使用其长度来计算训练步骤的数量。我们使用与之前章节相同的线性学习率调度:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler
num_train_epochs = <span class="hljs-number">3</span>
num_update_steps_per_epoch = <span class="hljs-built_in">len</span>(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qn5vsb">要将模型推送到 Hub,我们需要在工作文件夹中创建一个 <code>Repository</code> 对象。如果你尚未登录 Hugging Face Hub,请先登录。我们将根据我们给模型指定的模型 ID 确定仓库名称(可以根据自己的选择替换 <code>repo_name</code> ;只需要包含你的用户名即可,用户名可以使用 <code>get_full_repo_name()</code> 函数可以获取):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> Repository, get_full_repo_name
model_name = <span class="hljs-string">&quot;bert-finetuned-squad-accelerate&quot;</span>
repo_name = get_full_repo_name(model_name)
repo_name<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;sgugger/bert-finetuned-squad-accelerate&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hccqr1">然后,我们可以将该存储库克隆到本地文件夹中。如果在设定的目录中已经存在一个同名的文件夹,那么这个本地文件夹应该是我们正在使用的仓库克隆在本地的版本,否则它会报错:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->output_dir = <span class="hljs-string">&quot;bert-finetuned-squad-accelerate&quot;</span>
repo = Repository(output_dir, clone_from=repo_name)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1pq6saw">现在,我们可以通过调用 <code>repo.push_to_hub()</code> 方法上传保存在 <code>output_dir</code> 中的所有内容。这将帮助我们在每个时期结束时上传中间模型。</p> <h2 class="relative group"><a id="训练循环" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#训练循环"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>训练循环</span></h2> <p data-svelte-h="svelte-13swx0f">现在,我们准备编写完整的训练循环。在定义一个进度条以跟踪训练进度之后,循环分为三个部分:</p> <ul data-svelte-h="svelte-7cwya3"><li><p>训练本身,即对 <code>train_dataloader</code> 进行迭代,模型前向传播、反向传播和优化器更新。</p></li> <li><p>评估,我们将遍历整个评估数据集,同时收集 <code>start_logits</code><code>end_logits</code> 中的所有值。完成评估循环后,我们会将所有结果汇总到一起。需要注意的是,由于 <code>Accelerator</code> 可能会在最后添加一些额外的样本,以确保每个进程中的样本数量相同,因此我们需要对这些数据进行截断,以防止多余样本影响最终结果。</p></li> <li><p>保存和上传,首先保存模型和 Tokenizer,然后调用 <code>repo.push_to_hub()</code> 。与之前一样,我们使用 <code>blocking=False</code> 参数告诉🤗 Hub 库在异步进程中推送。这样,训练将继续进行,而这个(需要很长时间的)上传指令将在后台异步执行。</p></li></ul> <p data-svelte-h="svelte-1gcnnbn">以下训练循环的完整代码:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
<span class="hljs-keyword">import</span> torch
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_train_epochs):
<span class="hljs-comment"># 训练</span>
model.train()
<span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(train_dataloader):
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)
<span class="hljs-comment"># 评估</span>
model.<span class="hljs-built_in">eval</span>()
start_logits = []
end_logits = []
accelerator.<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Evaluation!&quot;</span>)
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> tqdm(eval_dataloader):
<span class="hljs-keyword">with</span> torch.no_grad():
outputs = model(**batch)
start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())
start_logits = np.concatenate(start_logits)
end_logits = np.concatenate(end_logits)
start_logits = start_logits[: <span class="hljs-built_in">len</span>(validation_dataset)]
end_logits = end_logits[: <span class="hljs-built_in">len</span>(validation_dataset)]
metrics = compute_metrics(
start_logits, end_logits, validation_dataset, raw_datasets[<span class="hljs-string">&quot;validation&quot;</span>]
)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;epoch <span class="hljs-subst">{epoch}</span>:&quot;</span>, metrics)
<span class="hljs-comment"># 保存和上传</span>
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
<span class="hljs-keyword">if</span> accelerator.is_main_process:
tokenizer.save_pretrained(output_dir)
repo.push_to_hub(
commit_message=<span class="hljs-string">f&quot;Training in progress epoch <span class="hljs-subst">{epoch}</span>&quot;</span>, blocking=<span class="hljs-literal">False</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s3m652">如果这是你第一次看到使用🤗 Accelerate 保存的模型,请花点时间了解一下与之相关的三行代码</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-83zik6">第一行很好理解:它告诉所有进程在继续之前等待所有进程都到达该阶段。这是为了确保我们在保存之前,在每个进程中都有相同的模型。然后,我们获取 <code>unwrapped_model</code> ,它是我们定义的基本模型。 <code>accelerator.prepare()</code> 方法会更改模型来适应分布式训练,因此它不再具有 <code>save_pretrained()</code> 方法;使用 <code>accelerator.unwrap_model()</code> 方法可以撤消这个更改。最后,我们调用 <code>save_pretrained()</code> ,告诉该方法应该使用 <code>accelerator.save()</code> 保存模型 而不是 <code>torch.save()</code></p> <p data-svelte-h="svelte-pe9cmv">完成后,你应该拥有一个产生与使用 <code>Trainer</code> 训练的模型非常相似的结果的模型。你可以在 <a href="https://huggingface.co/huggingface-course/bert-finetuned-squad-accelerate" rel="nofollow"><code>huggingface-course/bert-finetuned-squad-accelerate</code></a> 查看我们使用此代码训练的模型。如果你想测试对训练循环进行的任何调整,可以直接通过编辑上面显示的代码来实现!</p> <h2 class="relative group"><a id="使用微调模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#使用微调模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>使用微调模型</span></h2> <p data-svelte-h="svelte-rxroqt">我们已经向你展示了如何使用在模型中心上进行微调的模型,并使用推理小部件进行测试。要在本地使用 <code>pipeline</code> 来使用微调的模型,你只需指定模型标识符:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> pipeline
<span class="hljs-comment"># 将其替换为你自己的 checkpoint</span>
model_checkpoint = <span class="hljs-string">&quot;huggingface-course/bert-finetuned-squad&quot;</span>
question_answerer = pipeline(<span class="hljs-string">&quot;question-answering&quot;</span>, model=model_checkpoint)
context = <span class="hljs-string">&quot;&quot;&quot;
🤗 Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration
between them. It&#x27;s straightforward to train your models with one before loading them for inference with the other.
&quot;&quot;&quot;</span>
question = <span class="hljs-string">&quot;Which deep learning libraries back 🤗 Transformers?&quot;</span>
question_answerer(question=question, context=context)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;score&#x27;</span>: <span class="hljs-number">0.9979003071784973</span>,
<span class="hljs-string">&#x27;start&#x27;</span>: <span class="hljs-number">78</span>,
<span class="hljs-string">&#x27;end&#x27;</span>: <span class="hljs-number">105</span>,
<span class="hljs-string">&#x27;answer&#x27;</span>: <span class="hljs-string">&#x27;Jax, PyTorch and TensorFlow&#x27;</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-enmepv">很棒!我们的模型与 pipeline 的默认模型一样有效!</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/zh-CN/chapter7/7.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_9aawfw = {
assets: "/docs/course/pr_1021/zh-CN",
base: "/docs/course/pr_1021/zh-CN",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/course/pr_1021/zh-CN/_app/immutable/entry/start.f3a1a511.js"),
import("/docs/course/pr_1021/zh-CN/_app/immutable/entry/app.c39e37cf.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 59],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
195 kB
·
Xet hash:
83e3025c8f3d1683deeb4b53ba1af8ffa6015e0e2d931208bcb53f96c8c9074c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.