Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"A full training loop","local":"a-full-training","sections":[{"title":"Prepare for training","local":"prepare-for-training","sections":[],"depth":3},{"title":"The training loop","local":"the-training-loop","sections":[],"depth":3},{"title":"The evaluation loop","local":"the-evaluation-loop","sections":[],"depth":3},{"title":"Supercharge your training loop with ๐ค Accelerate","local":"supercharge-your-training-loop-with-accelerate","sections":[],"depth":3},{"title":"Next Steps and Best Practices","local":"next-steps-and-best-practices","sections":[],"depth":3},{"title":"Section Quiz","local":"section-quiz","sections":[{"title":"1. What is the main difference between Adam and AdamW optimizers?","local":"1-what-is-the-main-difference-between-adam-and-adamw-optimizers","sections":[],"depth":3},{"title":"2. In a training loop, what is the correct order of operations?","local":"2-in-a-training-loop-what-is-the-correct-order-of-operations","sections":[],"depth":3},{"title":"3. What does the ๐ค Accelerate library primarily help with?","local":"3-what-does-the--accelerate-library-primarily-help-with","sections":[],"depth":3},{"title":"4. Why do we move batches to the device in a training loop?","local":"4-why-do-we-move-batches-to-the-device-in-a-training-loop","sections":[],"depth":3},{"title":"5. What does model.eval() do before evaluation?","local":"5-what-does-modeleval-do-before-evaluation","sections":[],"depth":3},{"title":"6. What is the purpose of torch.no_grad() during evaluation?","local":"6-what-is-the-purpose-of-torchnograd-during-evaluation","sections":[],"depth":3},{"title":"7. What changes when you use ๐ค Accelerate in your training loop?","local":"7-what-changes-when-you-use--accelerate-in-your-training-loop","sections":[],"depth":3}],"depth":2}],"depth":1}"> | |
| <link href="/docs/course/pr_1021/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/entry/start.3d2e2978.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/singletons.7a9af56d.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/paths.692f56cf.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/entry/app.d60bb0c9.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/index.7cb9c9b8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/nodes/0.bbc29778.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/nodes/48.d54c6378.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/Tip.d10b3fc9.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/Youtube.8666c400.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/CodeBlock.abae2786.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/CourseFloatingBanner.df82c153.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/Question.7e41e492.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/stores.cb4752a8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1021/en/_app/immutable/chunks/getInferenceSnippets.a2135f3c.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"A full training loop","local":"a-full-training","sections":[{"title":"Prepare for training","local":"prepare-for-training","sections":[],"depth":3},{"title":"The training loop","local":"the-training-loop","sections":[],"depth":3},{"title":"The evaluation loop","local":"the-evaluation-loop","sections":[],"depth":3},{"title":"Supercharge your training loop with ๐ค Accelerate","local":"supercharge-your-training-loop-with-accelerate","sections":[],"depth":3},{"title":"Next Steps and Best Practices","local":"next-steps-and-best-practices","sections":[],"depth":3},{"title":"Section Quiz","local":"section-quiz","sections":[{"title":"1. What is the main difference between Adam and AdamW optimizers?","local":"1-what-is-the-main-difference-between-adam-and-adamw-optimizers","sections":[],"depth":3},{"title":"2. In a training loop, what is the correct order of operations?","local":"2-in-a-training-loop-what-is-the-correct-order-of-operations","sections":[],"depth":3},{"title":"3. What does the ๐ค Accelerate library primarily help with?","local":"3-what-does-the--accelerate-library-primarily-help-with","sections":[],"depth":3},{"title":"4. Why do we move batches to the device in a training loop?","local":"4-why-do-we-move-batches-to-the-device-in-a-training-loop","sections":[],"depth":3},{"title":"5. What does model.eval() do before evaluation?","local":"5-what-does-modeleval-do-before-evaluation","sections":[],"depth":3},{"title":"6. What is the purpose of torch.no_grad() during evaluation?","local":"6-what-is-the-purpose-of-torchnograd-during-evaluation","sections":[],"depth":3},{"title":"7. What changes when you use ๐ค Accelerate in your training loop?","local":"7-what-changes-when-you-use--accelerate-in-your-training-loop","sections":[],"depth":3}],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="a-full-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#a-full-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>A full training loop</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-3-questions" target="_blank"><img alt="Ask a Question" class="!m-0" src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjYxLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/en/chapter3/section4.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/en/chapter3/section4.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/Dh9CL8fyG80" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-57kt9p">Now weโll see how to achieve the same results as we did in the last section without using the <code>Trainer</code> class, implementing a training loop from scratch with modern PyTorch best practices. Again, we assume you have done the data processing in section 2. Here is a short summary covering everything you will need:</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-11sv2tl">๐๏ธ <strong>Training from Scratch</strong>: This section builds on the previous content. For comprehensive guidance on PyTorch training loops and best practices, check out the <a href="https://huggingface.co/docs/transformers/main/en/training#train-in-native-pytorch" rel="nofollow">๐ค Transformers training documentation</a> and the <a href="https://huggingface.co/learn/cookbook/en/fine_tuning_code_llm_on_single_gpu#model" rel="nofollow">custom training cookbook</a>.</p></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, DataCollatorWithPadding | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mrpc"</span>) | |
| checkpoint = <span class="hljs-string">"bert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_function</span>(<span class="hljs-params">example</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(example[<span class="hljs-string">"sentence1"</span>], example[<span class="hljs-string">"sentence2"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(tokenize_function, batched=<span class="hljs-literal">True</span>) | |
| data_collator = DataCollatorWithPadding(tokenizer=tokenizer)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="prepare-for-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#prepare-for-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Prepare for training</span></h3> <p data-svelte-h="svelte-ci9pg7">Before actually writing our training loop, we will need to define a few objects. The first ones are the dataloaders we will use to iterate over batches. But before we can define those dataloaders, we need to apply a bit of postprocessing to our <code>tokenized_datasets</code>, to take care of some things that the <code>Trainer</code> did for us automatically. Specifically, we need to:</p> <ul data-svelte-h="svelte-pk4vyx"><li>Remove the columns corresponding to values the model does not expect (like the <code>sentence1</code> and <code>sentence2</code> columns).</li> <li>Rename the column <code>label</code> to <code>labels</code> (because the model expects the argument to be named <code>labels</code>).</li> <li>Set the format of the datasets so they return PyTorch tensors instead of lists.</li></ul> <p data-svelte-h="svelte-ey27a0">Our <code>tokenized_datasets</code> has one method for each of those steps:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenized_datasets = tokenized_datasets.remove_columns([<span class="hljs-string">"sentence1"</span>, <span class="hljs-string">"sentence2"</span>, <span class="hljs-string">"idx"</span>]) | |
| tokenized_datasets = tokenized_datasets.rename_column(<span class="hljs-string">"label"</span>, <span class="hljs-string">"labels"</span>) | |
| tokenized_datasets.set_format(<span class="hljs-string">"torch"</span>) | |
| tokenized_datasets[<span class="hljs-string">"train"</span>].column_names<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-trxw7x">We can then check that the result only has columns that our model will accept:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-string">"attention_mask"</span>, <span class="hljs-string">"input_ids"</span>, <span class="hljs-string">"labels"</span>, <span class="hljs-string">"token_type_ids"</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-11c0vnt">Now that this is done, we can easily define our dataloaders:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader | |
| train_dataloader = DataLoader( | |
| tokenized_datasets[<span class="hljs-string">"train"</span>], shuffle=<span class="hljs-literal">True</span>, batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator | |
| ) | |
| eval_dataloader = DataLoader( | |
| tokenized_datasets[<span class="hljs-string">"validation"</span>], batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-11wtmu1">To quickly check there is no mistake in the data processing, we can inspect a batch like this:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| <span class="hljs-keyword">break</span> | |
| {k: v.shape <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'attention_mask'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]), | |
| <span class="hljs-string">'input_ids'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]), | |
| <span class="hljs-string">'labels'</span>: torch.Size([<span class="hljs-number">8</span>]), | |
| <span class="hljs-string">'token_type_ids'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>])}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-q4m5oq">Note that the actual shapes will probably be slightly different for you since we set <code>shuffle=True</code> for the training dataloader and we are padding to the maximum length inside the batch.</p> <p data-svelte-h="svelte-xq9k1z">Now that weโre completely finished with data preprocessing (a satisfying yet elusive goal for any ML practitioner), letโs turn to the model. We instantiate it exactly as we did in the previous section:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1cfpt2v">To make sure that everything will go smoothly during training, we pass our batch to this model:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->outputs = model(**batch) | |
| <span class="hljs-built_in">print</span>(outputs.loss, outputs.logits.shape)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tensor(<span class="hljs-number">0.5441</span>, grad_fn=<NllLossBackward>) torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">2</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ynwaiq">All ๐ค Transformers models will return the loss when <code>labels</code> are provided, and we also get the logits (two for each input in our batch, so a tensor of size 8 x 2).</p> <p data-svelte-h="svelte-1my9rg2">Weโre almost ready to write our training loop! Weโre just missing two things: an optimizer and a learning rate scheduler. Since we are trying to replicate what the <code>Trainer</code> was doing by hand, we will use the same defaults. The optimizer used by the <code>Trainer</code> is <code>AdamW</code>, which is the same as Adam, but with a twist for weight decay regularization (see <a href="https://arxiv.org/abs/1711.05101" rel="nofollow">โDecoupled Weight Decay Regularizationโ</a> by Ilya Loshchilov and Frank Hutter):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">5e-5</span>)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-o5kaes">๐ก <strong>Modern Optimization Tips</strong>: For even better performance, you can try:</p> <ul data-svelte-h="svelte-1qym9yd"><li><strong>AdamW with weight decay</strong>: <code>AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)</code></li> <li><strong>8-bit Adam</strong>: Use <code>bitsandbytes</code> for memory-efficient optimization</li> <li><strong>Different learning rates</strong>: Lower learning rates (1e-5 to 3e-5) often work better for large models</li></ul> <p data-svelte-h="svelte-x7tnyo">๐ <strong>Optimization Resources</strong>: Learn more about optimizers and training strategies in the <a href="https://huggingface.co/docs/transformers/main/en/performance#optimizer" rel="nofollow">๐ค Transformers optimization guide</a>.</p></div> <p data-svelte-h="svelte-12ciodg">Finally, the learning rate scheduler used by default is just a linear decay from the maximum value (5e-5) to 0. To properly define it, we need to know the number of training steps we will take, which is the number of epochs we want to run multiplied by the number of training batches (which is the length of our training dataloader). The <code>Trainer</code> uses three epochs by default, so we will follow that:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| <span class="hljs-built_in">print</span>(num_training_steps)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-number">1377</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="the-training-loop" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-training-loop"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The training loop</span></h3> <p data-svelte-h="svelte-1331u4o">One last thing: we will want to use the GPU if we have access to one (on a CPU, training might take several hours instead of a couple of minutes). To do this, we define a <code>device</code> we will put our model and our batches on:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| device = torch.device(<span class="hljs-string">"cuda"</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">"cpu"</span>) | |
| model.to(device) | |
| device<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->device(<span class="hljs-built_in">type</span>=<span class="hljs-string">'cuda'</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s26r6i">We are now ready to train! To get some sense of when training will be finished, we add a progress bar over our number of training steps, using the <code>tqdm</code> library:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| loss.backward() | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1dozp7f">๐ก <strong>Modern Training Optimizations</strong>: To make your training loop even more efficient, consider:</p> <ul data-svelte-h="svelte-ujbxeo"><li><strong>Gradient Clipping</strong>: Add <code>torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)</code> before <code>optimizer.step()</code></li> <li><strong>Mixed Precision</strong>: Use <code>torch.cuda.amp.autocast()</code> and <code>GradScaler</code> for faster training</li> <li><strong>Gradient Accumulation</strong>: Accumulate gradients over multiple batches to simulate larger batch sizes</li> <li><strong>Checkpointing</strong>: Save model checkpoints periodically to resume training if interrupted</li></ul> <p data-svelte-h="svelte-elu395">๐ง <strong>Implementation Guide</strong>: For detailed examples of these optimizations, see the <a href="https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one" rel="nofollow">๐ค Transformers efficient training guide</a> and the <a href="https://huggingface.co/docs/transformers/main/en/optimizers" rel="nofollow">range of optimizers</a>.</p></div> <p data-svelte-h="svelte-mcninq">You can see that the core of the training loop looks a lot like the one in the introduction. We didnโt ask for any reporting, so this training loop will not tell us anything about how the model fares. We need to add an evaluation loop for that.</p> <h3 class="relative group"><a id="the-evaluation-loop" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-evaluation-loop"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The evaluation loop</span></h3> <p data-svelte-h="svelte-1sbu2u1">As we did earlier, we will use a metric provided by the ๐ค Evaluate library. Weโve already seen the <code>metric.compute()</code> method, but metrics can actually accumulate batches for us as we go over the prediction loop with the method <code>add_batch()</code>. Once we have accumulated all the batches, we can get the final result with <code>metric.compute()</code>. Hereโs how to implement all of this in an evaluation loop:</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-n1do1t">๐ <strong>Evaluation Best Practices</strong>: For more sophisticated evaluation strategies and metrics, explore the <a href="https://huggingface.co/docs/evaluate/" rel="nofollow">๐ค Evaluate documentation</a> and the <a href="https://github.com/huggingface/evaluation-guidebook" rel="nofollow">comprehensive evaluation cookbook</a>.</p></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> evaluate | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mrpc"</span>) | |
| model.<span class="hljs-built_in">eval</span>() | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> eval_dataloader: | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| outputs = model(**batch) | |
| logits = outputs.logits | |
| predictions = torch.argmax(logits, dim=-<span class="hljs-number">1</span>) | |
| metric.add_batch(predictions=predictions, references=batch[<span class="hljs-string">"labels"</span>]) | |
| metric.compute()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'accuracy'</span>: <span class="hljs-number">0.8431372549019608</span>, <span class="hljs-string">'f1'</span>: <span class="hljs-number">0.8907849829351535</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-14qz8qy">Again, your results will be slightly different because of the randomness in the model head initialization and the data shuffling, but they should be in the same ballpark.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-14jbd71">โ๏ธ <strong>Try it out!</strong> Modify the previous training loop to fine-tune your model on the SST-2 dataset.</p></div> <h3 class="relative group"><a id="supercharge-your-training-loop-with-accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#supercharge-your-training-loop-with-accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Supercharge your training loop with ๐ค Accelerate</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/s7dy8QRgjJ0" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-lpq30y">The training loop we defined earlier works fine on a single CPU or GPU. But using the <a href="https://github.com/huggingface/accelerate" rel="nofollow">๐ค Accelerate</a> library, with just a few adjustments we can enable distributed training on multiple GPUs or TPUs. ๐ค Accelerate handles the complexity of distributed training, mixed precision, and device placement automatically. Starting from the creation of the training and validation dataloaders, here is what our manual training loop looks like:</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-mw1zvv">โก <strong>Accelerate Deep Dive</strong>: Learn everything about distributed training, mixed precision, and hardware optimization in the <a href="https://huggingface.co/docs/accelerate/" rel="nofollow">๐ค Accelerate documentation</a> and explore practical examples in the <a href="https://huggingface.co/docs/transformers/main/en/accelerate" rel="nofollow">transformers documentation</a>.</p></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| <span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, get_scheduler | |
| accelerator = Accelerator() | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>) | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>) | |
| train_dl, eval_dl, model, optimizer = accelerator.prepare( | |
| train_dataloader, eval_dataloader, model, optimizer | |
| ) | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dl) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dl: | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1q9wsp0">The first line to add is the import line. The second line instantiates an <code>Accelerator</code> object that will look at the environment and initialize the proper distributed setup. ๐ค Accelerate handles the device placement for you, so you can remove the lines that put the model on the device (or, if you prefer, change them to use <code>accelerator.device</code> instead of <code>device</code>).</p> <p data-svelte-h="svelte-15cl91a">Then the main bulk of the work is done in the line that sends the dataloaders, the model, and the optimizer to <code>accelerator.prepare()</code>. This will wrap those objects in the proper container to make sure your distributed training works as intended. The remaining changes to make are removing the line that puts the batch on the <code>device</code> (again, if you want to keep this you can just change it to use <code>accelerator.device</code>) and replacing <code>loss.backward()</code> with <code>accelerator.backward(loss)</code>.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400">โ ๏ธ In order to benefit from the speed-up offered by Cloud TPUs, we recommend padding your samples to a fixed length with the `padding="max_length"` and `max_length` arguments of the tokenizer.</div> <p data-svelte-h="svelte-1g09gg4">If youโd like to copy and paste it to play around, hereโs what the complete training loop looks like with ๐ค Accelerate:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| <span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, get_scheduler | |
| accelerator = Accelerator() | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>) | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>) | |
| train_dl, eval_dl, model, optimizer = accelerator.prepare( | |
| train_dataloader, eval_dataloader, model, optimizer | |
| ) | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dl) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dl: | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1476nx">Putting this in a <code>train.py</code> script will make that script runnable on any kind of distributed setup. To try it out in your distributed setup, run the command:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qyc3dl">which will prompt you to answer a few questions and dump your answers in a configuration file used by this command:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate <span class="hljs-built_in">launch</span> train.py<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dc7y78">which will launch the distributed training.</p> <p data-svelte-h="svelte-529dmx">If you want to try this in a Notebook (for instance, to test it with TPUs on Colab), just paste the code in a <code>training_function()</code> and run a last cell with:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher | |
| notebook_launcher(training_function)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-py198">You can find more examples in the <a href="https://github.com/huggingface/accelerate/tree/main/examples" rel="nofollow">๐ค Accelerate repo</a>.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-lrep30">๐ <strong>Distributed Training</strong>: For comprehensive coverage of multi-GPU and multi-node training, check out the <a href="https://huggingface.co/docs/transformers/main/en/perf_train_gpu_many" rel="nofollow">๐ค Transformers distributed training guide</a> and the <a href="https://huggingface.co/docs/transformers/main/en/accelerate" rel="nofollow">scaling training cookbook</a>.</p></div> <h3 class="relative group"><a id="next-steps-and-best-practices" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#next-steps-and-best-practices"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Next Steps and Best Practices</span></h3> <p data-svelte-h="svelte-ww0xia">Now that youโve learned how to implement training from scratch, here are some additional considerations for production use:</p> <p data-svelte-h="svelte-91nu4a"><strong>Model Evaluation</strong>: Always evaluate your model on multiple metrics, not just accuracy. Use the ๐ค Evaluate library for comprehensive evaluation.</p> <p data-svelte-h="svelte-ljahbj"><strong>Hyperparameter Tuning</strong>: Consider using libraries like Optuna or Ray Tune for systematic hyperparameter optimization.</p> <p data-svelte-h="svelte-j1czxl"><strong>Model Monitoring</strong>: Track training metrics, learning curves, and validation performance throughout training.</p> <p data-svelte-h="svelte-1vozwss"><strong>Model Sharing</strong>: Once trained, share your model on the Hugging Face Hub to make it available to the community.</p> <p data-svelte-h="svelte-quu3pr"><strong>Efficiency</strong>: For large models, consider techniques like gradient checkpointing, parameter-efficient fine-tuning (LoRA, AdaLoRA), or quantization methods.</p> <p data-svelte-h="svelte-l3wmo4">This concludes our deep dive into fine-tuning with custom training loops. The skills youโve learned here will serve you well when you need full control over the training process or want to implement custom training logic that goes beyond what the <code>Trainer</code> API offers.</p> <h2 class="relative group"><a id="section-quiz" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#section-quiz"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Section Quiz</span></h2> <p data-svelte-h="svelte-ehgujc">Test your understanding of custom training loops and advanced training techniques:</p> <h3 class="relative group"><a id="1-what-is-the-main-difference-between-adam-and-adamw-optimizers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-what-is-the-main-difference-between-adam-and-adamw-optimizers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>1. What is the main difference between Adam and AdamW optimizers?</span></h3> <div><form><label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="0"> <!-- HTML_TAG_START -->AdamW uses a different learning rate schedule.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="1"> <!-- HTML_TAG_START -->AdamW includes decoupled weight decay regularization.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="2"> <!-- HTML_TAG_START -->AdamW only works with transformer models.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="3"> <!-- HTML_TAG_START -->AdamW requires less memory than Adam.<!-- HTML_TAG_END --></label> <div class="flex flex-row items-center mt-3"><button class="btn px-4 mr-4" type="submit" disabled>Submit</button> </div></form></div> <h3 class="relative group"><a id="2-in-a-training-loop-what-is-the-correct-order-of-operations" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#2-in-a-training-loop-what-is-the-correct-order-of-operations"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>2. In a training loop, what is the correct order of operations?</span></h3> <div><form><label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="0"> <!-- HTML_TAG_START -->Forward pass โ Backward pass โ Optimizer step โ Zero gradients<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="1"> <!-- HTML_TAG_START -->Forward pass โ Backward pass โ Optimizer step โ Scheduler step โ Zero gradients<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="2"> <!-- HTML_TAG_START -->Zero gradients โ Forward pass โ Optimizer step โ Backward pass<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="3"> <!-- HTML_TAG_START -->Forward pass โ Zero gradients โ Backward pass โ Optimizer step<!-- HTML_TAG_END --></label> <div class="flex flex-row items-center mt-3"><button class="btn px-4 mr-4" type="submit" disabled>Submit</button> </div></form></div> <h3 class="relative group"><a id="3-what-does-the--accelerate-library-primarily-help-with" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-what-does-the--accelerate-library-primarily-help-with"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3. What does the ๐ค Accelerate library primarily help with?</span></h3> <div><form><label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="0"> <!-- HTML_TAG_START -->Making your models train faster by optimizing the forward pass.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="1"> <!-- HTML_TAG_START -->Automatically selecting the best hyperparameters.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="2"> <!-- HTML_TAG_START -->Enabling distributed training across multiple GPUs/TPUs with minimal code changes.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="3"> <!-- HTML_TAG_START -->Converting models to different frameworks like TensorFlow.<!-- HTML_TAG_END --></label> <div class="flex flex-row items-center mt-3"><button class="btn px-4 mr-4" type="submit" disabled>Submit</button> </div></form></div> <h3 class="relative group"><a id="4-why-do-we-move-batches-to-the-device-in-a-training-loop" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#4-why-do-we-move-batches-to-the-device-in-a-training-loop"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>4. Why do we move batches to the device in a training loop?</span></h3> <div><form><label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="0"> <!-- HTML_TAG_START -->To make the training faster.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="1"> <!-- HTML_TAG_START -->Because the model and data must be on the same device (CPU/GPU) for computation.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="2"> <!-- HTML_TAG_START -->To save memory.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="3"> <!-- HTML_TAG_START -->It's required by the DataLoader.<!-- HTML_TAG_END --></label> <div class="flex flex-row items-center mt-3"><button class="btn px-4 mr-4" type="submit" disabled>Submit</button> </div></form></div> <h3 class="relative group"><a id="5-what-does-modeleval-do-before-evaluation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#5-what-does-modeleval-do-before-evaluation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>5. What does model.eval() do before evaluation?</span></h3> <div><form><label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="0"> <!-- HTML_TAG_START -->It freezes the model parameters so they can't be updated.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="1"> <!-- HTML_TAG_START -->It changes the behavior of layers like dropout and batch normalization for inference.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="2"> <!-- HTML_TAG_START -->It enables gradient computation for evaluation metrics.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="3"> <!-- HTML_TAG_START -->It automatically calculates evaluation metrics.<!-- HTML_TAG_END --></label> <div class="flex flex-row items-center mt-3"><button class="btn px-4 mr-4" type="submit" disabled>Submit</button> </div></form></div> <h3 class="relative group"><a id="6-what-is-the-purpose-of-torchnograd-during-evaluation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#6-what-is-the-purpose-of-torchnograd-during-evaluation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>6. What is the purpose of torch.no_grad() during evaluation?</span></h3> <div><form><label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="0"> <!-- HTML_TAG_START -->To prevent the model from making predictions.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="1"> <!-- HTML_TAG_START -->To save memory and speed up computation by disabling gradient tracking.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="2"> <!-- HTML_TAG_START -->To enable evaluation mode for the model.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="3"> <!-- HTML_TAG_START -->To ensure consistent results across runs.<!-- HTML_TAG_END --></label> <div class="flex flex-row items-center mt-3"><button class="btn px-4 mr-4" type="submit" disabled>Submit</button> </div></form></div> <h3 class="relative group"><a id="7-what-changes-when-you-use--accelerate-in-your-training-loop" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#7-what-changes-when-you-use--accelerate-in-your-training-loop"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>7. What changes when you use ๐ค Accelerate in your training loop?</span></h3> <div><form><label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="0"> <!-- HTML_TAG_START -->You must rewrite your entire training loop from scratch.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="1"> <!-- HTML_TAG_START -->You wrap key objects with accelerator.prepare() and use accelerator.backward() instead of loss.backward().<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="2"> <!-- HTML_TAG_START -->You need to specify the number of GPUs in your code.<!-- HTML_TAG_END --></label> <label class="block"><input autocomplete="off" class="form-input -mt-1.5 mr-2" name="choice" type="checkbox" value="3"> <!-- HTML_TAG_START -->You must use a different optimizer and scheduler.<!-- HTML_TAG_END --></label> <div class="flex flex-row items-center mt-3"><button class="btn px-4 mr-4" type="submit" disabled>Submit</button> </div></form></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-l314m0">๐ก <strong>Key Takeaways:</strong></p> <ul data-svelte-h="svelte-w3ba9j"><li>Manual training loops give you complete control but require understanding of the proper sequence: forward โ backward โ optimizer step โ scheduler step โ zero gradients</li> <li>AdamW with weight decay is the recommended optimizer for transformer models</li> <li>Always use <code>model.eval()</code> and <code>torch.no_grad()</code> during evaluation for correct behavior and efficiency</li> <li>๐ค Accelerate makes distributed training accessible with minimal code changes</li> <li>Device management (moving tensors to GPU/CPU) is crucial for PyTorch operations</li> <li>Modern techniques like mixed precision, gradient accumulation, and gradient clipping can significantly improve training efficiency</li></ul></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/en/chapter3/4.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_engvle = { | |
| assets: "/docs/course/pr_1021/en", | |
| base: "/docs/course/pr_1021/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/course/pr_1021/en/_app/immutable/entry/start.3d2e2978.js"), | |
| import("/docs/course/pr_1021/en/_app/immutable/entry/app.d60bb0c9.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 48], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 87.5 kB
- Xet hash:
- 8f750a7c348975f314c596a087148ebc194ced84e06f2de02f58380c6dca7547
ยท
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.