Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"使用 Keras 微調一個模型","local":"使用-keras-微調一個模型","sections":[{"title":"訓練模型","local":"訓練模型","sections":[],"depth":3},{"title":"提升訓練的效果","local":"提升訓練的效果","sections":[],"depth":3},{"title":"模型預測","local":"模型預測","sections":[],"depth":3}],"depth":1}"> | |
| <link href="/docs/course/pr_1069/zh-TW/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/entry/start.01945326.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/singletons.40763523.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/paths.677dc4ce.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/entry/app.e8597ff2.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/index.2bf4358c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/nodes/0.decbb3b3.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/nodes/25.6d717e1f.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/Tip.363c041f.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/Youtube.1e50a667.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/CodeBlock.4e987730.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/CourseFloatingBanner.9ff4c771.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/DocNotebookDropdown.efc1fb7c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/FrameworkSwitchCourse.8d4d4ab6.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/zh-TW/_app/immutable/chunks/getInferenceSnippets.24b50994.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"使用 Keras 微調一個模型","local":"使用-keras-微調一個模型","sections":[{"title":"訓練模型","local":"訓練模型","sections":[],"depth":3},{"title":"提升訓練的效果","local":"提升訓練的效果","sections":[],"depth":3},{"title":"模型預測","local":"模型預測","sections":[],"depth":3}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> Pytorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="使用-keras-微調一個模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#使用-keras-微調一個模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>使用 Keras 微調一個模型</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-3-questions" target="_blank"><img alt="Ask a Question" class="!m-0" src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjYxLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/zh-CN/chapter3/section3_tf.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/zh-CN/chapter3/section3_tf.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <p data-svelte-h="svelte-qe99ht">完成上一節中的所有數據預處理工作後,您只剩下最後的幾個步驟來訓練模型。 但是請注意,<code>model.fit()</code> 命令在 CPU 上運行會非常緩慢。 如果您沒有GPU,則可以在 <a href="https://colab.research.google.com/" rel="nofollow">Google Colab</a> 上使用免費的 GPU 或 TPU(需要梯子)。</p> <p data-svelte-h="svelte-ecil1z">這一節的代碼示例假設您已經執行了上一節中的代碼示例。 下面一個簡短的摘要,包含了在開始學習這一節之前您需要的執行的代碼:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, DataCollatorWithPadding | |
| <span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mrpc"</span>) | |
| checkpoint = <span class="hljs-string">"bert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_function</span>(<span class="hljs-params">example</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(example[<span class="hljs-string">"sentence1"</span>], example[<span class="hljs-string">"sentence2"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(tokenize_function, batched=<span class="hljs-literal">True</span>) | |
| data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors=<span class="hljs-string">"tf"</span>) | |
| tf_train_dataset = tokenized_datasets[<span class="hljs-string">"train"</span>].to_tf_dataset( | |
| columns=[<span class="hljs-string">"attention_mask"</span>, <span class="hljs-string">"input_ids"</span>, <span class="hljs-string">"token_type_ids"</span>], | |
| label_cols=[<span class="hljs-string">"labels"</span>], | |
| shuffle=<span class="hljs-literal">True</span>, | |
| collate_fn=data_collator, | |
| batch_size=<span class="hljs-number">8</span>, | |
| ) | |
| tf_validation_dataset = tokenized_datasets[<span class="hljs-string">"validation"</span>].to_tf_dataset( | |
| columns=[<span class="hljs-string">"attention_mask"</span>, <span class="hljs-string">"input_ids"</span>, <span class="hljs-string">"token_type_ids"</span>], | |
| label_cols=[<span class="hljs-string">"labels"</span>], | |
| shuffle=<span class="hljs-literal">False</span>, | |
| collate_fn=data_collator, | |
| batch_size=<span class="hljs-number">8</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="訓練模型" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#訓練模型"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>訓練模型</span></h3> <p data-svelte-h="svelte-neyhrj">從🤗 Transformers 導入的 TensorFlow 模型已經是 Keras 模型。 下面的視頻是對 Keras 的簡短介紹。</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/rnTGBy2ax1c" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-1v5o97q">這意味著,一旦我們有了數據,就需要很少的工作就可以開始對其進行訓練。</p> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/AUozVp78dhk" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-2b9muo">和<a href="/course/chapter2">第二章</a>使用的方法一樣, 我們將使用二分類的 <code>TFAutoModelForSequenceClassification</code>類:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TFAutoModelForSequenceClassification | |
| model = TFAutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-twmyyw">您會注意到,與 <a href="/course/chapter2">第二章</a> 不同的是,您在實例化此預訓練模型後會收到警告。 這是因為 BERT 沒有對句子對進行分類進行預訓練,所以預訓練模型的 head 已經被丟棄,而是插入了一個適合序列分類的新 head。 警告表明一些權重沒有使用(對應於丟棄的預訓練頭),而其他一些權重是隨機初始化的(新頭的權重)。 最後鼓勵您訓練模型,這正是我們現在要做的。</p> <p data-svelte-h="svelte-6w9ps3">要在我們的數據集上微調模型,我們只需要在我們的模型上調用 <code>compile()</code> 方法,然後將我們的數據傳遞給 <code>fit()</code> 方法。 這將啟動微調過程(在 GPU 上應該需要幾分鐘)並輸出訓練loss,以及每個 epoch 結束時的驗證loss。</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1ywel4b">請注意🤗 Transformers 模型具有大多數 Keras 模型所沒有的特殊能力——它們可以自動使用內部計算的loss。 如果您沒有在 <code>compile()</code> 中設置損失函數,他們將默認使用內部計算的損失。 請注意,要使用內部損失,您需要將標籤作為輸入的一部分傳遞,而不是作為單獨的標籤(這是在 Keras 模型中使用標籤的正常方式)。 您將在課程的第 2 部分中看到這方面的示例,其中定義正確的損失函數可能很棘手。 然而,對於序列分類,標準的 Keras 損失函數可以正常工作,所以我們將在這裡使用它。</p></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tensorflow.keras.losses <span class="hljs-keyword">import</span> SparseCategoricalCrossentropy | |
| model.<span class="hljs-built_in">compile</span>( | |
| optimizer=<span class="hljs-string">"adam"</span>, | |
| loss=SparseCategoricalCrossentropy(from_logits=<span class="hljs-literal">True</span>), | |
| metrics=[<span class="hljs-string">"accuracy"</span>], | |
| ) | |
| model.fit( | |
| tf_train_dataset, | |
| validation_data=tf_validation_dataset, | |
| )<!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-87xo3y">請注意這裡有一個非常常見的陷阱——你只是<em>可以</em>將損失的名稱作為字符串傳遞給 Keras,但默認情況下,Keras 會假設你已經對輸出應用了 softmax。 然而,許多模型在應用 softmax 之前就輸出,也稱為 <em>logits</em>。 我們需要告訴損失函數我們的模型是否經過了softmax,唯一的方法是直接調用它,而不是用字符串的名稱。</p></div> <h3 class="relative group"><a id="提升訓練的效果" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#提升訓練的效果"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>提升訓練的效果</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/cpzq6ESSM5c" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-xom0gg">如果您嘗試上面的代碼,它肯定會運行,但您會發現loss只是緩慢或零星地下降。 主要原因是<em>學習率</em>。 與loss一樣,當我們將優化器的名稱作為字符串傳遞給 Keras 時,Keras 會初始化該優化器具有所有參數的默認值,包括學習率。 但是,根據長期經驗,我們知道Transformer 模型更適合使用比 Adam 的默認值(1e-3)也寫成為 10 的 -3 次方,或 0.001,低得多的學習率。 5e-5 (0.00005) 比默認值大約低 20 倍,是一個更好的起點。</p> <p data-svelte-h="svelte-1xlrj49">除了降低學習率,我們還有第二個技巧:我們可以慢慢降低學習率。在訓練過程中。 在文獻中,您有時會看到這被稱為 <em>decaying</em> 或 <em>annealing</em>學習率。 在 Keras 中,最好的方法是使用 <em>learning rate scheduler</em>。 一個好用的是<code>PolynomialDecay</code>——儘管有這個名字,但在默認設置下,它只是簡單地從初始值線性衰減學習率值在訓練過程中的最終值,這正是我們想要的。但是, 為了正確使用調度程序,我們需要告訴它訓練的次數。 我們將在下面為其計算“num_train_steps”。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tensorflow.keras.optimizers.schedules <span class="hljs-keyword">import</span> PolynomialDecay | |
| batch_size = <span class="hljs-number">8</span> | |
| num_epochs = <span class="hljs-number">3</span> | |
| <span class="hljs-comment"># 訓練步數是數據集中的樣本數除以batch size再乘以 epoch。</span> | |
| <span class="hljs-comment"># 注意這裡的tf_train_dataset是一個轉化為batch後的 tf.data.Dataset,</span> | |
| <span class="hljs-comment"># 不是原來的 Hugging Face Dataset,所以它的 len() 已經是 num_samples // batch_size。</span> | |
| num_train_steps = <span class="hljs-built_in">len</span>(tf_train_dataset) * num_epochs | |
| lr_scheduler = PolynomialDecay( | |
| initial_learning_rate=<span class="hljs-number">5e-5</span>, end_learning_rate=<span class="hljs-number">0.0</span>, decay_steps=num_train_steps | |
| ) | |
| <span class="hljs-keyword">from</span> tensorflow.keras.optimizers <span class="hljs-keyword">import</span> Adam | |
| opt = Adam(learning_rate=lr_scheduler)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-k5flef">🤗 Transformers 庫還有一個 <code>create_optimizer()</code> 函數,它將創建一個具有學習率衰減的 <code>AdamW</code> 優化器。 這是一個便捷的方式,您將在本課程的後續部分中詳細瞭解。</p></div> <p data-svelte-h="svelte-j5gumi">現在我們有了全新的優化器,我們可以嘗試使用它進行訓練。 首先,讓我們重新加載模型,以重置我們剛剛進行的訓練運行對權重的更改,然後我們可以使用新的優化器對其進行編譯:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| model = TFAutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>) | |
| loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=<span class="hljs-literal">True</span>) | |
| model.<span class="hljs-built_in">compile</span>(optimizer=opt, loss=loss, metrics=[<span class="hljs-string">"accuracy"</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ur03qf">現在,我們再次進行fit:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=<span class="hljs-number">3</span>)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-15lz9ny">💡 如果您想在訓練期間自動將模型上傳到 Hub,您可以在 <code>model.fit()</code> 方法中傳遞 <code>PushToHubCallback</code>。 我們將在 <a href="/course/chapter4/3">第四章</a> 中進行介紹</p></div> <h3 class="relative group"><a id="模型預測" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#模型預測"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>模型預測</span></h3> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/nx10eh4CoOs" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-p7uafd">訓練和觀察的loss下降都非常好,但是如果我們想從訓練後的模型中獲得輸出,或者計算一些指標,或者在生產中使用模型呢? 為此,我們可以使用<code>predict()</code> 方法。 這將返回模型的輸出頭的<em>logits</em>數值,每個類一個。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->preds = model.predict(tf_validation_dataset)[<span class="hljs-string">"logits"</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-37mj27">我們可以將這些 logit 轉換為模型的類別預測,方法是使用 argmax 找到最高的 logit,它對應於最有可能的類別:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->class_preds = np.argmax(preds, axis=<span class="hljs-number">1</span>) | |
| <span class="hljs-built_in">print</span>(preds.shape, class_preds.shape)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->(<span class="hljs-number">408</span>, <span class="hljs-number">2</span>) (<span class="hljs-number">408</span>,)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vkuhsj">現在,讓我們使用這些 <code>preds</code> 來計算一些指標! 我們可以像加載數據集一樣輕鬆地加載與 MRPC 數據集相關的指標,這次使用的是 <code>evaluate.load()</code> 函數。 返回的對象有一個 <code>compute()</code> 方法,我們可以使用它來進行度量計算:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> evaluate | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mrpc"</span>) | |
| metric.compute(predictions=class_preds, references=raw_datasets[<span class="hljs-string">"validation"</span>][<span class="hljs-string">"label"</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'accuracy'</span>: <span class="hljs-number">0.8578431372549019</span>, <span class="hljs-string">'f1'</span>: <span class="hljs-number">0.8996539792387542</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1r3acct">您獲得的確切結果可能會有所不同,因為模型頭的隨機初始化可能會改變它獲得的指標。 在這裡,我們可以看到我們的模型在驗證集上的準確率為 85.78%,F1 得分為 89.97。 這些是用於評估 GLUE 基準的 MRPC 數據集結果的兩個指標。 <a href="https://arxiv.org/pdf/1810.04805.pdf" rel="nofollow">BERT 論文</a> 中的表格報告了基本模型的 F1 分數為 88.9。 那是 <code>uncased</code> 模型,而我們目前使用的是 <code>cased</code> 模型,這解釋了為什麼我們會獲得更好的結果。</p> <p data-svelte-h="svelte-1xeqkwk">使用 Keras API 進行微調的介紹到此結束。 第 7 章將給出對大多數常見 NLP 任務執行此操作的示例。如果您想在 Keras API 上磨練自己的技能,請嘗試使第二節所使用的的數據處理在 GLUE SST-2 數據集上微調模型。</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/zh-TW/chapter3/3_tf.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_14z83n9 = { | |
| assets: "/docs/course/pr_1069/zh-TW", | |
| base: "/docs/course/pr_1069/zh-TW", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/course/pr_1069/zh-TW/_app/immutable/entry/start.01945326.js"), | |
| import("/docs/course/pr_1069/zh-TW/_app/immutable/entry/app.e8597ff2.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 25], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 40.8 kB
- Xet hash:
- e9ba8a7279e97a81df9c8e278b5326cb606d497b970fa3528f23d5dbc09249c5
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.