Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Debugging the training pipeline","local":"debugging-the-training-pipeline","sections":[{"title":"Debugging the training pipeline","local":"debugging-the-training-pipeline","sections":[{"title":"Check your data","local":"check-your-data","sections":[],"depth":3},{"title":"Check your model","local":"check-your-model","sections":[],"depth":3},{"title":"Check your hyperparameters","local":"check-your-hyperparameters","sections":[],"depth":3}],"depth":2},{"title":"Other potential issues","local":"other-potential-issues","sections":[{"title":"Dealing with out-of-memory errors","local":"dealing-with-out-of-memory-errors","sections":[],"depth":3},{"title":"Hungry Hungry TensorFlow 🦛","local":"hungry-hungry-tensorflow","sections":[],"depth":3},{"title":"Check your data (again!)","local":"check-your-data-again","sections":[],"depth":3},{"title":"Overfit your model on one batch","local":"overfit-your-model-on-one-batch","sections":[],"depth":3},{"title":"Don’t tune anything until you have a first baseline","local":"dont-tune-anything-until-you-have-a-first-baseline","sections":[],"depth":3},{"title":"Ask for help","local":"ask-for-help","sections":[],"depth":3}],"depth":2}],"depth":1}"> | |
| <link href="/docs/course/pr_1069/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/entry/start.c5306bb2.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/singletons.bc78d867.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/paths.76894643.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/entry/app.4264f5f8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/index.7cb9c9b8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/nodes/0.f5347c47.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/nodes/90.2808d34b.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/Tip.d10b3fc9.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/Youtube.8666c400.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/CodeBlock.abae2786.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/CourseFloatingBanner.df82c153.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/FrameworkSwitchCourse.97630871.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/en/_app/immutable/chunks/getInferenceSnippets.f9350a3f.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Debugging the training pipeline","local":"debugging-the-training-pipeline","sections":[{"title":"Debugging the training pipeline","local":"debugging-the-training-pipeline","sections":[{"title":"Check your data","local":"check-your-data","sections":[],"depth":3},{"title":"Check your model","local":"check-your-model","sections":[],"depth":3},{"title":"Check your hyperparameters","local":"check-your-hyperparameters","sections":[],"depth":3}],"depth":2},{"title":"Other potential issues","local":"other-potential-issues","sections":[{"title":"Dealing with out-of-memory errors","local":"dealing-with-out-of-memory-errors","sections":[],"depth":3},{"title":"Hungry Hungry TensorFlow 🦛","local":"hungry-hungry-tensorflow","sections":[],"depth":3},{"title":"Check your data (again!)","local":"check-your-data-again","sections":[],"depth":3},{"title":"Overfit your model on one batch","local":"overfit-your-model-on-one-batch","sections":[],"depth":3},{"title":"Don’t tune anything until you have a first baseline","local":"dont-tune-anything-until-you-have-a-first-baseline","sections":[],"depth":3},{"title":"Ask for help","local":"ask-for-help","sections":[],"depth":3}],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" 
fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> Pytorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="debugging-the-training-pipeline" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#debugging-the-training-pipeline"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 
8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Debugging the training pipeline</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-8-questions" target="_blank"><img alt="Ask a Question" class="!m-0" src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjY
xLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/en/chapter8/section4_tf.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/en/chapter8/section4_tf.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <p data-svelte-h="svelte-nwhezp">You’ve written a beautiful script to train or fine-tune a model on a given task, dutifully following the advice from <a href="/course/chapter7">Chapter 7</a>. But when you launch the command <code>model.fit()</code>, something horrible happens: you get an error 😱! Or worse, everything seems to be fine and the training runs without error, but the resulting model is crappy. In this section, we will show you what you can do to debug these kinds of issues.</p> <h2 class="relative group"><a id="debugging-the-training-pipeline" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#debugging-the-training-pipeline"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 
0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Debugging the training pipeline</span></h2> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/N9kO52itd0Q" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-1ioi9wy">The problem when you encounter an error in <code>model.fit()</code> is that it could come from multiple sources, as training usually brings together a lot of things that you’ve been working on up until that point. The problem could be something wrong in your dataset, or some issue when trying to batch elements of the dataset together. Or it could be something wrong in the model code, or your loss function or optimizer. And even if everything goes well for training, something could still go wrong during the evaluation if there is a problem with your metric.</p> <p data-svelte-h="svelte-1hhfuy3">The best way to debug an error that arises in <code>model.fit()</code> is to manually go through this whole pipeline to see where things went awry. 
The error is then often very easy to solve.</p> <p data-svelte-h="svelte-q32aak">To demonstrate this, we will use the following script that (tries to) fine-tune a DistilBERT model on the <a href="https://huggingface.co/datasets/glue" rel="nofollow">MNLI dataset</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">import</span> evaluate | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> ( | |
| AutoTokenizer, | |
| TFAutoModelForSequenceClassification, | |
| ) | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| model_checkpoint = <span class="hljs-string">"distilbert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(examples[<span class="hljs-string">"premise"</span>], examples[<span class="hljs-string">"hypothesis"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(preprocess_function, batched=<span class="hljs-literal">True</span>) | |
| train_dataset = tokenized_datasets[<span class="hljs-string">"train"</span>].to_tf_dataset( | |
| columns=[<span class="hljs-string">"input_ids"</span>, <span class="hljs-string">"labels"</span>], batch_size=<span class="hljs-number">16</span>, shuffle=<span class="hljs-literal">True</span> | |
| ) | |
| validation_dataset = tokenized_datasets[<span class="hljs-string">"validation_matched"</span>].to_tf_dataset( | |
| columns=[<span class="hljs-string">"input_ids"</span>, <span class="hljs-string">"labels"</span>], batch_size=<span class="hljs-number">16</span>, shuffle=<span class="hljs-literal">True</span> | |
| ) | |
| model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint) | |
| model.<span class="hljs-built_in">compile</span>(loss=<span class="hljs-string">"sparse_categorical_crossentropy"</span>, optimizer=<span class="hljs-string">"adam"</span>) | |
| model.fit(train_dataset)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1cpnctt">If you try to execute it, you might get some <code>VisibleDeprecationWarning</code>s when doing the dataset conversion — this is a known UX issue we have, so please ignore it. If you’re reading the course after, say, November 2021 and it’s still happening, then send rage tweets at @carrigmat until he fixes it.</p> <p data-svelte-h="svelte-kfr7z5">What’s a more serious problem, though, is that we get an outright error. And it’s really, terrifyingly long:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->ValueError: No gradients provided <span class="hljs-keyword">for</span> <span class="hljs-built_in">any</span> variable: [<span 
class="hljs-string">'tf_distil_bert_for_sequence_classification/distilbert/embeddings/word_embeddings/weight:0'</span>, <span class="hljs-string">'...'</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ptmr0s">What does that mean? We tried to train on our data, but we got no gradient? This is pretty perplexing; how do we even begin to debug something like that? When the error you get doesn’t immediately suggest where the problem is, the best solution is often to walk through things in sequence, making sure at each stage that everything looks right. And of course, the place to start is always to…</p> <h3 class="relative group"><a id="check-your-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#check-your-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Check your data</span></h3> <p data-svelte-h="svelte-17kaoxp">This goes without saying, but if your data is corrupted, Keras is not going to be able to fix it for you. 
So first things first, you need to have a look at what is inside your training set.</p> <p data-svelte-h="svelte-1j2vnd8">Although it’s tempting to look inside <code>raw_datasets</code> and <code>tokenized_datasets</code>, we highly recommend you go to the data right at the point where it’s going to enter the model. That means reading an output from the <code>tf.data.Dataset</code> you created with the <code>to_tf_dataset()</code> function! So how do we do that? <code>tf.data.Dataset</code> objects give us whole batches at a time and don’t support indexing, so we can’t just ask for <code>train_dataset[0]</code>. We can, however, ask it politely for a batch:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataset: | |
| <span class="hljs-keyword">break</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1r7b0mn"><code>break</code> ends the loop after one iteration, so this grabs the first batch that comes out of <code>train_dataset</code> and saves it as <code>batch</code>. Now, let’s take a look at what’s inside:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'attention_mask'</span>: <tf.Tensor: shape=(<span class="hljs-number">16</span>, <span class="hljs-number">76</span>), dtype=int64, numpy= | |
| array([[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]])>, | |
| <span class="hljs-string">'label'</span>: <tf.Tensor: shape=(<span class="hljs-number">16</span>,), dtype=int64, numpy=array([<span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">1</span>])>, | |
| <span class="hljs-string">'input_ids'</span>: <tf.Tensor: shape=(<span class="hljs-number">16</span>, <span class="hljs-number">76</span>), dtype=int64, numpy= | |
| array([[ <span class="hljs-number">101</span>, <span class="hljs-number">2174</span>, <span class="hljs-number">1010</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">3174</span>, <span class="hljs-number">2420</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">2044</span>, <span class="hljs-number">2048</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">3398</span>, <span class="hljs-number">3398</span>, ..., <span class="hljs-number">2051</span>, <span class="hljs-number">2894</span>, <span class="hljs-number">102</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">4124</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1999</span>, <span class="hljs-number">2070</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]])>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cjvsar">This looks right, doesn’t it? We’re passing the <code>labels</code>, <code>attention_mask</code>, and <code>input_ids</code> to the model, which should be everything it needs to compute outputs and calculate the loss. So why don’t we have a gradient? Look closer: we’re passing a single dictionary as input, but a training batch is usually an input tensor or dictionary, plus a labels tensor. Our labels are just a key in our input dictionary.</p> <p data-svelte-h="svelte-advhgk">Is this a problem? Not always, actually! But it’s one of the most common issues you’ll encounter when training Transformer models with TensorFlow. Our models can all compute loss internally, but to do that the labels need to be passed in the input dictionary. This internal loss is the one used when we don’t pass a <code>loss</code> argument to <code>compile()</code>. Keras, on the other hand, usually expects labels to be passed separately from the input dictionary, and loss computations will usually fail if you don’t do that.</p> <p data-svelte-h="svelte-gjcw8h">The problem has now become clearer: we passed a <code>loss</code> argument, which means we’re asking Keras to compute losses for us, but we passed our labels as inputs to the model, not as labels in the place Keras expects them! We need to choose one or the other: either we use the model’s internal loss and keep the labels where they are, or we keep using Keras losses, but we move the labels to the place Keras expects them. For simplicity, let’s take the first approach. 
Change the call to <code>compile()</code> to read:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model.<span class="hljs-built_in">compile</span>(optimizer=<span class="hljs-string">"adam"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1070po5">Now we’ll use the model’s internal loss, and this problem should be resolved!</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-eoh1kk">✏️ <strong>Your turn!</strong> As an optional challenge after we’ve resolved the other issues, you can try coming back to this step and getting the model to work with the original Keras-computed loss instead of the 
internal loss. You’ll need to add <code>"labels"</code> to the <code>label_cols</code> argument of <code>to_tf_dataset()</code> to ensure that the labels are correctly outputted, which will get you gradients — but there’s one more problem with the loss that we specified. Training will still run with this problem, but learning will be very slow and will plateau at a high training loss. Can you figure out what it is?</p> <p data-svelte-h="svelte-8egii9">A ROT13-encoded hint, if you’re stuck: Vs lbh ybbx ng gur bhgchgf bs FrdhraprPynffvsvpngvba zbqryf va Genafsbezref, gurve svefg bhgchg vf <code>ybtvgf</code>. Jung ner ybtvgf?</p> <p data-svelte-h="svelte-izrj8u">And a second hint: Jura lbh fcrpvsl bcgvzvmref, npgvingvbaf be ybffrf jvgu fgevatf, Xrenf frgf nyy gur nethzrag inyhrf gb gurve qrsnhygf. Jung nethzragf qbrf FcnefrPngrtbevpnyPebffragebcl unir, naq jung ner gurve qrsnhygf?</p></div> <p data-svelte-h="svelte-10wlyhd">Now, let’s try training. We should get gradients now, so hopefully (ominous music plays here) we can just call <code>model.fit()</code> and everything will work fine!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 
top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> <span class="hljs-number">246</span>/<span class="hljs-number">24543</span> [..............................] - ETA: <span class="hljs-number">15</span>:<span class="hljs-number">52</span> - loss: nan<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ksxchw">Oh no.</p> <p data-svelte-h="svelte-ygaffu"><code>nan</code> is not a very encouraging loss value. Still, we’ve checked our data, and it looks pretty good. If that’s not the problem, where can we go next? The obvious next step is to…</p> <h3 class="relative group"><a id="check-your-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#check-your-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Check your model</span></h3> <p data-svelte-h="svelte-1uzyoya"><code>model.fit()</code> is a really great convenience function in Keras, but it does a 
lot of things for you, and that can make it trickier to find exactly where a problem has occurred. If you’re debugging your model, one strategy that can really help is to pass just a single batch to the model, and look at the outputs for that one batch in detail. Another really helpful tip if the model is throwing errors is to <code>compile()</code> the model with <code>run_eagerly=True</code>. This will make it a lot slower, but it will make the error messages much more comprehensible, because they’ll indicate exactly where in your model’s code the problem occurred.</p> <p data-svelte-h="svelte-zwnuy3">For now, though, we don’t need <code>run_eagerly</code> just yet. Let’s run the <code>batch</code> we got before through the model and see what the outputs look like:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START 
-->model(batch)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->TFSequenceClassifierOutput(loss=<tf.Tensor: shape=(<span class="hljs-number">16</span>,), dtype=float32, numpy= | |
| array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, | |
| nan, nan, nan], dtype=float32)>, logits=<tf.Tensor: shape=(<span class="hljs-number">16</span>, <span class="hljs-number">2</span>), dtype=float32, numpy= | |
| array([[nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan], | |
| [nan, nan]], dtype=float32)>, hidden_states=<span class="hljs-literal">None</span>, attentions=<span class="hljs-literal">None</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vdbbfv">Well, this is tricky. Everything is <code>nan</code>! But that’s strange, isn’t it? How would all our logits become <code>nan</code>? <code>nan</code> means “not a number.” <code>nan</code> values often occur when you perform a forbidden operation, such as dividing zero by zero or taking the logarithm of a negative number. But one thing that’s very important to know about <code>nan</code> in machine learning is that this value tends to <em>propagate</em>. If you multiply a number by <code>nan</code>, the output is also <code>nan</code>. And if you get a <code>nan</code> anywhere in your output, your loss, or your gradient, then it will rapidly spread throughout your whole model — because when that <code>nan</code> value is propagated back through your network, you’ll get <code>nan</code> gradients, and when weight updates are computed with those gradients, you’ll get <code>nan</code> weights, and those weights will compute even more <code>nan</code> outputs! Soon enough the whole network will just be one big block of <code>nan</code>s. Once that happens, it’s pretty hard to see where the problem started. How can we isolate where <code>nan</code> first crept in?</p> <p data-svelte-h="svelte-1ng71q0">The answer is to try <em>reinitializing</em> our model. Once we started training, we got a <code>nan</code> somewhere and it quickly propagated through the whole model. 
So, let’s load the model from a checkpoint and not do any weight updates, and see where we get a <code>nan</code> value:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint) | |
| model(batch)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mgm7a4">When we run that, we get:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->TFSequenceClassifierOutput(loss=<tf.Tensor: shape=(<span class="hljs-number">16</span>,), dtype=float32, numpy= | |
| array([<span class="hljs-number">0.6844486</span> , nan, nan, <span class="hljs-number">0.67127866</span>, <span class="hljs-number">0.7068601</span> , | |
| nan, <span class="hljs-number">0.69309855</span>, nan, <span class="hljs-number">0.65531296</span>, nan, | |
| nan, nan, <span class="hljs-number">0.675402</span> , nan, nan, | |
| <span class="hljs-number">0.69831556</span>], dtype=float32)>, logits=<tf.Tensor: shape=(<span class="hljs-number">16</span>, <span class="hljs-number">2</span>), dtype=float32, numpy= | |
| array([[-<span class="hljs-number">0.04761693</span>, -<span class="hljs-number">0.06509043</span>], | |
| [-<span class="hljs-number">0.0481936</span> , -<span class="hljs-number">0.04556257</span>], | |
| [-<span class="hljs-number">0.0040929</span> , -<span class="hljs-number">0.05848458</span>], | |
| [-<span class="hljs-number">0.02417453</span>, -<span class="hljs-number">0.0684005</span> ], | |
| [-<span class="hljs-number">0.02517801</span>, -<span class="hljs-number">0.05241832</span>], | |
| [-<span class="hljs-number">0.04514256</span>, -<span class="hljs-number">0.0757378</span> ], | |
| [-<span class="hljs-number">0.02656011</span>, -<span class="hljs-number">0.02646275</span>], | |
| [ <span class="hljs-number">0.00766164</span>, -<span class="hljs-number">0.04350497</span>], | |
| [ <span class="hljs-number">0.02060014</span>, -<span class="hljs-number">0.05655622</span>], | |
| [-<span class="hljs-number">0.02615328</span>, -<span class="hljs-number">0.0447021</span> ], | |
| [-<span class="hljs-number">0.05119278</span>, -<span class="hljs-number">0.06928903</span>], | |
| [-<span class="hljs-number">0.02859691</span>, -<span class="hljs-number">0.04879177</span>], | |
| [-<span class="hljs-number">0.02210129</span>, -<span class="hljs-number">0.05791225</span>], | |
| [-<span class="hljs-number">0.02363213</span>, -<span class="hljs-number">0.05962167</span>], | |
| [-<span class="hljs-number">0.05352269</span>, -<span class="hljs-number">0.0481673</span> ], | |
| [-<span class="hljs-number">0.08141848</span>, -<span class="hljs-number">0.07110836</span>]], dtype=float32)>, hidden_states=<span class="hljs-literal">None</span>, attentions=<span class="hljs-literal">None</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-n2u44p"><em>Now</em> we’re getting somewhere! There are no <code>nan</code> values in our logits, which is reassuring. But we do see a few <code>nan</code> values in our loss! Is there something about those samples in particular that’s causing this problem? Let’s see which ones they are (note that if you run this code yourself, you may get different indices because the dataset has been shuffled):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np | |
| loss = model(batch).loss.numpy() | |
| indices = np.flatnonzero(np.isnan(loss)) | |
| indices<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->array([ <span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">5</span>, <span class="hljs-number">7</span>, <span class="hljs-number">9</span>, <span class="hljs-number">10</span>, <span class="hljs-number">11</span>, <span class="hljs-number">13</span>, <span class="hljs-number">14</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-kc1sp3">Let’s look at the samples these indices came from:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" 
type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->input_ids = batch[<span class="hljs-string">"input_ids"</span>].numpy() | |
| input_ids[indices]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->array([[ <span class="hljs-number">101</span>, <span class="hljs-number">2007</span>, <span class="hljs-number">2032</span>, <span class="hljs-number">2001</span>, <span class="hljs-number">1037</span>, <span class="hljs-number">16480</span>, <span class="hljs-number">3917</span>, <span class="hljs-number">2594</span>, <span class="hljs-number">4135</span>, | |
| <span class="hljs-number">23212</span>, <span class="hljs-number">3070</span>, <span class="hljs-number">2214</span>, <span class="hljs-number">10170</span>, <span class="hljs-number">1010</span>, <span class="hljs-number">2012</span>, <span class="hljs-number">4356</span>, <span class="hljs-number">1997</span>, <span class="hljs-number">3183</span>, | |
| <span class="hljs-number">6838</span>, <span class="hljs-number">12953</span>, <span class="hljs-number">2039</span>, <span class="hljs-number">2000</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">6147</span>, <span class="hljs-number">1997</span>, <span class="hljs-number">2010</span>, <span class="hljs-number">2606</span>, | |
| <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">6838</span>, <span class="hljs-number">2001</span>, <span class="hljs-number">3294</span>, <span class="hljs-number">6625</span>, <span class="hljs-number">3773</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">2214</span>, | |
| <span class="hljs-number">2158</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1998</span>, <span class="hljs-number">6814</span>, <span class="hljs-number">2016</span>, <span class="hljs-number">2234</span>, <span class="hljs-number">2461</span>, <span class="hljs-number">2153</span>, <span class="hljs-number">1998</span>, <span class="hljs-number">13322</span>, | |
| <span class="hljs-number">2009</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">2045</span>, <span class="hljs-number">1005</span>, <span class="hljs-number">1055</span>, <span class="hljs-number">2053</span>, <span class="hljs-number">3382</span>, <span class="hljs-number">2008</span>, | |
| <span class="hljs-number">2016</span>, <span class="hljs-number">1005</span>, <span class="hljs-number">2222</span>, <span class="hljs-number">3046</span>, <span class="hljs-number">8103</span>, <span class="hljs-number">2075</span>, <span class="hljs-number">2009</span>, <span class="hljs-number">2153</span>, <span class="hljs-number">1012</span>, | |
| <span class="hljs-number">102</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1998</span>, <span class="hljs-number">2007</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">3712</span>, <span class="hljs-number">4634</span>, <span class="hljs-number">1010</span>, <span class="hljs-number">2057</span>, <span class="hljs-number">8108</span>, | |
| <span class="hljs-number">2025</span>, <span class="hljs-number">3404</span>, <span class="hljs-number">2028</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">2616</span>, <span class="hljs-number">18449</span>, <span class="hljs-number">2125</span>, <span class="hljs-number">1999</span>, | |
| <span class="hljs-number">1037</span>, <span class="hljs-number">9666</span>, <span class="hljs-number">1997</span>, <span class="hljs-number">4100</span>, <span class="hljs-number">8663</span>, <span class="hljs-number">11020</span>, <span class="hljs-number">6313</span>, <span class="hljs-number">2791</span>, <span class="hljs-number">1998</span>, | |
| <span class="hljs-number">2431</span>, <span class="hljs-number">1011</span>, <span class="hljs-number">4301</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">2028</span>, <span class="hljs-number">1005</span>, <span class="hljs-number">1055</span>, <span class="hljs-number">5177</span>, | |
| <span class="hljs-number">2110</span>, <span class="hljs-number">1998</span>, <span class="hljs-number">3977</span>, <span class="hljs-number">2000</span>, <span class="hljs-number">2832</span>, <span class="hljs-number">2106</span>, <span class="hljs-number">2025</span>, <span class="hljs-number">2689</span>, <span class="hljs-number">2104</span>, | |
| <span class="hljs-number">2122</span>, <span class="hljs-number">6214</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1045</span>, <span class="hljs-number">2001</span>, <span class="hljs-number">1999</span>, <span class="hljs-number">1037</span>, <span class="hljs-number">13090</span>, <span class="hljs-number">5948</span>, <span class="hljs-number">2007</span>, <span class="hljs-number">2048</span>, | |
| <span class="hljs-number">2308</span>, <span class="hljs-number">2006</span>, <span class="hljs-number">2026</span>, <span class="hljs-number">5001</span>, <span class="hljs-number">2043</span>, <span class="hljs-number">2026</span>, <span class="hljs-number">2171</span>, <span class="hljs-number">2001</span>, <span class="hljs-number">2170</span>, | |
| <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">1045</span>, <span class="hljs-number">2001</span>, <span class="hljs-number">3564</span>, <span class="hljs-number">1999</span>, <span class="hljs-number">2277</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">2195</span>, <span class="hljs-number">4279</span>, <span class="hljs-number">2191</span>, <span class="hljs-number">2039</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">2181</span>, <span class="hljs-number">2124</span>, <span class="hljs-number">2004</span>, | |
| <span class="hljs-number">1996</span>, <span class="hljs-number">2225</span>, <span class="hljs-number">7363</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">2045</span>, <span class="hljs-number">2003</span>, <span class="hljs-number">2069</span>, <span class="hljs-number">2028</span>, | |
| <span class="hljs-number">2451</span>, <span class="hljs-number">1999</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">2225</span>, <span class="hljs-number">7363</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">2061</span>, <span class="hljs-number">2008</span>, <span class="hljs-number">1045</span>, <span class="hljs-number">2123</span>, <span class="hljs-number">1005</span>, <span class="hljs-number">1056</span>, <span class="hljs-number">2113</span>, <span class="hljs-number">2065</span>, | |
| <span class="hljs-number">2009</span>, <span class="hljs-number">2428</span>, <span class="hljs-number">10654</span>, <span class="hljs-number">7347</span>, <span class="hljs-number">2030</span>, <span class="hljs-number">2009</span>, <span class="hljs-number">7126</span>, <span class="hljs-number">2256</span>, <span class="hljs-number">2495</span>, | |
| <span class="hljs-number">2291</span>, <span class="hljs-number">102</span>, <span class="hljs-number">2009</span>, <span class="hljs-number">2003</span>, <span class="hljs-number">5094</span>, <span class="hljs-number">2256</span>, <span class="hljs-number">2495</span>, <span class="hljs-number">2291</span>, <span class="hljs-number">2035</span>, | |
| <span class="hljs-number">2105</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">2051</span>, <span class="hljs-number">1010</span>, <span class="hljs-number">2029</span>, <span class="hljs-number">3216</span>, <span class="hljs-number">2019</span>, <span class="hljs-number">2503</span>, <span class="hljs-number">3444</span>, <span class="hljs-number">1010</span>, | |
| <span class="hljs-number">6732</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">2265</span>, <span class="hljs-number">2038</span>, <span class="hljs-number">19840</span>, <span class="hljs-number">2098</span>, <span class="hljs-number">2125</span>, <span class="hljs-number">9906</span>, <span class="hljs-number">1998</span>, | |
| <span class="hljs-number">2003</span>, <span class="hljs-number">2770</span>, <span class="hljs-number">2041</span>, <span class="hljs-number">1997</span>, <span class="hljs-number">4784</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">2051</span>, <span class="hljs-number">6732</span>, | |
| <span class="hljs-number">1996</span>, <span class="hljs-number">2265</span>, <span class="hljs-number">2003</span>, <span class="hljs-number">9525</span>, <span class="hljs-number">1998</span>, <span class="hljs-number">4569</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1996</span>, <span class="hljs-number">10556</span>, <span class="hljs-number">2140</span>, <span class="hljs-number">11515</span>, <span class="hljs-number">2058</span>, <span class="hljs-number">1010</span>, <span class="hljs-number">2010</span>, <span class="hljs-number">2162</span>, | |
| <span class="hljs-number">2252</span>, <span class="hljs-number">5689</span>, <span class="hljs-number">2013</span>, <span class="hljs-number">2010</span>, <span class="hljs-number">7223</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">2043</span>, <span class="hljs-number">1996</span>, | |
| <span class="hljs-number">10556</span>, <span class="hljs-number">2140</span>, <span class="hljs-number">11515</span>, <span class="hljs-number">2058</span>, <span class="hljs-number">1010</span>, <span class="hljs-number">2010</span>, <span class="hljs-number">2252</span>, <span class="hljs-number">3062</span>, <span class="hljs-number">2000</span>, | |
| <span class="hljs-number">1996</span>, <span class="hljs-number">2598</span>, <span class="hljs-number">1012</span>, <span class="hljs-number">102</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">13543</span>, <span class="hljs-number">1999</span>, <span class="hljs-number">2049</span>, <span class="hljs-number">6143</span>, <span class="hljs-number">2933</span>, <span class="hljs-number">2443</span>, <span class="hljs-number">102</span>, <span class="hljs-number">2025</span>, | |
| <span class="hljs-number">13543</span>, <span class="hljs-number">1999</span>, <span class="hljs-number">6143</span>, <span class="hljs-number">2933</span>, <span class="hljs-number">2003</span>, <span class="hljs-number">2443</span>, <span class="hljs-number">102</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-nb4d2c">Well, there’s a lot in here, but nothing stands out as unusual. Let’s look at the labels:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->labels = batch[<span class="hljs-string">'labels'</span>].numpy() | |
| labels[indices]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->array([<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1mjpex2">Ah! The <code>nan</code> samples all have the same label, and it’s label 2. This is a very strong hint. 
The fact that we’re only getting a loss of <code>nan</code> when our label is 2 suggests that this is a very good time to check the number of labels in our model:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model.config.num_labels<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-number">2</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8tvrnv">Now we see the problem: the model thinks there are only two classes, but the labels go up to 2, which means there are in fact three classes (because 0 is also a class). This is how we got a <code>nan</code> — by trying to compute the loss for a nonexistent class! 
Let’s try changing that and fitting the model again:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint, <span class="hljs-attribute">num_labels</span>=3) | |
| model.compile(<span class="hljs-attribute">optimizer</span>=<span class="hljs-string">'adam'</span>) | |
| model.fit(train_dataset)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> <span class="hljs-number">869</span>/<span class="hljs-number">24543</span> [>.............................] - ETA: <span class="hljs-number">15</span>:<span class="hljs-number">29</span> - loss: <span class="hljs-number">1.1032</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1lpz9k6">We’re training! No more <code>nan</code>s, and our loss is declining… sort of. If you watch it for a while, you might start to get a bit impatient, because the loss value stays stubbornly high. Let’s stop training here and try to think about what could be causing this problem. At this point, we’re pretty sure both the data and the model are okay, but our model isn’t learning well. What else is left? 
It’s time to…</p> <h3 class="relative group"><a id="check-your-hyperparameters" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#check-your-hyperparameters"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Check your hyperparameters</span></h3> <p data-svelte-h="svelte-1erdz5n">If you look back at the code above, you might not be able to see any hyperparameters at all, except perhaps the <code>batch_size</code>, and that doesn’t seem like a likely culprit. Don’t be fooled, though; there are always hyperparameters, and if you can’t see them, it just means that you don’t know what they’re set to. In particular, remember a critical thing about Keras: if you set a loss, optimizer, or activation function with a string, <em>all of its arguments will be set to their default values</em>. This means that even though using strings for this is very convenient, you should be very careful when doing so, as it can easily hide critical things from you. (Anyone trying the optional challenge above should take careful note of this fact.)</p> <p data-svelte-h="svelte-16qubtb">In this case, where have we set an argument with a string? 
We were setting the loss with a string initially, but we’re not doing that anymore. We are, however, setting the optimizer with a string. Could that be hiding anything from us? Let’s take a look at <a href="https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam" rel="nofollow">its arguments</a>.</p> <p data-svelte-h="svelte-ev3jp2">Does anything stand out here? That’s right — the learning rate! When we just use the string <code>'adam'</code>, we’re going to get the default learning rate, which is 0.001, or 1e-3. This is way too high for a Transformer model! In general, we recommend trying learning rates between 1e-5 and 1e-4 for your models; that’s somewhere between 10X and 100X smaller than the value we’re actually using here. That sounds like it might be a major problem, so let’s try reducing it. To do that, we need to import the actual <code>optimizer</code> object. While we’re at it, let’s reinitialize the model from the checkpoint, in case training with the high learning rate damaged its weights:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div 
class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> tensorflow.keras.optimizers <span class="hljs-keyword">import</span> Adam | |
| model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=<span class="hljs-number">3</span>) | |
| model.<span class="hljs-built_in">compile</span>(optimizer=Adam(<span class="hljs-number">5e-5</span>))<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-or1qvk">💡 You can also import the <code>create_optimizer()</code> function from 🤗 Transformers, which will give you an AdamW optimizer with correct weight decay as well as learning rate warmup and decay. This optimizer will often produce slightly better results than the ones you get with the default Adam optimizer.</p></div> <p data-svelte-h="svelte-1rbb4n7">Now, we can try fitting the model with the new, improved learning rate:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model.fit(train_dataset)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-number">319</span>/<span class="hljs-number">24543</span> [..............................] - ETA: <span class="hljs-number">16</span>:07 - loss: <span class="hljs-number">0.9718</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ortff">Now our loss is really going somewhere! Training finally looks like it’s working. There’s a lesson here: when your model is running but loss isn’t declining, and you’re sure your data is okay, it’s a good idea to check hyperparameters like the learning rate and weight decay. 
Setting either of those too high is very likely to cause training to “stall” at a high loss value.</p> <h2 class="relative group"><a id="other-potential-issues" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#other-potential-issues"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Other potential issues</span></h2> <p data-svelte-h="svelte-1kstazi">We’ve covered the issues in the script above, but there are several other common errors you might face. 
Let’s take a look at a (very incomplete) list.</p> <h3 class="relative group"><a id="dealing-with-out-of-memory-errors" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dealing-with-out-of-memory-errors"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Dealing with out-of-memory errors</span></h3> <p data-svelte-h="svelte-roog6f">The telltale sign of running out of memory is an error like “OOM when allocating tensor” — OOM is short for “out of memory.” This is a very common hazard when dealing with large language models. If you encounter this, a good strategy is to halve your batch size and try again. Bear in mind, though, that some models are <em>very</em> large. For example, the full-size GPT-2 has 1.5B parameters, which means you’ll need 6 GB of memory just to store the model, and another 6 GB for its gradients! Training the full GPT-2 model will usually require over 20 GB of VRAM no matter what batch size you use, which only a few GPUs have. 
More lightweight models like <code>distilbert-base-cased</code> are much easier to run, and train much more quickly too.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-9j5678">In the next part of the course, we’ll look at more advanced techniques that can help you reduce your memory footprint and let you fine-tune the biggest models.</p></div> <h3 class="relative group"><a id="hungry-hungry-tensorflow" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#hungry-hungry-tensorflow"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Hungry Hungry TensorFlow 🦛</span></h3> <p data-svelte-h="svelte-a90tcv">One particular quirk of TensorFlow that you should be aware of is that it allocates <em>all</em> of your GPU memory to itself as soon as you load a model or do any training, and then it divides up that memory as required. 
This is different from the behavior of other frameworks, like PyTorch, which allocate memory as required with CUDA rather than doing it internally. One advantage of the TensorFlow approach is that it can often give useful errors when you run out of memory, and it can recover from that state without crashing your whole Python process (or notebook kernel). But there’s also an important downside: if you run two TensorFlow processes at once, then <strong>you’re going to have a bad time</strong>.</p> <p data-svelte-h="svelte-11jlq39">If you’re running on Colab you don’t need to worry about this, but if you’re running locally this is definitely something you should be careful about. In particular, be aware that closing a notebook tab does not necessarily shut that notebook down! You may need to select running notebooks (the ones with a green icon) and manually shut them down in the directory listing. Any running notebook that was using TensorFlow could still be holding on to a bunch of your GPU memory, and that means any new notebook you start may encounter some very odd issues.</p> <p data-svelte-h="svelte-1omyoz1">If you start getting errors about CUDA, BLAS, or cuBLAS in code that worked before, this is very often the culprit. You can use a command like <code>nvidia-smi</code> to check — when you shut down or restart your current notebook, is most of your memory free, or is it still in use? 
If it’s still in use, something else is holding on to it!</p> <h3 class="relative group"><a id="check-your-data-again" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#check-your-data-again"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Check your data (again!)</span></h3> <p data-svelte-h="svelte-yl2m3g">Your model will only learn something if it’s actually possible to learn anything from your data. If there is a bug that corrupts the data or the labels are assigned randomly, your model is very unlikely to learn anything useful from your dataset. One helpful tool here is <code>tokenizer.decode()</code>. This will turn <code>input_ids</code> back into strings, so you can view the data and see if your training data is teaching what you want it to teach. 
For example, after you get a <code>batch</code> from your <code>tf.data.Dataset</code> like we did above, you can decode the first element like so:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->input_ids = batch[<span class="hljs-string">"input_ids"</span>].numpy() | |
| tokenizer.decode(input_ids[<span class="hljs-number">0</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1bj28iy">Then you can compare it with the first label, like so:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->labels = batch[<span class="hljs-string">"labels"</span>].numpy() | |
| label = labels[<span class="hljs-number">0</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ixs1c3">Once you can view your data like this, you can ask yourself the following questions:</p> <ul data-svelte-h="svelte-1m9el5s"><li>Is the decoded data understandable?</li> <li>Do you agree with the labels?</li> <li>Is there one label that’s more common than the others?</li> <li>What should the loss/metric be if the model predicted a random answer/always the same answer?</li></ul> <p data-svelte-h="svelte-1u7fsc">After looking at your data, go through a few of the model’s predictions — if your model outputs tokens, try decoding them too! If the model is always predicting the same thing it might be because your dataset is biased toward one category (for classification problems), so techniques like oversampling rare classes might help. Alternatively, this can also be caused by training issues like bad hyperparameter settings.</p> <p data-svelte-h="svelte-1uctb5u">If the loss/metric you get on your initial model before any training is very different from the loss/metric you would expect for random predictions, double-check the way your loss or metric is computed, as there is probably a bug there. 
If you are using several losses that you add at the end, make sure they are of the same scale.</p> <p data-svelte-h="svelte-8qgi72">When you are sure your data is perfect, you can see if the model is capable of training on it with one simple test.</p> <h3 class="relative group"><a id="overfit-your-model-on-one-batch" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#overfit-your-model-on-one-batch"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Overfit your model on one batch</span></h3> <p data-svelte-h="svelte-ud8t14">Overfitting is usually something we try to avoid when training, as it means the model is not learning to recognize the general features we want it to but is instead just memorizing the training samples. However, trying to train your model on one batch over and over again is a good test to check if the problem as you framed it can be solved by the model you are attempting to train. 
It will also help you see if your initial learning rate is too high.</p> <p data-svelte-h="svelte-10eu86u">Doing this once you have defined your <code>model</code> is really easy; just grab a batch of training data, then treat that <code>batch</code> as your entire dataset, fitting on it for a large number of epochs:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataset: | |
| <span class="hljs-keyword">break</span> | |
| <span class="hljs-comment"># Make sure you have run model.compile() and set your optimizer,</span> | |
| <span class="hljs-comment"># and your loss/metrics if you're using them</span> | |
| model.fit(batch, epochs=<span class="hljs-number">20</span>)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-ye7yyo">💡 If your training data is unbalanced, make sure to build a batch of training data containing all the labels.</p></div> <p data-svelte-h="svelte-9qhcul">The resulting model should have close-to-perfect results on the <code>batch</code>, with a loss declining quickly toward 0 (or the minimum value for the loss you’re using).</p> <p data-svelte-h="svelte-1appsmg">If you don’t manage to have your model obtain perfect results like this, it means there is something wrong with the way you framed the problem or your data, so you should fix that. Only when you manage to pass the overfitting test can you be sure that your model can actually learn something.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-62ftx4">⚠️ You will have to recreate your model and recompile after this overfitting test, as the model obtained probably won’t be able to recover and learn something useful on your full dataset.</p></div> <h3 class="relative group"><a id="dont-tune-anything-until-you-have-a-first-baseline" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dont-tune-anything-until-you-have-a-first-baseline"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Don’t tune anything until you have a first baseline</span></h3> <p data-svelte-h="svelte-1btfwlq">Intense hyperparameter tuning is always emphasized as being the hardest part of machine learning, but it’s just the last step to help you gain a little bit on the metric. <em>Very</em> bad values for your hyperparameters, like using the default Adam learning rate of 1e-3 with a Transformer model, will make learning proceed very slowly or completely stall, of course, but most of the time “reasonable” hyperparameters, like a learning rate from 1e-5 to 5e-5, will work just fine to give you good results. So, don’t launch into a time-consuming and costly hyperparameter search until you have something that beats the baseline you have on your dataset.</p> <p data-svelte-h="svelte-ngq5bm">Once you have a good enough model, you can start tweaking a bit. Rather than launching a thousand runs with different hyperparameters, compare a couple of runs with different values for one hyperparameter to get an idea of which has the greatest impact.</p> <p data-svelte-h="svelte-17qlrbw">If you are tweaking the model itself, keep it simple and don’t try anything you can’t reasonably justify. 
Always make sure you go back to the overfitting test to verify that your change hasn’t had any unintended consequences.</p> <h3 class="relative group"><a id="ask-for-help" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ask-for-help"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Ask for help</span></h3> <p data-svelte-h="svelte-1kejrvw">Hopefully you will have found some advice in this section that helped you solve your issue, but if that’s not the case, remember you can always ask the community on the <a href="https://discuss.huggingface.co/" rel="nofollow">forums</a>.</p> <p data-svelte-h="svelte-1wt0l4q">Here are some additional resources that may prove helpful:</p> <ul data-svelte-h="svelte-vkgmp2"><li><a href="https://docs.google.com/presentation/d/1yHLPvPhUs2KGI5ZWo0sU-PKU3GimAk3iTsI38Z-B5Gw/edit#slide=id.p" rel="nofollow">“Reproducibility as a vehicle for engineering best practices”</a> by Joel Grus</li> <li><a href="https://towardsdatascience.com/checklist-for-debugging-neural-networks-d8b2a9434f21" rel="nofollow">“Checklist for debugging neural networks”</a> by Cecelia Shao</li> <li><a 
href="https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765" rel="nofollow">“How to unit test machine learning code”</a> by Chase Roberts</li> <li><a href="https://karpathy.github.io/2019/04/25/recipe/" rel="nofollow">“A Recipe for Training Neural Networks”</a> by Andrej Karpathy</li></ul> <p data-svelte-h="svelte-vjjrmr">Of course, not every problem you encounter when training neural nets is your own fault! If you encounter something in the 🤗 Transformers or 🤗 Datasets library that does not seem right, you may have encountered a bug. You should definitely tell us all about it, and in the next section we’ll explain exactly how to do that.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/en/chapter8/4_tf.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | 
| <script> | 
| // SvelteKit client bootstrap for this page: configures the runtime globals, | 
| // loads the start/app entry modules, then calls kit.start on this page | 
| // (presumably hydrating the server-rendered markup above - verify). | 
| { | 
| // Global runtime config; both asset and base paths point at this docs build | 
| // prefix. The variable name looks build-generated; presumably it is read by | 
| // the runtime modules imported below - TODO confirm. | 
| __sveltekit_1y0degu = { | 
| assets: "/docs/course/pr_1069/en", | 
| base: "/docs/course/pr_1069/en", | 
| env: {} | 
| }; | 
| // Capture the element that contains this script element right now: | 
| // document.currentScript is only set while the script is first executing, | 
| // so it would be null inside the asynchronous .then() callback below. | 
| const element = document.currentScript.parentElement; | 
| // Per-node data handed to kit.start; both entries are null here (presumably | 
| // one slot per entry in node_ids - verify). | 
| const data = [null,null]; | 
| // Dynamically import the start and app entry modules (the same files that | 
| // are modulepreloaded in the document head) and start the client once both load. | 
| Promise.all([ | 
| import("/docs/course/pr_1069/en/_app/immutable/entry/start.c5306bb2.js"), | 
| import("/docs/course/pr_1069/en/_app/immutable/entry/app.4264f5f8.js") | 
| ]).then(([kit, app]) => { | 
| // Mount into the captured container; form and error are null here | 
| // (presumably no form-action result or error state for this page). | 
| kit.start(app, element, { | 
| node_ids: [0, 90], | 
| data, | 
| form: null, | 
| error: null | 
| }); | 
| }); | 
| } | 
| </script> | 
Xet Storage Details
- Size:
- 125 kB
- Xet hash:
- 69068de5b80b4ccc64b3f38ee4b665b0cf045c878c571ffd25c841635bfdbdac
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.