Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Translation","local":"translation","sections":[{"title":"Preparing the data","local":"preparing-the-data","sections":[{"title":"The KDE4 dataset","local":"the-kde4-dataset","sections":[],"depth":3},{"title":"Processing the data","local":"processing-the-data","sections":[],"depth":3}],"depth":2},{"title":"Fine-tuning the model with the Trainer API","local":"fine-tuning-the-model-with-the-trainer-api","sections":[],"depth":2},{"title":"Fine-tuning the model with Keras","local":"fine-tuning-the-model-with-keras","sections":[{"title":"Data collation","local":"data-collation","sections":[],"depth":3},{"title":"Metrics","local":"metrics","sections":[],"depth":3},{"title":"Fine-tuning the model","local":"fine-tuning-the-model","sections":[],"depth":3}],"depth":2},{"title":"A custom training loop","local":"a-custom-training-loop","sections":[{"title":"Preparing everything for training","local":"preparing-everything-for-training","sections":[],"depth":3},{"title":"Training loop","local":"training-loop","sections":[],"depth":3}],"depth":2},{"title":"Using the fine-tuned model","local":"using-the-fine-tuned-model","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/course/pr_1021/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/entry/start.3d2e2978.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/singletons.7a9af56d.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/paths.692f56cf.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/entry/app.d60bb0c9.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/index.7cb9c9b8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/nodes/0.bbc29778.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/nodes/80.f6b20f59.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/Tip.d10b3fc9.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/Youtube.8666c400.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/CodeBlock.abae2786.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/CourseFloatingBanner.df82c153.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/FrameworkSwitchCourse.97630871.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/getInferenceSnippets.a2135f3c.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Translation","local":"translation","sections":[{"title":"Preparing the data","local":"preparing-the-data","sections":[{"title":"The KDE4 dataset","local":"the-kde4-dataset","sections":[],"depth":3},{"title":"Processing the data","local":"processing-the-data","sections":[],"depth":3}],"depth":2},{"title":"Fine-tuning the model with the Trainer API","local":"fine-tuning-the-model-with-the-trainer-api","sections":[],"depth":2},{"title":"Fine-tuning the model with Keras","local":"fine-tuning-the-model-with-keras","sections":[{"title":"Data collation","local":"data-collation","sections":[],"depth":3},{"title":"Metrics","local":"metrics","sections":[],"depth":3},{"title":"Fine-tuning the model","local":"fine-tuning-the-model","sections":[],"depth":3}],"depth":2},{"title":"A custom training loop","local":"a-custom-training-loop","sections":[{"title":"Preparing everything for training","local":"preparing-everything-for-training","sections":[],"depth":3},{"title":"Training loop","local":"training-loop","sections":[],"depth":3}],"depth":2},{"title":"Using the fine-tuned model","local":"using-the-fine-tuned-model","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> Pytorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="translation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#translation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Translation</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-7-questions" target="_blank"><img alt="Ask a Question" class="!m-0" src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjYxLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/en/chapter7/section4_pt.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/en/chapter7/section4_pt.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <p data-svelte-h="svelte-eyotq9">Let’s now dive into translation. This is another <a href="/course/chapter1/7">sequence-to-sequence task</a>, which means it’s a problem that can be formulated as going from one sequence to another. In that sense the problem is pretty close to <a href="/course/chapter7/6">summarization</a>, and you could adapt what we will see here to other sequence-to-sequence problems such as:</p> <ul data-svelte-h="svelte-mdl7i"><li><strong>Style transfer</strong>: Creating a model that <em>translates</em> texts written in a certain style to another (e.g., formal to casual or Shakespearean English to modern English)</li> <li><strong>Generative question answering</strong>: Creating a model that generates answers to questions, given a context</li></ul> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/1JvfrvZgi6c" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-qi4xhg">If you have a big enough corpus of texts in two (or more) languages, you can train a new translation model from scratch like we will in the section on <a href="/course/chapter7/6">causal language modeling</a>. It will be faster, however, to fine-tune an existing translation model, be it a multilingual one like mT5 or mBART that you want to fine-tune to a specific language pair, or even a model specialized for translation from one language to another that you want to fine-tune to your specific corpus.</p> <p data-svelte-h="svelte-19rlunj">In this section, we will fine-tune a Marian model pretrained to translate from English to French (since a lot of Hugging Face employees speak both those languages) on the <a href="https://huggingface.co/datasets/kde4" rel="nofollow">KDE4 dataset</a>, which is a dataset of localized files for the <a href="https://apps.kde.org/" rel="nofollow">KDE apps</a>. The model we will use has been pretrained on a large corpus of French and English texts taken from the <a href="https://opus.nlpl.eu/" rel="nofollow">Opus dataset</a>, which actually contains the KDE4 dataset. But even if the pretrained model we use has seen that data during its pretraining, we will see that we can get a better version of it after fine-tuning.</p> <p data-svelte-h="svelte-104mbqb">Once we’re finished, we will have a model able to make predictions like this one:</p> <iframe src="https://course-demos-marian-finetuned-kde4-en-to-fr.hf.space" frameborder="0" height="350" title="Gradio app" class="block dark:hidden container p-0 flex-grow space-iframe" allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads"></iframe> <a class="flex justify-center" href="/huggingface-course/marian-finetuned-kde4-en-to-fr" data-svelte-h="svelte-1ifl6xg"><img class="block dark:hidden lg:w-3/5" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/modeleval-marian-finetuned-kde4-en-to-fr.png" alt="One-hot encoded labels for question answering."> <img class="hidden dark:block lg:w-3/5" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/modeleval-marian-finetuned-kde4-en-to-fr-dark.png" alt="One-hot encoded labels for question answering."></a> <p data-svelte-h="svelte-8feqj6">As in the previous sections, you can find the actual model that we’ll train and upload to the Hub using the code below and double-check its predictions <a href="https://huggingface.co/huggingface-course/marian-finetuned-kde4-en-to-fr?text=This+plugin+allows+you+to+automatically+translate+web+pages+between+several+languages." rel="nofollow">here</a>.</p> <h2 class="relative group"><a id="preparing-the-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preparing-the-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preparing the data</span></h2> <p data-svelte-h="svelte-1w48ec5">To fine-tune or train a translation model from scratch, we will need a dataset suitable for the task. As mentioned previously, we’ll use the <a href="https://huggingface.co/datasets/kde4" rel="nofollow">KDE4 dataset</a> in this section, but you can adapt the code to use your own data quite easily, as long as you have pairs of sentences in the two languages you want to translate from and into. Refer back to <a href="/course/chapter5">Chapter 5</a> if you need a reminder of how to load your custom data in a <code>Dataset</code>.</p> <h3 class="relative group"><a id="the-kde4-dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-kde4-dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The KDE4 dataset</span></h3> <p data-svelte-h="svelte-psemqo">As usual, we download our dataset using the <code>load_dataset()</code> function:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| raw_datasets = load_dataset(<span class="hljs-string">"kde4"</span>, lang1=<span class="hljs-string">"en"</span>, lang2=<span class="hljs-string">"fr"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-waadns">If you want to work with a different pair of languages, you can specify them by their codes. A total of 92 languages are available for this dataset; you can see them all by expanding the language tags on its <a href="https://huggingface.co/datasets/kde4" rel="nofollow">dataset card</a>.</p> <img src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/language_tags.png" alt="Language available for the KDE4 dataset." width="100%"> <p data-svelte-h="svelte-11kki55">Let’s have a look at the dataset:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->raw_datasets<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({ | |
| train: Dataset({ | |
| features: [<span class="hljs-string">'id'</span>, <span class="hljs-string">'translation'</span>], | |
| num_rows: <span class="hljs-number">210173</span> | |
| }) | |
| })<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-am89ab">We have 210,173 pairs of sentences, but in one single split, so we will need to create our own validation set. As we saw in <a href="/course/chapter5">Chapter 5</a>, a <code>Dataset</code> has a <code>train_test_split()</code> method that can help us. We’ll provide a seed for reproducibility:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->split_datasets = raw_datasets[<span class="hljs-string">"train"</span>].train_test_split(train_size=<span class="hljs-number">0.9</span>, seed=<span class="hljs-number">20</span>) | |
| split_datasets<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({ | |
| train: Dataset({ | |
| features: [<span class="hljs-string">'id'</span>, <span class="hljs-string">'translation'</span>], | |
| num_rows: <span class="hljs-number">189155</span> | |
| }) | |
| test: Dataset({ | |
| features: [<span class="hljs-string">'id'</span>, <span class="hljs-string">'translation'</span>], | |
| num_rows: <span class="hljs-number">21018</span> | |
| }) | |
| })<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1b0nmgl">We can rename the <code>"test"</code> key to <code>"validation"</code> like this:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->split_datasets[<span class="hljs-string">"validation"</span>] = split_datasets.pop(<span class="hljs-string">"test"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yn8zd9">Now let’s take a look at one element of the dataset:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->split_datasets[<span class="hljs-string">"train"</span>][<span class="hljs-number">1</span>][<span class="hljs-string">"translation"</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'en'</span>: <span class="hljs-string">'Default to expanded threads'</span>, | |
| <span class="hljs-string">'fr'</span>: <span class="hljs-string">'Par défaut, développer les fils de discussion'</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1neulry">We get a dictionary with two sentences in the pair of languages we requested. One particularity of this dataset full of technical computer science terms is that they are all fully translated in French. However, French engineers leave most computer science-specific words in English when they talk. Here, for instance, the word “threads” might well appear in a French sentence, especially in a technical conversation; but in this dataset it has been translated into the more correct “fils de discussion.” The pretrained model we use, which has been pretrained on a larger corpus of French and English sentences, takes the easier option of leaving the word as is:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> pipeline | |
| model_checkpoint = <span class="hljs-string">"Helsinki-NLP/opus-mt-en-fr"</span> | |
| translator = pipeline(<span class="hljs-string">"translation"</span>, model=model_checkpoint) | |
| translator(<span class="hljs-string">"Default to expanded threads"</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[{<span class="hljs-string">'translation_text'</span>: <span class="hljs-string">'Par défaut pour les threads élargis'</span>}]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1x2u8k0">Another example of this behavior can be seen with the word “plugin,” which isn’t officially a French word but which most native speakers will understand and not bother to translate. | |
| In the KDE4 dataset this word has been translated in French into the more official “module d’extension”:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->split_datasets[<span class="hljs-string">"train"</span>][<span class="hljs-number">172</span>][<span class="hljs-string">"translation"</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'en'</span>: <span class="hljs-string">'Unable to import %1 using the OFX importer plugin. This file is not the correct format.'</span>, | |
| <span class="hljs-string">'fr'</span>: <span class="hljs-string">"Impossible d'importer %1 en utilisant le module d'extension d'importation OFX. Ce fichier n'a pas un format correct."</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1pzq3b6">Our pretrained model, however, sticks with the compact and familiar English word:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->translator( | |
| <span class="hljs-string">"Unable to import %1 using the OFX importer plugin. This file is not the correct format."</span> | |
| )<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[{<span class="hljs-string">'translation_text'</span>: <span class="hljs-string">"Impossible d'importer %1 en utilisant le plugin d'importateur OFX. Ce fichier n'est pas le bon format."</span>}]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xxhzkh">It will be interesting to see if our fine-tuned model picks up on those particularities of the dataset (spoiler alert: it will).</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/0Oxphw4Q9fo" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-uwcde8">✏️ <strong>Your turn!</strong> Another English word that is often used in French is “email.” Find the first sample in the training dataset that uses this word. How is it translated? How does the pretrained model translate the same English sentence?</p></div> <h3 class="relative group"><a id="processing-the-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#processing-the-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Processing the data</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/XAR8jnZZuUs" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-7kxu8d">You should know the drill by now: the texts all need to be converted into sets of token IDs so the model can make sense of them. For this task, we’ll need to tokenize both the inputs and the targets. Our first task is to create our <code>tokenizer</code> object. As noted earlier, we’ll be using a Marian English to French pretrained model. If you are trying this code with another pair of languages, make sure to adapt the model checkpoint. The <a href="https://huggingface.co/Helsinki-NLP" rel="nofollow">Helsinki-NLP</a> organization provides more than a thousand models in multiple languages.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer | |
| model_checkpoint = <span class="hljs-string">"Helsinki-NLP/opus-mt-en-fr"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors=<span class="hljs-string">"pt"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1akd2n2">You can also replace the <code>model_checkpoint</code> with any other model you prefer from the <a href="https://huggingface.co/models" rel="nofollow">Hub</a>, or a local folder where you’ve saved a pretrained model and a tokenizer.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-w81hxj">💡 If you are using a multilingual tokenizer such as mBART, mBART-50, or M2M100, you will need to set the language codes of your inputs and targets in the tokenizer by setting <code>tokenizer.src_lang</code> and <code>tokenizer.tgt_lang</code> to the right values.</p></div> <p data-svelte-h="svelte-1pc1kc0">The preparation of our data is pretty straightforward. There’s just one thing to remember; you need to ensure that the tokenizer processes the targets in the output language (here, French). You can do this by passing the targets to the <code>text_targets</code> argument of the tokenizer’s <code>__call__</code> method.</p> <p data-svelte-h="svelte-16j9n8c">To see how this works, let’s process one sample of each language in the training set:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->en_sentence = split_datasets[<span class="hljs-string">"train"</span>][<span class="hljs-number">1</span>][<span class="hljs-string">"translation"</span>][<span class="hljs-string">"en"</span>] | |
| fr_sentence = split_datasets[<span class="hljs-string">"train"</span>][<span class="hljs-number">1</span>][<span class="hljs-string">"translation"</span>][<span class="hljs-string">"fr"</span>] | |
| inputs = tokenizer(en_sentence, text_target=fr_sentence) | |
| inputs<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'input_ids'</span>: [<span class="hljs-number">47591</span>, <span class="hljs-number">12</span>, <span class="hljs-number">9842</span>, <span class="hljs-number">19634</span>, <span class="hljs-number">9</span>, <span class="hljs-number">0</span>], <span class="hljs-string">'attention_mask'</span>: [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>], <span class="hljs-string">'labels'</span>: [<span class="hljs-number">577</span>, <span class="hljs-number">5891</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3184</span>, <span class="hljs-number">16</span>, <span class="hljs-number">2542</span>, <span class="hljs-number">5</span>, <span class="hljs-number">1710</span>, <span class="hljs-number">0</span>]}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-o6ed2q">As we can see, the output contains the input IDs associated with the English sentence, while the IDs associated with the French one are stored in the <code>labels</code> field. If you forget to indicate that you are tokenizing labels, they will be tokenized by the input tokenizer, which in the case of a Marian model is not going to go well at all:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->wrong_targets = tokenizer(fr_sentence) | |
| <span class="hljs-built_in">print</span>(tokenizer.convert_ids_to_tokens(wrong_targets[<span class="hljs-string">"input_ids"</span>])) | |
| <span class="hljs-built_in">print</span>(tokenizer.convert_ids_to_tokens(inputs[<span class="hljs-string">"labels"</span>]))<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-string">'▁Par'</span>, <span class="hljs-string">'▁dé'</span>, <span class="hljs-string">'f'</span>, <span class="hljs-string">'aut'</span>, <span class="hljs-string">','</span>, <span class="hljs-string">'▁dé'</span>, <span class="hljs-string">'ve'</span>, <span class="hljs-string">'lop'</span>, <span class="hljs-string">'per'</span>, <span class="hljs-string">'▁les'</span>, <span class="hljs-string">'▁fil'</span>, <span class="hljs-string">'s'</span>, <span class="hljs-string">'▁de'</span>, <span class="hljs-string">'▁discussion'</span>, <span class="hljs-string">'</s>'</span>] | |
| [<span class="hljs-string">'▁Par'</span>, <span class="hljs-string">'▁défaut'</span>, <span class="hljs-string">','</span>, <span class="hljs-string">'▁développer'</span>, <span class="hljs-string">'▁les'</span>, <span class="hljs-string">'▁fils'</span>, <span class="hljs-string">'▁de'</span>, <span class="hljs-string">'▁discussion'</span>, <span class="hljs-string">'</s>'</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10w8fo4">As we can see, using the English tokenizer to preprocess a French sentence results in a lot more tokens, since the tokenizer doesn’t know any French words (except those that also appear in the English language, like “discussion”).</p> <p data-svelte-h="svelte-q0ybzj">Since <code>inputs</code> is a dictionary with our usual keys (input IDs, attention mask, etc.), the last step is to define the preprocessing function we will apply on the datasets:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->max_length = <span class="hljs-number">128</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>): | |
| inputs = [ex[<span class="hljs-string">"en"</span>] <span class="hljs-keyword">for</span> ex <span class="hljs-keyword">in</span> examples[<span class="hljs-string">"translation"</span>]] | |
| targets = [ex[<span class="hljs-string">"fr"</span>] <span class="hljs-keyword">for</span> ex <span class="hljs-keyword">in</span> examples[<span class="hljs-string">"translation"</span>]] | |
| model_inputs = tokenizer( | |
| inputs, text_target=targets, max_length=max_length, truncation=<span class="hljs-literal">True</span> | |
| ) | |
| <span class="hljs-keyword">return</span> model_inputs<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-wqps0w">Note that we set the same maximum length for our inputs and outputs. Since the texts we’re dealing with seem pretty short, we use 128.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-9rujvy">💡 If you are using a T5 model (more specifically, one of the <code>t5-xxx</code> checkpoints), the model will expect the text inputs to have a prefix indicating the task at hand, such as <code>translate: English to French:</code>.</p></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1i4t0xl">⚠️ We don’t pay attention to the attention mask of the targets, as the model won’t expect it. Instead, the labels corresponding to a padding token should be set to <code>-100</code> so they are ignored in the loss computation. This will be done by our data collator later on since we are applying dynamic padding, but if you use padding here, you should adapt the preprocessing function to set all labels that correspond to the padding token to <code>-100</code>.</p></div> <p data-svelte-h="svelte-u9dyhf">We can now apply that preprocessing in one go on all the splits of our dataset:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenized_datasets = split_datasets.<span class="hljs-built_in">map</span>( | |
| preprocess_function, | |
| batched=<span class="hljs-literal">True</span>, | |
| remove_columns=split_datasets[<span class="hljs-string">"train"</span>].column_names, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1a79fos">Now that the data has been preprocessed, we are ready to fine-tune our pretrained model!</p> <h2 class="relative group"><a id="fine-tuning-the-model-with-the-trainer-api" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-the-model-with-the-trainer-api"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning the model with the Trainer API</span></h2> <p data-svelte-h="svelte-1u27cnc">The actual code using the <code>Trainer</code> will be the same as before, with just one little change: we use a <a href="https://huggingface.co/transformers/main_classes/trainer.html#seq2seqtrainer" rel="nofollow"><code>Seq2SeqTrainer</code></a> here, which is a subclass of <code>Trainer</code> that will allow us to properly deal with the evaluation, using the <code>generate()</code> method to predict outputs from the inputs. We’ll dive into that in more detail when we talk about the metric computation.</p> <p data-svelte-h="svelte-c4jxu5">First things first, we need an actual model to fine-tune. We’ll use the usual <code>AutoModel</code> API:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSeq2SeqLM | |
| model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1merkap">Note that this time we are using a model that was trained on a translation task and can actually be used already, so there is no warning about missing weights or newly initialized ones.</p> <h3 class="relative group"><a id="data-collation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#data-collation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Data collation</span></h3> <p data-svelte-h="svelte-1uzskyq">We’ll need a data collator to deal with the padding for dynamic batching. We can’t just use a <code>DataCollatorWithPadding</code> like in <a href="/course/chapter3">Chapter 3</a> in this case, because that only pads the inputs (input IDs, attention mask, and token type IDs). Our labels should also be padded to the maximum length encountered in the labels. And, as mentioned previously, the padding value used to pad the labels should be <code>-100</code> and not the padding token of the tokenizer, to make sure those padded values are ignored in the loss computation.</p> <p data-svelte-h="svelte-18ubg06">This is all done by a <a href="https://huggingface.co/transformers/main_classes/data_collator.html#datacollatorforseq2seq" rel="nofollow"><code>DataCollatorForSeq2Seq</code></a>. Like the <code>DataCollatorWithPadding</code>, it takes the <code>tokenizer</code> used to preprocess the inputs, but it also takes the <code>model</code>. This is because this data collator will also be responsible for preparing the decoder input IDs, which are shifted versions of the labels with a special token at the beginning. Since this shift is done slightly differently for different architectures, the <code>DataCollatorForSeq2Seq</code> needs to know the <code>model</code> object:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> DataCollatorForSeq2Seq | |
| data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-bijhma">To test this on a few samples, we just call it on a list of examples from our tokenized training set:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->batch = data_collator([tokenized_datasets[<span class="hljs-string">"train"</span>][i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">1</span>, <span class="hljs-number">3</span>)]) | |
| batch.keys()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->dict_keys([<span class="hljs-string">'attention_mask'</span>, <span class="hljs-string">'input_ids'</span>, <span class="hljs-string">'labels'</span>, <span class="hljs-string">'decoder_input_ids'</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-990iko">We can check our labels have been padded to the maximum length of the batch, using <code>-100</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->batch[<span class="hljs-string">"labels"</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tensor([[ <span class="hljs-number">577</span>, <span class="hljs-number">5891</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3184</span>, <span class="hljs-number">16</span>, <span class="hljs-number">2542</span>, <span class="hljs-number">5</span>, <span class="hljs-number">1710</span>, <span class="hljs-number">0</span>, -<span class="hljs-number">100</span>, | |
| -<span class="hljs-number">100</span>, -<span class="hljs-number">100</span>, -<span class="hljs-number">100</span>, -<span class="hljs-number">100</span>, -<span class="hljs-number">100</span>, -<span class="hljs-number">100</span>], | |
| [ <span class="hljs-number">1211</span>, <span class="hljs-number">3</span>, <span class="hljs-number">49</span>, <span class="hljs-number">9409</span>, <span class="hljs-number">1211</span>, <span class="hljs-number">3</span>, <span class="hljs-number">29140</span>, <span class="hljs-number">817</span>, <span class="hljs-number">3124</span>, <span class="hljs-number">817</span>, | |
| <span class="hljs-number">550</span>, <span class="hljs-number">7032</span>, <span class="hljs-number">5821</span>, <span class="hljs-number">7907</span>, <span class="hljs-number">12649</span>, <span class="hljs-number">0</span>]])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hivgoi">And we can also have a look at the decoder input IDs, to see that they are shifted versions of the labels:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->batch[<span class="hljs-string">"decoder_input_ids"</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tensor([[<span class="hljs-number">59513</span>, <span class="hljs-number">577</span>, <span class="hljs-number">5891</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3184</span>, <span class="hljs-number">16</span>, <span class="hljs-number">2542</span>, <span class="hljs-number">5</span>, <span class="hljs-number">1710</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">59513</span>, <span class="hljs-number">59513</span>, <span class="hljs-number">59513</span>, <span class="hljs-number">59513</span>, <span class="hljs-number">59513</span>, <span class="hljs-number">59513</span>], | |
| [<span class="hljs-number">59513</span>, <span class="hljs-number">1211</span>, <span class="hljs-number">3</span>, <span class="hljs-number">49</span>, <span class="hljs-number">9409</span>, <span class="hljs-number">1211</span>, <span class="hljs-number">3</span>, <span class="hljs-number">29140</span>, <span class="hljs-number">817</span>, <span class="hljs-number">3124</span>, | |
| <span class="hljs-number">817</span>, <span class="hljs-number">550</span>, <span class="hljs-number">7032</span>, <span class="hljs-number">5821</span>, <span class="hljs-number">7907</span>, <span class="hljs-number">12649</span>]])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-unhint">Here are the labels for the first and second elements in our dataset:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">1</span>, <span class="hljs-number">3</span>): | |
| <span class="hljs-built_in">print</span>(tokenized_datasets[<span class="hljs-string">"train"</span>][i][<span class="hljs-string">"labels"</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-number">577</span>, <span class="hljs-number">5891</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3184</span>, <span class="hljs-number">16</span>, <span class="hljs-number">2542</span>, <span class="hljs-number">5</span>, <span class="hljs-number">1710</span>, <span class="hljs-number">0</span>] | |
| [<span class="hljs-number">1211</span>, <span class="hljs-number">3</span>, <span class="hljs-number">49</span>, <span class="hljs-number">9409</span>, <span class="hljs-number">1211</span>, <span class="hljs-number">3</span>, <span class="hljs-number">29140</span>, <span class="hljs-number">817</span>, <span class="hljs-number">3124</span>, <span class="hljs-number">817</span>, <span class="hljs-number">550</span>, <span class="hljs-number">7032</span>, <span class="hljs-number">5821</span>, <span class="hljs-number">7907</span>, <span class="hljs-number">12649</span>, <span class="hljs-number">0</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-gyo8yo">We will pass this <code>data_collator</code> along to the <code>Seq2SeqTrainer</code>. Next, let’s have a look at the metric.</p> <h3 class="relative group"><a id="metrics" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#metrics"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Metrics</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/M05L1DhFqcw" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-1dssf7l">The feature that <code>Seq2SeqTrainer</code> adds to its superclass <code>Trainer</code> is the ability to use the <code>generate()</code> method during evaluation or prediction. During training, the model will use the <code>decoder_input_ids</code> with an attention mask ensuring it does not use the tokens after the token it’s trying to predict, to speed up training. During inference we won’t be able to use those since we won’t have labels, so it’s a good idea to evaluate our model with the same setup.</p> <p data-svelte-h="svelte-thx2aq">As we saw in <a href="/course/chapter1/6">Chapter 1</a>, the decoder performs inference by predicting tokens one by one — something that’s implemented behind the scenes in 🤗 Transformers by the <code>generate()</code> method. The <code>Seq2SeqTrainer</code> will let us use that method for evaluation if we set <code>predict_with_generate=True</code>.</p> <p data-svelte-h="svelte-1kiw7od">The traditional metric used for translation is the <a href="https://en.wikipedia.org/wiki/BLEU" rel="nofollow">BLEU score</a>, introduced in <a href="https://aclanthology.org/P02-1040.pdf" rel="nofollow">a 2002 article</a> by Kishore Papineni et al. The BLEU score evaluates how close the translations are to their labels. It does not measure the intelligibility or grammatical correctness of the model’s generated outputs, but uses statistical rules to ensure that all the words in the generated outputs also appear in the targets. In addition, there are rules that penalize repetitions of the same words if they are not also repeated in the targets (to avoid the model outputting sentences like <code>"the the the the the"</code>) and output sentences that are shorter than those in the targets (to avoid the model outputting sentences like <code>"the"</code>).</p> <p data-svelte-h="svelte-1bnfkqb">One weakness with BLEU is that it expects the text to already be tokenized, which makes it difficult to compare scores between models that use different tokenizers. So instead, the most commonly used metric for benchmarking translation models today is <a href="https://github.com/mjpost/sacrebleu" rel="nofollow">SacreBLEU</a>, which addresses this weakness (and others) by standardizing the tokenization step. To use this metric, we first need to install the SacreBLEU library:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!pip install sacrebleu<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jh6rtm">We can then load it via <code>evaluate.load()</code> like we did in <a href="/course/chapter3">Chapter 3</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> evaluate | |
| metric = evaluate.load(<span class="hljs-string">"sacrebleu"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1oah3bx">This metric will take texts as inputs and targets. It is designed to accept several acceptable targets, as there are often multiple acceptable translations of the same sentence — the dataset we’re using only provides one, but it’s not uncommon in NLP to find datasets that give several sentences as labels. So, the predictions should be a list of sentences, but the references should be a list of lists of sentences.</p> <p data-svelte-h="svelte-1h2fqkp">Let’s try an example:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->predictions = [ | |
| <span class="hljs-string">"This plugin lets you translate web pages between several languages automatically."</span> | |
| ] | |
| references = [ | |
| [ | |
| <span class="hljs-string">"This plugin allows you to automatically translate web pages between several languages."</span> | |
| ] | |
| ] | |
| metric.compute(predictions=predictions, references=references)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'score'</span>: <span class="hljs-number">46.750469682990165</span>, | |
| <span class="hljs-string">'counts'</span>: [<span class="hljs-number">11</span>, <span class="hljs-number">6</span>, <span class="hljs-number">4</span>, <span class="hljs-number">3</span>], | |
| <span class="hljs-string">'totals'</span>: [<span class="hljs-number">12</span>, <span class="hljs-number">11</span>, <span class="hljs-number">10</span>, <span class="hljs-number">9</span>], | |
| <span class="hljs-string">'precisions'</span>: [<span class="hljs-number">91.67</span>, <span class="hljs-number">54.54</span>, <span class="hljs-number">40.0</span>, <span class="hljs-number">33.33</span>], | |
| <span class="hljs-string">'bp'</span>: <span class="hljs-number">0.9200444146293233</span>, | |
| <span class="hljs-string">'sys_len'</span>: <span class="hljs-number">12</span>, | |
| <span class="hljs-string">'ref_len'</span>: <span class="hljs-number">13</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1iu1ast">This gets a BLEU score of 46.75, which is rather good — for reference, the original Transformer model in the <a href="https://arxiv.org/pdf/1706.03762.pdf" rel="nofollow">“Attention Is All You Need” paper</a> achieved a BLEU score of 41.8 on a similar translation task between English and French! (For more information about the individual metrics, like <code>counts</code> and <code>bp</code>, see the <a href="https://github.com/mjpost/sacrebleu/blob/078c440168c6adc89ba75fe6d63f0d922d42bcfe/sacrebleu/metrics/bleu.py#L74" rel="nofollow">SacreBLEU repository</a>.) On the other hand, if we try with the two bad types of predictions (lots of repetitions or too short) that often come out of translation models, we will get rather bad BLEU scores:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->predictions = [<span class="hljs-string">"This This This This"</span>] | |
| references = [ | |
| [ | |
| <span class="hljs-string">"This plugin allows you to automatically translate web pages between several languages."</span> | |
| ] | |
| ] | |
| metric.compute(predictions=predictions, references=references)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'score'</span>: <span class="hljs-number">1.683602693167689</span>, | |
| <span class="hljs-string">'counts'</span>: [<span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| <span class="hljs-string">'totals'</span>: [<span class="hljs-number">4</span>, <span class="hljs-number">3</span>, <span class="hljs-number">2</span>, <span class="hljs-number">1</span>], | |
| <span class="hljs-string">'precisions'</span>: [<span class="hljs-number">25.0</span>, <span class="hljs-number">16.67</span>, <span class="hljs-number">12.5</span>, <span class="hljs-number">12.5</span>], | |
| <span class="hljs-string">'bp'</span>: <span class="hljs-number">0.10539922456186433</span>, | |
| <span class="hljs-string">'sys_len'</span>: <span class="hljs-number">4</span>, | |
| <span class="hljs-string">'ref_len'</span>: <span class="hljs-number">13</span>}<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->predictions = [<span class="hljs-string">"This plugin"</span>] | |
| references = [ | |
| [ | |
| <span class="hljs-string">"This plugin allows you to automatically translate web pages between several languages."</span> | |
| ] | |
| ] | |
| metric.compute(predictions=predictions, references=references)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'score'</span>: <span class="hljs-number">0.0</span>, | |
| <span class="hljs-string">'counts'</span>: [<span class="hljs-number">2</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| <span class="hljs-string">'totals'</span>: [<span class="hljs-number">2</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| <span class="hljs-string">'precisions'</span>: [<span class="hljs-number">100.0</span>, <span class="hljs-number">100.0</span>, <span class="hljs-number">0.0</span>, <span class="hljs-number">0.0</span>], | |
| <span class="hljs-string">'bp'</span>: <span class="hljs-number">0.004086771438464067</span>, | |
| <span class="hljs-string">'sys_len'</span>: <span class="hljs-number">2</span>, | |
| <span class="hljs-string">'ref_len'</span>: <span class="hljs-number">13</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-tsitwn">The score can go from 0 to 100, and higher is better.</p> <p data-svelte-h="svelte-12prd39">To get from the model outputs to texts the metric can use, we will use the <code>tokenizer.batch_decode()</code> method. We just have to clean up all the <code>-100</code>s in the labels (the tokenizer will automatically do the same for the padding token):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_preds</span>): | |
| preds, labels = eval_preds | |
| <span class="hljs-comment"># In case the model returns more than the prediction logits</span> | |
| <span class="hljs-keyword">if</span> <span class="hljs-built_in">isinstance</span>(preds, <span class="hljs-built_in">tuple</span>): | |
| preds = preds[<span class="hljs-number">0</span>] | |
| decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=<span class="hljs-literal">True</span>) | |
| <span class="hljs-comment"># Replace -100s in the labels as we can't decode them</span> | |
| labels = np.where(labels != -<span class="hljs-number">100</span>, labels, tokenizer.pad_token_id) | |
| decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=<span class="hljs-literal">True</span>) | |
| <span class="hljs-comment"># Some simple post-processing</span> | |
| decoded_preds = [pred.strip() <span class="hljs-keyword">for</span> pred <span class="hljs-keyword">in</span> decoded_preds] | |
| decoded_labels = [[label.strip()] <span class="hljs-keyword">for</span> label <span class="hljs-keyword">in</span> decoded_labels] | |
| result = metric.compute(predictions=decoded_preds, references=decoded_labels) | |
| <span class="hljs-keyword">return</span> {<span class="hljs-string">"bleu"</span>: result[<span class="hljs-string">"score"</span>]}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gk0eqs">Now that this is done, we are ready to fine-tune our model!</p> <h3 class="relative group"><a id="fine-tuning-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning the model</span></h3> <p data-svelte-h="svelte-gd91fw">The first step is to log in to Hugging Face, so you’re able to upload your results to the Model Hub. There’s a convenience function to help you with this in a notebook:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login | |
| notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ied0vh">This will display a widget where you can enter your Hugging Face login credentials.</p> <p data-svelte-h="svelte-648vlf">If you aren’t working in a notebook, just type the following line in your terminal:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->huggingface-cli login<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hue7r3">Once this is done, we can define our <code>Seq2SeqTrainingArguments</code>. Like for the <code>Trainer</code>, we use a subclass of <code>TrainingArguments</code> that contains a few more fields:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Seq2SeqTrainingArguments | |
| args = Seq2SeqTrainingArguments( | |
| <span class="hljs-string">f"marian-finetuned-kde4-en-to-fr"</span>, | |
| evaluation_strategy=<span class="hljs-string">"no"</span>, | |
| save_strategy=<span class="hljs-string">"epoch"</span>, | |
| learning_rate=<span class="hljs-number">2e-5</span>, | |
| per_device_train_batch_size=<span class="hljs-number">32</span>, | |
| per_device_eval_batch_size=<span class="hljs-number">64</span>, | |
| weight_decay=<span class="hljs-number">0.01</span>, | |
| save_total_limit=<span class="hljs-number">3</span>, | |
| num_train_epochs=<span class="hljs-number">3</span>, | |
| predict_with_generate=<span class="hljs-literal">True</span>, | |
| fp16=<span class="hljs-literal">True</span>, | |
| push_to_hub=<span class="hljs-literal">True</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vr7kv5">Apart from the usual hyperparameters (like learning rate, number of epochs, batch size, and some weight decay), here are a few changes compared to what we saw in the previous sections:</p> <ul data-svelte-h="svelte-ckg3z2"><li>We don’t set any regular evaluation, as evaluation takes a while; we will just evaluate our model once before training and after.</li> <li>We set <code>fp16=True</code>, which speeds up training on modern GPUs.</li> <li>We set <code>predict_with_generate=True</code>, as discussed above.</li> <li>We use <code>push_to_hub=True</code> to upload the model to the Hub at the end of each epoch.</li></ul> <p data-svelte-h="svelte-14ymb4g">Note that you can specify the full name of the repository you want to push to with the <code>hub_model_id</code> argument (in particular, you will have to use this argument to push to an organization). For instance, when we pushed the model to the <a href="https://huggingface.co/huggingface-course" rel="nofollow"><code>huggingface-course</code> organization</a>, we added <code>hub_model_id="huggingface-course/marian-finetuned-kde4-en-to-fr"</code> to <code>Seq2SeqTrainingArguments</code>. By default, the repository used will be in your namespace and named after the output directory you set, so in our case it will be <code>"sgugger/marian-finetuned-kde4-en-to-fr"</code> (which is the model we linked to at the beginning of this section).</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-449gha">💡 If the output directory you are using already exists, it needs to be a local clone of the repository you want to push to. If it isn’t, you’ll get an error when defining your <code>Seq2SeqTrainer</code> and will need to set a new name.</p></div> <p data-svelte-h="svelte-vnqwl3">Finally, we just pass everything to the <code>Seq2SeqTrainer</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Seq2SeqTrainer | |
| trainer = Seq2SeqTrainer( | |
| model, | |
| args, | |
| train_dataset=tokenized_datasets[<span class="hljs-string">"train"</span>], | |
| eval_dataset=tokenized_datasets[<span class="hljs-string">"validation"</span>], | |
| data_collator=data_collator, | |
| tokenizer=tokenizer, | |
| compute_metrics=compute_metrics, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ixolzg">Before training, we’ll first look at the score our model gets, to double-check that we’re not making things worse with our fine-tuning. This command will take a bit of time, so you can grab a coffee while it executes:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.evaluate(max_length=max_length)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'eval_loss'</span>: <span class="hljs-number">1.6964408159255981</span>, | |
| <span class="hljs-string">'eval_bleu'</span>: <span class="hljs-number">39.26865061007616</span>, | |
| <span class="hljs-string">'eval_runtime'</span>: <span class="hljs-number">965.8884</span>, | |
| <span class="hljs-string">'eval_samples_per_second'</span>: <span class="hljs-number">21.76</span>, | |
| <span class="hljs-string">'eval_steps_per_second'</span>: <span class="hljs-number">0.341</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-16qcd8b">A BLEU score of 39 is not too bad, which reflects the fact that our model is already good at translating English sentences to French ones.</p> <p data-svelte-h="svelte-1shq479">Next is the training, which will also take a bit of time:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-2l4fow">Note that while the training happens, each time the model is saved (here, every epoch) it is uploaded to the Hub in the background. This way, you will be able to to resume your training on another machine if necessary.</p> <p data-svelte-h="svelte-1tyayn3">Once training is done, we evaluate our model again — hopefully we will see some amelioration in the BLEU score!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.evaluate(max_length=max_length)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'eval_loss'</span>: <span class="hljs-number">0.8558505773544312</span>, | |
| <span class="hljs-string">'eval_bleu'</span>: <span class="hljs-number">52.94161337775576</span>, | |
| <span class="hljs-string">'eval_runtime'</span>: <span class="hljs-number">714.2576</span>, | |
| <span class="hljs-string">'eval_samples_per_second'</span>: <span class="hljs-number">29.426</span>, | |
| <span class="hljs-string">'eval_steps_per_second'</span>: <span class="hljs-number">0.461</span>, | |
| <span class="hljs-string">'epoch'</span>: <span class="hljs-number">3.0</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-2wjeu9">That’s a nearly 14-point improvement, which is great.</p> <p data-svelte-h="svelte-emcbnj">Finally, we use the <code>push_to_hub()</code> method to make sure we upload the latest version of the model. The <code>Trainer</code> also drafts a model card with all the evaluation results and uploads it. This model card contains metadata that helps the Model Hub pick the widget for the inference demo. Usually, there is no need to say anything as it can infer the right widget from the model class, but in this case, the same model class can be used for all kinds of sequence-to-sequence problems, so we specify it’s a translation model:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.push_to_hub(tags=<span class="hljs-string">"translation"</span>, commit_message=<span class="hljs-string">"Training complete"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mme7gz">This command returns the URL of the commit it just did, if you want to inspect it:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">'https://huggingface.co/sgugger/marian-finetuned-kde4-en-to-fr/commit/3601d621e3baae2bc63d3311452535f8f58f6ef3'</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1is5qce">At this stage, you can use the inference widget on the Model Hub to test your model and share it with your friends. You have successfully fine-tuned a model on a translation task — congratulations!</p> <p data-svelte-h="svelte-1q5ysr9">If you want to dive a bit more deeply into the training loop, we will now show you how to do the same thing using 🤗 Accelerate.</p> <h2 class="relative group"><a id="a-custom-training-loop" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#a-custom-training-loop"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>A custom training loop</span></h2> <p data-svelte-h="svelte-qavg11">Let’s now take a look at the full training loop, so you can easily customize the parts you need. It will look a lot like what we did in <a href="/course/chapter7/2">section 2</a> and <a href="/course/chapter3/4">Chapter 3</a>.</p> <h3 class="relative group"><a id="preparing-everything-for-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preparing-everything-for-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preparing everything for training</span></h3> <p data-svelte-h="svelte-1xyz8xy">You’ve seen all of this a few times now, so we’ll go through the code quite quickly. First we’ll build the <code>DataLoader</code>s from our datasets, after setting the datasets to the <code>"torch"</code> format so we get PyTorch tensors:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader | |
| tokenized_datasets.set_format(<span class="hljs-string">"torch"</span>) | |
| train_dataloader = DataLoader( | |
| tokenized_datasets[<span class="hljs-string">"train"</span>], | |
| shuffle=<span class="hljs-literal">True</span>, | |
| collate_fn=data_collator, | |
| batch_size=<span class="hljs-number">8</span>, | |
| ) | |
| eval_dataloader = DataLoader( | |
| tokenized_datasets[<span class="hljs-string">"validation"</span>], collate_fn=data_collator, batch_size=<span class="hljs-number">8</span> | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-e8509p">Next we reinstantiate our model, to make sure we’re not continuing the fine-tuning from before but starting from the pretrained model again:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-47lq6z">Then we will need an optimizer:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">2e-5</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18hq2jn">Once we have all those objects, we can send them to the <code>accelerator.prepare()</code> method. Remember that if you want to train on TPUs in a Colab notebook, you will need to move all of this code into a training function, and that shouldn’t execute any cell that instantiates an <code>Accelerator</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| accelerator = Accelerator() | |
| model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( | |
| model, optimizer, train_dataloader, eval_dataloader | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-udl66p">Now that we have sent our <code>train_dataloader</code> to <code>accelerator.prepare()</code>, we can use its length to compute the number of training steps. Remember we should always do this after preparing the dataloader, as that method will change the length of the <code>DataLoader</code>. We use a classic linear schedule from the learning rate to 0:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler | |
| num_train_epochs = <span class="hljs-number">3</span> | |
| num_update_steps_per_epoch = <span class="hljs-built_in">len</span>(train_dataloader) | |
| num_training_steps = num_train_epochs * num_update_steps_per_epoch | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-v6q3yq">Lastly, to push our model to the Hub, we will need to create a <code>Repository</code> object in a working folder. First log in to the Hugging Face Hub, if you’re not logged in already. We’ll determine the repository name from the model ID we want to give our model (feel free to replace the <code>repo_name</code> with your own choice; it just needs to contain your username, which is what the function <code>get_full_repo_name()</code> does):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> Repository, get_full_repo_name | |
| model_name = <span class="hljs-string">"marian-finetuned-kde4-en-to-fr-accelerate"</span> | |
| repo_name = get_full_repo_name(model_name) | |
| repo_name<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">'sgugger/marian-finetuned-kde4-en-to-fr-accelerate'</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-e9w2am">Then we can clone that repository in a local folder. If it already exists, this local folder should be a clone of the repository we are working with:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->output_dir = <span class="hljs-string">"marian-finetuned-kde4-en-to-fr-accelerate"</span> | |
| repo = Repository(output_dir, clone_from=repo_name)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1mcbsiy">We can now upload anything we save in <code>output_dir</code> by calling the <code>repo.push_to_hub()</code> method. This will help us upload the intermediate models at the end of each epoch.</p> <h3 class="relative group"><a id="training-loop" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-loop"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training loop</span></h3> <p data-svelte-h="svelte-1x90r0t">We are now ready to write the full training loop. To simplify its evaluation part, we define this <code>postprocess()</code> function that takes predictions and labels and converts them to the lists of strings our <code>metric</code> object will expect:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">postprocess</span>(<span class="hljs-params">predictions, labels</span>): | |
| predictions = predictions.cpu().numpy() | |
| labels = labels.cpu().numpy() | |
| decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=<span class="hljs-literal">True</span>) | |
| <span class="hljs-comment"># Replace -100 in the labels as we can't decode them.</span> | |
| labels = np.where(labels != -<span class="hljs-number">100</span>, labels, tokenizer.pad_token_id) | |
| decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=<span class="hljs-literal">True</span>) | |
| <span class="hljs-comment"># Some simple post-processing</span> | |
| decoded_preds = [pred.strip() <span class="hljs-keyword">for</span> pred <span class="hljs-keyword">in</span> decoded_preds] | |
| decoded_labels = [[label.strip()] <span class="hljs-keyword">for</span> label <span class="hljs-keyword">in</span> decoded_labels] | |
| <span class="hljs-keyword">return</span> decoded_preds, decoded_labels<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-11l1dbc">The training loop looks a lot like the ones in <a href="/course/chapter7/2">section 2</a> and <a href="/course/chapter3">Chapter 3</a>, with a few differences in the evaluation part — so let’s focus on that!</p> <p data-svelte-h="svelte-1sfscsd">The first thing to note is that we use the <code>generate()</code> method to compute predictions, but this is a method on our base model, not the wrapped model 🤗 Accelerate created in the <code>prepare()</code> method. That’s why we unwrap the model first, then call this method.</p> <p data-svelte-h="svelte-1cf30vw">The second thing is that, like with <a href="/course/chapter7/2">token classification</a>, two processes may have padded the inputs and labels to different shapes, so we use <code>accelerator.pad_across_processes()</code> to make the predictions and labels the same shape before calling the <code>gather()</code> method. If we don’t do this, the evaluation will either error out or hang forever.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm | |
| <span class="hljs-keyword">import</span> torch | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_train_epochs): | |
| <span class="hljs-comment"># Training</span> | |
| model.train() | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>) | |
| <span class="hljs-comment"># Evaluation</span> | |
| model.<span class="hljs-built_in">eval</span>() | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> tqdm(eval_dataloader): | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| generated_tokens = accelerator.unwrap_model(model).generate( | |
| batch[<span class="hljs-string">"input_ids"</span>], | |
| attention_mask=batch[<span class="hljs-string">"attention_mask"</span>], | |
| max_length=<span class="hljs-number">128</span>, | |
| ) | |
| labels = batch[<span class="hljs-string">"labels"</span>] | |
| <span class="hljs-comment"># Necessary to pad predictions and labels for being gathered</span> | |
| generated_tokens = accelerator.pad_across_processes( | |
| generated_tokens, dim=<span class="hljs-number">1</span>, pad_index=tokenizer.pad_token_id | |
| ) | |
| labels = accelerator.pad_across_processes(labels, dim=<span class="hljs-number">1</span>, pad_index=-<span class="hljs-number">100</span>) | |
| predictions_gathered = accelerator.gather(generated_tokens) | |
| labels_gathered = accelerator.gather(labels) | |
| decoded_preds, decoded_labels = postprocess(predictions_gathered, labels_gathered) | |
| metric.add_batch(predictions=decoded_preds, references=decoded_labels) | |
| results = metric.compute() | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"epoch <span class="hljs-subst">{epoch}</span>, BLEU score: <span class="hljs-subst">{results[<span class="hljs-string">'score'</span>]:<span class="hljs-number">.2</span>f}</span>"</span>) | |
| <span class="hljs-comment"># Save and upload</span> | |
| accelerator.wait_for_everyone() | |
| unwrapped_model = accelerator.unwrap_model(model) | |
| unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save) | |
| <span class="hljs-keyword">if</span> accelerator.is_main_process: | |
| tokenizer.save_pretrained(output_dir) | |
| repo.push_to_hub( | |
| commit_message=<span class="hljs-string">f"Training in progress epoch <span class="hljs-subst">{epoch}</span>"</span>, blocking=<span class="hljs-literal">False</span> | |
| )<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->epoch <span class="hljs-number">0</span>, BLEU score: <span class="hljs-number">53.47</span> | |
| epoch <span class="hljs-number">1</span>, BLEU score: <span class="hljs-number">54.24</span> | |
| epoch <span class="hljs-number">2</span>, BLEU score: <span class="hljs-number">54.44</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ezkapr">Once this is done, you should have a model that has results pretty similar to the one trained with the <code>Seq2SeqTrainer</code>. You can check the one we trained using this code at <a href="https://huggingface.co/huggingface-course/marian-finetuned-kde4-en-to-fr-accelerate" rel="nofollow"><em>huggingface-course/marian-finetuned-kde4-en-to-fr-accelerate</em></a>. And if you want to test out any tweaks to the training loop, you can directly implement them by editing the code shown above!</p> <h2 class="relative group"><a id="using-the-fine-tuned-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-the-fine-tuned-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using the fine-tuned model</span></h2> <p data-svelte-h="svelte-d11l7w">We’ve already shown you how you can use the model we fine-tuned on the Model Hub with the inference widget. To use it locally in a <code>pipeline</code>, we just have to specify the proper model identifier:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> pipeline | |
| <span class="hljs-comment"># Replace this with your own checkpoint</span> | |
| model_checkpoint = <span class="hljs-string">"huggingface-course/marian-finetuned-kde4-en-to-fr"</span> | |
| translator = pipeline(<span class="hljs-string">"translation"</span>, model=model_checkpoint) | |
| translator(<span class="hljs-string">"Default to expanded threads"</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[{<span class="hljs-string">'translation_text'</span>: <span class="hljs-string">'Par défaut, développer les fils de discussion'</span>}]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gj0rwy">As expected, our pretrained model adapted its knowledge to the corpus we fine-tuned it on, and instead of leaving the English word “threads” alone, it now translates it to the French official version. It’s the same for “plugin”:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->translator( | |
| <span class="hljs-string">"Unable to import %1 using the OFX importer plugin. This file is not the correct format."</span> | |
| )<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[{<span class="hljs-string">'translation_text'</span>: <span class="hljs-string">"Impossible d'importer %1 en utilisant le module externe d'importation OFX. Ce fichier n'est pas le bon format."</span>}]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1o5slqz">Another great example of domain adaptation!</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-9o8gsu">✏️ <strong>Your turn!</strong> What does the model return on the sample with the word “email” you identified earlier?</p></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/en/chapter7/4.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_engvle = { | |
| assets: "/docs/course/pr_1021/en", | |
| base: "/docs/course/pr_1021/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/course/pr_1021/en/_app/immutable/entry/start.3d2e2978.js"), | |
| import("/docs/course/pr_1021/en/_app/immutable/entry/app.d60bb0c9.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 80], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 165 kB
- Xet hash:
- 03dd56218d29b2eb50249effe261a42252420cf857e17d20242e918c1db41ca9
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.