Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"一個完整的訓練","local":"一個完整的訓練","sections":[{"title":"訓練前的準備","local":"訓練前的準備","sections":[],"depth":3},{"title":"訓練循環","local":"訓練循環","sections":[],"depth":3},{"title":"評估循環","local":"評估循環","sections":[],"depth":3},{"title":"S使用🤗 Accelerate加速您的訓練循環","local":"s使用-accelerate加速您的訓練循環","sections":[],"depth":3}],"depth":1}"> | |
| <link href="/docs/course/pr_1069/zh-TW/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/entry/start.01945326.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/singletons.40763523.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/paths.677dc4ce.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/entry/app.e8597ff2.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/index.2bf4358c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/nodes/0.decbb3b3.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/nodes/26.28dbe435.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/Tip.363c041f.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/Youtube.1e50a667.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/CodeBlock.4e987730.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/CourseFloatingBanner.9ff4c771.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/DocNotebookDropdown.efc1fb7c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/getInferenceSnippets.24b50994.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"一個完整的訓練","local":"一個完整的訓練","sections":[{"title":"訓練前的準備","local":"訓練前的準備","sections":[],"depth":3},{"title":"訓練循環","local":"訓練循環","sections":[],"depth":3},{"title":"評估循環","local":"評估循環","sections":[],"depth":3},{"title":"S使用🤗 Accelerate加速您的訓練循環","local":"s使用-accelerate加速您的訓練循環","sections":[],"depth":3}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="一個完整的訓練" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#一個完整的訓練"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>一個完整的訓練</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-3-questions" target="_blank"><img alt="Ask a Question" class="!m-0" src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjYxLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/zh-CN/chapter3/section4.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/zh-CN/chapter3/section4.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/Dh9CL8fyG80" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-45snkl">現在,我們將瞭解如何在不使用<code>Trainer</code>類的情況下獲得與上一節相同的結果。同樣,我們假設您已經學習了第 2 節中的數據處理。下面是一個簡短的總結,涵蓋了您需要的所有內容:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, DataCollatorWithPadding | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mrpc"</span>) | |
| checkpoint = <span class="hljs-string">"bert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_function</span>(<span class="hljs-params">example</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(example[<span class="hljs-string">"sentence1"</span>], example[<span class="hljs-string">"sentence2"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(tokenize_function, batched=<span class="hljs-literal">True</span>) | |
| data_collator = DataCollatorWithPadding(tokenizer=tokenizer)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="訓練前的準備" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#訓練前的準備"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>訓練前的準備</span></h3> <p data-svelte-h="svelte-1luc03v">在實際編寫我們的訓練循環之前,我們需要定義一些對象。第一個是我們將用於迭代批次的數據加載器。我們需要對我們的<code>tokenized_datasets</code>做一些處理,來處理<code>Trainer</code>自動為我們做的一些事情。具體來說,我們需要:</p> <ul data-svelte-h="svelte-16t4zf0"><li>刪除與模型不期望的值相對應的列(如<code>sentence1</code>和<code>sentence2</code>列)。</li> <li>將列名<code>label</code>重命名為<code>labels</code>(因為模型期望參數是<code>labels</code>)。</li> <li>設置數據集的格式,使其返回 PyTorch 張量而不是列表。</li></ul> <p data-svelte-h="svelte-uj5r7x">針對上面的每個步驟,我們的 <code>tokenized_datasets</code> 都有一個方法:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenized_datasets = tokenized_datasets.remove_columns([<span class="hljs-string">"sentence1"</span>, <span class="hljs-string">"sentence2"</span>, <span class="hljs-string">"idx"</span>]) | |
| tokenized_datasets = tokenized_datasets.rename_column(<span class="hljs-string">"label"</span>, <span class="hljs-string">"labels"</span>) | |
| tokenized_datasets.set_format(<span class="hljs-string">"torch"</span>) | |
| tokenized_datasets[<span class="hljs-string">"train"</span>].column_names<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ilinva">然後,我們可以檢查結果中是否只有模型能夠接受的列:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-string">"attention_mask"</span>, <span class="hljs-string">"input_ids"</span>, <span class="hljs-string">"labels"</span>, <span class="hljs-string">"token_type_ids"</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-uwy8ea">至此,我們可以輕鬆定義數據加載器:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader | |
| train_dataloader = DataLoader( | |
| tokenized_datasets[<span class="hljs-string">"train"</span>], shuffle=<span class="hljs-literal">True</span>, batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator | |
| ) | |
| eval_dataloader = DataLoader( | |
| tokenized_datasets[<span class="hljs-string">"validation"</span>], batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ycgcns">為了快速檢驗數據處理中沒有錯誤,我們可以這樣檢驗其中的一個批次:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| <span class="hljs-keyword">break</span> | |
| {k: v.shape <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'attention_mask'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]), | |
| <span class="hljs-string">'input_ids'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]), | |
| <span class="hljs-string">'labels'</span>: torch.Size([<span class="hljs-number">8</span>]), | |
| <span class="hljs-string">'token_type_ids'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>])}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-c56v8i">請注意,實際的形狀可能與您略有不同,因為我們為訓練數據加載器設置了<code>shuffle=True</code>,並且模型會將句子填充到<code>batch</code>中的最大長度。</p> <p data-svelte-h="svelte-1thu4ij">現在我們已經完全完成了數據預處理(對於任何 ML 從業者來說都是一個令人滿意但難以實現的目標),讓我們將注意力轉向模型。我們完全像在上一節中所做的那樣實例化它:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1k5txvx">為了確保訓練過程中一切順利,我們將<code>batch</code>傳遞給這個模型:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->outputs = model(**batch) | |
| <span class="hljs-built_in">print</span>(outputs.loss, outputs.logits.shape)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tensor(<span class="hljs-number">0.5441</span>, grad_fn=<NllLossBackward>) torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">2</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s0jxmf">當我們提供 <code>labels</code> 時, 🤗 Transformers 模型都將返回這個<code>batch</code>的<code>loss</code>,我們還得到了 <code>logits</code>(<code>batch</code>中的每個輸入有兩個,所以張量大小為 8 x 2)。</p> <p data-svelte-h="svelte-crabuu">我們幾乎準備好編寫我們的訓練循環了!我們只是缺少兩件事:優化器和學習率調度器。由於我們試圖自行實現 <code>Trainer</code>的功能,我們將使用相同的優化器和學習率調度器。<code>Trainer</code> 使用的優化器是 <code>AdamW</code> , 與 <code>Adam</code> 相同,但在權重衰減正則化方面有所不同(參見<a href="https://arxiv.org/abs/1711.05101" rel="nofollow">“Decoupled Weight Decay Regularization”</a>作者:Ilya Loshchilov 和 Frank Hutter):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">5e-5</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-11cj6lf">最後,默認使用的學習率調度器只是從最大值 (5e-5) 到 0 的線性衰減。 為了定義它,我們需要知道我們訓練的次數,即所有數據訓練的次數(epochs)乘以的數據量(這是我們所有訓練數據的數量)。<code>Trainer</code>默認情況下使用三個<code>epochs</code>,因此我們定義訓練過程如下:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| <span class="hljs-built_in">print</span>(num_training_steps)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-number">1377</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="訓練循環" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#訓練循環"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>訓練循環</span></h3> <p data-svelte-h="svelte-1kip7jk">最後一件事:如果我們可以訪問 GPU,我們將希望使用 GPU(在 CPU 上,訓練可能需要幾個小時而不是幾分鐘)。為此,我們定義了一個 <code>device</code>,它在GPU可用的情況下指向GPU 我們將把我們的模型和<code>batche</code>放在<code>device</code>上:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| device = torch.device(<span class="hljs-string">"cuda"</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">"cpu"</span>) | |
| model.to(device) | |
| device<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->device(<span class="hljs-built_in">type</span>=<span class="hljs-string">'cuda'</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vg6486">我們現在準備好訓練了!為了瞭解訓練何時結束,我們使用 <code>tqdm</code> 庫,在訓練步驟數上添加了一個進度條:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| loss.backward() | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-q2yxta">您可以看到訓練循環的核心與介紹中的非常相似。我們沒有要求任何檢驗,所以這個訓練循環不會告訴我們任何關於模型目前的狀態。我們需要為此添加一個評估循環。</p> <h3 class="relative group"><a id="評估循環" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#評估循環"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>評估循環</span></h3> <p data-svelte-h="svelte-12cb50l">正如我們之前所做的那樣,我們將使用 🤗 Evaluate 庫提供的指標。我們已經瞭解了 <code>metric.compute()</code> 方法,當我們使用 <code>add_batch()</code>方法進行預測循環時,實際上該指標可以為我們累積所有 <code>batch</code> 的結果。一旦我們累積了所有 <code>batch</code> ,我們就可以使用 <code>metric.compute()</code> 得到最終結果 .以下是在評估循環中實現所有這些的方法:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> evaluate | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mrpc"</span>) | |
| model.<span class="hljs-built_in">eval</span>() | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> eval_dataloader: | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| outputs = model(**batch) | |
| logits = outputs.logits | |
| predictions = torch.argmax(logits, dim=-<span class="hljs-number">1</span>) | |
| metric.add_batch(predictions=predictions, references=batch[<span class="hljs-string">"labels"</span>]) | |
| metric.compute()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'accuracy'</span>: <span class="hljs-number">0.8431372549019608</span>, <span class="hljs-string">'f1'</span>: <span class="hljs-number">0.8907849829351535</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xsc2k">同樣,由於模型頭部初始化和數據改組的隨機性,您的結果會略有不同,但它們應該在同一個範圍內。</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-17q0fgc">✏️ <strong>試試看!</strong> 修改之前的訓練循環以在 SST-2 數據集上微調您的模型。</p></div> <h3 class="relative group"><a id="s使用-accelerate加速您的訓練循環" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#s使用-accelerate加速您的訓練循環"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>S使用🤗 Accelerate加速您的訓練循環</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/s7dy8QRgjJ0" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-1735t3r">我們之前定義的訓練循環在單個 CPU 或 GPU 上運行良好。但是使用<a href="https://github.com/huggingface/accelerate" rel="nofollow">🤗 Accelerate</a>庫,只需進行一些調整,我們就可以在多個 GPU 或 TPU 上啟用分佈式訓練。從創建訓練和驗證數據加載器開始,我們的手動訓練循環如下所示:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, get_scheduler | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>) | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>) | |
| device = torch.device(<span class="hljs-string">"cuda"</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">"cpu"</span>) | |
| model.to(device) | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| loss.backward() | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1br0vyr">以下是變化:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-addition">+ from accelerate import Accelerator</span> | |
| from torch.optim import AdamW | |
| from transformers import AutoModelForSequenceClassification, get_scheduler | |
| <span class="hljs-addition">+ accelerator = Accelerator()</span> | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2) | |
| optimizer = AdamW(model.parameters(), lr=3e-5) | |
| <span class="hljs-deletion">- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")</span> | |
| <span class="hljs-deletion">- model.to(device)</span> | |
| <span class="hljs-addition">+ train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(</span> | |
| <span class="hljs-addition">+ train_dataloader, eval_dataloader, model, optimizer</span> | |
| <span class="hljs-addition">+ )</span> | |
| num_epochs = 3 | |
| num_training_steps = num_epochs * len(train_dataloader) | |
| lr_scheduler = get_scheduler( | |
| "linear", | |
| optimizer=optimizer, | |
| num_warmup_steps=0, | |
| num_training_steps=num_training_steps | |
| ) | |
| progress_bar = tqdm(range(num_training_steps)) | |
| model.train() | |
| for epoch in range(num_epochs): | |
| for batch in train_dataloader: | |
| <span class="hljs-deletion">- batch = {k: v.to(device) for k, v in batch.items()}</span> | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| <span class="hljs-deletion">- loss.backward()</span> | |
| <span class="hljs-addition">+ accelerator.backward(loss)</span> | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(1)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8f6itc">要添加的第一行是導入<code>Accelerator</code>。第二行實例化一個 <code>Accelerator</code>對象 ,它將查看環境並初始化適當的分佈式設置。 🤗 Accelerate 為您處理數據在設備間的傳遞,因此您可以刪除將模型放在設備上的那行代碼(或者,如果您願意,可使用 <code>accelerator.device</code> 代替 <code>device</code> )。</p> <p data-svelte-h="svelte-1eoa0j4">然後大部分工作會在將數據加載器、模型和優化器發送到的<code>accelerator.prepare()</code>中完成。這將會把這些對象包裝在適當的容器中,以確保您的分佈式訓練按預期工作。要進行的其餘更改是刪除將<code>batch</code>放在 <code>device</code> 的那行代碼(同樣,如果您想保留它,您可以將其更改為使用 <code>accelerator.device</code> ) 並將 <code>loss.backward()</code> 替換為<code>accelerator.backward(loss)</code>。</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400">⚠️ 為了使雲端 TPU 提供的加速發揮最大的效益,我們建議使用標記器(tokenizer)的 `padding=max_length` 和 `max_length` 參數將您的樣本填充到固定長度。</div> <p data-svelte-h="svelte-1uj9qtn">如果您想複製並粘貼來直接運行,以下是 🤗 Accelerate 的完整訓練循環:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| <span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, get_scheduler | |
| accelerator = Accelerator() | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>) | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>) | |
| train_dl, eval_dl, model, optimizer = accelerator.prepare( | |
| train_dataloader, eval_dataloader, model, optimizer | |
| ) | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dl) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dl: | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18claa0">把這個放在 <code>train.py</code> 文件中,可以讓它在任何類型的分佈式設置上運行。要在分佈式設置中試用它,請運行以下命令:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qdvt6p">這將詢問您幾個配置的問題並將您的回答轉儲到此命令使用的配置文件中:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate <span class="hljs-built_in">launch</span> train.py<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ncn3iw">這將啟動分佈式訓練</p> <p data-svelte-h="svelte-k1igho">這將啟動分佈式訓練。如果您想在 Notebook 中嘗試此操作(例如,在 Colab 上使用 TPU 進行測試),只需將代碼粘貼到 <code>training_function()</code> 並使用以下命令運行最後一個單元格:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher | |
| notebook_launcher(training_function)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1pplyqx">您可以在<a href="https://github.com/huggingface/accelerate/tree/main/examples" rel="nofollow">🤗 Accelerate repo</a>找到更多的示例。</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/zh-TW/chapter3/4.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_14z83n9 = { | |
| assets: "/docs/course/pr_1069/zh-TW", | |
| base: "/docs/course/pr_1069/zh-TW", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/course/pr_1069/zh-TW/_app/immutable/entry/start.01945326.js"), | |
| import("/docs/course/pr_1069/zh-TW/_app/immutable/entry/app.e8597ff2.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 26], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 60.9 kB
- Xet hash:
- 7f59961fcf0e5227d6e2a05244dfc1d15470a9307d1a80a9ac0286bfc6dcf146
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.