Buckets:

rtrm's picture
download
raw
201 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Summarization&quot;,&quot;local&quot;:&quot;summarization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Preparing a multilingual corpus&quot;,&quot;local&quot;:&quot;preparing-a-multilingual-corpus&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Models for text summarization&quot;,&quot;local&quot;:&quot;models-for-text-summarization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Preprocessing the data&quot;,&quot;local&quot;:&quot;preprocessing-the-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Metrics for text summarization&quot;,&quot;local&quot;:&quot;metrics-for-text-summarization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Creating a strong baseline&quot;,&quot;local&quot;:&quot;creating-a-strong-baseline&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning mT5 with the Trainer API&quot;,&quot;local&quot;:&quot;fine-tuning-mt5-with-the-trainer-api&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning mT5 with Keras&quot;,&quot;local&quot;:&quot;fine-tuning-mt5-with-keras&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning mT5 with 🤗 Accelerate&quot;,&quot;local&quot;:&quot;fine-tuning-mt5-with-accelerate&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Preparing everything for training&quot;,&quot;local&quot;:&quot;preparing-everything-for-training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Training loop&quot;,&quot;local&quot;:&quot;training-loop&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using your fine-tuned model&quot;,&quot;local&quot;:&quot;using-your-fine-tuned-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/course/pr_1069/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/entry/start.c5306bb2.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/scheduler.37c15a92.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/singletons.bc78d867.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/index.18351ede.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/paths.76894643.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/entry/app.4264f5f8.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/index.7cb9c9b8.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/nodes/0.f5347c47.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/nodes/81.0bf9209d.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/Tip.d10b3fc9.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/Youtube.8666c400.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/CodeBlock.abae2786.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/CourseFloatingBanner.df82c153.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/FrameworkSwitchCourse.97630871.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/getInferenceSnippets.f9350a3f.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Summarization&quot;,&quot;local&quot;:&quot;summarization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Preparing a multilingual corpus&quot;,&quot;local&quot;:&quot;preparing-a-multilingual-corpus&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Models for text summarization&quot;,&quot;local&quot;:&quot;models-for-text-summarization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Preprocessing the data&quot;,&quot;local&quot;:&quot;preprocessing-the-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Metrics for text summarization&quot;,&quot;local&quot;:&quot;metrics-for-text-summarization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Creating a strong baseline&quot;,&quot;local&quot;:&quot;creating-a-strong-baseline&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning mT5 with the Trainer API&quot;,&quot;local&quot;:&quot;fine-tuning-mt5-with-the-trainer-api&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning mT5 with Keras&quot;,&quot;local&quot;:&quot;fine-tuning-mt5-with-keras&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning mT5 with 🤗 Accelerate&quot;,&quot;local&quot;:&quot;fine-tuning-mt5-with-accelerate&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Preparing everything for training&quot;,&quot;local&quot;:&quot;preparing-everything-for-training&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Training loop&quot;,&quot;local&quot;:&quot;training-loop&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using your fine-tuned model&quot;,&quot;local&quot;:&quot;using-your-fine-tuned-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> PyTorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 
63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="summarization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#summarization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Summarization</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-7-questions" target="_blank"><img alt="Ask a Question" class="!m-0" 
src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjYxLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/en/chapter7/section5_pt.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/en/chapter7/section5_pt.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <p data-svelte-h="svelte-vih9wa">In this section we’ll take a look at how Transformer models can be 
used to condense long documents into summaries, a task known as <em>text summarization</em>. This is one of the most challenging NLP tasks as it requires a range of abilities, such as understanding long passages and generating coherent text that captures the main topics in a document. However, when done well, text summarization is a powerful tool that can speed up various business processes by relieving domain experts of the burden of reading long documents in detail.</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/yHnr5Dk2zCI" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-1gh0rb6">Although there already exist various fine-tuned models for summarization on the <a href="https://huggingface.co/models?pipeline_tag=summarization&amp;sort=downloads" rel="nofollow">Hugging Face Hub</a>, almost all of these are only suitable for English documents. So, to add a twist in this section, we’ll train a bilingual model for English and Spanish. 
By the end of this section, you’ll have a <a href="https://huggingface.co/huggingface-course/mt5-small-finetuned-amazon-en-es" rel="nofollow">model</a> that can summarize customer reviews like the one shown here:</p> <iframe src="https://course-demos-mt5-small-finetuned-amazon-en-es.hf.space" frameborder="0" height="400" title="Gradio app" class="block dark:hidden container p-0 flex-grow space-iframe" allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads"></iframe> <p data-svelte-h="svelte-15nukw4">As we’ll see, these summaries are concise because they’re learned from the titles that customers provide in their product reviews. 
Let’s start by putting together a suitable bilingual corpus for this task.</p> <h2 class="relative group"><a id="preparing-a-multilingual-corpus" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preparing-a-multilingual-corpus"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preparing a multilingual corpus</span></h2> <p data-svelte-h="svelte-182urio">We’ll use the <a href="https://huggingface.co/datasets/amazon_reviews_multi" rel="nofollow">Multilingual Amazon Reviews Corpus</a> to create our bilingual summarizer. This corpus consists of Amazon product reviews in six languages and is typically used to benchmark multilingual classifiers. However, since each review is accompanied by a short title, we can use the titles as the target summaries for our model to learn from! 
To get started, let’s download the English and Spanish subsets from the Hugging Face Hub:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
spanish_dataset = load_dataset(<span class="hljs-string">&quot;amazon_reviews_multi&quot;</span>, <span class="hljs-string">&quot;es&quot;</span>)
english_dataset = load_dataset(<span class="hljs-string">&quot;amazon_reviews_multi&quot;</span>, <span class="hljs-string">&quot;en&quot;</span>)
english_dataset<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({
    train: Dataset({
        features: [<span class="hljs-string">&#x27;review_id&#x27;</span>, <span class="hljs-string">&#x27;product_id&#x27;</span>, <span class="hljs-string">&#x27;reviewer_id&#x27;</span>, <span class="hljs-string">&#x27;stars&#x27;</span>, <span class="hljs-string">&#x27;review_body&#x27;</span>, <span class="hljs-string">&#x27;review_title&#x27;</span>, <span class="hljs-string">&#x27;language&#x27;</span>, <span class="hljs-string">&#x27;product_category&#x27;</span>],
        num_rows: <span class="hljs-number">200000</span>
    })
    validation: Dataset({
        features: [<span class="hljs-string">&#x27;review_id&#x27;</span>, <span class="hljs-string">&#x27;product_id&#x27;</span>, <span class="hljs-string">&#x27;reviewer_id&#x27;</span>, <span class="hljs-string">&#x27;stars&#x27;</span>, <span class="hljs-string">&#x27;review_body&#x27;</span>, <span class="hljs-string">&#x27;review_title&#x27;</span>, <span class="hljs-string">&#x27;language&#x27;</span>, <span class="hljs-string">&#x27;product_category&#x27;</span>],
        num_rows: <span class="hljs-number">5000</span>
    })
    test: Dataset({
        features: [<span class="hljs-string">&#x27;review_id&#x27;</span>, <span class="hljs-string">&#x27;product_id&#x27;</span>, <span class="hljs-string">&#x27;reviewer_id&#x27;</span>, <span class="hljs-string">&#x27;stars&#x27;</span>, <span class="hljs-string">&#x27;review_body&#x27;</span>, <span class="hljs-string">&#x27;review_title&#x27;</span>, <span class="hljs-string">&#x27;language&#x27;</span>, <span class="hljs-string">&#x27;product_category&#x27;</span>],
        num_rows: <span class="hljs-number">5000</span>
    })
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fcazu2">As you can see, for each language there are 200,000 reviews for the <code>train</code> split, and 5,000 reviews for each of the <code>validation</code> and <code>test</code> splits. The review information we are interested in is contained in the <code>review_body</code> and <code>review_title</code> columns. Let’s take a look at a few examples by creating a simple function that takes a random sample from the training set with the techniques we learned in <a href="/course/chapter5">Chapter 5</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">show_samples</span>(<span class="hljs-params">dataset, num_samples=<span class="hljs-number">3</span>, seed=<span 
class="hljs-number">42</span></span>):
    sample = dataset[<span class="hljs-string">&quot;train&quot;</span>].shuffle(seed=seed).select(<span class="hljs-built_in">range</span>(num_samples))
    <span class="hljs-keyword">for</span> example <span class="hljs-keyword">in</span> sample:
        <span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;\n&#x27;&gt;&gt; Title: <span class="hljs-subst">{example[<span class="hljs-string">&#x27;review_title&#x27;</span>]}</span>&#x27;&quot;</span>)
        <span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt; Review: <span class="hljs-subst">{example[<span class="hljs-string">&#x27;review_body&#x27;</span>]}</span>&#x27;&quot;</span>)
show_samples(english_dataset)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt; Title: Worked in front position, not rear&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: 3 stars because these are not rear brakes as stated in the item description. At least the mount adapter only worked on the front fork of the bike that I got it for.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Title: meh&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: Does it’s job and it’s gorgeous but mine is falling apart, I had to basically put it together again with hot glue&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Title: Can\&#x27;t beat these for the money&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: Bought this for handling miscellaneous aircraft parts and hanger &quot;stuff&quot; that I needed to organize; it really fit the bill. The unit arrived quickly, was well packaged and arrived intact (always a good sign). There are five wall mounts-- three on the top and two on the bottom. I wanted to mount it on the wall, so all I had to do was to remove the top two layers of plastic drawers, as well as the bottom corner drawers, place it when I wanted and mark it; I then used some of the new plastic screw in wall anchors (the 50 pound variety) and it easily mounted to the wall. Some have remarked that they wanted dividers for the drawers, and that they made those. Good idea. My application was that I needed something that I can see the contents at about eye level, so I wanted the fuller-sized drawers. I also like that these are the new plastic that doesn\&#x27;t get brittle and split like my older plastic drawers did. I like the all-plastic construction. It\&#x27;s heavy duty enough to hold metal parts, but being made of plastic it\&#x27;s not as heavy as a metal frame, so you can easily mount it to the wall and still load it up with heavy stuff, or light stuff. No problem there. For the money, you can\&#x27;t beat it. Best one of these I\&#x27;ve bought to date-- and I\&#x27;ve been using some version of these for over forty years.&#x27;</span><!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1oj9qcf">✏️ <strong>Try it out!</strong> Change the random seed in the <code>Dataset.shuffle()</code> command to explore other reviews in the corpus. 
If you’re a Spanish speaker, take a look at some of the reviews in <code>spanish_dataset</code> to see if the titles also seem like reasonable summaries.</p></div> <p data-svelte-h="svelte-hpyi6w">This sample shows the diversity of reviews one typically finds online, ranging from positive to negative (and everything in between!). Although the example with the “meh” title is not very informative, the other titles look like decent summaries of the reviews themselves. Training a summarization model on all 400,000 reviews would take far too long on a single GPU, so instead we’ll focus on generating summaries for a single domain of products. To get a feel for what domains we can choose from, let’s convert <code>english_dataset</code> to a <code>pandas.DataFrame</code> and compute the number of reviews per product category:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->english_dataset.set_format(<span class="hljs-string">&quot;pandas&quot;</span>)
english_df = english_dataset[<span class="hljs-string">&quot;train&quot;</span>][:]
<span class="hljs-comment"># Show counts for top 20 products</span>
english_df[<span class="hljs-string">&quot;product_category&quot;</span>].value_counts()[:<span class="hljs-number">20</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->home <span class="hljs-number">17679</span>
apparel <span class="hljs-number">15951</span>
wireless <span class="hljs-number">15717</span>
other <span class="hljs-number">13418</span>
beauty <span class="hljs-number">12091</span>
drugstore <span class="hljs-number">11730</span>
kitchen <span class="hljs-number">10382</span>
toy <span class="hljs-number">8745</span>
sports <span class="hljs-number">8277</span>
automotive <span class="hljs-number">7506</span>
lawn_and_garden <span class="hljs-number">7327</span>
home_improvement <span class="hljs-number">7136</span>
pet_products <span class="hljs-number">7082</span>
digital_ebook_purchase <span class="hljs-number">6749</span>
pc <span class="hljs-number">6401</span>
electronics <span class="hljs-number">6186</span>
office_product <span class="hljs-number">5521</span>
shoes <span class="hljs-number">5197</span>
grocery <span class="hljs-number">4730</span>
book <span class="hljs-number">3756</span>
Name: product_category, dtype: int64<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-bxczij">The most popular products in the English dataset are about household items, clothing, and wireless electronics. To stick with the Amazon theme, though, let’s focus on summarizing book reviews — after all, this is what the company was founded on! We can see two product categories that fit the bill (<code>book</code> and <code>digital_ebook_purchase</code>), so let’s filter the datasets in both languages for just these products. As we saw in <a href="/course/chapter5">Chapter 5</a>, the <code>Dataset.filter()</code> function allows us to slice a dataset very efficiently, so we can define a simple function to do this:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title 
function_">filter_books</span>(<span class="hljs-params">example</span>):
<span class="hljs-keyword">return</span> (
example[<span class="hljs-string">&quot;product_category&quot;</span>] == <span class="hljs-string">&quot;book&quot;</span>
<span class="hljs-keyword">or</span> example[<span class="hljs-string">&quot;product_category&quot;</span>] == <span class="hljs-string">&quot;digital_ebook_purchase&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-724upw">Now when we apply this function to <code>english_dataset</code> and <code>spanish_dataset</code>, the result will contain just those rows involving the book categories. Before applying the filter, let’s switch the format of <code>english_dataset</code> from <code>&quot;pandas&quot;</code> back to <code>&quot;arrow&quot;</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->english_dataset.reset_format()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-bvpqh">We can then apply the filter function, and as a sanity check let’s inspect a sample of reviews to see if they are indeed about books:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 
cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->spanish_books = spanish_dataset.<span class="hljs-built_in">filter</span>(filter_books)
english_books = english_dataset.<span class="hljs-built_in">filter</span>(filter_books)
show_samples(english_books)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt; Title: I\&#x27;m dissapointed.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: I guess I had higher expectations for this book from the reviews. I really thought I\&#x27;d at least like it. The plot idea was great. I loved Ash but, it just didnt go anywhere. Most of the book was about their radio show and talking to callers. I wanted the author to dig deeper so we could really get to know the characters. All we know about Grace is that she is attractive looking, Latino and is kind of a brat. I\&#x27;m dissapointed.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Title: Good art, good price, poor design&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: I had gotten the DC Vintage calendar the past two years, but it was on backorder forever this year and I saw they had shrunk the dimensions for no good reason. This one has good art choices but the design has the fold going through the picture, so it\&#x27;s less aesthetically pleasing, especially if you want to keep a picture to hang. For the price, a good calendar&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Title: Helpful&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: Nearly all the tips useful and. I consider myself an intermediate to advanced user of OneNote. I would highly recommend.&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-19bhr1q">Okay, we can see that the reviews are not strictly about books and might refer to things like calendars and electronic applications such as OneNote. Nevertheless, the domain seems about right to train a summarization model on. Before we look at various models that are suitable for this task, we have one last bit of data preparation to do: combining the English and Spanish reviews as a single <code>DatasetDict</code> object. 🤗 Datasets provides a handy <code>concatenate_datasets()</code> function that (as the name suggests) will stack two <code>Dataset</code> objects on top of each other. So, to create our bilingual dataset, we’ll loop over each split, concatenate the datasets for that split, and shuffle the result to ensure our model doesn’t overfit to a single language:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div 
class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> concatenate_datasets, DatasetDict
books_dataset = DatasetDict()
<span class="hljs-keyword">for</span> split <span class="hljs-keyword">in</span> english_books.keys():
books_dataset[split] = concatenate_datasets(
[english_books[split], spanish_books[split]]
)
books_dataset[split] = books_dataset[split].shuffle(seed=<span class="hljs-number">42</span>)
<span class="hljs-comment"># Peek at a few examples</span>
show_samples(books_dataset)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt; Title: Easy to follow!!!!&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: I loved The dash diet weight loss Solution. Never hungry. I would recommend this diet. Also the menus are well rounded. Try it. Has lots of the information need thanks.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Title: PARCIALMENTE DAÑADO&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: Me llegó el día que tocaba, junto a otros libros que pedí, pero la caja llegó en mal estado lo cual dañó las esquinas de los libros porque venían sin protección (forro).&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Title: no lo he podido descargar&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt; Review: igual que el anterior&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1rxhw56">This certainly looks like a mix of English and Spanish reviews! Now that we have a training corpus, one final thing to check is the distribution of words in the reviews and their titles. This is especially important for summarization tasks, where short reference summaries in the data can bias the model to only output one or two words in the generated summaries. The plots below show the word distributions, and we can see that the titles are heavily skewed toward just 1-2 words:</p> <div class="flex justify-center" data-svelte-h="svelte-1pwnbfd"><img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/review-lengths.svg" alt="Word count distributions for the review titles and texts."> <img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/review-lengths-dark.svg" alt="Word count distributions for the review titles and texts."></div> <p data-svelte-h="svelte-y3cfja">To deal with this, we’ll filter out the examples with very short titles so that our model can produce more interesting summaries. 
Since we’re dealing with English and Spanish texts, we can use a rough heuristic to split the titles on whitespace and then use our trusty <code>Dataset.filter()</code> method as follows:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->books_dataset = books_dataset.<span class="hljs-built_in">filter</span>(<span class="hljs-keyword">lambda</span> x: <span class="hljs-built_in">len</span>(x[<span class="hljs-string">&quot;review_title&quot;</span>].split()) &gt; <span class="hljs-number">2</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-802ffu">Now that we’ve prepared our corpus, let’s take a look at a few possible Transformer models that one might fine-tune on it!</p> <h2 class="relative group"><a id="models-for-text-summarization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute 
with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#models-for-text-summarization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Models for text summarization</span></h2> <p data-svelte-h="svelte-1unhc81">If you think about it, text summarization is a similar sort of task to machine translation: we have a body of text like a review that we’d like to “translate” into a shorter version that captures the salient features of the input. Accordingly, most Transformer models for summarization adopt the encoder-decoder architecture that we first encountered in <a href="/course/chapter1">Chapter 1</a>, although there are some exceptions like the GPT family of models which can also be used for summarization in few-shot settings. 
The following table lists some popular pretrained models that can be fine-tuned for summarization.</p> <table data-svelte-h="svelte-1oone28"><thead><tr><th align="center">Transformer model</th> <th>Description</th> <th align="center">Multilingual?</th></tr></thead> <tbody><tr><td align="center"><a href="https://huggingface.co/gpt2-xl" rel="nofollow">GPT-2</a></td> <td>Although trained as an auto-regressive language model, you can make GPT-2 generate summaries by appending “TL;DR” at the end of the input text.</td> <td align="center">❌</td></tr> <tr><td align="center"><a href="https://huggingface.co/google/pegasus-large" rel="nofollow">PEGASUS</a></td> <td>Uses a pretraining objective to predict masked sentences in multi-sentence texts. This pretraining objective is closer to summarization than vanilla language modeling and scores highly on popular benchmarks.</td> <td align="center">❌</td></tr> <tr><td align="center"><a href="https://huggingface.co/t5-base" rel="nofollow">T5</a></td> <td>A universal Transformer architecture that formulates all tasks in a text-to-text framework; e.g., the input format for the model to summarize a document is <code>summarize: ARTICLE</code>.</td> <td align="center">❌</td></tr> <tr><td align="center"><a href="https://huggingface.co/google/mt5-base" rel="nofollow">mT5</a></td> <td>A multilingual version of T5, pretrained on the multilingual Common Crawl corpus (mC4), covering 101 languages.</td> <td align="center">✅</td></tr> <tr><td align="center"><a href="https://huggingface.co/facebook/bart-base" rel="nofollow">BART</a></td> <td>A novel Transformer architecture with both an encoder and a decoder stack trained to reconstruct corrupted input that combines the pretraining schemes of BERT and GPT-2.</td> <td align="center">❌</td></tr> <tr><td align="center"><a href="https://huggingface.co/facebook/mbart-large-50" rel="nofollow">mBART-50</a></td> <td>A multilingual version of BART, pretrained on 50 languages.</td> <td 
align="center">✅</td></tr></tbody></table> <p data-svelte-h="svelte-qa3zel">As you can see from this table, the majority of Transformer models for summarization (and indeed most NLP tasks) are monolingual. This is great if your task is in a “high-resource” language like English or German, but less so for the thousands of other languages in use across the world. Fortunately, there is a class of multilingual Transformer models, like mT5 and mBART, that come to the rescue. These models are pretrained using language modeling, but with a twist: instead of training on a corpus of one language, they are trained jointly on texts in over 50 languages at once!</p> <p data-svelte-h="svelte-4rtqmb">We’ll focus on mT5, an interesting architecture based on T5 that was pretrained in a text-to-text framework. In T5, every NLP task is formulated in terms of a prompt prefix like <code>summarize:</code> which conditions the model to adapt the generated text to the prompt. As shown in the figure below, this makes T5 extremely versatile, as you can solve many tasks with a single model!</p> <div class="flex justify-center" data-svelte-h="svelte-k7lnur"><img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/t5.svg" alt="Different tasks performed by the T5 architecture."> <img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/t5-dark.svg" alt="Different tasks performed by the T5 architecture."></div> <p data-svelte-h="svelte-1iimti4">mT5 doesn’t use prefixes, but shares much of the versatility of T5 and has the advantage of being multilingual. 
Now that we’ve picked a model, let’s take a look at preparing our data for training.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-rgei27">✏️ <strong>Try it out!</strong> Once you’ve worked through this section, see how well mT5 compares to mBART by fine-tuning the latter with the same techniques. For bonus points, you can also try fine-tuning T5 on just the English reviews. Since T5 has a special prefix prompt, you’ll need to prepend <code>summarize:</code> to the input examples in the preprocessing steps below.</p></div> <h2 class="relative group"><a id="preprocessing-the-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preprocessing-the-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preprocessing the data</span></h2> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/1m7BerpSq8A" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; 
clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-165o15w">Our next task is to tokenize and encode our reviews and their titles. As usual, we begin by loading the tokenizer associated with the pretrained model checkpoint. We’ll use <code>mt5-small</code> as our checkpoint so we can fine-tune the model in a reasonable amount of time:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer
model_checkpoint = <span class="hljs-string">&quot;google/mt5-small&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-gxjsqb">💡 In the early stages of your NLP projects, a good practice is to train a class of “small” models on a small sample of data. This allows you to debug and iterate faster toward an end-to-end workflow. Once you are confident in the results, you can always scale up the model by simply changing the model checkpoint!</p></div> <p data-svelte-h="svelte-1olylm2">Let’s test out the mT5 tokenizer on a small example:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->inputs = 
tokenizer(<span class="hljs-string">&quot;I loved reading the Hunger Games!&quot;</span>)
inputs<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;input_ids&#x27;</span>: [<span class="hljs-number">336</span>, <span class="hljs-number">259</span>, <span class="hljs-number">28387</span>, <span class="hljs-number">11807</span>, <span class="hljs-number">287</span>, <span class="hljs-number">62893</span>, <span class="hljs-number">295</span>, <span class="hljs-number">12507</span>, <span class="hljs-number">1</span>], <span class="hljs-string">&#x27;attention_mask&#x27;</span>: [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span 
class="hljs-number">1</span>, <span class="hljs-number">1</span>]}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10aknob">Here we can see the familiar <code>input_ids</code> and <code>attention_mask</code> that we encountered in our first fine-tuning experiments back in <a href="/course/chapter3">Chapter 3</a>. Let’s decode these input IDs with the tokenizer’s <code>convert_ids_to_tokens()</code> function to see what kind of tokenizer we’re dealing with:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenizer.convert_ids_to_tokens(inputs.input_ids)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 
mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-string">&#x27;▁I&#x27;</span>, <span class="hljs-string">&#x27;▁&#x27;</span>, <span class="hljs-string">&#x27;loved&#x27;</span>, <span class="hljs-string">&#x27;▁reading&#x27;</span>, <span class="hljs-string">&#x27;▁the&#x27;</span>, <span class="hljs-string">&#x27;▁Hung&#x27;</span>, <span class="hljs-string">&#x27;er&#x27;</span>, <span class="hljs-string">&#x27;▁Games&#x27;</span>, <span class="hljs-string">&#x27;&lt;/s&gt;&#x27;</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1k3p6bc">The special Unicode character <code>▁</code> and end-of-sequence token <code>&lt;/s&gt;</code> indicate that we’re dealing with the SentencePiece tokenizer, which is based on the Unigram segmentation algorithm discussed in <a href="/course/chapter6">Chapter 6</a>. 
Unigram is especially useful for multilingual corpora since it allows SentencePiece to be agnostic about accents, punctuation, and the fact that many languages, like Japanese, do not have whitespace characters.</p> <p data-svelte-h="svelte-1xp22s9">To tokenize our corpus, we have to deal with a subtlety associated with summarization: because our labels are also text, it is possible that they exceed the model’s maximum context size. This means we need to apply truncation to both the reviews and their titles to ensure we don’t pass excessively long inputs to our model. The tokenizers in 🤗 Transformers provide a nifty <code>text_target</code> argument that allows you to tokenize the labels in parallel to the inputs. Here is an example of how the inputs and targets are processed for mT5:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- 
HTML_TAG_START -->max_input_length = <span class="hljs-number">512</span>
max_target_length = <span class="hljs-number">30</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>):
    model_inputs = tokenizer(
        examples[<span class="hljs-string">&quot;review_body&quot;</span>],
        max_length=max_input_length,
        truncation=<span class="hljs-literal">True</span>,
    )
    labels = tokenizer(
        examples[<span class="hljs-string">&quot;review_title&quot;</span>], max_length=max_target_length, truncation=<span class="hljs-literal">True</span>
    )
    model_inputs[<span class="hljs-string">&quot;labels&quot;</span>] = labels[<span class="hljs-string">&quot;input_ids&quot;</span>]
    <span class="hljs-keyword">return</span> model_inputs<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cdpb6n">Let’s walk through this code to understand what’s happening. The first thing we’ve done is define values for <code>max_input_length</code> and <code>max_target_length</code>, which set the upper limits for how long our reviews and titles can be. Since the review body is typically much larger than the title, we’ve scaled these values accordingly.</p> <p data-svelte-h="svelte-rm19fl">With <code>preprocess_function()</code>, it is then a simple matter to tokenize the whole corpus using the handy <code>Dataset.map()</code> function we’ve used extensively throughout this course:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenized_datasets = books_dataset.<span 
class="hljs-built_in">map</span>(preprocess_function, batched=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1bchku6">Now that the corpus has been preprocessed, let’s take a look at some metrics that are commonly used for summarization. As we’ll see, there is no silver bullet when it comes to measuring the quality of machine-generated text.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1m43pbb">💡 You may have noticed that we used <code>batched=True</code> in our <code>Dataset.map()</code> function above. This encodes the examples in batches of 1,000 (the default) and allows you to make use of the multithreading capabilities of the fast tokenizers in 🤗 Transformers. Where possible, try using <code>batched=True</code> to get the most out of your preprocessing!</p></div> <h2 class="relative group"><a id="metrics-for-text-summarization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#metrics-for-text-summarization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 
11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Metrics for text summarization</span></h2> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/TMshhnrEXlg" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-ft2w74">In comparison to most of the other tasks we’ve covered in this course, measuring the performance of text generation tasks like summarization or translation is not as straightforward. For example, given a review like “I loved reading the Hunger Games”, there are multiple valid summaries, like “I loved the Hunger Games” or “Hunger Games is a great read”. Clearly, applying some sort of exact match between the generated summary and the label is not a good solution — even humans would fare poorly under such a metric, because we all have our own writing style.</p> <p data-svelte-h="svelte-ffia08">For summarization, one of the most commonly used metrics is the <a href="https://en.wikipedia.org/wiki/ROUGE_(metric)" rel="nofollow">ROUGE score</a> (short for Recall-Oriented Understudy for Gisting Evaluation). The basic idea behind this metric is to compare a generated summary against a set of reference summaries that are typically created by humans. 
To make this more precise, suppose we want to compare the following two summaries:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->generated_summary = <span class="hljs-string">&quot;I absolutely loved reading the Hunger Games&quot;</span>
reference_summary = <span class="hljs-string">&quot;I loved reading the Hunger Games&quot;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1pifmxc">One way to compare them could be to count the number of overlapping words, which in this case would be 6. However, this is a bit crude, so instead ROUGE is based on computing the <em>precision</em> and <em>recall</em> scores for the overlap.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1ggbwee">🙋 Don’t worry if this is the first time you’ve heard of precision and recall — we’ll go through some explicit examples together to make it all clear. These metrics are usually encountered in classification tasks, so if you want to understand how precision and recall are defined in that context, we recommend checking out the <code>scikit-learn</code> <a href="https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html" rel="nofollow">guides</a>.</p></div> <p>For ROUGE, recall measures how much of the reference summary is captured by the generated one. If we are just comparing words, recall can be calculated according to the following formula:
<!-- HTML_TAG_START --><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mrow><mi mathvariant="normal">R</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">c</mi><mi mathvariant="normal">a</mi><mi mathvariant="normal">l</mi><mi mathvariant="normal">l</mi></mrow><mo>=</mo><mfrac><mrow><mi mathvariant="normal">N</mi><mi mathvariant="normal">u</mi><mi mathvariant="normal">m</mi><mi mathvariant="normal">b</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">r</mi><mtext></mtext><mi mathvariant="normal">o</mi><mi mathvariant="normal">f</mi><mtext></mtext><mi mathvariant="normal">o</mi><mi mathvariant="normal">v</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">l</mi><mi mathvariant="normal">a</mi><mi mathvariant="normal">p</mi><mi mathvariant="normal">p</mi><mi mathvariant="normal">i</mi><mi mathvariant="normal">n</mi><mi mathvariant="normal">g</mi><mtext></mtext><mi mathvariant="normal">w</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">d</mi><mi mathvariant="normal">s</mi></mrow><mrow><mi mathvariant="normal">T</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">t</mi><mi mathvariant="normal">a</mi><mi mathvariant="normal">l</mi><mtext></mtext><mi mathvariant="normal">n</mi><mi mathvariant="normal">u</mi><mi mathvariant="normal">m</mi><mi mathvariant="normal">b</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">r</mi><mtext></mtext><mi mathvariant="normal">o</mi><mi mathvariant="normal">f</mi><mtext></mtext><mi mathvariant="normal">w</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">d</mi><mi mathvariant="normal">s</mi><mtext></mtext><mi mathvariant="normal">i</mi><mi mathvariant="normal">n</mi><mtext></mtext><mi mathvariant="normal">r</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">f</mi><mi 
mathvariant="normal">e</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">n</mi><mi mathvariant="normal">c</mi><mi mathvariant="normal">e</mi><mtext></mtext><mi mathvariant="normal">s</mi><mi mathvariant="normal">u</mi><mi mathvariant="normal">m</mi><mi mathvariant="normal">m</mi><mi mathvariant="normal">a</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">y</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex"> \mathrm{Recall} = \frac{\mathrm{Number\,of\,overlapping\, words}}{\mathrm{Total\, number\, of\, words\, in\, reference\, summary}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord"><span class="mord mathrm">Recall</span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:2.2519em;vertical-align:-0.8804em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3714em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathrm">Total</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">number</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm" style="margin-right:0.07778em;">of</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">words</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">in</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">reference</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord 
mathrm" style="margin-right:0.01389em;">summary</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathrm">Number</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm" style="margin-right:0.07778em;">of</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm" style="margin-right:0.01389em;">overlapping</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">words</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.8804em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span><!-- HTML_TAG_END --></p> <p>For our simple example above, this formula gives a perfect recall of 6/6 = 1; i.e., all the words in the reference summary have been produced by the model. This may sound great, but imagine if our generated summary had been “I really really loved reading the Hunger Games all night”. This would also have perfect recall, but is arguably a worse summary since it is verbose. To deal with these scenarios we also compute the precision, which in the ROUGE context measures how much of the generated summary was relevant:
<!-- HTML_TAG_START --><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mrow><mi mathvariant="normal">P</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">c</mi><mi mathvariant="normal">i</mi><mi mathvariant="normal">s</mi><mi mathvariant="normal">i</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">n</mi></mrow><mo>=</mo><mfrac><mrow><mi mathvariant="normal">N</mi><mi mathvariant="normal">u</mi><mi mathvariant="normal">m</mi><mi mathvariant="normal">b</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">r</mi><mtext></mtext><mi mathvariant="normal">o</mi><mi mathvariant="normal">f</mi><mtext></mtext><mi mathvariant="normal">o</mi><mi mathvariant="normal">v</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">l</mi><mi mathvariant="normal">a</mi><mi mathvariant="normal">p</mi><mi mathvariant="normal">p</mi><mi mathvariant="normal">i</mi><mi mathvariant="normal">n</mi><mi mathvariant="normal">g</mi><mtext></mtext><mi mathvariant="normal">w</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">d</mi><mi mathvariant="normal">s</mi></mrow><mrow><mi mathvariant="normal">T</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">t</mi><mi mathvariant="normal">a</mi><mi mathvariant="normal">l</mi><mtext></mtext><mi mathvariant="normal">n</mi><mi mathvariant="normal">u</mi><mi mathvariant="normal">m</mi><mi mathvariant="normal">b</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">r</mi><mtext></mtext><mi mathvariant="normal">o</mi><mi mathvariant="normal">f</mi><mtext></mtext><mi mathvariant="normal">w</mi><mi mathvariant="normal">o</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">d</mi><mi mathvariant="normal">s</mi><mtext></mtext><mi mathvariant="normal">i</mi><mi mathvariant="normal">n</mi><mtext></mtext><mi 
mathvariant="normal">g</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">n</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">a</mi><mi mathvariant="normal">t</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">d</mi><mtext></mtext><mi mathvariant="normal">s</mi><mi mathvariant="normal">u</mi><mi mathvariant="normal">m</mi><mi mathvariant="normal">m</mi><mi mathvariant="normal">a</mi><mi mathvariant="normal">r</mi><mi mathvariant="normal">y</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex"> \mathrm{Precision} = \frac{\mathrm{Number\,of\,overlapping\, words}}{\mathrm{Total\, number\, of\, words\, in\, generated\, summary}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord"><span class="mord mathrm">Precision</span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:2.2519em;vertical-align:-0.8804em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.3714em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathrm">Total</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">number</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm" style="margin-right:0.07778em;">of</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">words</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">in</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord 
mathrm">generated</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm" style="margin-right:0.01389em;">summary</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathrm">Number</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm" style="margin-right:0.07778em;">of</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm" style="margin-right:0.01389em;">overlapping</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathrm">words</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.8804em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span><!-- HTML_TAG_END --></p> <p data-svelte-h="svelte-128ef20">Applying this to our verbose summary gives a precision of 6/10 = 0.6, which is considerably worse than the precision of 6/7 = 0.86 obtained by our shorter one. In practice, both precision and recall are usually computed, and then the F1-score (the harmonic mean of precision and recall) is reported. 
We can do this easily with 🤗 Evaluate by first installing the <code>rouge_score</code> package:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!pip install rouge_score<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1moj4p7">and then loading the ROUGE metric as follows:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> evaluate
rouge_score = evaluate.load(<span class="hljs-string">&quot;rouge&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1pm5m1">Then we can use the <code>rouge_score.compute()</code> function to calculate all the metrics at once:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->scores = rouge_score.compute(
predictions=[generated_summary], references=[reference_summary]
)
scores<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;rouge1&#x27;</span>: AggregateScore(low=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>), mid=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>), high=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>)),
<span class="hljs-string">&#x27;rouge2&#x27;</span>: AggregateScore(low=Score(precision=<span class="hljs-number">0.67</span>, recall=<span class="hljs-number">0.8</span>, fmeasure=<span class="hljs-number">0.73</span>), mid=Score(precision=<span class="hljs-number">0.67</span>, recall=<span class="hljs-number">0.8</span>, fmeasure=<span class="hljs-number">0.73</span>), high=Score(precision=<span class="hljs-number">0.67</span>, recall=<span class="hljs-number">0.8</span>, fmeasure=<span class="hljs-number">0.73</span>)),
<span class="hljs-string">&#x27;rougeL&#x27;</span>: AggregateScore(low=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>), mid=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>), high=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>)),
<span class="hljs-string">&#x27;rougeLsum&#x27;</span>: AggregateScore(low=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>), mid=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>), high=Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>))}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vxjqyc">Whoa, there’s a lot of information in that output — what does it all mean? First, 🤗 Datasets actually computes confidence intervals for precision, recall, and F1-score; these are the <code>low</code>, <code>mid</code>, and <code>high</code> attributes you can see here. Moreover, 🤗 Datasets computes a variety of ROUGE scores which are based on different types of text granularity when comparing the generated and reference summaries. The <code>rouge1</code> variant is the overlap of unigrams — this is just a fancy way of saying the overlap of words and is exactly the metric we’ve discussed above. 
To verify this, let’s pull out the <code>mid</code> value of our scores:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->scores[<span class="hljs-string">&quot;rouge1&quot;</span>].mid<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" 
transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Score(precision=<span class="hljs-number">0.86</span>, recall=<span class="hljs-number">1.0</span>, fmeasure=<span class="hljs-number">0.92</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-tfvibj">Great, the precision and recall numbers match up! Now what about those other ROUGE scores? <code>rouge2</code> measures the overlap between bigrams (think the overlap of pairs of words), while <code>rougeL</code> and <code>rougeLsum</code> measure the longest matching sequences of words by looking for the longest common subsequences in the generated and reference summaries. The “sum” in <code>rougeLsum</code> refers to the fact that this metric is computed over a whole summary, while <code>rougeL</code> is computed as the average over individual sentences.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1e616sx">✏️ <strong>Try it out!</strong> Create your own example of a generated and reference summary and see if the resulting ROUGE scores agree with a manual calculation based on the formulas for precision and recall. 
For bonus points, split the text into bigrams and compare the precision and recall for the <code>rouge2</code> metric.</p></div> <p data-svelte-h="svelte-i8acz1">We’ll use these ROUGE scores to track the performance of our model, but before doing that let’s do something every good NLP practitioner should do: create a strong, yet simple baseline!</p> <h3 class="relative group"><a id="creating-a-strong-baseline" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#creating-a-strong-baseline"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Creating a strong baseline</span></h3> <p data-svelte-h="svelte-1cuk3qa">A common baseline for text summarization is to simply take the first three sentences of an article, often called the <em>lead-3</em> baseline. We could use full stops to track the sentence boundaries, but this will fail on acronyms like “U.S.” or “U.N.” — so instead we’ll use the <code>nltk</code> library, which includes a better algorithm to handle these cases. 
You can install the package using <code>pip</code> as follows:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!pip install nltk<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1phjm6q">and then download the punctuation rules:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" 
transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> nltk
nltk.download(<span class="hljs-string">&quot;punkt&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18wyy6x">Next, we import the sentence tokenizer from <code>nltk</code> and create a simple function to extract the first three sentences in a review. The convention in text summarization is to separate the sentences of each summary with a newline, so let’s also include this and test it on a training example:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> nltk.tokenize <span class="hljs-keyword">import</span> sent_tokenize
<span class="hljs-keyword">def</span> <span class="hljs-title function_">three_sentence_summary</span>(<span class="hljs-params">text</span>):
<span class="hljs-keyword">return</span> <span class="hljs-string">&quot;\n&quot;</span>.join(sent_tokenize(text)[:<span class="hljs-number">3</span>])
<span class="hljs-built_in">print</span>(three_sentence_summary(books_dataset[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">1</span>][<span class="hljs-string">&quot;review_body&quot;</span>]))<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;I grew up reading Koontz, and years ago, I stopped,convinced i had &quot;outgrown&quot; him.&#x27;</span>
<span class="hljs-string">&#x27;Still,when a friend was looking for something suspenseful too read, I suggested Koontz.&#x27;</span>
<span class="hljs-string">&#x27;She found Strangers.&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1isqbeo">This seems to work, so let’s now implement a function that extracts these “summaries” from a dataset and computes the ROUGE scores for the baseline:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">evaluate_baseline</span>(<span class="hljs-params">dataset, metric</span>):
summaries = [three_sentence_summary(text) <span class="hljs-keyword">for</span> text <span class="hljs-keyword">in</span> dataset[<span class="hljs-string">&quot;review_body&quot;</span>]]
<span class="hljs-keyword">return</span> metric.compute(predictions=summaries, references=dataset[<span class="hljs-string">&quot;review_title&quot;</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-65wxwj">We can then use this function to compute the ROUGE scores over the validation set and prettify them a bit using Pandas:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> pandas <span class="hljs-keyword">as</span> pd
score = evaluate_baseline(books_dataset[<span class="hljs-string">&quot;validation&quot;</span>], rouge_score)
rouge_names = [<span class="hljs-string">&quot;rouge1&quot;</span>, <span class="hljs-string">&quot;rouge2&quot;</span>, <span class="hljs-string">&quot;rougeL&quot;</span>, <span class="hljs-string">&quot;rougeLsum&quot;</span>]
rouge_dict = <span class="hljs-built_in">dict</span>((rn, <span class="hljs-built_in">round</span>(score[rn].mid.fmeasure * <span class="hljs-number">100</span>, <span class="hljs-number">2</span>)) <span class="hljs-keyword">for</span> rn <span class="hljs-keyword">in</span> rouge_names)
rouge_dict<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">16.74</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">8.83</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">15.6</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">15.96</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-154pk5t">We can see that the <code>rouge2</code> score is significantly lower than the rest; this likely reflects the fact that review titles are typically concise and so the lead-3 baseline is too verbose. 
Now that we have a good baseline to work from, let’s turn our attention toward fine-tuning mT5!</p> <h2 class="relative group"><a id="fine-tuning-mt5-with-the-trainer-api" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-mt5-with-the-trainer-api"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning mT5 with the Trainer API</span></h2> <p data-svelte-h="svelte-13fz6t3">Fine-tuning a model for summarization is very similar to the other tasks we’ve covered in this chapter. The first thing we need to do is load the pretrained model from the <code>mt5-small</code> checkpoint. 
Since summarization is a sequence-to-sequence task, we can load the model with the <code>AutoModelForSeq2SeqLM</code> class, which will automatically download and cache the weights:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1dyy1at">💡 If you’re wondering why you don’t see any warnings about fine-tuning the model on a downstream task, that’s because for sequence-to-sequence tasks we keep all the weights of the network. Compare this to our text classification model in <a href="/course/chapter3">Chapter 3</a>, where the head of the pretrained model was replaced with a randomly initialized network.</p></div> <p data-svelte-h="svelte-11h147n">The next thing we need to do is log in to the Hugging Face Hub. If you’re running this code in a notebook, you can do so with the following utility function:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" 
style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login
notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8p0cke">which will display a widget where you can enter your credentials. Alternatively, you can run this command in your terminal and log in there:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->huggingface-<span class="hljs-keyword">cli</span> login<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lra7fb">We’ll need to generate summaries in order to compute ROUGE scores during training. Fortunately, 🤗 Transformers provides dedicated <code>Seq2SeqTrainingArguments</code> and <code>Seq2SeqTrainer</code> classes that can do this for us automatically! 
To see how this works, let’s first define the hyperparameters and other arguments for our experiments:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Seq2SeqTrainingArguments
batch_size = <span class="hljs-number">8</span>
num_train_epochs = <span class="hljs-number">8</span>
<span class="hljs-comment"># Show the training loss with every epoch</span>
logging_steps = <span class="hljs-built_in">len</span>(tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>]) // batch_size
model_name = model_checkpoint.split(<span class="hljs-string">&quot;/&quot;</span>)[-<span class="hljs-number">1</span>]
args = Seq2SeqTrainingArguments(
output_dir=<span class="hljs-string">f&quot;<span class="hljs-subst">{model_name}</span>-finetuned-amazon-en-es&quot;</span>,
evaluation_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
learning_rate=<span class="hljs-number">5.6e-5</span>,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=<span class="hljs-number">0.01</span>,
save_total_limit=<span class="hljs-number">3</span>,
num_train_epochs=num_train_epochs,
predict_with_generate=<span class="hljs-literal">True</span>,
logging_steps=logging_steps,
push_to_hub=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1kfg486">Here, the <code>predict_with_generate</code> argument has been set to indicate that we should generate summaries during evaluation so that we can compute ROUGE scores for each epoch. As discussed in <a href="/course/chapter1">Chapter 1</a>, the decoder performs inference by predicting tokens one by one, and this is implemented by the model’s <code>generate()</code> method. Setting <code>predict_with_generate=True</code> tells the <code>Seq2SeqTrainer</code> to use that method for evaluation. We’ve also adjusted some of the default hyperparameters, like the learning rate, number of epochs, and weight decay, and we’ve set the <code>save_total_limit</code> option to only save up to 3 checkpoints during training — this is because even the “small” version of mT5 uses around a GB of hard drive space, and we can save a bit of room by limiting the number of copies we save.</p> <p data-svelte-h="svelte-9f4al4">The <code>push_to_hub=True</code> argument will allow us to push the model to the Hub after training; you’ll find the repository under your user profile in the location defined by <code>output_dir</code>. Note that you can specify the name of the repository you want to push to with the <code>hub_model_id</code> argument (in particular, you will have to use this argument to push to an organization). For instance, when we pushed the model to the <a href="https://huggingface.co/huggingface-course" rel="nofollow"><code>huggingface-course</code> organization</a>, we added <code>hub_model_id=&quot;huggingface-course/mt5-finetuned-amazon-en-es&quot;</code> to <code>Seq2SeqTrainingArguments</code>.</p> <p data-svelte-h="svelte-1w4eq0d">The next thing we need to do is provide the trainer with a <code>compute_metrics()</code> function so that we can evaluate our model during training. 
For summarization this is a bit more involved than simply calling <code>rouge_score.compute()</code> on the model’s predictions, since we need to <em>decode</em> the outputs and labels into text before we can compute the ROUGE scores. The following function does exactly that, and also makes use of the <code>sent_tokenize()</code> function from <code>nltk</code> to separate the summary sentences with newlines:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_pred</span>):
predictions, labels = eval_pred
<span class="hljs-comment"># Decode generated summaries into text</span>
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=<span class="hljs-literal">True</span>)
<span class="hljs-comment"># Replace -100 in the labels as we can&#x27;t decode them</span>
labels = np.where(labels != -<span class="hljs-number">100</span>, labels, tokenizer.pad_token_id)
<span class="hljs-comment"># Decode reference summaries into text</span>
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=<span class="hljs-literal">True</span>)
<span class="hljs-comment"># ROUGE expects a newline after each sentence</span>
decoded_preds = [<span class="hljs-string">&quot;\n&quot;</span>.join(sent_tokenize(pred.strip())) <span class="hljs-keyword">for</span> pred <span class="hljs-keyword">in</span> decoded_preds]
decoded_labels = [<span class="hljs-string">&quot;\n&quot;</span>.join(sent_tokenize(label.strip())) <span class="hljs-keyword">for</span> label <span class="hljs-keyword">in</span> decoded_labels]
<span class="hljs-comment"># Compute ROUGE scores</span>
result = rouge_score.compute(
predictions=decoded_preds, references=decoded_labels, use_stemmer=<span class="hljs-literal">True</span>
)
<span class="hljs-comment"># Extract the median scores</span>
result = {key: value.mid.fmeasure * <span class="hljs-number">100</span> <span class="hljs-keyword">for</span> key, value <span class="hljs-keyword">in</span> result.items()}
<span class="hljs-keyword">return</span> {k: <span class="hljs-built_in">round</span>(v, <span class="hljs-number">4</span>) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> result.items()}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1nhn5pq">Next, we need to define a data collator for our sequence-to-sequence task. Since mT5 is an encoder-decoder Transformer model, one subtlety with preparing our batches is that during decoding we need to shift the labels to the right by one. This is required to ensure that the decoder only sees the previous ground truth labels and not the current or future ones, which would be easy for the model to memorize. This is similar to how masked self-attention is applied to the inputs in a task like <a href="/course/chapter7/6">causal language modeling</a>.</p> <p data-svelte-h="svelte-1yx0z91">Luckily, 🤗 Transformers provides a <code>DataCollatorForSeq2Seq</code> collator that will dynamically pad the inputs and the labels for us. 
To instantiate this collator, we simply need to provide the <code>tokenizer</code> and <code>model</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dta7o8">Let’s see what this collator produces when fed a small batch of examples. First, we need to remove the columns with strings because the collator won’t know how to pad these elements:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenized_datasets = tokenized_datasets.remove_columns(
books_dataset[<span class="hljs-string">&quot;train&quot;</span>].column_names
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xaeqdc">Since the collator expects a list of <code>dict</code>s, where each <code>dict</code> represents a single example in the dataset, we also need to wrangle the data into the expected format before passing it to the data collator:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->features = [tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>][i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">2</span>)]
data_collator(features)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;attention_mask&#x27;</span>: tensor([[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span 
class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>,
<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>],
[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>,
<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>]]), <span class="hljs-string">&#x27;input_ids&#x27;</span>: tensor([[ <span class="hljs-number">1494</span>, <span class="hljs-number">259</span>, <span class="hljs-number">8622</span>, <span class="hljs-number">390</span>, <span class="hljs-number">259</span>, <span class="hljs-number">262</span>, <span class="hljs-number">2316</span>, <span class="hljs-number">3435</span>, <span class="hljs-number">955</span>,
<span class="hljs-number">772</span>, <span class="hljs-number">281</span>, <span class="hljs-number">772</span>, <span class="hljs-number">1617</span>, <span class="hljs-number">263</span>, <span class="hljs-number">305</span>, <span class="hljs-number">14701</span>, <span class="hljs-number">260</span>, <span class="hljs-number">1385</span>,
<span class="hljs-number">3031</span>, <span class="hljs-number">259</span>, <span class="hljs-number">24146</span>, <span class="hljs-number">332</span>, <span class="hljs-number">1037</span>, <span class="hljs-number">259</span>, <span class="hljs-number">43906</span>, <span class="hljs-number">305</span>, <span class="hljs-number">336</span>,
<span class="hljs-number">260</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>],
[ <span class="hljs-number">259</span>, <span class="hljs-number">27531</span>, <span class="hljs-number">13483</span>, <span class="hljs-number">259</span>, <span class="hljs-number">7505</span>, <span class="hljs-number">260</span>, <span class="hljs-number">112240</span>, <span class="hljs-number">15192</span>, <span class="hljs-number">305</span>,
<span class="hljs-number">53198</span>, <span class="hljs-number">276</span>, <span class="hljs-number">259</span>, <span class="hljs-number">74060</span>, <span class="hljs-number">263</span>, <span class="hljs-number">260</span>, <span class="hljs-number">459</span>, <span class="hljs-number">25640</span>, <span class="hljs-number">776</span>,
<span class="hljs-number">2119</span>, <span class="hljs-number">336</span>, <span class="hljs-number">259</span>, <span class="hljs-number">2220</span>, <span class="hljs-number">259</span>, <span class="hljs-number">18896</span>, <span class="hljs-number">288</span>, <span class="hljs-number">4906</span>, <span class="hljs-number">288</span>,
<span class="hljs-number">1037</span>, <span class="hljs-number">3931</span>, <span class="hljs-number">260</span>, <span class="hljs-number">7083</span>, <span class="hljs-number">101476</span>, <span class="hljs-number">1143</span>, <span class="hljs-number">260</span>, <span class="hljs-number">1</span>]]), <span class="hljs-string">&#x27;labels&#x27;</span>: tensor([[ <span class="hljs-number">7483</span>, <span class="hljs-number">259</span>, <span class="hljs-number">2364</span>, <span class="hljs-number">15695</span>, <span class="hljs-number">1</span>, -<span class="hljs-number">100</span>],
[ <span class="hljs-number">259</span>, <span class="hljs-number">27531</span>, <span class="hljs-number">13483</span>, <span class="hljs-number">259</span>, <span class="hljs-number">7505</span>, <span class="hljs-number">1</span>]]), <span class="hljs-string">&#x27;decoder_input_ids&#x27;</span>: tensor([[ <span class="hljs-number">0</span>, <span class="hljs-number">7483</span>, <span class="hljs-number">259</span>, <span class="hljs-number">2364</span>, <span class="hljs-number">15695</span>, <span class="hljs-number">1</span>],
[ <span class="hljs-number">0</span>, <span class="hljs-number">259</span>, <span class="hljs-number">27531</span>, <span class="hljs-number">13483</span>, <span class="hljs-number">259</span>, <span class="hljs-number">7505</span>]])}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lmn2nq">The main thing to notice here is that the second example is longer than the first one, so the <code>input_ids</code> and <code>attention_mask</code> of the first example have been padded on the right with a <code>[PAD]</code> token (whose ID is <code>0</code>). Similarly, we can see that the <code>labels</code> have been padded with <code>-100</code>s, to make sure the padding tokens are ignored by the loss function. And finally, we can see a new <code>decoder_input_ids</code> which has shifted the labels to the right by inserting a <code>[PAD]</code> token in the first entry.</p> <p data-svelte-h="svelte-14wouu">We finally have all the ingredients we need to train with! We now simply need to instantiate the trainer with the standard arguments:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 
translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Seq2SeqTrainer
trainer = Seq2SeqTrainer(
model,
args,
train_dataset=tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=tokenized_datasets[<span class="hljs-string">&quot;validation&quot;</span>],
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gr4lzn">and launch our training run:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8gu1qg">During training, you should see the training loss decrease and the ROUGE scores increase with each epoch. 
Once the training is complete, you can see the final ROUGE scores by running <code>Trainer.evaluate()</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.evaluate()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;eval_loss&#x27;</span>: <span class="hljs-number">3.028524398803711</span>,
<span class="hljs-string">&#x27;eval_rouge1&#x27;</span>: <span class="hljs-number">16.9728</span>,
<span class="hljs-string">&#x27;eval_rouge2&#x27;</span>: <span class="hljs-number">8.2969</span>,
<span class="hljs-string">&#x27;eval_rougeL&#x27;</span>: <span class="hljs-number">16.8366</span>,
<span class="hljs-string">&#x27;eval_rougeLsum&#x27;</span>: <span class="hljs-number">16.851</span>,
<span class="hljs-string">&#x27;eval_gen_len&#x27;</span>: <span class="hljs-number">10.1597</span>,
<span class="hljs-string">&#x27;eval_runtime&#x27;</span>: <span class="hljs-number">6.1054</span>,
<span class="hljs-string">&#x27;eval_samples_per_second&#x27;</span>: <span class="hljs-number">38.982</span>,
<span class="hljs-string">&#x27;eval_steps_per_second&#x27;</span>: <span class="hljs-number">4.914</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-c5gs0a">From the scores we can see that our model has handily outperformed our lead-3 baseline — nice! The final thing to do is push the model weights to the Hub, as follows:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.push_to_hub(<span class="hljs-attribute">commit_message</span>=<span class="hljs-string">&quot;Training complete&quot;</span>, <span class="hljs-attribute">tags</span>=<span class="hljs-string">&quot;summarization&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer 
focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;https://huggingface.co/huggingface-course/mt5-finetuned-amazon-en-es/commit/aa0536b829b28e73e1e4b94b8a5aacec420d40e0&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-euvgvz">This will save the checkpoint and configuration files to <code>output_dir</code>, before uploading all the files to the Hub. By specifying the <code>tags</code> argument, we also ensure that the widget on the Hub will be one for a summarization pipeline instead of the default text generation one associated with the mT5 architecture (for more information about model tags, see the <a href="https://huggingface.co/docs/hub/main#how-is-a-models-type-of-inference-api-and-widget-determined" rel="nofollow">🤗 Hub documentation</a>). 
The output from <code>trainer.push_to_hub()</code> is a URL to the Git commit hash, so you can easily see the changes that were made to the model repository!</p> <p data-svelte-h="svelte-thah2z">To wrap up this section, let’s take a look at how we can also fine-tune mT5 using the low-level features provided by 🤗 Accelerate.</p> <h2 class="relative group"><a id="fine-tuning-mt5-with-accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-mt5-with-accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning mT5 with 🤗 Accelerate</span></h2> <p data-svelte-h="svelte-uwq75m">Fine-tuning our model with 🤗 Accelerate is very similar to the text classification example we encountered in <a href="/course/chapter3">Chapter 3</a>. The main differences will be the need to explicitly generate our summaries during training and define how we compute the ROUGE scores (recall that the <code>Seq2SeqTrainer</code> took care of the generation for us). 
Let’s take a look at how we can implement these two requirements within 🤗 Accelerate!</p> <h3 class="relative group"><a id="preparing-everything-for-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preparing-everything-for-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preparing everything for training</span></h3> <p data-svelte-h="svelte-1raz9tt">The first thing we need to do is create a <code>DataLoader</code> for each of our splits. 
Since the PyTorch dataloaders expect batches of tensors, we need to set the format to <code>&quot;torch&quot;</code> in our datasets:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenized_datasets.set_format(<span class="hljs-string">&quot;torch&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ymev3i">Now that we’ve got datasets consisting of just tensors, the next thing to do is instantiate the <code>DataCollatorForSeq2Seq</code> again. 
For this we need to provide a fresh version of the model, so let’s load it again from our cache:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vgcsz6">We can then instantiate the data collator and use this to define our dataloaders:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" 
viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader
batch_size = <span class="hljs-number">8</span>
train_dataloader = DataLoader(
tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>],
shuffle=<span class="hljs-literal">True</span>,
collate_fn=data_collator,
batch_size=batch_size,
)
eval_dataloader = DataLoader(
tokenized_datasets[<span class="hljs-string">&quot;validation&quot;</span>], collate_fn=data_collator, batch_size=batch_size
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-41m32h">The next thing to do is define the optimizer we want to use. As in our other examples, we’ll use <code>AdamW</code>, which works well for most problems:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">2e-5</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fp4o9v">Finally, we feed our model, optimizer, and dataloaders to the <code>accelerator.prepare()</code> method:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1h0m1lz">🚨 If you’re training on a TPU, you’ll need to move all the code above into a dedicated training function. See <a href="/course/chapter3">Chapter 3</a> for more details.</p></div> <p data-svelte-h="svelte-h8aj3b">Now that we’ve prepared our objects, there are three remaining things to do:</p> <ul data-svelte-h="svelte-1odn6yw"><li>Define the learning rate schedule.</li> <li>Implement a function to post-process the summaries for evaluation.</li> <li>Create a repository on the Hub that we can push our model to.</li></ul> <p data-svelte-h="svelte-12m7glt">For the learning rate schedule, we’ll use the standard linear one from previous sections:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black 
border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler
num_train_epochs = <span class="hljs-number">10</span>
num_update_steps_per_epoch = <span class="hljs-built_in">len</span>(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1l4rcdk">For post-processing, we need a function that splits the generated summaries into sentences that are separated by newlines. This is the format the ROUGE metric expects, and we can achieve this with the following snippet of code:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">postprocess_text</span>(<span class="hljs-params">preds, labels</span>):
preds = [pred.strip() <span class="hljs-keyword">for</span> pred <span class="hljs-keyword">in</span> preds]
labels = [label.strip() <span class="hljs-keyword">for</span> label <span class="hljs-keyword">in</span> labels]
<span class="hljs-comment"># ROUGE expects a newline after each sentence</span>
preds = [<span class="hljs-string">&quot;\n&quot;</span>.join(nltk.sent_tokenize(pred)) <span class="hljs-keyword">for</span> pred <span class="hljs-keyword">in</span> preds]
labels = [<span class="hljs-string">&quot;\n&quot;</span>.join(nltk.sent_tokenize(label)) <span class="hljs-keyword">for</span> label <span class="hljs-keyword">in</span> labels]
<span class="hljs-keyword">return</span> preds, labels<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17mpp9">This should look familiar to you if you recall how we defined the <code>compute_metrics()</code> function of the <code>Seq2SeqTrainer</code>.</p> <p data-svelte-h="svelte-1wkjrqf">Finally, we need to create a model repository on the Hugging Face Hub. For this, we can use the appropriately titled 🤗 Hub library. We just need to define a name for our repository, and the library has a utility function to combine the repository ID with the user profile:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> get_full_repo_name
model_name = <span class="hljs-string">&quot;mt5-finetuned-amazon-en-es-accelerate&quot;</span>
repo_name = get_full_repo_name(model_name)
repo_name<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;lewtun/mt5-finetuned-amazon-en-es-accelerate&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1yghzy">Now we can use this repository name to clone a local version to our results directory that will store the training artifacts:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid 
meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> Repository
output_dir = <span class="hljs-string">&quot;results-mt5-finetuned-amazon-en-es-accelerate&quot;</span>
repo = Repository(output_dir, clone_from=repo_name)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vbolln">This will allow us to push the artifacts back to the Hub by calling the <code>repo.push_to_hub()</code> method during training! Let’s now wrap up our analysis by writing out the training loop.</p> <h3 class="relative group"><a id="training-loop" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-loop"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training loop</span></h3> <p data-svelte-h="svelte-1cqv7vl">The training loop for summarization is quite similar to the other 🤗 Accelerate examples that we’ve encountered and is roughly split into four main steps:</p> <ol data-svelte-h="svelte-yr5mj5"><li>Train the model by iterating over all the examples in <code>train_dataloader</code> for each epoch.</li> <li>Generate model summaries at the end of each epoch, by first generating the tokens and then decoding them (and the reference summaries) into text.</li> <li>Compute the ROUGE scores using the same techniques we saw earlier.</li> <li>Save the checkpoints and push everything to the Hub. 
Here we rely on the nifty <code>blocking=False</code> argument of the <code>Repository.push_to_hub()</code> method so that we can push the checkpoints per epoch <em>asynchronously</em>. This allows us to continue training without having to wait for the somewhat slow upload associated with a GB-sized model!</li></ol> <p data-svelte-h="svelte-14dai24">These steps can be seen in the following block of code:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_train_epochs):
<span class="hljs-comment"># Training</span>
model.train()
<span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(train_dataloader):
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)
<span class="hljs-comment"># Evaluation</span>
model.<span class="hljs-built_in">eval</span>()
<span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(eval_dataloader):
<span class="hljs-keyword">with</span> torch.no_grad():
generated_tokens = accelerator.unwrap_model(model).generate(
batch[<span class="hljs-string">&quot;input_ids&quot;</span>],
attention_mask=batch[<span class="hljs-string">&quot;attention_mask&quot;</span>],
)
generated_tokens = accelerator.pad_across_processes(
generated_tokens, dim=<span class="hljs-number">1</span>, pad_index=tokenizer.pad_token_id
)
labels = batch[<span class="hljs-string">&quot;labels&quot;</span>]
<span class="hljs-comment"># If we did not pad to max length, we need to pad the labels too</span>
labels = accelerator.pad_across_processes(
batch[<span class="hljs-string">&quot;labels&quot;</span>], dim=<span class="hljs-number">1</span>, pad_index=tokenizer.pad_token_id
)
generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
labels = accelerator.gather(labels).cpu().numpy()
<span class="hljs-comment"># Replace -100 in the labels as we can&#x27;t decode them</span>
labels = np.where(labels != -<span class="hljs-number">100</span>, labels, tokenizer.pad_token_id)
<span class="hljs-keyword">if</span> <span class="hljs-built_in">isinstance</span>(generated_tokens, <span class="hljs-built_in">tuple</span>):
generated_tokens = generated_tokens[<span class="hljs-number">0</span>]
decoded_preds = tokenizer.batch_decode(
generated_tokens, skip_special_tokens=<span class="hljs-literal">True</span>
)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=<span class="hljs-literal">True</span>)
decoded_preds, decoded_labels = postprocess_text(
decoded_preds, decoded_labels
)
rouge_score.add_batch(predictions=decoded_preds, references=decoded_labels)
<span class="hljs-comment"># Compute metrics</span>
result = rouge_score.compute()
<span class="hljs-comment"># Extract the median ROUGE scores</span>
result = {key: value.mid.fmeasure * <span class="hljs-number">100</span> <span class="hljs-keyword">for</span> key, value <span class="hljs-keyword">in</span> result.items()}
result = {k: <span class="hljs-built_in">round</span>(v, <span class="hljs-number">4</span>) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> result.items()}
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Epoch <span class="hljs-subst">{epoch}</span>:&quot;</span>, result)
<span class="hljs-comment"># Save and upload</span>
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
<span class="hljs-keyword">if</span> accelerator.is_main_process:
tokenizer.save_pretrained(output_dir)
repo.push_to_hub(
commit_message=<span class="hljs-string">f&quot;Training in progress epoch <span class="hljs-subst">{epoch}</span>&quot;</span>, blocking=<span class="hljs-literal">False</span>
)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Epoch <span class="hljs-number">0</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">5.6351</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">1.1625</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">5.4866</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">5.5005</span>}
Epoch <span class="hljs-number">1</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">9.8646</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">3.4106</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">9.9439</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">9.9306</span>}
Epoch <span class="hljs-number">2</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">11.0872</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">3.3273</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">11.0508</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">10.9468</span>}
Epoch <span class="hljs-number">3</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">11.8587</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">4.8167</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">11.7986</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">11.7518</span>}
Epoch <span class="hljs-number">4</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">12.9842</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">5.5887</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">12.7546</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">12.7029</span>}
Epoch <span class="hljs-number">5</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">13.4628</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">6.4598</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">13.312</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">13.2913</span>}
Epoch <span class="hljs-number">6</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">12.9131</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">5.8914</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">12.6896</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">12.5701</span>}
Epoch <span class="hljs-number">7</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">13.3079</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">6.2994</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">13.1536</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">13.1194</span>}
Epoch <span class="hljs-number">8</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">13.96</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">6.5998</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">13.9123</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">13.7744</span>}
Epoch <span class="hljs-number">9</span>: {<span class="hljs-string">&#x27;rouge1&#x27;</span>: <span class="hljs-number">14.1192</span>, <span class="hljs-string">&#x27;rouge2&#x27;</span>: <span class="hljs-number">7.0059</span>, <span class="hljs-string">&#x27;rougeL&#x27;</span>: <span class="hljs-number">14.1172</span>, <span class="hljs-string">&#x27;rougeLsum&#x27;</span>: <span class="hljs-number">13.9509</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-172kpng">And that’s it! Once you run this, you’ll have a model and results that are pretty similar to the ones we obtained with the <code>Trainer</code>.</p> <h2 class="relative group"><a id="using-your-fine-tuned-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-your-fine-tuned-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using your fine-tuned model</span></h2> <p data-svelte-h="svelte-kygi7d">Once you’ve pushed the model to the Hub, you can play with it either via the inference widget or with a <code>pipeline</code> object, as follows:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button 
class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> pipeline
hub_model_id = <span class="hljs-string">&quot;huggingface-course/mt5-small-finetuned-amazon-en-es&quot;</span>
summarizer = pipeline(<span class="hljs-string">&quot;summarization&quot;</span>, model=hub_model_id)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-nht6a4">We can feed some examples from the test set (which the model has not seen) to our pipeline to get a feel for the quality of the summaries. First let’s implement a simple function to show the review, title, and generated summary together:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">print_summary</span>(<span class="hljs-params">idx</span>):
review = books_dataset[<span class="hljs-string">&quot;test&quot;</span>][idx][<span class="hljs-string">&quot;review_body&quot;</span>]
title = books_dataset[<span class="hljs-string">&quot;test&quot;</span>][idx][<span class="hljs-string">&quot;review_title&quot;</span>]
summary = summarizer(books_dataset[<span class="hljs-string">&quot;test&quot;</span>][idx][<span class="hljs-string">&quot;review_body&quot;</span>])[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;summary_text&quot;</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt;&gt; Review: <span class="hljs-subst">{review}</span>&#x27;&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;\n&#x27;&gt;&gt;&gt; Title: <span class="hljs-subst">{title}</span>&#x27;&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;\n&#x27;&gt;&gt;&gt; Summary: <span class="hljs-subst">{summary}</span>&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1wlcxao">Let’s take a look at one of the English examples we get:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->print_summary(<span class="hljs-number">100</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" 
width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; Review: Nothing special at all about this product... the book is too small and stiff and hard to write in. The huge sticker on the back doesn’t come off and looks super tacky. I would not purchase this again. I could have just bought a journal from the dollar store and it would be basically the same thing. It’s also really expensive for what it is.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Title: Not impressed at all... buy something else&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Summary: Nothing special at all about this product&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-4jjf7d">This is not too bad! We can see that our model has actually been able to perform <em>abstractive</em> summarization by augmenting parts of the review with new words. And perhaps the coolest aspect of our model is that it is bilingual, so we can also generate summaries of Spanish reviews:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->print_summary(<span class="hljs-number">0</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 
text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; Review: Es una trilogia que se hace muy facil de leer. Me ha gustado, no me esperaba el final para nada&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Title: Buena literatura para adolescentes&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Summary: Muy facil de leer&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1p3b7rw">The summary translates into “Very easy to read” in English, which we can see in this case was extracted directly from the review. Nevertheless, this shows the versatility of the mT5 model and has given you a taste of what it’s like to deal with a multilingual corpus!</p> <p data-svelte-h="svelte-1gun2uf">Next, we’ll turn our attention to a slightly more complex task: training a language model from scratch.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/en/chapter7/5.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
// Client bootstrap: records path settings in a global object, then loads the
// two entry bundles and calls kit.start() to start the app inside the element
// that contains this script.
{
// Assigned without a declaration keyword, so in this classic (non-module)
// script it becomes a global. NOTE(review): presumably read by the imported
// bundles under this exact name — confirm before renaming.
__sveltekit_1y0degu = {
// assets and base are set to the same path prefix.
assets: "/docs/course/pr_1069/en",
base: "/docs/course/pr_1069/en",
env: {}
};
// Mount target: the parent element of the currently executing script element.
const element = document.currentScript.parentElement;
// NOTE(review): two null entries, matching the two ids in node_ids below;
// presumably one data slot per node — confirm.
const data = [null,null];
// Load both entry modules in parallel; the callback runs only after both
// dynamic imports have resolved.
Promise.all([
import("/docs/course/pr_1069/en/_app/immutable/entry/start.c5306bb2.js"),
import("/docs/course/pr_1069/en/_app/immutable/entry/app.4264f5f8.js")
]).then(([kit, app]) => {
// Hand the app module, the mount element and the initial state to the
// start() function exported by the first bundle.
kit.start(app, element, {
node_ids: [0, 81],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
201 kB
·
Xet hash:
d761d21d1a83c4fa927f846ab31f54bfbe8a6a882fe0e0fbea375401df98c480

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.