| <meta charset="utf-8" /><meta http-equiv="content-security-policy" content=""><meta name="hf:doc:metadata" content="{"local":"train-with-datasets","sections":[{"local":"tokenize","title":"Tokenize"},{"local":"use-in-pytorch-or-tensorflow","sections":[{"local":"pytorch","title":"PyTorch"},{"local":"tensorflow","title":"TensorFlow"}],"title":"Use in PyTorch or TensorFlow"}],"title":"Train with 🤗 Datasets"}" data-svelte="svelte-1phssyn"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/assets/pages/__layout.svelte-efc77dbd.css"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/start-0f8c1da7.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/vendor-8138ceec.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/paths-4b3c6e7e.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/pages/__layout.svelte-efb8e839.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/pages/use_dataset.mdx-a1d0c966.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/Tip-12722dfc.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/IconCopyLink-2dd3a6ac.js"> | |
| <link rel="modulepreload" href="/docs/datasets/v2.2.2/en/_app/chunks/CodeBlock-fc89709f.js"> | |
| <h1 class="relative group"><a id="train-with-datasets" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#train-with-datasets"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Train with 🤗 Datasets | |
| </span></h1> | |
| <p>So far, you loaded a dataset from the Hugging Face Hub and learned how to access the information stored inside the dataset. Now you will tokenize and use your dataset with a framework such as PyTorch or TensorFlow. By default, all the dataset columns are returned as Python objects. But you can bridge the gap between a Python object and your machine learning framework by setting the format of a dataset. Formatting casts the columns into compatible PyTorch or TensorFlow types.</p> | |
| <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p>Often, you may want to modify the structure and content of your dataset before you use it to train a model. For example, you may want to remove a column or cast it to a different type. 🤗 Datasets provides the necessary tools to do this, but because every dataset is different, the right processing approach varies from one dataset to the next. For more detailed information about preprocessing data, take a look at our <a href="https://huggingface.co/transformers/preprocessing#" rel="nofollow">guide</a> from the 🤗 Transformers library. Then come back and read our <a href="./process">How-to Process</a> guide to see all the different methods for processing your dataset.</p></div> | |
| <h2 class="relative group"><a id="tokenize" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#tokenize"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Tokenize | |
| </span></h2> | |
| <p>Tokenization divides text into smaller units called tokens, which can be whole words or pieces of words. Each token is then converted into a number, and these numbers are what the model receives as its input.</p> | |
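<p>To make the idea concrete, here is a minimal sketch of tokenization in plain Python. It is not the BERT tokenizer: the vocabulary, the <code>wordpiece</code> helper, and the <code>encode</code> function are hypothetical and exist only for illustration. The sketch splits each word greedily into the longest pieces found in a toy vocabulary, then maps every token to an integer ID.</p>

```python
# Toy vocabulary (hypothetical; a real BERT vocabulary has roughly 29k entries).
VOCAB = {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "token": 3, "##ize": 4, "the": 5, "text": 6}


def wordpiece(word, vocab):
    """Greedily split one word into the longest pieces present in the vocabulary."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        while end > start:
            candidate = word[start:end]
            if start > 0:
                # Pieces that continue a word are marked with a '##' prefix.
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            # No piece matched: the whole word maps to the unknown token.
            return ["[UNK]"]
        pieces.append(piece)
        start = end
    return pieces


def encode(text, vocab=VOCAB):
    """Return (tokens, input_ids) for a text, wrapped in [CLS] ... [SEP]."""
    tokens = ["[CLS]"]
    for word in text.lower().split():
        tokens.extend(wordpiece(word, vocab))
    tokens.append("[SEP]")
    return tokens, [vocab[t] for t in tokens]


tokens, input_ids = encode("Tokenize the text")
print(tokens)
print(input_ids)
```

<p>With this toy vocabulary, the word <code>tokenize</code> is split into <code>token</code> and <code>##ize</code>, and the whole sentence becomes the IDs <code>[1, 3, 4, 5, 6, 2]</code>. A pretrained tokenizer follows the same overall pattern but uses the vocabulary learned for its model.</p>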
| <p>The first step is to install the 🤗 Transformers library:</p> | |
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START -->pip <span class="hljs-keyword">install</span> transformers<!-- HTML_TAG_END --></pre></div> | |
| <p>Next, import a tokenizer. It is important to use the tokenizer that is associated with the model you are using, so the text is split in the same way as it was when the model was pretrained. In this example, load the <a href="https://huggingface.co/transformers/model_doc/bert#berttokenizerfast" rel="nofollow">BERT tokenizer</a> because you are using the <a href="https://huggingface.co/bert-base-cased" rel="nofollow">BERT</a> model:</p> | |
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> BertTokenizerFast | |
| <span class="hljs-meta">>>> </span>tokenizer = BertTokenizerFast.from_pretrained(<span class="hljs-string">'bert-base-cased'</span>)<!-- HTML_TAG_END --></pre></div> | |
| <p>Now you can tokenize the <code>sentence1</code> field of the dataset:</p> | |
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>encoded_dataset = dataset.<span class="hljs-built_in">map</span>(<span class="hljs-keyword">lambda</span> examples: tokenizer(examples[<span class="hljs-string">'sentence1'</span>]), batched=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">>>> </span>encoded_dataset.column_names | |
| [<span class="hljs-string">'sentence1'</span>, <span class="hljs-string">'sentence2'</span>, <span class="hljs-string">'label'</span>, <span class="hljs-string">'idx'</span>, <span class="hljs-string">'input_ids'</span>, <span class="hljs-string">'token_type_ids'</span>, <span class="hljs-string">'attention_mask'</span>] | |
| <span class="hljs-meta">>>> </span>encoded_dataset[<span class="hljs-number">0</span>] | |
| {<span class="hljs-string">'sentence1'</span>: <span class="hljs-string">'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .'</span>, | |
| <span class="hljs-string">'sentence2'</span>: <span class="hljs-string">'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .'</span>, | |
| <span class="hljs-string">'label'</span>: <span class="hljs-number">1</span>, | |
| <span class="hljs-string">'idx'</span>: <span class="hljs-number">0</span>, | |
| <span class="hljs-string">'input_ids'</span>: [ <span class="hljs-number">101</span>, <span class="hljs-number">7277</span>, <span class="hljs-number">2180</span>, <span class="hljs-number">5303</span>, <span class="hljs-number">4806</span>, <span class="hljs-number">1117</span>, <span class="hljs-number">1711</span>, <span class="hljs-number">117</span>, <span class="hljs-number">2292</span>, <span class="hljs-number">1119</span>, <span class="hljs-number">1270</span>, <span class="hljs-number">107</span>, <span class="hljs-number">1103</span>, <span class="hljs-number">7737</span>, <span class="hljs-number">107</span>, <span class="hljs-number">117</span>, <span class="hljs-number">1104</span>, <span class="hljs-number">9938</span>, <span class="hljs-number">4267</span>, <span class="hljs-number">12223</span>, <span class="hljs-number">21811</span>, <span class="hljs-number">1117</span>, <span class="hljs-number">2554</span>, <span class="hljs-number">119</span>, <span class="hljs-number">102</span>], | |
| <span class="hljs-string">'token_type_ids'</span>: [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| <span class="hljs-string">'attention_mask'</span>: [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>] | |
| }<!-- HTML_TAG_END --></pre></div> | |
| <p>The tokenization process creates three new columns: <code>input_ids</code>, <code>token_type_ids</code>, and <code>attention_mask</code>. These are the inputs to the model.</p> | |
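<p>The role of <code>attention_mask</code> becomes clearer once sequences of different lengths are padded to a common length. The following is a minimal sketch in plain Python, not the tokenizer's actual implementation: the <code>pad_batch</code> helper is hypothetical, and it assumes a padding ID of <code>0</code>, which matches the zeros in the padded outputs shown on this page.</p>

```python
def pad_batch(batch_input_ids, pad_id=0):
    """Pad every sequence to the longest one and build the matching attention mask."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        # Real tokens keep their IDs; padding positions get pad_id.
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[101, 7277, 102], [101, 1109, 4173, 119, 102]])
print(batch["input_ids"])
print(batch["attention_mask"])
```

<p>The shorter sequence is extended with two padding IDs, and its mask ends in two zeros so that those positions can be ignored by the model.</p>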
| <h2 class="relative group"><a id="use-in-pytorch-or-tensorflow" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#use-in-pytorch-or-tensorflow"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Use in PyTorch or TensorFlow | |
| </span></h2> | |
| <p>Next, format the dataset into compatible PyTorch or TensorFlow types.</p> | |
| <h3 class="relative group"><a id="pytorch" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pytorch"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>PyTorch | |
| </span></h3> | |
| <p>If you are using PyTorch, set the format with <a href="/docs/datasets/v2.2.2/en/package_reference/main_classes#datasets.Dataset.set_format">Dataset.set_format()</a>, which accepts two main arguments:</p> | |
| <ol><li><p><code>type</code> defines the type of column to cast to. For example, <code>torch</code> returns PyTorch tensors.</p></li> | |
<li><p><code>columns</code> specifies which columns should be formatted.</p></li></ol> | |
| <p>After you set the format, wrap the dataset with <code>torch.utils.data.DataLoader</code>. Your dataset is now ready for use in a training loop!</p> | |
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer | |
| <span class="hljs-meta">>>> </span>dataset = load_dataset(<span class="hljs-string">'glue'</span>, <span class="hljs-string">'mrpc'</span>, split=<span class="hljs-string">'train'</span>) | |
| <span class="hljs-meta">>>> </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">'bert-base-cased'</span>) | |
| <span class="hljs-meta">>>> </span>dataset = dataset.<span class="hljs-built_in">map</span>(<span class="hljs-keyword">lambda</span> e: tokenizer(e[<span class="hljs-string">'sentence1'</span>], truncation=<span class="hljs-literal">True</span>, padding=<span class="hljs-string">'max_length'</span>), batched=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">>>> </span>dataset.set_format(<span class="hljs-built_in">type</span>=<span class="hljs-string">'torch'</span>, columns=[<span class="hljs-string">'input_ids'</span>, <span class="hljs-string">'token_type_ids'</span>, <span class="hljs-string">'attention_mask'</span>, <span class="hljs-string">'label'</span>]) | |
| <span class="hljs-meta">>>> </span>dataloader = torch.utils.data.DataLoader(dataset, batch_size=<span class="hljs-number">32</span>) | |
| <span class="hljs-meta">>>> </span><span class="hljs-built_in">next</span>(<span class="hljs-built_in">iter</span>(dataloader)) | |
| {<span class="hljs-string">'attention_mask'</span>: tensor([[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]]), | |
| <span class="hljs-string">'input_ids'</span>: tensor([[ <span class="hljs-number">101</span>, <span class="hljs-number">7277</span>, <span class="hljs-number">2180</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">1109</span>, <span class="hljs-number">4173</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]]), | |
| <span class="hljs-string">'label'</span>: tensor([<span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>]), | |
| <span class="hljs-string">'token_type_ids'</span>: tensor([[<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]])}<!-- HTML_TAG_END --></pre></div> | |
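<p>Conceptually, a data loader groups individual examples into column-wise batches. The sketch below illustrates that idea with the standard library only. It is not how <code>torch.utils.data.DataLoader</code> is implemented, and the <code>iter_batches</code> helper and the sample rows are hypothetical. It also shows why the examples above pad to a fixed length: every row in a batch needs the same number of token IDs before the batch can be stacked into a rectangular tensor.</p>

```python
def iter_batches(examples, batch_size):
    """Yield column-wise batches (dict of lists) from a list of row dicts."""
    for start in range(0, len(examples), batch_size):
        rows = examples[start:start + batch_size]
        # Turn a list of rows into one list per column.
        yield {key: [row[key] for row in rows] for key in rows[0]}


# Hypothetical, already padded examples (all input_ids have length 4).
examples = [
    {"input_ids": [101, 7277, 102, 0], "label": 1},
    {"input_ids": [101, 1109, 102, 0], "label": 0},
    {"input_ids": [101, 4173, 119, 102], "label": 1},
]

batches = list(iter_batches(examples, batch_size=2))
for batch in batches:
    print(batch)
```

<p>Three examples with a batch size of two produce one full batch and one final batch holding the remaining example.</p>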
| <h3 class="relative group"><a id="tensorflow" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#tensorflow"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>TensorFlow | |
| </span></h3> | |
| <p>If you are using TensorFlow, you can use <a href="/docs/datasets/v2.2.2/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset">Dataset.to_tf_dataset()</a> to wrap the dataset with a <strong>tf.data.Dataset</strong>, which is natively understood by Keras. | |
| This means a <strong>tf.data.Dataset</strong> object can be iterated over to yield batches of data, and can be passed directly to methods like <strong>model.fit()</strong>.</p> | |
| <p><a href="/docs/datasets/v2.2.2/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset">Dataset.to_tf_dataset()</a> accepts several arguments:</p> | |
| <ol><li><p><code>columns</code> specifies which columns should be formatted (includes the inputs and labels).</p></li> | |
| <li><p><code>shuffle</code> determines whether the dataset should be shuffled.</p></li> | |
| <li><p><code>batch_size</code> specifies the batch size.</p></li> | |
| <li><p><code>collate_fn</code> specifies a data collator that will batch each processed example and apply padding. If you are using a <code>DataCollator</code>, make sure you set <code>return_tensors="tf"</code> when you initialize it to return <code>tf.Tensor</code> outputs.</p></li></ol> | |
| <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> | |
| <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> | |
| Copied</div></button></div> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf | |
| <span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
<span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, DataCollatorWithPadding | |
| <span class="hljs-meta">>>> </span>dataset = load_dataset(<span class="hljs-string">'glue'</span>, <span class="hljs-string">'mrpc'</span>, split=<span class="hljs-string">'train'</span>) | |
| <span class="hljs-meta">>>> </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">'bert-base-cased'</span>) | |
| <span class="hljs-meta">>>> </span>dataset = dataset.<span class="hljs-built_in">map</span>(<span class="hljs-keyword">lambda</span> e: tokenizer(e[<span class="hljs-string">'sentence1'</span>], truncation=<span class="hljs-literal">True</span>, padding=<span class="hljs-string">'max_length'</span>), batched=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">>>> </span>data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors=<span class="hljs-string">"tf"</span>) | |
<span class="hljs-meta">>>> </span>train_dataset = dataset.to_tf_dataset( | |
| <span class="hljs-meta">... </span> columns=[<span class="hljs-string">'input_ids'</span>, <span class="hljs-string">'token_type_ids'</span>, <span class="hljs-string">'attention_mask'</span>, <span class="hljs-string">'label'</span>], | |
| <span class="hljs-meta">... </span> shuffle=<span class="hljs-literal">True</span>, | |
| <span class="hljs-meta">... </span> batch_size=<span class="hljs-number">16</span>, | |
| <span class="hljs-meta">... </span> collate_fn=data_collator, | |
| <span class="hljs-meta">... </span>) | |
| <span class="hljs-meta">>>> </span>model.fit(train_dataset) <span class="hljs-comment"># The output tf.data.Dataset is ready for training immediately</span> | |
| <span class="hljs-meta">>>> </span><span class="hljs-built_in">next</span>(<span class="hljs-built_in">iter</span>(train_dataset)) <span class="hljs-comment"># You can also iterate over the dataset manually to get batches</span> | |
| {<span class="hljs-string">'attention_mask'</span>: <tf.Tensor: shape=(<span class="hljs-number">16</span>, <span class="hljs-number">512</span>), dtype=int64, numpy= | |
| array([[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]])>, | |
| <span class="hljs-string">'input_ids'</span>: <tf.Tensor: shape=(<span class="hljs-number">16</span>, <span class="hljs-number">512</span>), dtype=int64, numpy= | |
| array([[ <span class="hljs-number">101</span>, <span class="hljs-number">11336</span>, <span class="hljs-number">11154</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [ <span class="hljs-number">101</span>, <span class="hljs-number">156</span>, <span class="hljs-number">22705</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]])>, | |
| <span class="hljs-string">'labels'</span>: <tf.Tensor: shape=(<span class="hljs-number">16</span>,), dtype=int64, numpy= | |
| array([<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>])>, | |
| <span class="hljs-string">'token_type_ids'</span>: <tf.Tensor: shape=(<span class="hljs-number">16</span>, <span class="hljs-number">512</span>), dtype=int64, numpy= | |
| array([[<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>], | |
| ..., | |
| [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, ..., <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>]])> | |
| }<!-- HTML_TAG_END --></pre></div> | |
| <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p><a href="/docs/datasets/v2.2.2/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset">Dataset.to_tf_dataset()</a> is the easiest way to create a TensorFlow compatible dataset. If you don’t want a <code>tf.data.Dataset</code> and would rather the dataset emit <code>tf.Tensor</code> objects, take a look at the <a href="./process#format">format</a> section instead!</p></div> | |
| <p>Your dataset is now ready for use in a training loop!</p> | |
| <script type="module" data-hydrate="1f9t50y"> | |
| import { start } from "/docs/datasets/v2.2.2/en/_app/start-0f8c1da7.js"; | |
| start({ | |
| target: document.querySelector('[data-hydrate="1f9t50y"]').parentNode, | |
| paths: {"base":"/docs/datasets/v2.2.2/en","assets":"/docs/datasets/v2.2.2/en"}, | |
| session: {}, | |
| route: false, | |
| spa: false, | |
| trailing_slash: "never", | |
| hydrate: { | |
| status: 200, | |
| error: null, | |
| nodes: [ | |
| import("/docs/datasets/v2.2.2/en/_app/pages/__layout.svelte-efb8e839.js"), | |
| import("/docs/datasets/v2.2.2/en/_app/pages/use_dataset.mdx-a1d0c966.js") | |
| ], | |
| params: {} | |
| } | |
| }); | |
| </script> | |