| <meta charset="utf-8" /><meta http-equiv="content-security-policy" content=""><meta name="hf:doc:metadata" content='{"local":"quick-start","sections":[{"local":"load-the-dataset-and-model","title":"Load the dataset and model"},{"local":"tokenize-the-dataset","title":"Tokenize the dataset"},{"local":"format-the-dataset","title":"Format the dataset"},{"local":"train-the-model","title":"Train the model"},{"local":"whats-next","title":"What&#39;s next?"}],"title":"Quick Start"}' data-svelte="svelte-1phssyn"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/assets/pages/__layout.svelte-efc77dbd.css"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/start-0f8c1da7.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/vendor-8138ceec.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/paths-4b3c6e7e.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/pages/__layout.svelte-efb8e839.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/pages/quickstart.mdx-c9b63796.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/Tip-12722dfc.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/IconCopyLink-2dd3a6ac.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/CodeBlock-fc89709f.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/Markdown-7202589c.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/IconTensorflow-7f573d67.js"> | |
| <h1 class="relative group"><a id="quick-start" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quick-start"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Quick Start | |
| </span></h1> | |
| <p>The quick start is intended for developers who are ready to dive into the code and see an end-to-end example of how to integrate 🤗 Datasets into their model training workflow. If you are a beginner looking for a gentler introduction, we recommend starting with the <a href="./tutorial">tutorials</a>.</p> | |
| <p>In the quick start, you will walk through all the steps to fine-tune <a href="https://huggingface.co/bert-base-cased" rel="nofollow">BERT</a> on a paraphrase classification task. The exact steps may vary depending on the dataset you use, but the general workflow for loading and processing a dataset stays the same.</p> | |
| <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p>For more detailed information on loading and processing a dataset, take a look at <a href="https://huggingface.co/course/chapter3/1?fw=pt" rel="nofollow">Chapter 3</a> of the Hugging Face course! It covers additional important topics like dynamic padding and fine-tuning with the Trainer API.</p></div> | |
| <p>Get started by installing 🤗 Datasets:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START -->pip <span class="hljs-keyword">install</span> datasets<!-- HTML_TAG_END --></pre></div> | |
| <h2 class="relative group"><a id="load-the-dataset-and-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#load-the-dataset-and-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Load the dataset and model | |
| </span></h2> | |
| <p>Begin by loading the training split of the <a href="https://huggingface.co/datasets/glue/viewer/mrpc" rel="nofollow">Microsoft Research Paraphrase Corpus (MRPC)</a> from the <a href="https://huggingface.co/datasets/glue" rel="nofollow">General Language Understanding Evaluation (GLUE) benchmark</a>. MRPC is a corpus of human-annotated sentence pairs used to train a model to determine whether the two sentences in a pair are semantically equivalent.</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-meta">>>> </span>dataset = load_dataset(<span class="hljs-string">'glue'</span>, <span class="hljs-string">'mrpc'</span>, split=<span class="hljs-string">'train'</span>)<!-- HTML_TAG_END --></pre></div> | |
| <p>Next, load the pretrained BERT model and its tokenizer from the <a href="https://huggingface.co/transformers/" rel="nofollow">🤗 Transformers</a> library:</p> | |
| <div class="space-y-10 py-6 2xl:py-8 2xl:-mx-4"> | |
| <div class="border border-gray-200 rounded-xl px-4 relative"><div class="flex h-[22px] mt-[-12.5px] justify-between leading-none"><div class="flex px-1 items-center space-x-1 bg-white dark:bg-gray-950"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> | |
| <span>Pytorch</span></div> | |
| </div> | |
| <div class="framework-content"> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, AutoTokenizer | |
| <span class="hljs-meta">>>> </span>model = AutoModelForSequenceClassification.from_pretrained(<span class="hljs-string">'bert-base-cased'</span>) | |
| Some weights of the model checkpoint at bert-base-cased were <span class="hljs-keyword">not</span> used when initializing BertForSequenceClassification: [<span class="hljs-string">'cls.predictions.bias'</span>, <span class="hljs-string">'cls.predictions.transform.dense.weight'</span>, <span class="hljs-string">'cls.predictions.transform.dense.bias'</span>, <span class="hljs-string">'cls.predictions.decoder.weight'</span>, <span class="hljs-string">'cls.seq_relationship.weight'</span>, <span class="hljs-string">'cls.seq_relationship.bias'</span>, <span class="hljs-string">'cls.predictions.transform.LayerNorm.weight'</span>, <span class="hljs-string">'cls.predictions.transform.LayerNorm.bias'</span>] | |
| - This IS expected <span class="hljs-keyword">if</span> you are initializing BertForSequenceClassification <span class="hljs-keyword">from</span> the checkpoint of a model trained on another task <span class="hljs-keyword">or</span> <span class="hljs-keyword">with</span> another architecture (e.g. initializing a BertForSequenceClassification model <span class="hljs-keyword">from</span> a BertForPretraining model). | |
| - This IS NOT expected <span class="hljs-keyword">if</span> you are initializing BertForSequenceClassification <span class="hljs-keyword">from</span> the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model <span class="hljs-keyword">from</span> a BertForSequenceClassification model). | |
| Some weights of BertForSequenceClassification were <span class="hljs-keyword">not</span> initialized <span class="hljs-keyword">from</span> the model checkpoint at bert-base-cased <span class="hljs-keyword">and</span> are newly initialized: [<span class="hljs-string">'classifier.weight'</span>, <span class="hljs-string">'classifier.bias'</span>] | |
| You should probably TRAIN this model on a down-stream task to be able to use it <span class="hljs-keyword">for</span> predictions <span class="hljs-keyword">and</span> inference. | |
| <span class="hljs-meta">>>> </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">'bert-base-cased'</span>)<!-- HTML_TAG_END --></pre></div></div></div> | |
| <div class="border border-gray-200 rounded-xl px-4 relative"><div class="flex h-[22px] mt-[-12.5px] justify-between leading-none"><div class="flex px-1 items-center space-x-1 bg-white dark:bg-gray-950"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> | |
| <span>TensorFlow</span></div> | |
| </div> | |
| <div class="framework-content"> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TFAutoModelForSequenceClassification, AutoTokenizer | |
| <span class="hljs-meta">>>> </span>model = TFAutoModelForSequenceClassification.from_pretrained(<span class="hljs-string">"bert-base-cased"</span>) | |
| Some weights of the model checkpoint at bert-base-cased were <span class="hljs-keyword">not</span> used when initializing TFBertForSequenceClassification: [<span class="hljs-string">'nsp___cls'</span>, <span class="hljs-string">'mlm___cls'</span>] | |
| - This IS expected <span class="hljs-keyword">if</span> you are initializing TFBertForSequenceClassification <span class="hljs-keyword">from</span> the checkpoint of a model trained on another task <span class="hljs-keyword">or</span> <span class="hljs-keyword">with</span> another architecture (e.g. initializing a BertForSequenceClassification model <span class="hljs-keyword">from</span> a BertForPretraining model). | |
| - This IS NOT expected <span class="hljs-keyword">if</span> you are initializing TFBertForSequenceClassification <span class="hljs-keyword">from</span> the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model <span class="hljs-keyword">from</span> a BertForSequenceClassification model). | |
| Some weights of TFBertForSequenceClassification were <span class="hljs-keyword">not</span> initialized <span class="hljs-keyword">from</span> the model checkpoint at bert-base-cased <span class="hljs-keyword">and</span> are newly initialized: [<span class="hljs-string">'dropout_37'</span>, <span class="hljs-string">'classifier'</span>] | |
| You should probably TRAIN this model on a down-stream task to be able to use it <span class="hljs-keyword">for</span> predictions <span class="hljs-keyword">and</span> inference. | |
| <span class="hljs-meta">>>> </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">'bert-base-cased'</span>)<!-- HTML_TAG_END --></pre></div> | |
| </div></div> | |
| </div> | |
| <h2 class="relative group"><a id="tokenize-the-dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#tokenize-the-dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Tokenize the dataset | |
| </span></h2> | |
| <p>Next, tokenize the text to turn it into sequences of integers (token IDs) the model can understand. Encode the entire dataset with <a href="/docs/datasets/v2.2.2/en/package_reference/main_classes#datasets.Dataset.map">Dataset.map()</a>, truncating and padding the inputs to the model's maximum length. Because every example then has the same length, the examples can be stacked into tensor batches.</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">encode</span>(<span class="hljs-params">examples</span>): | |
| <span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> tokenizer(examples[<span class="hljs-string">'sentence1'</span>], examples[<span class="hljs-string">'sentence2'</span>], truncation=<span class="hljs-literal">True</span>, padding=<span class="hljs-string">'max_length'</span>) | |
| <span class="hljs-meta">>>> </span>dataset = dataset.<span class="hljs-built_in">map</span>(encode, batched=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">>>> </span>dataset[<span class="hljs-number">0</span>] | |
| {<span class="hljs-string">'sentence1'</span>: <span class="hljs-string">'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .'</span>, | |
| <span class="hljs-string">'sentence2'</span>: <span class="hljs-string">'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .'</span>, | |
| <span class="hljs-string">'label'</span>: <span class="hljs-number">1</span>, | |
| <span class="hljs-string">'idx'</span>: <span class="hljs-number">0</span>, | |
| <span class="hljs-string">'input_ids'</span>: array([ <span class="hljs-number">101</span>, <span class="hljs-number">7277</span>, <span class="hljs-number">2180</span>, <span class="hljs-number">5303</span>, <span class="hljs-number">4806</span>, <span class="hljs-number">1117</span>, <span class="hljs-number">1711</span>, <span class="hljs-number">117</span>, <span class="hljs-number">2292</span>, <span class="hljs-number">1119</span>, <span class="hljs-number">1270</span>, <span class="hljs-number">107</span>, <span class="hljs-number">1103</span>, <span class="hljs-number">7737</span>, <span class="hljs-number">107</span>, <span class="hljs-number">117</span>, <span class="hljs-number">1104</span>, <span class="hljs-number">9938</span>, <span class="hljs-number">4267</span>, <span class="hljs-number">12223</span>, <span class="hljs-number">21811</span>, <span class="hljs-number">1117</span>, <span class="hljs-number">2554</span>, <span class="hljs-number">119</span>, <span class="hljs-number">102</span>, <span class="hljs-number">11336</span>, <span class="hljs-number">6732</span>, <span class="hljs-number">3384</span>, <span class="hljs-number">1106</span>, <span class="hljs-number">1140</span>, <span class="hljs-number">1112</span>, <span class="hljs-number">1178</span>, <span class="hljs-number">107</span>, <span class="hljs-number">1103</span>, <span class="hljs-number">7737</span>, <span class="hljs-number">107</span>, <span class="hljs-number">117</span>, <span class="hljs-number">7277</span>, <span class="hljs-number">2180</span>, <span class="hljs-number">5303</span>, <span class="hljs-number">4806</span>, <span class="hljs-number">1117</span>, <span class="hljs-number">1711</span>, <span class="hljs-number">1104</span>, <span class="hljs-number">9938</span>, <span class="hljs-number">4267</span>, <span class="hljs-number">12223</span>, <span class="hljs-number">21811</span>, <span class="hljs-number">1117</span>, <span 
class="hljs-number">2554</span>, <span class="hljs-number">119</span>, <span class="hljs-number">102</span>]), | |
| <span class="hljs-string">'token_type_ids'</span>: array([<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>]), | |
| <span class="hljs-string">'attention_mask'</span>: array([<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>])}<!-- HTML_TAG_END --></pre></div> | |
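| <p>Notice that the dataset now has three new columns: <code>input_ids</code>, <code>token_type_ids</code>, and <code>attention_mask</code>. These columns are the inputs to the model. The arrays above are shortened for readability: with <code>padding='max_length'</code>, each sequence is actually padded to the tokenizer's maximum length (512 tokens for <code>bert-base-cased</code>), and the padded positions have an <code>attention_mask</code> value of <code>0</code>.</p> | |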
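<p>To make the encoded columns less opaque, the following is a simplified, pure-Python sketch of what <code>dataset.map(encode, batched=True)</code> produces for a sentence pair. It is not how 🤗 Datasets or 🤗 Transformers implement this: the helpers <code>encode_pair</code> and <code>map_batched</code> are hypothetical, the "sentences" are invented lists of token IDs standing in for real tokenizer output, and the truncation loop is a simplified version of the tokenizer's <code>longest_first</code> strategy (what <code>truncation=True</code> selects). Only the special-token IDs <code>101</code> ([CLS]), <code>102</code> ([SEP]), and <code>0</code> ([PAD]) come from the <code>bert-base-cased</code> vocabulary, and the default <code>batch_size=1000</code> mirrors the default of <code>Dataset.map()</code>.</p>

```python
# Simplified sketch of what dataset.map(encode, batched=True) produces for a
# sentence-pair task. NOT the 🤗 Datasets / 🤗 Transformers implementation:
# helper names are hypothetical and the "sentences" are pre-tokenized toy IDs.
from typing import Callable, Dict, List

# Special-token IDs from the bert-base-cased vocabulary: [CLS], [SEP], [PAD].
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def encode_pair(ids_a: List[int], ids_b: List[int], max_length: int) -> Dict[str, List[int]]:
    """Build [CLS] A [SEP] B [SEP], truncated and padded to exactly max_length."""
    if 3 > max_length:
        raise ValueError("max_length must leave room for 3 special tokens")
    a, b = list(ids_a), list(ids_b)
    # Simplified 'longest_first' truncation: trim one token at a time from the
    # longer sequence (ties trim the second one here; the real tie rule may differ).
    while len(a) + len(b) > max_length - 3:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    input_ids = [CLS_ID] + a + [SEP_ID] + b + [SEP_ID]
    # Segment 0 covers [CLS] A [SEP]; segment 1 covers B [SEP].
    token_type_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1)
    attention_mask = [1] * len(input_ids)
    # Pad every field to max_length; padded positions get attention_mask 0.
    pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [PAD_ID] * pad,
        "token_type_ids": token_type_ids + [0] * pad,
        "attention_mask": attention_mask + [0] * pad,
    }


def map_batched(
    columns: Dict[str, list],
    fn: Callable[[Dict[str, list]], Dict[str, list]],
    batch_size: int = 1000,
) -> Dict[str, list]:
    """Call fn on consecutive slices of the columns and add the columns it returns."""
    num_rows = len(next(iter(columns.values())))
    new_columns: Dict[str, list] = {}
    for start in range(0, num_rows, batch_size):
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in fn(batch).items():
            new_columns.setdefault(name, []).extend(values)
    # Existing columns are kept; returned columns are added (or overwritten).
    return {**columns, **new_columns}


def encode(examples: Dict[str, list]) -> Dict[str, list]:
    # Stand-in for the tokenizer call: encode each (sentence1, sentence2) pair.
    pairs = [encode_pair(a, b, max_length=8)
             for a, b in zip(examples["sentence1"], examples["sentence2"])]
    return {key: [p[key] for p in pairs]
            for key in ("input_ids", "token_type_ids", "attention_mask")}


toy = {
    "sentence1": [[7, 8], [7, 8, 9, 10, 11]],
    "sentence2": [[5], [5, 6, 4]],
    "label": [1, 0],
}
encoded = map_batched(toy, encode, batch_size=1)
print(encoded["input_ids"])
# [[101, 7, 8, 102, 5, 102, 0, 0], [101, 7, 8, 9, 102, 5, 6, 102]]
```

<p>The toy call uses <code>batch_size=1</code> only so that each batch holds a single example; the result is the same with the default batch size because each pair is encoded independently. Padding every pair to <code>max_length</code> keeps tensor shapes fixed but spends computation on padding tokens; dynamic padding, covered in the Hugging Face course, is an alternative that pads each batch only to its longest example.</p>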
| <h2 class="relative group"><a id="format-the-dataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#format-the-dataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Format the dataset | |
| </span></h2> | |
| <p>Depending on whether you are using PyTorch, TensorFlow, or JAX, you will need to format the dataset accordingly. There are three changes you need to make to the dataset:</p> | |
| <div class="space-y-10 py-6 2xl:py-8 2xl:-mx-4"> | |
| <div class="border border-gray-200 rounded-xl px-4 relative"><div class="flex h-[22px] mt-[-12.5px] justify-between leading-none"><div class="flex px-1 items-center space-x-1 bg-white dark:bg-gray-950"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> | |
| <span>PyTorch</span></div> | |
| <div class="cursor-pointer flex items-center justify-center space-x-1 text-sm px-2 bg-white dark:bg-gray-950 hover:underline leading-none"><svg class="" width="0.9em" height="0.9em" viewBox="0 0 10 9" fill="currentColor" xmlns="http://www.w3.org/2000/svg"><path d="M1.39125 1.9725L0.0883333 0.669997L0.677917 0.0804138L8.9275 8.33041L8.33792 8.91958L6.95875 7.54041C6.22592 8.00523 5.37572 8.25138 4.50792 8.25C2.26125 8.25 0.392083 6.63333 0 4.5C0.179179 3.52946 0.667345 2.64287 1.39167 1.9725H1.39125ZM5.65667 6.23833L5.04667 5.62833C4.81335 5.73996 4.55116 5.77647 4.29622 5.73282C4.04129 5.68918 3.80617 5.56752 3.62328 5.38463C3.44039 5.20175 3.31874 4.96663 3.27509 4.71169C3.23144 4.45676 3.26795 4.19456 3.37958 3.96125L2.76958 3.35125C2.50447 3.75187 2.38595 4.2318 2.4341 4.70978C2.48225 5.18777 2.6941 5.63442 3.0338 5.97411C3.37349 6.31381 3.82015 6.52567 4.29813 6.57382C4.77611 6.62197 5.25605 6.50345 5.65667 6.23833ZM2.83042 1.06666C3.35 0.862497 3.91625 0.749997 4.50792 0.749997C6.75458 0.749997 8.62375 2.36666 9.01583 4.5C8.88816 5.19404 8.60119 5.84899 8.1775 6.41333L6.56917 4.805C6.61694 4.48317 6.58868 4.15463 6.48664 3.84569C6.3846 3.53675 6.21162 3.256 5.98156 3.02594C5.7515 2.79588 5.47075 2.6229 5.16181 2.52086C4.85287 2.41882 4.52433 2.39056 4.2025 2.43833L2.83042 1.06708V1.06666Z" fill="currentColor"></path></svg> | |
| <span>Hide PyTorch content</span></div></div> | |
| <div class="framework-content"> | |
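| <ol><li>Add a <code>labels</code> column that copies the values of the <code>label</code> column, because <code>labels</code> is the input name expected by <a href="https://huggingface.co/transformers/model_doc/bert#transformers.BertForSequenceClassification.forward" rel="nofollow">BertForSequenceClassification</a>. Alternatively, <code>Dataset.rename_column()</code> renames the column instead of copying it:</li></ol> | |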
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>dataset = dataset.<span class="hljs-built_in">map</span>(<span class="hljs-keyword">lambda</span> examples: {<span class="hljs-string">'labels'</span>: examples[<span class="hljs-string">'label'</span>]}, batched=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> | |
| <ol start="2"><li>Return the data as PyTorch tensors instead of Python objects.</li> | |
| <li>Keep only the columns the model needs: <code>input_ids</code>, <code>token_type_ids</code>, <code>attention_mask</code>, and <code>labels</code>.</li></ol> | |
| <p><a href="/docs/datasets/v2.2.2/en/package_reference/main_classes#datasets.Dataset.set_format">Dataset.set_format()</a> handles both of these steps on the fly. The conversion and column selection are applied when you access the data, and the underlying dataset is not rewritten. After you set the format, wrap the dataset in <code>torch.utils.data.DataLoader</code>:</p> | |
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-meta">>>> </span>dataset.set_format(<span class="hljs-built_in">type</span>=<span class="hljs-string">'torch'</span>, columns=[<span class="hljs-string">'input_ids'</span>, <span class="hljs-string">'token_type_ids'</span>, <span class="hljs-string">'attention_mask'</span>, <span class="hljs-string">'labels'</span>]) | |
| <span class="hljs-meta">>>> </span>dataloader = torch.utils.data.DataLoader(dataset, batch_size=<span class="hljs-number">32</span>) | |
| <span class="hljs-meta">>>> </span><span class="hljs-built_in">next</span>(<span class="hljs-built_in">iter</span>(dataloader)) | |
| {<span class="hljs-string">'attention_mask'</span>: tensor([[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]]), | |
| <span class="hljs-string">'input_ids'</span>: tensor([[ <span class="hljs-number">101</span>, <span class="hljs-number">7277</span>, <span class="hljs-number">2180</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">10684</span>, <span class="hljs-number">2599</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1220</span>, <span class="hljs-number">1125</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">16944</span>, <span class="hljs-number">1107</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1109</span>, <span class="hljs-number">11896</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1109</span>, <span class="hljs-number">4173</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]]), | |
| <span class="hljs-string">'labels'</span>: tensor([<span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>]), | |
| <span class="hljs-string">'token_type_ids'</span>: tensor([[<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]])}<!-- HTML_TAG_END --></pre></div></div></div> | |
| <div class="border border-gray-200 rounded-xl px-4 relative"><div class="flex h-[22px] mt-[-12.5px] justify-between leading-none"><div class="flex px-1 items-center space-x-1 bg-white dark:bg-gray-950"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> | |
| <span>TensorFlow</span></div> | |
| <div class="cursor-pointer flex items-center justify-center space-x-1 text-sm px-2 bg-white dark:bg-gray-950 hover:underline leading-none"><svg class="" width="0.9em" height="0.9em" viewBox="0 0 10 9" fill="currentColor" xmlns="http://www.w3.org/2000/svg"><path d="M1.39125 1.9725L0.0883333 0.669997L0.677917 0.0804138L8.9275 8.33041L8.33792 8.91958L6.95875 7.54041C6.22592 8.00523 5.37572 8.25138 4.50792 8.25C2.26125 8.25 0.392083 6.63333 0 4.5C0.179179 3.52946 0.667345 2.64287 1.39167 1.9725H1.39125ZM5.65667 6.23833L5.04667 5.62833C4.81335 5.73996 4.55116 5.77647 4.29622 5.73282C4.04129 5.68918 3.80617 5.56752 3.62328 5.38463C3.44039 5.20175 3.31874 4.96663 3.27509 4.71169C3.23144 4.45676 3.26795 4.19456 3.37958 3.96125L2.76958 3.35125C2.50447 3.75187 2.38595 4.2318 2.4341 4.70978C2.48225 5.18777 2.6941 5.63442 3.0338 5.97411C3.37349 6.31381 3.82015 6.52567 4.29813 6.57382C4.77611 6.62197 5.25605 6.50345 5.65667 6.23833ZM2.83042 1.06666C3.35 0.862497 3.91625 0.749997 4.50792 0.749997C6.75458 0.749997 8.62375 2.36666 9.01583 4.5C8.88816 5.19404 8.60119 5.84899 8.1775 6.41333L6.56917 4.805C6.61694 4.48317 6.58868 4.15463 6.48664 3.84569C6.3846 3.53675 6.21162 3.256 5.98156 3.02594C5.7515 2.79588 5.47075 2.6229 5.16181 2.52086C4.85287 2.41882 4.52433 2.39056 4.2025 2.43833L2.83042 1.06708V1.06666Z" fill="currentColor"></path></svg> | |
| <span>Hide TensorFlow content</span></div></div> | |
| <div class="framework-content"> | |
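| <ol><li>Add a <code>labels</code> column that copies the values of the <code>label</code> column, because <code>labels</code> is the input name expected by <a href="https://huggingface.co/transformers/model_doc/bert#tfbertforsequenceclassification" rel="nofollow">TFBertForSequenceClassification</a>. Alternatively, <code>Dataset.rename_column()</code> renames the column instead of copying it:</li></ol> | |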
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>dataset = dataset.<span class="hljs-built_in">map</span>(<span class="hljs-keyword">lambda</span> examples: {<span class="hljs-string">'labels'</span>: examples[<span class="hljs-string">'label'</span>]}, batched=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> | |
| <ol start="2"><li>Return the data as TensorFlow tensors instead of Python objects.</li> | |
| <li>Keep only the columns the model needs: <code>input_ids</code>, <code>token_type_ids</code>, <code>attention_mask</code>, and <code>labels</code>.</li></ol> | |
| <p><a href="/docs/datasets/v2.2.2/en/package_reference/main_classes#datasets.Dataset.set_format">Dataset.set_format()</a> handles both of these steps on the fly. The conversion and column selection are applied when you access the data, and the underlying dataset is not rewritten. After you set the format, convert each input column to a dense tensor padded to <code>tokenizer.model_max_length</code>, then build a <code>tf.data.Dataset</code> from the resulting tensors:</p> | |
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| <span class="hljs-meta">>>> </span>dataset.set_format(<span class="hljs-built_in">type</span>=<span class="hljs-string">'tensorflow'</span>, columns=[<span class="hljs-string">'input_ids'</span>, <span class="hljs-string">'token_type_ids'</span>, <span class="hljs-string">'attention_mask'</span>, <span class="hljs-string">'labels'</span>]) | |
| <span class="hljs-meta">>>> </span>features = {x: dataset[x].to_tensor(default_value=<span class="hljs-number">0</span>, shape=[<span class="hljs-literal">None</span>, tokenizer.model_max_length]) <span class="hljs-keyword">for</span> x <span class="hljs-keyword">in</span> [<span class="hljs-string">'input_ids'</span>, <span class="hljs-string">'token_type_ids'</span>, <span class="hljs-string">'attention_mask'</span>]} | |
| <span class="hljs-meta">>>> </span>tfdataset = tf.data.Dataset.from_tensor_slices((features, dataset[<span class="hljs-string">"labels"</span>])).batch(<span class="hljs-number">32</span>) | |
| <span class="hljs-meta">>>> </span><span class="hljs-built_in">next</span>(<span class="hljs-built_in">iter</span>(tfdataset)) | |
| ({<span class="hljs-string">'input_ids'</span>: <tf.Tensor: shape=(<span class="hljs-number">32</span>, <span class="hljs-number">512</span>), dtype=int32, numpy= | |
| array([[ <span class="hljs-number">101</span>, <span class="hljs-number">7277</span>, <span class="hljs-number">2180</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">10684</span>, <span class="hljs-number">2599</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1220</span>, <span class="hljs-number">1125</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1109</span>, <span class="hljs-number">2026</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">22263</span>, <span class="hljs-number">1107</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">142</span>, <span class="hljs-number">1813</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]], dtype=int32)>, <span class="hljs-string">'token_type_ids'</span>: <tf.Tensor: shape=(<span class="hljs-number">32</span>, <span class="hljs-number">512</span>), dtype=int32, numpy= | |
| array([[<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]], dtype=int32)>, <span class="hljs-string">'attention_mask'</span>: <tf.Tensor: shape=(<span class="hljs-number">32</span>, <span class="hljs-number">512</span>), dtype=int32, numpy= | |
| array([[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]], dtype=int32)>}, <tf.Tensor: shape=(<span class="hljs-number">32</span>,), dtype=int64, numpy= | |
| array([<span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, | |
| <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>])>)<!-- HTML_TAG_END --></pre></div> | |
| </div></div> | |
| </div> | |
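<p>Both framework paths prepare the data in the same way. Only the selected columns are kept, and consecutive rows are grouped into batches with one entry per column. The sketch below mimics that selection and batching in pure Python, using nested lists instead of tensors, to show the shape of a batch. The <code>select_and_batch</code> helper and the toy rows are hypothetical; this is not how Datasets, PyTorch, or TensorFlow implement it. It keeps a final batch that is smaller than <code>batch_size</code>. <code>DataLoader</code> does the same with its default <code>drop_last=False</code>, and so does <code>tf.data.Dataset.batch()</code> with its default <code>drop_remainder=False</code>.</p>

```python
from typing import Any, Dict, Iterator, List


def select_and_batch(rows: List[Dict[str, Any]], columns: List[str], batch_size: int) -> Iterator[Dict[str, List[Any]]]:
    """Hypothetical helper: keep only `columns` from each row, then yield
    dicts that map each column name to a list of up to batch_size values."""
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]  # the last chunk may be shorter
        yield {col: [row[col] for row in chunk] for col in columns}


# Toy rows: 'label', 'idx', and 'sentence1' are dropped because they are not selected
rows = [
    {"idx": i, "sentence1": f"sentence {i}", "label": i % 2, "labels": i % 2, "input_ids": [101, i, 102]}
    for i in range(5)
]
batches = list(select_and_batch(rows, columns=["input_ids", "labels"], batch_size=2))
print(len(batches))           # 3
print(batches[0])             # {'input_ids': [[101, 0, 102], [101, 1, 102]], 'labels': [0, 1]}
print(batches[-1]["labels"])  # [0]
```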
| <h2 class="relative group"><a id="train-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#train-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Train the model | |
| </span></h2> | |
| <div class="space-y-10 py-6 2xl:py-8 2xl:-mx-4"> | |
| <div class="border border-gray-200 rounded-xl px-4 relative"><div class="flex h-[22px] mt-[-12.5px] justify-between leading-none"><div class="flex px-1 items-center space-x-1 bg-white dark:bg-gray-950"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> | |
| <span>PyTorch</span></div> | |
| <div class="cursor-pointer flex items-center justify-center space-x-1 text-sm px-2 bg-white dark:bg-gray-950 hover:underline leading-none"><svg class="" width="0.9em" height="0.9em" viewBox="0 0 10 9" fill="currentColor" xmlns="http://www.w3.org/2000/svg"><path d="M1.39125 1.9725L0.0883333 0.669997L0.677917 0.0804138L8.9275 8.33041L8.33792 8.91958L6.95875 7.54041C6.22592 8.00523 5.37572 8.25138 4.50792 8.25C2.26125 8.25 0.392083 6.63333 0 4.5C0.179179 3.52946 0.667345 2.64287 1.39167 1.9725H1.39125ZM5.65667 6.23833L5.04667 5.62833C4.81335 5.73996 4.55116 5.77647 4.29622 5.73282C4.04129 5.68918 3.80617 5.56752 3.62328 5.38463C3.44039 5.20175 3.31874 4.96663 3.27509 4.71169C3.23144 4.45676 3.26795 4.19456 3.37958 3.96125L2.76958 3.35125C2.50447 3.75187 2.38595 4.2318 2.4341 4.70978C2.48225 5.18777 2.6941 5.63442 3.0338 5.97411C3.37349 6.31381 3.82015 6.52567 4.29813 6.57382C4.77611 6.62197 5.25605 6.50345 5.65667 6.23833ZM2.83042 1.06666C3.35 0.862497 3.91625 0.749997 4.50792 0.749997C6.75458 0.749997 8.62375 2.36666 9.01583 4.5C8.88816 5.19404 8.60119 5.84899 8.1775 6.41333L6.56917 4.805C6.61694 4.48317 6.58868 4.15463 6.48664 3.84569C6.3846 3.53675 6.21162 3.256 5.98156 3.02594C5.7515 2.79588 5.47075 2.6229 5.16181 2.52086C4.85287 2.41882 4.52433 2.39056 4.2025 2.43833L2.83042 1.06708V1.06666Z" fill="currentColor"></path></svg> | |
| <span>Hide PyTorch content</span></div></div> | |
| <div class="framework-content"> | |
| <p>Lastly, create a simple training loop and start training:</p> | |
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> tqdm <span class="hljs-keyword">import</span> tqdm | |
| <span class="hljs-meta">>>> </span>device = <span class="hljs-string">'cuda'</span> <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> <span class="hljs-string">'cpu'</span> | |
| <span class="hljs-meta">>>> </span>model.train().to(device) | |
| <span class="hljs-meta">>>> </span>optimizer = torch.optim.AdamW(params=model.parameters(), lr=<span class="hljs-number">1e-5</span>) | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">3</span>): | |
| <span class="hljs-meta">... </span>    <span class="hljs-keyword">for</span> i, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(tqdm(dataloader)): | |
| <span class="hljs-meta">... </span>        batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| <span class="hljs-meta">... </span>        outputs = model(**batch) | |
| <span class="hljs-meta">... </span>        loss = outputs.loss  <span class="hljs-comment"># returned because the batch contains `labels`</span> | |
| <span class="hljs-meta">... </span>        loss.backward() | |
| <span class="hljs-meta">... </span>        optimizer.step() | |
| <span class="hljs-meta">... </span>        optimizer.zero_grad() | |
| <span class="hljs-meta">... </span>        <span class="hljs-keyword">if</span> i % <span class="hljs-number">10</span> == <span class="hljs-number">0</span>: | |
| <span class="hljs-meta">... </span>            <span class="hljs-built_in">print</span>(<span class="hljs-string">f"loss: <span class="hljs-subst">{loss.item()}</span>"</span>)<!-- HTML_TAG_END --></pre></div></div></div> | |
| <div class="border border-gray-200 rounded-xl px-4 relative"><div class="flex h-[22px] mt-[-12.5px] justify-between leading-none"><div class="flex px-1 items-center space-x-1 bg-white dark:bg-gray-950"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> | |
| <span>TensorFlow</span></div> | |
| </div> | |
| <div class="framework-content"> | |
| <p>Lastly, compile the model and start training:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE, from_logits=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">>>> </span>opt = tf.keras.optimizers.Adam(learning_rate=<span class="hljs-number">3e-5</span>) | |
| <span class="hljs-meta">>>> </span>model.<span class="hljs-built_in">compile</span>(optimizer=opt, loss=loss_fn, metrics=[<span class="hljs-string">"accuracy"</span>]) | |
| <span class="hljs-meta">>>> </span>model.fit(tfdataset, epochs=<span class="hljs-number">3</span>)<!-- HTML_TAG_END --></pre></div> | |
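<p>In the compile call above, <code>SparseCategoricalCrossentropy(from_logits=True)</code> expects integer class labels and the model's raw, unnormalized scores (logits), and applies a softmax internally before taking the negative log-probability of the true class. Passing <code>reduction=tf.keras.losses.Reduction.NONE</code> makes the loss object return one value per example rather than a single value averaged over the batch. The plain-Python sketch below reproduces that computation for illustration only, so it runs without TensorFlow; <code>sparse_ce_from_logits</code>, <code>batch_loss</code>, <code>accuracy</code>, and the sample logits are hypothetical and not part of Keras.</p>

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label]), computed stably."""
    m = max(logits)  # subtracting the max logit keeps exp() from overflowing
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def batch_loss(batch_logits, labels, reduction="none"):
    """'none' returns one loss per example; 'mean' averages them into a scalar."""
    per_example = [sparse_ce_from_logits(z, y) for z, y in zip(batch_logits, labels)]
    if reduction == "none":
        return per_example
    return sum(per_example) / len(per_example)


def accuracy(batch_logits, labels):
    """Fraction of examples whose highest-scoring class equals the label."""
    preds = [z.index(max(z)) for z in batch_logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


# Hypothetical logits for two MRPC pairs: class 0 = not equivalent, 1 = equivalent.
batch_logits = [[0.5, -0.5], [-1.0, 2.0]]
labels = [1, 1]

print(batch_loss(batch_logits, labels))          # per-example losses; the misclassified first pair is larger
print(batch_loss(batch_logits, labels, "mean"))  # one loss averaged over the batch
print(accuracy(batch_logits, labels))            # 0.5
```

<p>Subtracting the largest logit before exponentiating leaves the result unchanged mathematically but avoids overflow for large scores, which is generally why computing the loss directly from logits is considered more numerically stable than applying a softmax first.</p>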
| </div></div> | |
| </div> | |
| <h2 class="relative group"><a id="whats-next" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#whats-next"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>What's next? | |
| </span></h2> | |
| <p>This completes the basic steps of loading a dataset to train a model. You loaded and processed the MRPC dataset and used it to fine-tune BERT to determine whether the two sentences in each pair have the same meaning.</p> | |
| <p>For your next steps, take a look at our <a href="./how_to">How-to guides</a> to learn how to accomplish specific tasks (e.g., loading a dataset offline, adding a dataset to the Hub, or renaming a column). If you want to deepen your knowledge of 🤗 Datasets core concepts, read our <a href="./about_arrow">Conceptual Guides</a>.</p> | |
| <script type="module" data-hydrate="1jbcfhl"> | |
| import { start } from "/docs/datasets/v2.2.2/en/_app/start-0f8c1da7.js"; | |
| start({ | |
| target: document.querySelector('[data-hydrate="1jbcfhl"]').parentNode, | |
| paths: {"base":"/docs/datasets/v2.2.2/en","assets":"/docs/datasets/v2.2.2/en"}, | |
| session: {}, | |
| route: false, | |
| spa: false, | |
| trailing_slash: "never", | |
| hydrate: { | |
| status: 200, | |
| error: null, | |
| nodes: [ | |
| import("/docs/datasets/v2.2.2/en/_app/pages/__layout.svelte-efb8e839.js"), | |
| import("/docs/datasets/v2.2.2/en/_app/pages/quickstart.mdx-c9b63796.js") | |
| ], | |
| params: {} | |
| } | |
| }); | |
| </script> | |