| <meta charset="utf-8" /><meta http-equiv="content-security-policy" content=""><meta name="hf:doc:metadata" content="{"local":"using-datasets-with-tensorflow","sections":[{"local":"dataset-format","title":"Dataset format"},{"local":"ndimensional-arrays","title":"N-dimensional arrays"},{"local":"other-feature-types","title":"Other feature types"},{"local":"data-loading","sections":[{"local":"using-totfdataset","title":"Using `to_tf_dataset()`"},{"local":"when-to-use-totfdataset","title":"When to use to_tf_dataset"},{"local":"caveats-and-limitations","title":"Caveats and limitations"}],"title":"Data loading"}],"title":"Using Datasets with TensorFlow"}" data-svelte="svelte-1phssyn"> | |
| <h1 class="relative group"><a id="using-datasets-with-tensorflow" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-datasets-with-tensorflow"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Using Datasets with TensorFlow | |
| </span></h1> | |
| <p>This document is a quick introduction to using <code>datasets</code> with TensorFlow, with a particular focus on how to get | |
| <code>tf.Tensor</code> objects out of our datasets, and how to stream data from Hugging Face <code>Dataset</code> objects to Keras methods | |
| like <code>model.fit()</code>.</p> | |
| <h2 class="relative group"><a id="dataset-format" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dataset-format"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Dataset format | |
| </span></h2> | |
| <p>By default, datasets return regular Python objects: integers, floats, strings, lists, etc.</p> | |
| <p>To get TensorFlow tensors instead, you can set the format of the dataset to <code>tf</code>:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset | |
| <span class="hljs-meta">>>> </span>data = [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]] | |
<span class="hljs-meta">>>> </span>ds = Dataset.from_dict({<span class="hljs-string">"data"</span>: data}) | |
| <span class="hljs-meta">>>> </span>ds = ds.with_format(<span class="hljs-string">"tf"</span>) | |
| <span class="hljs-meta">>>> </span>ds[<span class="hljs-number">0</span>] | |
| {<span class="hljs-string">'data'</span>: <tf.Tensor: shape=(<span class="hljs-number">2</span>,), dtype=int64, numpy=array([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>])>} | |
| <span class="hljs-meta">>>> </span>ds[:<span class="hljs-number">2</span>] | |
| {<span class="hljs-string">'data'</span>: <tf.Tensor: shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=int64, numpy= | |
| array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], | |
| [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])>}<!-- HTML_TAG_END --></pre></div> | |
| <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p>A <a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Dataset">Dataset</a> object is a wrapper around an Arrow table, which allows fast reads from arrays in the dataset into TensorFlow tensors.</p></div> | |
| <p>This can be useful for converting your dataset to a dict of <code>Tensor</code> objects, or for writing a generator to load TF | |
samples from it. To convert the entire dataset to <code>Tensor</code> objects, simply query the full dataset:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>ds[:] | |
| {<span class="hljs-string">'data'</span>: <tf.Tensor: shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=int64, numpy= | |
| array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], | |
| [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])>}<!-- HTML_TAG_END --></pre></div> | |
| <h2 class="relative group"><a id="ndimensional-arrays" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ndimensional-arrays"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>N-dimensional arrays | |
| </span></h2> | |
| <p>If your dataset consists of N-dimensional arrays, they are treated as nested lists by default. | |
As a result, a TensorFlow-formatted dataset outputs a <code>RaggedTensor</code> instead of a single tensor:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset | |
| <span class="hljs-meta">>>> </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],[[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]] | |
| <span class="hljs-meta">>>> </span>ds = Dataset.from_dict({<span class="hljs-string">"data"</span>: data}) | |
| <span class="hljs-meta">>>> </span>ds = ds.with_format(<span class="hljs-string">"tf"</span>) | |
| <span class="hljs-meta">>>> </span>ds[<span class="hljs-number">0</span>] | |
| {<span class="hljs-string">'data'</span>: <tf.RaggedTensor [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]]>}<!-- HTML_TAG_END --></pre></div> | |
| <p>To get a single tensor, you must explicitly use an <code>Array</code> feature type (here <code>Array2D</code>) and specify the shape of your tensors:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, Array2D | |
| <span class="hljs-meta">>>> </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],[[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]] | |
| <span class="hljs-meta">>>> </span>features = Features({<span class="hljs-string">"data"</span>: Array2D(shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=<span class="hljs-string">'int32'</span>)}) | |
| <span class="hljs-meta">>>> </span>ds = Dataset.from_dict({<span class="hljs-string">"data"</span>: data}, features=features) | |
| <span class="hljs-meta">>>> </span>ds = ds.with_format(<span class="hljs-string">"tf"</span>) | |
| <span class="hljs-meta">>>> </span>ds[<span class="hljs-number">0</span>] | |
| {<span class="hljs-string">'data'</span>: <tf.Tensor: shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=int64, numpy= | |
| array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], | |
| [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])>} | |
| <span class="hljs-meta">>>> </span>ds[:<span class="hljs-number">2</span>] | |
| {<span class="hljs-string">'data'</span>: <tf.Tensor: shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=int64, numpy= | |
| array([[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], | |
| [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]], | |
| [[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>], | |
| [<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]])>}<!-- HTML_TAG_END --></pre></div> | |
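<p>Before declaring a fixed shape, it can help to check that the nested lists really are rectangular. The following is a minimal pure-Python sketch, not part of <code>datasets</code>: the helper name <code>nested_shape</code> is hypothetical and only illustrates how the per-row shape passed to a fixed-shape feature could be worked out.</p>

```python
from typing import Tuple


def nested_shape(value) -> Tuple[int, ...]:
    """Return the shape of rectangular nested lists; raise ValueError if ragged.

    Hypothetical helper for illustration only, not a `datasets` function.
    """
    if not isinstance(value, list):
        return ()  # a scalar contributes no further dimensions
    inner_shapes = {nested_shape(item) for item in value}
    if len(inner_shapes) > 1:
        raise ValueError("nested lists are ragged; no fixed shape exists")
    inner = inner_shapes.pop() if inner_shapes else ()
    return (len(value),) + inner


data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
full_shape = nested_shape(data)   # includes the leading "number of rows" axis
per_row_shape = full_shape[1:]    # the part a fixed-shape feature describes
```

<p>Here <code>per_row_shape</code> matches the <code>shape</code> argument used in the example above. Data that raises <code>ValueError</code> has no single fixed shape, so it cannot be described this way.</p>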
| <h2 class="relative group"><a id="other-feature-types" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#other-feature-types"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Other feature types | |
| </span></h2> | |
| <p><a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.ClassLabel">ClassLabel</a> data are properly converted to tensors:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, ClassLabel | |
| <span class="hljs-meta">>>> </span>data = [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>] | |
| <span class="hljs-meta">>>> </span>features = Features({<span class="hljs-string">"data"</span>: ClassLabel(names=[<span class="hljs-string">"negative"</span>, <span class="hljs-string">"positive"</span>])}) | |
| <span class="hljs-meta">>>> </span>ds = Dataset.from_dict({<span class="hljs-string">"data"</span>: data}, features=features) | |
| <span class="hljs-meta">>>> </span>ds = ds.with_format(<span class="hljs-string">"tf"</span>) | |
| <span class="hljs-meta">>>> </span>ds[:<span class="hljs-number">3</span>] | |
{<span class="hljs-string">'data'</span>: <tf.Tensor: shape=(<span class="hljs-number">3</span>,), dtype=int64, numpy=array([<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>])>}<!-- HTML_TAG_END --></pre></div> | |
| <p>Strings are also supported:</p> | |
| <div class="code-block relative"> | |
<pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset | |
| <span class="hljs-meta">>>> </span>text = [<span class="hljs-string">"foo"</span>, <span class="hljs-string">"bar"</span>] | |
| <span class="hljs-meta">>>> </span>data = [<span class="hljs-number">0</span>, <span class="hljs-number">1</span>] | |
| <span class="hljs-meta">>>> </span>ds = Dataset.from_dict({<span class="hljs-string">"text"</span>: text, <span class="hljs-string">"data"</span>: data}) | |
| <span class="hljs-meta">>>> </span>ds = ds.with_format(<span class="hljs-string">"tf"</span>) | |
| <span class="hljs-meta">>>> </span>ds[:<span class="hljs-number">2</span>] | |
| {<span class="hljs-string">'text'</span>: <tf.Tensor: shape=(<span class="hljs-number">2</span>,), dtype=string, numpy=array([<span class="hljs-string">b'foo'</span>, <span class="hljs-string">b'bar'</span>], dtype=<span class="hljs-built_in">object</span>)>, | |
| <span class="hljs-string">'data'</span>: <tf.Tensor: shape=(<span class="hljs-number">2</span>,), dtype=int64, numpy=array([<span class="hljs-number">0</span>, <span class="hljs-number">1</span>])>}<!-- HTML_TAG_END --></pre></div> | |
| <p>You can also explicitly format certain columns and leave the other columns unformatted:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>ds = ds.with_format(<span class="hljs-string">"tf"</span>, columns=[<span class="hljs-string">"data"</span>], output_all_columns=<span class="hljs-literal">True</span>) | |
| <span class="hljs-meta">>>> </span>ds[:<span class="hljs-number">2</span>] | |
| {<span class="hljs-string">'data'</span>: <tf.Tensor: shape=(<span class="hljs-number">2</span>,), dtype=int64, numpy=array([<span class="hljs-number">0</span>, <span class="hljs-number">1</span>])>, | |
| <span class="hljs-string">'text'</span>: [<span class="hljs-string">'foo'</span>, <span class="hljs-string">'bar'</span>]}<!-- HTML_TAG_END --></pre></div> | |
| <p>The <a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Image">Image</a> and <a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Audio">Audio</a> feature types are not supported yet.</p> | |
| <h2 class="relative group"><a id="data-loading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#data-loading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Data loading | |
| </span></h2> | |
| <p>Although you can load individual samples and batches just by indexing into your dataset, this won’t work if you want | |
to use Keras methods like <code>fit()</code> and <code>predict()</code>. You could write a generator function that shuffles and loads batches | |
from your dataset and call <code>fit()</code> on that, but that is a lot of unnecessary work. Instead, if you want to stream | |
data from your dataset on the fly, we recommend converting your dataset to a <code>tf.data.Dataset</code> using the | |
<code>to_tf_dataset()</code> method.</p> | |
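<p>To give a sense of the manual work involved, here is a rough pure-Python sketch of such a generator. It is only an illustration: <code>shuffled_batches</code> and its parameters are hypothetical names, the rows are plain dicts standing in for dataset samples, and conversion to tensors is left out.</p>

```python
import random
from typing import Dict, Iterator, List, Sequence


def shuffled_batches(
    rows: Sequence[Dict[str, object]],
    batch_size: int,
    seed: int = 0,
) -> Iterator[Dict[str, List[object]]]:
    """Yield shuffled batches as dicts of column lists (hypothetical helper)."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # shuffle once per pass over the data
    for start in range(0, len(indices), batch_size):
        batch_rows = [rows[i] for i in indices[start:start + batch_size]]
        # Turn a list of row dicts into a dict of column lists.
        yield {key: [row[key] for row in batch_rows] for key in batch_rows[0]}


# Plain Python rows standing in for samples read from a dataset.
rows = [
    {"inputs": [1, 2], "labels": 0},
    {"inputs": [3, 4], "labels": 1},
    {"inputs": [5, 6], "labels": 0},
]
batches = list(shuffled_batches(rows, batch_size=2))
```

<p>Even this small version has to handle shuffling, slicing, and regrouping rows into columns by hand, and it still does no prefetching or parallel loading.</p>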
| <p>The <code>tf.data.Dataset</code> class covers a wide range of use cases: it is often created from tensors in memory, or by using a load function to read files on disk | |
or external storage. The dataset can be transformed arbitrarily with the <code>map()</code> method, or methods like <code>batch()</code> | |
and <code>shuffle()</code> can be used to create a dataset that’s ready for training. These methods do not modify the stored data | |
in any way. Instead, they build a data pipeline graph that is executed when the dataset is iterated over, | |
usually during model training or inference. This is different from the <code>map()</code> method of Hugging Face <code>Dataset</code> objects, | |
which runs the map function immediately and saves the new or changed columns.</p> | |
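<p>The difference between deferred and immediate execution can be illustrated with plain Python. This is only an analogy, not how either library is implemented: Python's built-in <code>map()</code> stands in for a deferred pipeline, and a list comprehension stands in for an eager map that stores its results. The names <code>add_one</code> and <code>calls</code> are made up for this sketch.</p>

```python
calls = []


def add_one(x):
    calls.append(x)  # record when the function actually runs
    return x + 1


data = [1, 2, 3]

# Eager: the function runs on every element right away and the results are stored.
eager = [add_one(x) for x in data]
calls_after_eager = len(calls)

# Deferred: the built-in map() only records the transformation...
lazy = map(add_one, data)
calls_after_lazy_definition = len(calls)

# ...and runs it when the result is iterated over.
lazy_result = list(lazy)
calls_after_iteration = len(calls)
```

<p>In this sketch, defining the deferred version triggers no calls to <code>add_one</code>. The work happens only on iteration, much as a <code>tf.data.Dataset</code> pipeline runs when it is consumed.</p>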
| <p>Since the entire data preprocessing pipeline can be compiled in a <code>tf.data.Dataset</code>, this approach allows for massively | |
parallel, asynchronous data loading and training. However, the requirement for graph compilation can be a limitation, | |
particularly for Hugging Face tokenizers, which are usually not (yet!) compilable as part of a TF graph. As a result, | |
we usually advise preprocessing the dataset as a Hugging Face dataset, where arbitrary Python functions can be | |
used, and then converting to <code>tf.data.Dataset</code> afterwards using <code>to_tf_dataset()</code> to get a batched dataset ready for | |
training. For examples of this approach, see the <a href="https://github.com/huggingface/transformers/tree/main/examples" rel="nofollow">examples</a> or <a href="https://huggingface.co/docs/transformers/notebooks" rel="nofollow">notebooks</a> for <code>transformers</code>.</p> | |
| <h3 class="relative group"><a id="using-totfdataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-totfdataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Using <code>to_tf_dataset()</code></span></h3> | |
| <p>Using <code>to_tf_dataset()</code> is straightforward. Once your dataset is preprocessed and ready, simply call it like so:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset | |
| <span class="hljs-meta">>>> </span>data = {<span class="hljs-string">"inputs"</span>: [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]], <span class="hljs-string">"labels"</span>: [<span class="hljs-number">0</span>, <span class="hljs-number">1</span>]} | |
| <span class="hljs-meta">>>> </span>ds = Dataset.from_dict(data) | |
| <span class="hljs-meta">>>> </span>tf_ds = ds.to_tf_dataset( | |
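<span class="hljs-meta">... </span>    columns=[<span class="hljs-string">"inputs"</span>], | |
<span class="hljs-meta">... </span>    label_cols=[<span class="hljs-string">"labels"</span>], | |
<span class="hljs-meta">... </span>    batch_size=<span class="hljs-number">2</span>, | |
<span class="hljs-meta">... </span>    shuffle=<span class="hljs-literal">True</span> | |
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> | |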
| <p>The returned <code>tf_ds</code> object is now ready to train on and can be passed directly to <code>model.fit()</code>. Because | |
the batch size was set when creating the dataset, you don’t need to specify it when calling <code>fit()</code>:</p> | |
| <div class="code-block relative"> | |
| <pre><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span>model.fit(tf_ds, epochs=<span class="hljs-number">2</span>)<!-- HTML_TAG_END --></pre></div> | |
| <p>For a full description of the arguments, please see the <a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset">to_tf_dataset()</a> documentation. In many cases, | |
| you will also need to add a <code>collate_fn</code> to your call. This is a function that takes multiple elements of the dataset | |
| and combines them into a single batch. When all elements have the same length, the built-in default collator will | |
| suffice, but for more complex tasks a custom collator may be necessary. In particular, many tasks have samples | |
| with varying sequence lengths which will require a <a href="https://huggingface.co/docs/transformers/main/en/main_classes/data_collator" rel="nofollow">data collator</a> that can pad batches correctly. You can see examples | |
| of this in the <code>transformers</code> NLP <a href="https://github.com/huggingface/transformers/tree/main/examples" rel="nofollow">examples</a> and | |
| <a href="https://huggingface.co/docs/transformers/notebooks" rel="nofollow">notebooks</a>, where variable sequence lengths are very common.</p> | |
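| <p>As a minimal sketch of what such a padding collator could look like, the function below pads every list-valued column to the | |
| length of the longest sample in the current batch. It assumes that the <code>collate_fn</code> receives a list of per-sample dictionaries and | |
| returns a dictionary of arrays, and that the sequences hold integers such as token IDs. The names <code>pad_collate</code> and <code>pad_value</code> are | |
| hypothetical and not part of the library; for real tasks, a ready-made data collator from <code>transformers</code> is usually the better choice.</p> | |

```python
import numpy as np


def pad_collate(samples, pad_value=0):
    """Combine a list of per-sample dicts into one dict of batched arrays.

    Hypothetical helper: list-valued columns are padded to the longest
    sample in *this batch*; scalar columns are simply stacked.
    """
    batch = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        if isinstance(values[0], (list, tuple, np.ndarray)):
            # Assumption: sequences contain integers (e.g. token IDs).
            max_len = max(len(v) for v in values)
            padded = np.full((len(values), max_len), pad_value, dtype=np.int64)
            for row, v in enumerate(values):
                padded[row, : len(v)] = v
            batch[key] = padded
        else:
            batch[key] = np.asarray(values)
    return batch


# Two samples with different sequence lengths.
samples = [{"inputs": [1, 2, 3], "labels": 0}, {"inputs": [4], "labels": 1}]
batch = pad_collate(samples)
print(batch["inputs"])
print(batch["labels"])
```

| <p>Under these assumptions, the function could be supplied through the <code>collate_fn</code> argument of <code>to_tf_dataset()</code> so that each batch | |
| is padded only as far as its own longest sample.</p> | |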
| <h3 class="relative group"><a id="when-to-use-totfdataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#when-to-use-totfdataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>When to use to_tf_dataset | |
| </span></h3> | |
| <p>The astute reader may have noticed at this point that we have offered two approaches to achieve the same goal: if you | |
| want to pass your dataset to a TensorFlow model, you can either convert the dataset to a <code>Tensor</code> or <code>dict</code> of <code>Tensors</code> | |
| using <code>.with_format('tf')</code>, or you can convert the dataset to a <code>tf.data.Dataset</code> with <code>to_tf_dataset()</code>. Either of these | |
| can be passed to <code>model.fit()</code>, so which should you choose?</p> | |
| <p>The key thing to recognize is that when you convert the whole dataset to <code>Tensor</code>s, the result is static and fully loaded into | |
| RAM. This is simple and convenient, but if any of the following apply, you should probably use <code>to_tf_dataset()</code> | |
| instead:</p> | |
| <ul><li>Your dataset is too large to fit in RAM. <code>to_tf_dataset()</code> streams only one batch at a time, so even very large | |
| datasets can be handled with this method.</li> | |
| <li>You want to apply random transformations using <code>dataset.with_transform()</code> or the <code>collate_fn</code>. This is | |
| common in several modalities, such as image augmentations when training vision models, or random masking when training | |
| masked language models. Using <code>to_tf_dataset()</code> will apply those transformations | |
| at the moment when a batch is loaded, which means the same samples will get different augmentations each time | |
| they are loaded. This is usually what you want.</li> | |
| <li>Your data has a variable dimension, such as input texts in NLP that consist of varying | |
| numbers of tokens. When you create a batch with samples with a variable dimension, the standard solution is to | |
| pad the shorter samples to the length of the longest one. When you stream samples from a dataset with <code>to_tf_dataset</code>, | |
| you can apply this padding to each batch via your <code>collate_fn</code>. However, if you want to convert | |
| such a dataset to dense <code>Tensor</code>s, then you will have to pad samples to the length of the longest sample in <em>the | |
| entire dataset!</em> This can result in huge amounts of padding, which wastes memory and reduces your model’s speed.</li></ul> | |
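| <p>To make the padding cost in the last point concrete, the following sketch counts how many positions are stored when sequences are | |
| padded to the longest sample in the entire dataset versus the longest sample in each batch. The sequence lengths are made-up illustrative | |
| values, and <code>padded_token_counts</code> is a hypothetical helper rather than a library function.</p> | |

```python
def padded_token_counts(lengths, batch_size):
    """Return (global_total, per_batch_total) padded positions.

    global_total: every sample padded to the longest sample in the dataset.
    per_batch_total: every sample padded to the longest sample in its batch.
    """
    global_total = len(lengths) * max(lengths)
    per_batch_total = 0
    for start in range(0, len(lengths), batch_size):
        chunk = lengths[start : start + batch_size]
        per_batch_total += len(chunk) * max(chunk)
    return global_total, per_batch_total


# Illustrative lengths: three short samples and one very long one.
lengths = [3, 5, 4, 120]
global_total, per_batch_total = padded_token_counts(lengths, batch_size=2)
print(sum(lengths), global_total, per_batch_total)  # prints: 132 480 250
```

| <p>In this toy case a single long sample forces every row to length 120 under dataset-wide padding, while per-batch padding only pays | |
| that cost in the batch that actually contains the long sample.</p> | |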
| <h3 class="relative group"><a id="caveats-and-limitations" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#caveats-and-limitations"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> | |
| <span>Caveats and limitations | |
| </span></h3> | |
| <p>Right now, <code>to_tf_dataset()</code> always returns a batched dataset. We will add support for unbatched datasets soon!</p> | |
| <script type="module" data-hydrate="a5tgxd"> | |
| import { start } from "/docs/datasets/v2.3.2/en/_app/start-hf-doc-builder.js"; | |
| start({ | |
| target: document.querySelector('[data-hydrate="a5tgxd"]').parentNode, | |
| paths: {"base":"/docs/datasets/v2.3.2/en","assets":"/docs/datasets/v2.3.2/en"}, | |
| session: {}, | |
| route: false, | |
| spa: false, | |
| trailing_slash: "never", | |
| hydrate: { | |
| status: 200, | |
| error: null, | |
| nodes: [ | |
| import("/docs/datasets/v2.3.2/en/_app/pages/__layout.svelte-hf-doc-builder.js"), | |
| import("/docs/datasets/v2.3.2/en/_app/pages/use_with_tensorflow.mdx-hf-doc-builder.js") | |
| ], | |
| params: {} | |
| } | |
| }); | |
| </script> | |