Buckets:

rtrm's picture
download
raw
161 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Fine-tuning a masked language model&quot;,&quot;local&quot;:&quot;fine-tuning-a-masked-language-model&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Picking a pretrained model for masked language modeling&quot;,&quot;local&quot;:&quot;picking-a-pretrained-model-for-masked-language-modeling&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;The dataset&quot;,&quot;local&quot;:&quot;the-dataset&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Preprocessing the data&quot;,&quot;local&quot;:&quot;preprocessing-the-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning DistilBERT with the Trainer API&quot;,&quot;local&quot;:&quot;fine-tuning-distilbert-with-the-trainer-api&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Perplexity for language models&quot;,&quot;local&quot;:&quot;perplexity-for-language-models&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning DistilBERT with 🤗 Accelerate&quot;,&quot;local&quot;:&quot;fine-tuning-distilbert-with-accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using our fine-tuned model&quot;,&quot;local&quot;:&quot;using-our-fine-tuned-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/course/pr_1069/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/entry/start.c5306bb2.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/scheduler.37c15a92.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/singletons.bc78d867.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/index.18351ede.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/paths.76894643.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/entry/app.4264f5f8.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/index.7cb9c9b8.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/nodes/0.f5347c47.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/nodes/79.d62ff13d.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/Tip.d10b3fc9.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/Youtube.8666c400.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/CodeBlock.abae2786.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/CourseFloatingBanner.df82c153.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/FrameworkSwitchCourse.97630871.js">
<link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/getInferenceSnippets.f9350a3f.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Fine-tuning a masked language model&quot;,&quot;local&quot;:&quot;fine-tuning-a-masked-language-model&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Picking a pretrained model for masked language modeling&quot;,&quot;local&quot;:&quot;picking-a-pretrained-model-for-masked-language-modeling&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;The dataset&quot;,&quot;local&quot;:&quot;the-dataset&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Preprocessing the data&quot;,&quot;local&quot;:&quot;preprocessing-the-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning DistilBERT with the Trainer API&quot;,&quot;local&quot;:&quot;fine-tuning-distilbert-with-the-trainer-api&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Perplexity for language models&quot;,&quot;local&quot;:&quot;perplexity-for-language-models&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning DistilBERT with 🤗 Accelerate&quot;,&quot;local&quot;:&quot;fine-tuning-distilbert-with-accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using our fine-tuned model&quot;,&quot;local&quot;:&quot;using-our-fine-tuned-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> Pytorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="fine-tuning-a-masked-language-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-a-masked-language-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning a masked language model</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-7-questions" target="_blank"><img alt="Ask a Question" class="!m-0" src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjYxLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/en/chapter7/section3_pt.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/en/chapter7/section3_pt.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <p data-svelte-h="svelte-11rqc5m">For many NLP applications involving Transformer models, you can simply take a pretrained model from the Hugging Face Hub and fine-tune it directly on your data for the task at hand. Provided that the corpus used for pretraining is not too different from the corpus used for fine-tuning, transfer learning will usually produce good results.</p> <p data-svelte-h="svelte-17pvd7s">However, there are a few cases where you’ll want to first fine-tune the language models on your data, before training a task-specific head. For example, if your dataset contains legal contracts or scientific articles, a vanilla Transformer model like BERT will typically treat the domain-specific words in your corpus as rare tokens, and the resulting performance may be less than satisfactory. By fine-tuning the language model on in-domain data you can boost the performance of many downstream tasks, which means you usually only have to do this step once!</p> <p data-svelte-h="svelte-9mmj84">This process of fine-tuning a pretrained language model on in-domain data is usually called <em>domain adaptation</em>. It was popularized in 2018 by <a href="https://arxiv.org/abs/1801.06146" rel="nofollow">ULMFiT</a>, which was one of the first neural architectures (based on LSTMs) to make transfer learning really work for NLP. An example of domain adaptation with ULMFiT is shown in the image below; in this section we’ll do something similar, but with a Transformer instead of an LSTM!</p> <div class="flex justify-center" data-svelte-h="svelte-1lvgpuv"><img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/ulmfit.svg" alt="ULMFiT."> <img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/ulmfit-dark.svg" alt="ULMFiT."></div> <p data-svelte-h="svelte-19eeiah">By the end of this section you’ll have a <a href="https://huggingface.co/huggingface-course/distilbert-base-uncased-finetuned-imdb?text=This+is+a+great+%5BMASK%5D." rel="nofollow">masked language model</a> on the Hub that can autocomplete sentences as shown below:</p> <iframe src="https://course-demos-distilbert-base-uncased-finetuned-imdb.hf.space" frameborder="0" height="300" title="Gradio app" class="block dark:hidden container p-0 flex-grow space-iframe" allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads"></iframe> <p data-svelte-h="svelte-1tyjc61">Let’s dive in!</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/mqElG5QJWUg" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-kl6xhw">🙋 If the terms “masked language modeling” and “pretrained model” sound unfamiliar to you, go check out <a href="/course/chapter1">Chapter 1</a>, where we explain all these core concepts, complete with videos!</p></div> <h2 class="relative group"><a id="picking-a-pretrained-model-for-masked-language-modeling" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#picking-a-pretrained-model-for-masked-language-modeling"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Picking a pretrained model for masked language modeling</span></h2> <p data-svelte-h="svelte-194w5r2">To get started, let’s pick a suitable pretrained model for masked language modeling. As shown in the following screenshot, you can find a list of candidates by applying the “Fill-Mask” filter on the <a href="https://huggingface.co/models?pipeline_tag=fill-mask&sort=downloads" rel="nofollow">Hugging Face Hub</a>:</p> <div class="flex justify-center" data-svelte-h="svelte-zc90xi"><img src="https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/mlm-models.png" alt="Hub models." width="80%"></div> <p data-svelte-h="svelte-1uiv2x4">Although the BERT and RoBERTa family of models are the most downloaded, we’ll use a model called <a href="https://huggingface.co/distilbert-base-uncased" rel="nofollow">DistilBERT</a>
that can be trained much faster with little to no loss in downstream performance. This model was trained using a special technique called <a href="https://en.wikipedia.org/wiki/Knowledge_distillation" rel="nofollow"><em>knowledge distillation</em></a>, where a large “teacher model” like BERT is used to guide the training of a “student model” that has far fewer parameters. An explanation of the details of knowledge distillation would take us too far afield in this section, but if you’re interested you can read all about it in <a href="https://www.oreilly.com/library/view/natural-language-processing/9781098136789/" rel="nofollow"><em>Natural Language Processing with Transformers</em></a> (colloquially known as the Transformers textbook).</p> <p data-svelte-h="svelte-1cnwv98">Let’s go ahead and download DistilBERT using the <code>AutoModelForMaskedLM</code> class:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForMaskedLM
model_checkpoint = <span class="hljs-string">&quot;distilbert-base-uncased&quot;</span>
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ezrcnz">We can see how many parameters this model has by calling the <code>num_parameters()</code> method:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->distilbert_num_parameters = model.num_parameters() / <span class="hljs-number">1_000_000</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt;&gt; DistilBERT number of parameters: <span class="hljs-subst">{<span class="hljs-built_in">round</span>(distilbert_num_parameters)}</span>M&#x27;&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt;&gt; BERT number of parameters: 110M&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; DistilBERT number of parameters: 67M&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; BERT number of parameters: 110M&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1g3lyuf">With around 67 million parameters, DistilBERT is approximately two times smaller than the BERT base model, which roughly translates into a two-fold speedup in training — nice! Let’s now see what kinds of tokens this model predicts are the most likely completions of a small sample of text:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->text = <span class="hljs-string">&quot;This is a great [MASK].&quot;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hat418">As humans, we can imagine many possibilities for the <code>[MASK]</code> token, such as “day”, “ride”, or “painting”. For pretrained models, the predictions depend on the corpus the model was trained on, since it learns to pick up the statistical patterns present in the data. Like BERT, DistilBERT was pretrained on the <a href="https://huggingface.co/datasets/wikipedia" rel="nofollow">English Wikipedia</a> and <a href="https://huggingface.co/datasets/bookcorpus" rel="nofollow">BookCorpus</a> datasets, so we expect the predictions for <code>[MASK]</code> to reflect these domains. To predict the mask we need DistilBERT’s tokenizer to produce the inputs for the model, so let’s download that from the Hub as well:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1qgwydn">With a tokenizer and a model, we can now pass our text example to the model, extract the logits, and print out the top 5 candidates:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
inputs = tokenizer(text, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>)
token_logits = model(**inputs).logits
<span class="hljs-comment"># Find the location of [MASK] and extract its logits</span>
mask_token_index = torch.where(inputs[<span class="hljs-string">&quot;input_ids&quot;</span>] == tokenizer.mask_token_id)[<span class="hljs-number">1</span>]
mask_token_logits = token_logits[<span class="hljs-number">0</span>, mask_token_index, :]
<span class="hljs-comment"># Pick the [MASK] candidates with the highest logits</span>
top_5_tokens = torch.topk(mask_token_logits, <span class="hljs-number">5</span>, dim=<span class="hljs-number">1</span>).indices[<span class="hljs-number">0</span>].tolist()
<span class="hljs-keyword">for</span> token <span class="hljs-keyword">in</span> top_5_tokens:
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt;&gt; <span class="hljs-subst">{text.replace(tokenizer.mask_token, tokenizer.decode([token]))}</span>&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; This is a great deal.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; This is a great success.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; This is a great adventure.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; This is a great idea.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; This is a great feat.&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3jwbuo">We can see from the outputs that the model’s predictions refer to everyday terms, which is perhaps not surprising given the foundation of English Wikipedia. Let’s see how we can change this domain to something a bit more niche — highly polarized movie reviews!</p> <h2 class="relative group"><a id="the-dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The dataset</span></h2> <p data-svelte-h="svelte-e3ztt2">To showcase domain adaptation, we’ll use the famous <a href="https://huggingface.co/datasets/imdb" rel="nofollow">Large Movie Review Dataset</a> (or IMDb for short), which is a corpus of movie reviews that is often used to benchmark sentiment analysis models. By fine-tuning DistilBERT on this corpus, we expect the language model will adapt its vocabulary from the factual data of Wikipedia that it was pretrained on to the more subjective elements of movie reviews. We can get the data from the Hugging Face Hub with the <code>load_dataset()</code> function from 🤗 Datasets:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
imdb_dataset = load_dataset(<span class="hljs-string">&quot;imdb&quot;</span>)
imdb_dataset<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({
train: Dataset({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>],
num_rows: <span class="hljs-number">25000</span>
})
test: Dataset({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>],
num_rows: <span class="hljs-number">25000</span>
})
unsupervised: Dataset({
features: [<span class="hljs-string">&#x27;text&#x27;</span>, <span class="hljs-string">&#x27;label&#x27;</span>],
num_rows: <span class="hljs-number">50000</span>
})
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1uxz3j4">We can see that the <code>train</code> and <code>test</code> splits each consist of 25,000 reviews, while there is an unlabeled split called <code>unsupervised</code> that contains 50,000 reviews. Let’s take a look at a few samples to get an idea of what kind of text we’re dealing with. As we’ve done in previous chapters of the course, we’ll chain the <code>Dataset.shuffle()</code> and <code>Dataset.select()</code> functions to create a random sample:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->sample = imdb_dataset[<span class="hljs-string">&quot;train&quot;</span>].shuffle(seed=<span class="hljs-number">42</span>).select(<span class="hljs-built_in">range</span>(<span class="hljs-number">3</span>))
<span class="hljs-keyword">for</span> row <span class="hljs-keyword">in</span> sample:
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;\n&#x27;&gt;&gt;&gt; Review: <span class="hljs-subst">{row[<span class="hljs-string">&#x27;text&#x27;</span>]}</span>&#x27;&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt;&gt; Label: <span class="hljs-subst">{row[<span class="hljs-string">&#x27;label&#x27;</span>]}</span>&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->
<span class="hljs-string">&#x27;&gt;&gt;&gt; Review: This is your typical Priyadarshan movie--a bunch of loony characters out on some silly mission. His signature climax has the entire cast of the film coming together and fighting each other in some crazy moshpit over hidden money. Whether it is a winning lottery ticket in Malamaal Weekly, black money in Hera Pheri, &quot;kodokoo&quot; in Phir Hera Pheri, etc., etc., the director is becoming ridiculously predictable. Don\&#x27;t get me wrong; as clichéd and preposterous his movies may be, I usually end up enjoying the comedy. However, in most his previous movies there has actually been some good humor, (Hungama and Hera Pheri being noteworthy ones). Now, the hilarity of his films is fading as he is using the same formula over and over again.&lt;br /&gt;&lt;br /&gt;Songs are good. Tanushree Datta looks awesome. Rajpal Yadav is irritating, and Tusshar is not a whole lot better. Kunal Khemu is OK, and Sharman Joshi is the best.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Label: 0&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Review: Okay, the story makes no sense, the characters lack any dimensionally, the best dialogue is ad-libs about the low quality of movie, the cinematography is dismal, and only editing saves a bit of the muddle, but Sam&quot; Peckinpah directed the film. Somehow, his direction is not enough. For those who appreciate Peckinpah and his great work, this movie is a disappointment. Even a great cast cannot redeem the time the viewer wastes with this minimal effort.&lt;br /&gt;&lt;br /&gt;The proper response to the movie is the contempt that the director San Peckinpah, James Caan, Robert Duvall, Burt Young, Bo Hopkins, Arthur Hill, and even Gig Young bring to their work. Watch the great Peckinpah films. Skip this mess.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Label: 0&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Review: I saw this movie at the theaters when I was about 6 or 7 years old. I loved it then, and have recently come to own a VHS version. &lt;br /&gt;&lt;br /&gt;My 4 and 6 year old children love this movie and have been asking again and again to watch it. &lt;br /&gt;&lt;br /&gt;I have enjoyed watching it again too. Though I have to admit it is not as good on a little TV.&lt;br /&gt;&lt;br /&gt;I do not have older children so I do not know what they would think of it. &lt;br /&gt;&lt;br /&gt;The songs are very cute. My daughter keeps singing them over and over.&lt;br /&gt;&lt;br /&gt;Hope this helps.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Label: 1&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-13i7i9s">Yep, these are certainly movie reviews, and if you’re old enough you may even understand the comment in the last review about owning a VHS version 😜! Although we won’t need the labels for language modeling, we can already see that a <code>0</code> denotes a negative review, while a <code>1</code> corresponds to a positive one.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-fd14wo">✏️ <strong>Try it out!</strong> Create a random sample of the <code>unsupervised</code> split and verify that the labels are neither <code>0</code> nor <code>1</code>. While you’re at it, you could also check that the labels in the <code>train</code> and <code>test</code> splits are indeed <code>0</code> or <code>1</code> — this is a useful sanity check that every NLP practitioner should perform at the start of a new project!</p></div> <p data-svelte-h="svelte-1a1khye">Now that we’ve had a quick look at the data, let’s dive into preparing it for masked language modeling. As we’ll see, there are some additional steps that one needs to take compared to the sequence classification tasks we saw in <a href="/course/chapter3">Chapter 3</a>. Let’s go!</p> <h2 class="relative group"><a id="preprocessing-the-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preprocessing-the-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preprocessing the data</span></h2> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/8PmhEIXhBvI" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-1whaa76">For both auto-regressive and masked language modeling, a common preprocessing step is to concatenate all the examples and then split the whole corpus into chunks of equal size. This is quite different from our usual approach, where we simply tokenize individual examples. Why concatenate everything together? The reason is that individual examples might get truncated if they’re too long, and that would result in losing information that might be useful for the language modeling task!</p> <p data-svelte-h="svelte-18o7drt">So to get started, we’ll first tokenize our corpus as usual, but <em>without</em> setting the <code>truncation=True</code> option in our tokenizer. We’ll also grab the word IDs if they are available ((which they will be if we’re using a fast tokenizer, as described in <a href="/course/chapter6/3">Chapter 6</a>), as we will need them later on to do whole word masking. We’ll wrap this in a simple function, and while we’re at it we’ll remove the <code>text</code> and <code>label</code> columns since we don’t need them any longer:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_function</span>(<span class="hljs-params">examples</span>):
result = tokenizer(examples[<span class="hljs-string">&quot;text&quot;</span>])
<span class="hljs-keyword">if</span> tokenizer.is_fast:
result[<span class="hljs-string">&quot;word_ids&quot;</span>] = [result.word_ids(i) <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-built_in">len</span>(result[<span class="hljs-string">&quot;input_ids&quot;</span>]))]
<span class="hljs-keyword">return</span> result
<span class="hljs-comment"># Use batched=True to activate fast multithreading!</span>
tokenized_datasets = imdb_dataset.<span class="hljs-built_in">map</span>(
tokenize_function, batched=<span class="hljs-literal">True</span>, remove_columns=[<span class="hljs-string">&quot;text&quot;</span>, <span class="hljs-string">&quot;label&quot;</span>]
)
tokenized_datasets<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({
train: Dataset({
features: [<span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;word_ids&#x27;</span>],
num_rows: <span class="hljs-number">25000</span>
})
test: Dataset({
features: [<span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;word_ids&#x27;</span>],
num_rows: <span class="hljs-number">25000</span>
})
unsupervised: Dataset({
features: [<span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;word_ids&#x27;</span>],
num_rows: <span class="hljs-number">50000</span>
})
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1btid30">Since DistilBERT is a BERT-like model, we can see that the encoded texts consist of the <code>input_ids</code> and <code>attention_mask</code> that we’ve seen in other chapters, as well as the <code>word_ids</code> we added.</p> <p data-svelte-h="svelte-188lv63">Now that we’ve tokenized our movie reviews, the next step is to group them all together and split the result into chunks. But how big should these chunks be? This will ultimately be determined by the amount of GPU memory that you have available, but a good starting point is to see what the model’s maximum context size is. This can be inferred by inspecting the <code>model_max_length</code> attribute of the tokenizer:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenizer.model_max_length<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-number">512</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-x8g16c">This value is derived from the <em>tokenizer_config.json</em> file associated with a checkpoint; in this case we can see that the context size is 512 tokens, just like with BERT.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-snuco1">✏️ <strong>Try it out!</strong> Some Transformer models, like <a href="https://huggingface.co/google/bigbird-roberta-base" rel="nofollow">BigBird</a> and <a href="hf.co/allenai/longformer-base-4096">Longformer</a>, have a much longer context length than BERT and other early Transformer models. Instantiate the tokenizer for one of these checkpoints and verify that the <code>model_max_length</code> agrees with what’s quoted on its model card.</p></div> <p data-svelte-h="svelte-xnt0eg">So, in order to run our experiments on GPUs like those found on Google Colab, we’ll pick something a bit smaller that can fit in memory:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->chunk_size = <span class="hljs-number">128</span><!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1w9d6if">Note that using a small chunk size can be detrimental in real-world scenarios, so you should use a size that corresponds to the use case you will apply your model to.</p></div> <p data-svelte-h="svelte-b36s9w">Now comes the fun part. To show how the concatenation works, let’s take a few reviews from our tokenized training set and print out the number of tokens per review:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Slicing produces a list of lists for each feature</span>
tokenized_samples = tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>][:<span class="hljs-number">3</span>]
<span class="hljs-keyword">for</span> idx, sample <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(tokenized_samples[<span class="hljs-string">&quot;input_ids&quot;</span>]):
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt;&gt; Review <span class="hljs-subst">{idx}</span> length: <span class="hljs-subst">{<span class="hljs-built_in">len</span>(sample)}</span>&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; Review 0 length: 200&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Review 1 length: 559&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Review 2 length: 192&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mssxs4">We can then concatenate all these examples with a simple dictionary comprehension, as follows:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->concatenated_examples = {
k: <span class="hljs-built_in">sum</span>(tokenized_samples[k], []) <span class="hljs-keyword">for</span> k <span class="hljs-keyword">in</span> tokenized_samples.keys()
}
total_length = <span class="hljs-built_in">len</span>(concatenated_examples[<span class="hljs-string">&quot;input_ids&quot;</span>])
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt;&gt; Concatenated reviews length: <span class="hljs-subst">{total_length}</span>&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; Concatenated reviews length: 951&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1qtwsnu">Great, the total length checks out — so now let’s split the concatenated reviews into chunks of the size given by <code>chunk_size</code>. To do so, we iterate over the features in <code>concatenated_examples</code> and use a list comprehension to create slices of each feature. The result is a dictionary of chunks for each feature:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->chunks = {
k: [t[i : i + chunk_size] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">0</span>, total_length, chunk_size)]
<span class="hljs-keyword">for</span> k, t <span class="hljs-keyword">in</span> concatenated_examples.items()
}
<span class="hljs-keyword">for</span> chunk <span class="hljs-keyword">in</span> chunks[<span class="hljs-string">&quot;input_ids&quot;</span>]:
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&#x27;&gt;&gt;&gt; Chunk length: <span class="hljs-subst">{<span class="hljs-built_in">len</span>(chunk)}</span>&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; Chunk length: 128&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Chunk length: 128&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Chunk length: 128&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Chunk length: 128&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Chunk length: 128&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Chunk length: 128&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Chunk length: 128&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; Chunk length: 55&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1y1gzji">As you can see in this example, the last chunk will generally be smaller than the maximum chunk size. There are two main strategies for dealing with this:</p> <ul data-svelte-h="svelte-187zn68"><li>Drop the last chunk if it’s smaller than <code>chunk_size</code>.</li> <li>Pad the last chunk until its length equals <code>chunk_size</code>.</li></ul> <p data-svelte-h="svelte-jf6vcc">We’ll take the first approach here, so let’s wrap all of the above logic in a single function that we can apply to our tokenized datasets:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">group_texts</span>(<span class="hljs-params">examples</span>):
<span class="hljs-comment"># Concatenate all texts</span>
concatenated_examples = {k: <span class="hljs-built_in">sum</span>(examples[k], []) <span class="hljs-keyword">for</span> k <span class="hljs-keyword">in</span> examples.keys()}
<span class="hljs-comment"># Compute length of concatenated texts</span>
total_length = <span class="hljs-built_in">len</span>(concatenated_examples[<span class="hljs-built_in">list</span>(examples.keys())[<span class="hljs-number">0</span>]])
<span class="hljs-comment"># We drop the last chunk if it&#x27;s smaller than chunk_size</span>
total_length = (total_length // chunk_size) * chunk_size
<span class="hljs-comment"># Split by chunks of max_len</span>
result = {
k: [t[i : i + chunk_size] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">0</span>, total_length, chunk_size)]
<span class="hljs-keyword">for</span> k, t <span class="hljs-keyword">in</span> concatenated_examples.items()
}
<span class="hljs-comment"># Create a new labels column</span>
result[<span class="hljs-string">&quot;labels&quot;</span>] = result[<span class="hljs-string">&quot;input_ids&quot;</span>].copy()
<span class="hljs-keyword">return</span> result<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-5629lr">Note that in the last step of <code>group_texts()</code> we create a new <code>labels</code> column which is a copy of the <code>input_ids</code> one. As we’ll see shortly, that’s because in masked language modeling the objective is to predict randomly masked tokens in the input batch, and by creating a <code>labels</code> column we provide the ground truth for our language model to learn from.</p> <p data-svelte-h="svelte-18n5xxn">Let’s now apply <code>group_texts()</code> to our tokenized datasets using our trusty <code>Dataset.map()</code> function:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->lm_datasets = tokenized_datasets.<span class="hljs-built_in">map</span>(group_texts, batched=<span class="hljs-literal">True</span>)
lm_datasets<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({
train: Dataset({
features: [<span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;labels&#x27;</span>, <span class="hljs-string">&#x27;word_ids&#x27;</span>],
num_rows: <span class="hljs-number">61289</span>
})
test: Dataset({
features: [<span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;labels&#x27;</span>, <span class="hljs-string">&#x27;word_ids&#x27;</span>],
num_rows: <span class="hljs-number">59905</span>
})
unsupervised: Dataset({
features: [<span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;labels&#x27;</span>, <span class="hljs-string">&#x27;word_ids&#x27;</span>],
num_rows: <span class="hljs-number">122963</span>
})
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-sr887y">You can see that grouping and then chunking the texts has produced many more examples than our original 25,000 for the <code>train</code> and <code>test</code> splits. That’s because we now have examples involving <em>contiguous tokens</em> that span across multiple examples from the original corpus. You can see this explicitly by looking for the special <code>[SEP]</code> and <code>[CLS]</code> tokens in one of the chunks:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenizer.decode(lm_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">1</span>][<span class="hljs-string">&quot;input_ids&quot;</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&quot;.... at.......... high. a classic line : inspector : i&#x27;m here to sack one of your teachers. student : welcome to bromwell high. i expect that many adults of my age think that bromwell high is far fetched. what a pity that it isn&#x27;t! [SEP] [CLS] homelessness ( or houselessness as george carlin stated ) has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school, work, or vote for the matter. most people think of the homeless&quot;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-kah7qo">In this example you can see two overlapping movie reviews, one about a high school movie and the other about homelessness. Let’s also check out what the labels look like for masked language modeling:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenizer.decode(lm_datasets[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">1</span>][<span class="hljs-string">&quot;labels&quot;</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&quot;.... at.......... high. a classic line : inspector : i&#x27;m here to sack one of your teachers. student : welcome to bromwell high. i expect that many adults of my age think that bromwell high is far fetched. what a pity that it isn&#x27;t! [SEP] [CLS] homelessness ( or houselessness as george carlin stated ) has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school, work, or vote for the matter. most people think of the homeless&quot;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1iew6is">As expected from our <code>group_texts()</code> function above, this looks identical to the decoded <code>input_ids</code> — but then how can our model possibly learn anything? We’re missing a key step: inserting <code>[MASK]</code> tokens at random positions in the inputs! Let’s see how we can do this on the fly during fine-tuning using a special data collator.</p> <h2 class="relative group"><a id="fine-tuning-distilbert-with-the-trainer-api" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-distilbert-with-the-trainer-api"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning DistilBERT with the Trainer API</span></h2> <p data-svelte-h="svelte-jaj1rc">Fine-tuning a masked language model is almost identical to fine-tuning a sequence classification model, like we did in <a href="/course/chapter3">Chapter 3</a>. The only difference is that we need a special data collator that can randomly mask some of the tokens in each batch of texts. Fortunately, 🤗 Transformers comes prepared with a dedicated <code>DataCollatorForLanguageModeling</code> for just this task. We just have to pass it the tokenizer and an <code>mlm_probability</code> argument that specifies what fraction of the tokens to mask. We’ll pick 15%, which is the amount used for BERT and a common choice in the literature:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=<span class="hljs-number">0.15</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dxdd6j">To see how the random masking works, let’s feed a few examples to the data collator. Since it expects a list of <code>dict</code>s, where each <code>dict</code> represents a single chunk of contiguous text, we first iterate over the dataset before feeding the batch to the collator. We remove the <code>&quot;word_ids&quot;</code> key for this data collator as it does not expect it:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->samples = [lm_datasets[<span class="hljs-string">&quot;train&quot;</span>][i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">2</span>)]
<span class="hljs-keyword">for</span> sample <span class="hljs-keyword">in</span> samples:
_ = sample.pop(<span class="hljs-string">&quot;word_ids&quot;</span>)
<span class="hljs-keyword">for</span> chunk <span class="hljs-keyword">in</span> data_collator(samples)[<span class="hljs-string">&quot;input_ids&quot;</span>]:
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;\n&#x27;&gt;&gt;&gt; <span class="hljs-subst">{tokenizer.decode(chunk)}</span>&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; [CLS] bromwell [MASK] is a cartoon comedy. it ran at the same [MASK] as some other [MASK] about school life, [MASK] as &quot; teachers &quot;. [MASK] [MASK] [MASK] in the teaching [MASK] lead [MASK] to believe that bromwell high\&#x27;[MASK] satire is much closer to reality than is &quot; teachers &quot;. the scramble [MASK] [MASK] financially, the [MASK]ful students whogn [MASK] right through [MASK] pathetic teachers\&#x27;pomp, the pettiness of the whole situation, distinction remind me of the schools i knew and their students. when i saw [MASK] episode in [MASK] a student repeatedly tried to burn down the school, [MASK] immediately recalled. [MASK]...&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; .... at.. [MASK]... [MASK]... high. a classic line plucked inspector : i\&#x27;[MASK] here to [MASK] one of your [MASK]. student : welcome to bromwell [MASK]. i expect that many adults of my age think that [MASK]mwell [MASK] is [MASK] fetched. what a pity that it isn\&#x27;t! [SEP] [CLS] [MASK]ness ( or [MASK]lessness as george 宇in stated )公 been an issue for years but never [MASK] plan to help those on the street that were once considered human [MASK] did everything from going to school, [MASK], [MASK] vote for the matter. most people think [MASK] the homeless&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-tf324a">Nice, it worked! We can see that the <code>[MASK]</code> token has been randomly inserted at various locations in our text. These will be the tokens which our model will have to predict during training — and the beauty of the data collator is that it will randomize the <code>[MASK]</code> insertion with every batch!</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1f3fue2">✏️ <strong>Try it out!</strong> Run the code snippet above several times to see the random masking happen in front of your very eyes! Also replace the <code>tokenizer.decode()</code> method with <code>tokenizer.convert_ids_to_tokens()</code> to see that sometimes a single token from a given word is masked, and not the others.</p></div> <p data-svelte-h="svelte-w1yods">One side effect of random masking is that our evaluation metrics will not be deterministic when using the <code>Trainer</code>, since we use the same data collator for the training and test sets. We’ll see later, when we look at fine-tuning with 🤗 Accelerate, how we can use the flexibility of a custom evaluation loop to freeze the randomness.</p> <p data-svelte-h="svelte-171x2x4">When training models for masked language modeling, one technique that can be used is to mask whole words together, not just individual tokens. This approach is called <em>whole word masking</em>. If we want to use whole word masking, we will need to build a data collator ourselves. A data collator is just a function that takes a list of samples and converts them into a batch, so let’s do this now! We’ll use the word IDs computed earlier to make a map between word indices and the corresponding tokens, then randomly decide which words to mask and apply that mask on the inputs. Note that the labels are all <code>-100</code> except for the ones corresponding to mask words.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> collections
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> default_data_collator
wwm_probability = <span class="hljs-number">0.2</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">whole_word_masking_data_collator</span>(<span class="hljs-params">features</span>):
<span class="hljs-keyword">for</span> feature <span class="hljs-keyword">in</span> features:
word_ids = feature.pop(<span class="hljs-string">&quot;word_ids&quot;</span>)
<span class="hljs-comment"># Create a map between words and corresponding token indices</span>
mapping = collections.defaultdict(<span class="hljs-built_in">list</span>)
current_word_index = -<span class="hljs-number">1</span>
current_word = <span class="hljs-literal">None</span>
<span class="hljs-keyword">for</span> idx, word_id <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(word_ids):
<span class="hljs-keyword">if</span> word_id <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>:
<span class="hljs-keyword">if</span> word_id != current_word:
current_word = word_id
current_word_index += <span class="hljs-number">1</span>
mapping[current_word_index].append(idx)
<span class="hljs-comment"># Randomly mask words</span>
mask = np.random.binomial(<span class="hljs-number">1</span>, wwm_probability, (<span class="hljs-built_in">len</span>(mapping),))
input_ids = feature[<span class="hljs-string">&quot;input_ids&quot;</span>]
labels = feature[<span class="hljs-string">&quot;labels&quot;</span>]
new_labels = [-<span class="hljs-number">100</span>] * <span class="hljs-built_in">len</span>(labels)
<span class="hljs-keyword">for</span> word_id <span class="hljs-keyword">in</span> np.where(mask)[<span class="hljs-number">0</span>]:
word_id = word_id.item()
<span class="hljs-keyword">for</span> idx <span class="hljs-keyword">in</span> mapping[word_id]:
new_labels[idx] = labels[idx]
input_ids[idx] = tokenizer.mask_token_id
feature[<span class="hljs-string">&quot;labels&quot;</span>] = new_labels
<span class="hljs-keyword">return</span> default_data_collator(features)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xi9xjp">Next, we can try it on the same samples as before:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->samples = [lm_datasets[<span class="hljs-string">&quot;train&quot;</span>][i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">2</span>)]
batch = whole_word_masking_data_collator(samples)
<span class="hljs-keyword">for</span> chunk <span class="hljs-keyword">in</span> batch[<span class="hljs-string">&quot;input_ids&quot;</span>]:
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;\n&#x27;&gt;&gt;&gt; <span class="hljs-subst">{tokenizer.decode(chunk)}</span>&#x27;&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; [CLS] bromwell high is a cartoon comedy [MASK] it ran at the same time as some other programs about school life, such as &quot; teachers &quot;. my 35 years in the teaching profession lead me to believe that bromwell high\&#x27;s satire is much closer to reality than is &quot; teachers &quot;. the scramble to survive financially, the insightful students who can see right through their pathetic teachers\&#x27;pomp, the pettiness of the whole situation, all remind me of the schools i knew and their students. when i saw the episode in which a student repeatedly tried to burn down the school, i immediately recalled.....&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; .... [MASK] [MASK] [MASK] [MASK]....... high. a classic line : inspector : i\&#x27;m here to sack one of your teachers. student : welcome to bromwell high. i expect that many adults of my age think that bromwell high is far fetched. what a pity that it isn\&#x27;t! [SEP] [CLS] homelessness ( or houselessness as george carlin stated ) has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school, work, or vote for the matter. most people think of the homeless&#x27;</span><!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-wmwcem">✏️ <strong>Try it out!</strong> Run the code snippet above several times to see the random masking happen in front of your very eyes! Also replace the <code>tokenizer.decode()</code> method with <code>tokenizer.convert_ids_to_tokens()</code> to see that the tokens from a given word are always masked together.</p></div> <p data-svelte-h="svelte-1hbi27y">Now that we have two data collators, the rest of the fine-tuning steps are standard. Training can take a while on Google Colab if you’re not lucky enough to score a mythical P100 GPU 😭, so we’ll first downsample the size of the training set to a few thousand examples. Don’t worry, we’ll still get a pretty decent language model! A quick way to downsample a dataset in 🤗 Datasets is via the <code>Dataset.train_test_split()</code> function that we saw in <a href="/course/chapter5">Chapter 5</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->train_size = <span class="hljs-number">10_000</span>
test_size = <span class="hljs-built_in">int</span>(<span class="hljs-number">0.1</span> * train_size)
downsampled_dataset = lm_datasets[<span class="hljs-string">&quot;train&quot;</span>].train_test_split(
train_size=train_size, test_size=test_size, seed=<span class="hljs-number">42</span>
)
downsampled_dataset<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->DatasetDict({
train: Dataset({
features: [<span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;labels&#x27;</span>, <span class="hljs-string">&#x27;word_ids&#x27;</span>],
num_rows: <span class="hljs-number">10000</span>
})
test: Dataset({
features: [<span class="hljs-string">&#x27;attention_mask&#x27;</span>, <span class="hljs-string">&#x27;input_ids&#x27;</span>, <span class="hljs-string">&#x27;labels&#x27;</span>, <span class="hljs-string">&#x27;word_ids&#x27;</span>],
num_rows: <span class="hljs-number">1000</span>
})
})<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-su015l">This has automatically created new <code>train</code> and <code>test</code> splits, with the training set size set to 10,000 examples and the validation set to 10% of that — feel free to increase this if you have a beefy GPU! The next thing we need to do is log in to the Hugging Face Hub. If you’re running this code in a notebook, you can do so with the following utility function:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login
notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3fqg55">which will display a widget where you can enter your credentials. Alternatively, you can run:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->huggingface-<span class="hljs-keyword">cli</span> login<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1sc64uc">in your favorite terminal and log in there.</p> <p data-svelte-h="svelte-jgpv5a">Once we’re logged in, we can specify the arguments for the <code>Trainer</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
batch_size = <span class="hljs-number">64</span>
<span class="hljs-comment"># Show the training loss with every epoch</span>
logging_steps = <span class="hljs-built_in">len</span>(downsampled_dataset[<span class="hljs-string">&quot;train&quot;</span>]) // batch_size
model_name = model_checkpoint.split(<span class="hljs-string">&quot;/&quot;</span>)[-<span class="hljs-number">1</span>]
training_args = TrainingArguments(
output_dir=<span class="hljs-string">f&quot;<span class="hljs-subst">{model_name}</span>-finetuned-imdb&quot;</span>,
overwrite_output_dir=<span class="hljs-literal">True</span>,
evaluation_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
learning_rate=<span class="hljs-number">2e-5</span>,
weight_decay=<span class="hljs-number">0.01</span>,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
push_to_hub=<span class="hljs-literal">True</span>,
fp16=<span class="hljs-literal">True</span>,
logging_steps=logging_steps,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zfa7rd">Here we tweaked a few of the default options, including <code>logging_steps</code> to ensure we track the training loss with each epoch. We’ve also used <code>fp16=True</code> to enable mixed-precision training, which gives us another boost in speed. By default, the <code>Trainer</code> will remove any columns that are not part of the model’s <code>forward()</code> method. This means that if you’re using the whole word masking collator, you’ll also need to set <code>remove_unused_columns=False</code> to ensure we don’t lose the <code>word_ids</code> column during training.</p> <p data-svelte-h="svelte-1ot1vnd">Note that you can specify the name of the repository you want to push to with the <code>hub_model_id</code> argument (in particular, you will have to use this argument to push to an organization). For instance, when we pushed the model to the <a href="https://huggingface.co/huggingface-course" rel="nofollow"><code>huggingface-course</code> organization</a>, we added <code>hub_model_id=&quot;huggingface-course/distilbert-finetuned-imdb&quot;</code> to <code>TrainingArguments</code>. By default, the repository used will be in your namespace and named after the output directory you set, so in our case it will be <code>&quot;lewtun/distilbert-finetuned-imdb&quot;</code>.</p> <p data-svelte-h="svelte-47mng">We now have all the ingredients to instantiate the <code>Trainer</code>. Here we just use the standard <code>data_collator</code>, but you can try the whole word masking collator and compare the results as an exercise:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=downsampled_dataset[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=downsampled_dataset[<span class="hljs-string">&quot;test&quot;</span>],
data_collator=data_collator,
tokenizer=tokenizer,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-c9apqh">We’re now ready to run <code>trainer.train()</code> — but before doing so let’s briefly look at <em>perplexity</em>, which is a common metric to evaluate the performance of language models.</p> <h3 class="relative group"><a id="perplexity-for-language-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#perplexity-for-language-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Perplexity for language models</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/NURcDHhYe98" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-now2dv">Unlike other tasks like text classification or question answering where we’re given a labeled corpus to train on, with language modeling we don’t have any explicit labels. So how do we determine what makes a good language model? Like with the autocorrect feature in your phone, a good language model is one that assigns high probabilities to sentences that are grammatically correct, and low probabilities to nonsense sentences. To give you a better idea of what this looks like, you can find whole sets of “autocorrect fails” online, where the model in a person’s phone has produced some rather funny (and often inappropriate) completions!</p> <p data-svelte-h="svelte-1nciy2w">Assuming our test set consists mostly of sentences that are grammatically correct, then one way to measure the quality of our language model is to calculate the probabilities it assigns to the next word in all the sentences of the test set. High probabilities indicates that the model is not “surprised” or “perplexed” by the unseen examples, and suggests it has learned the basic patterns of grammar in the language. There are various mathematical definitions of perplexity, but the one we’ll use defines it as the exponential of the cross-entropy loss. Thus, we can calculate the perplexity of our pretrained model by using the <code>Trainer.evaluate()</code> function to compute the cross-entropy loss on the test set and then taking the exponential of the result:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> math
eval_results = trainer.evaluate()
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&gt;&gt;&gt; Perplexity: <span class="hljs-subst">{math.exp(eval_results[<span class="hljs-string">&#x27;eval_loss&#x27;</span>]):<span class="hljs-number">.2</span>f}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>Perplexity: <span class="hljs-number">21.75</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-n4gsqr">A lower perplexity score means a better language model, and we can see here that our starting model has a somewhat large value. Let’s see if we can lower it by fine-tuning! To do that, we first run the training loop:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-e645ia">and then compute the resulting perplexity on the test set as before:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->eval_results = trainer.evaluate()
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&gt;&gt;&gt; Perplexity: <span class="hljs-subst">{math.exp(eval_results[<span class="hljs-string">&#x27;eval_loss&#x27;</span>]):<span class="hljs-number">.2</span>f}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>Perplexity: <span class="hljs-number">11.32</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-14b9dr5">Nice — this is quite a reduction in perplexity, which tells us the model has learned something about the domain of movie reviews!</p> <p data-svelte-h="svelte-13hhldr">Once training is finished, we can push the model card with the training information to the Hub (the checkpoints are saved during training itself):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.push_to_hub()<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1jdek3j">✏️ <strong>Your turn!</strong> Run the training above after changing the data collator to the whole word masking collator. Do you get better results?</p></div> <p data-svelte-h="svelte-18f0f90">In our use case we didn’t need to do anything special with the training loop, but in some cases you might need to implement some custom logic. For these applications, you can use 🤗 Accelerate — let’s take a look!</p> <h2 class="relative group"><a id="fine-tuning-distilbert-with-accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-distilbert-with-accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning DistilBERT with 🤗 Accelerate</span></h2> <p data-svelte-h="svelte-17r4mi9">As we saw with the <code>Trainer</code>, fine-tuning a masked language model is very similar to the text classification example from <a href="/course/chapter3">Chapter 3</a>. In fact, the only subtlety is the use of a special data collator, and we’ve already covered that earlier in this section!</p> <p data-svelte-h="svelte-gxymov">However, we saw that <code>DataCollatorForLanguageModeling</code> also applies random masking with each evaluation, so we’ll see some fluctuations in our perplexity scores with each training run. One way to eliminate this source of randomness is to apply the masking <em>once</em> on the whole test set, and then use the default data collator in 🤗 Transformers to collect the batches during evaluation. To see how this works, let’s implement a simple function that applies the masking on a batch, similar to our first encounter with <code>DataCollatorForLanguageModeling</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">insert_random_mask</span>(<span class="hljs-params">batch</span>):
features = [<span class="hljs-built_in">dict</span>(<span class="hljs-built_in">zip</span>(batch, t)) <span class="hljs-keyword">for</span> t <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(*batch.values())]
masked_inputs = data_collator(features)
<span class="hljs-comment"># Create a new &quot;masked&quot; column for each column in the dataset</span>
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;masked_&quot;</span> + k: v.numpy() <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> masked_inputs.items()}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-14pqtt6">Next, we’ll apply this function to our test set and drop the unmasked columns so we can replace them with the masked ones. You can use whole word masking by replacing the <code>data_collator</code> above with the appropriate one, in which case you should remove the first line here:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->downsampled_dataset = downsampled_dataset.remove_columns([<span class="hljs-string">&quot;word_ids&quot;</span>])
eval_dataset = downsampled_dataset[<span class="hljs-string">&quot;test&quot;</span>].<span class="hljs-built_in">map</span>(
insert_random_mask,
batched=<span class="hljs-literal">True</span>,
remove_columns=downsampled_dataset[<span class="hljs-string">&quot;test&quot;</span>].column_names,
)
eval_dataset = eval_dataset.rename_columns(
{
<span class="hljs-string">&quot;masked_input_ids&quot;</span>: <span class="hljs-string">&quot;input_ids&quot;</span>,
<span class="hljs-string">&quot;masked_attention_mask&quot;</span>: <span class="hljs-string">&quot;attention_mask&quot;</span>,
<span class="hljs-string">&quot;masked_labels&quot;</span>: <span class="hljs-string">&quot;labels&quot;</span>,
}
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8xq44y">We can then set up the dataloaders as usual, but we’ll use the <code>default_data_collator</code> from 🤗 Transformers for the evaluation set:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> default_data_collator
batch_size = <span class="hljs-number">64</span>
train_dataloader = DataLoader(
downsampled_dataset[<span class="hljs-string">&quot;train&quot;</span>],
shuffle=<span class="hljs-literal">True</span>,
batch_size=batch_size,
collate_fn=data_collator,
)
eval_dataloader = DataLoader(
eval_dataset, batch_size=batch_size, collate_fn=default_data_collator
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ixsyix">Form here, we follow the standard steps with 🤗 Accelerate. The first order of business is to load a fresh version of the pretrained model:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = <span class="hljs-module-access"><span class="hljs-module"><span class="hljs-identifier">AutoModelForMaskedLM</span>.</span></span>from<span class="hljs-constructor">_pretrained(<span class="hljs-params">model_checkpoint</span>)</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1uukzr7">Then we need to specify the optimizer; we’ll use the standard <code>AdamW</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">5e-5</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xxic1e">With these objects, we can now prepare everything for training with the <code>Accelerator</code> object:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cqa387">Now that our model, optimizer, and dataloaders are configured, we can specify the learning rate scheduler as follows:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler
num_train_epochs = <span class="hljs-number">3</span>
num_update_steps_per_epoch = <span class="hljs-built_in">len</span>(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1asoup2">There is just one last thing to do before training: create a model repository on the Hugging Face Hub! We can use the 🤗 Hub library to first generate the full name of our repo:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> get_full_repo_name
model_name = <span class="hljs-string">&quot;distilbert-base-uncased-finetuned-imdb-accelerate&quot;</span>
repo_name = get_full_repo_name(model_name)
repo_name<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;lewtun/distilbert-base-uncased-finetuned-imdb-accelerate&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-om9ony">then create and clone the repository using the <code>Repository</code> class from 🤗 Hub:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> Repository
output_dir = model_name
repo = Repository(output_dir, clone_from=repo_name)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ut1bsc">With that done, it’s just a simple matter of writing out the full training and evaluation loop:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> math
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_train_epochs):
<span class="hljs-comment"># Training</span>
model.train()
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)
<span class="hljs-comment"># Evaluation</span>
model.<span class="hljs-built_in">eval</span>()
losses = []
<span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(eval_dataloader):
<span class="hljs-keyword">with</span> torch.no_grad():
outputs = model(**batch)
loss = outputs.loss
losses.append(accelerator.gather(loss.repeat(batch_size)))
losses = torch.cat(losses)
losses = losses[: <span class="hljs-built_in">len</span>(eval_dataset)]
<span class="hljs-keyword">try</span>:
perplexity = math.exp(torch.mean(losses))
<span class="hljs-keyword">except</span> OverflowError:
perplexity = <span class="hljs-built_in">float</span>(<span class="hljs-string">&quot;inf&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&gt;&gt;&gt; Epoch <span class="hljs-subst">{epoch}</span>: Perplexity: <span class="hljs-subst">{perplexity}</span>&quot;</span>)
<span class="hljs-comment"># Save and upload</span>
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
<span class="hljs-keyword">if</span> accelerator.is_main_process:
tokenizer.save_pretrained(output_dir)
repo.push_to_hub(
commit_message=<span class="hljs-string">f&quot;Training in progress epoch <span class="hljs-subst">{epoch}</span>&quot;</span>, blocking=<span class="hljs-literal">False</span>
)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>Epoch <span class="hljs-number">0</span>: Perplexity: <span class="hljs-number">11.397545307900472</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>Epoch <span class="hljs-number">1</span>: Perplexity: <span class="hljs-number">10.904909330983092</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>Epoch <span class="hljs-number">2</span>: Perplexity: <span class="hljs-number">10.729503505340409</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-63yh5i">Cool, we’ve been able to evaluate perplexity with each epoch and ensure that multiple training runs are reproducible!</p> <h2 class="relative group"><a id="using-our-fine-tuned-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-our-fine-tuned-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using our fine-tuned model</span></h2> <p data-svelte-h="svelte-azcu8o">You can interact with your fine-tuned model either by using its widget on the Hub or locally with the <code>pipeline</code> from 🤗 Transformers. Let’s use the latter to download our model using the <code>fill-mask</code> pipeline:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> pipeline
mask_filler = pipeline(
<span class="hljs-string">&quot;fill-mask&quot;</span>, model=<span class="hljs-string">&quot;huggingface-course/distilbert-base-uncased-finetuned-imdb&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-12ftkgw">We can then feed the pipeline our sample text of “This is a great [MASK]” and see what the top 5 predictions are:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->preds = mask_filler(text)
<span class="hljs-keyword">for</span> pred <span class="hljs-keyword">in</span> preds:
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;&gt;&gt;&gt; <span class="hljs-subst">{pred[<span class="hljs-string">&#x27;sequence&#x27;</span>]}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&#x27;&gt;&gt;&gt; this is a great movie.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; this is a great film.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; this is a great story.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; this is a great movies.&#x27;</span>
<span class="hljs-string">&#x27;&gt;&gt;&gt; this is a great character.&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8cy1ke">Neat — our model has clearly adapted its weights to predict words that are more strongly associated with movies!</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/0Oxphw4Q9fo" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-w2yc62">This wraps up our first experiment with training a language model. In <a href="/course/en/chapter7/6">section 6</a> you’ll learn how to train an auto-regressive model like GPT-2 from scratch; head over there if you’d like to see how you can pretrain your very own Transformer model!</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-13mk4ny">✏️ <strong>Try it out!</strong> To quantify the benefits of domain adaptation, fine-tune a classifier on the IMDb labels for both the pretrained and fine-tuned DistilBERT checkpoints. If you need a refresher on text classification, check out <a href="/course/chapter3">Chapter 3</a>.</p></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/en/chapter7/3.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_1y0degu = {
assets: "/docs/course/pr_1069/en",
base: "/docs/course/pr_1069/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/course/pr_1069/en/_app/immutable/entry/start.c5306bb2.js"),
import("/docs/course/pr_1069/en/_app/immutable/entry/app.4264f5f8.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 79],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
161 kB
·
Xet hash:
6621acbfb0f49d3bddc10836cd3ec9f33366c0184c1e7acd2aff5c7d96978dfd

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.