Buckets:

hf-doc-build/doc / datasets /v2.3.2 /en /use_with_tensorflow.html
rtrm's picture
download
raw
39.5 kB
<meta charset="utf-8" /><meta http-equiv="content-security-policy" content=""><meta name="hf:doc:metadata" content="{&quot;local&quot;:&quot;using-datasets-with-tensorflow&quot;,&quot;sections&quot;:[{&quot;local&quot;:&quot;dataset-format&quot;,&quot;title&quot;:&quot;Dataset format&quot;},{&quot;local&quot;:&quot;ndimensional-arrays&quot;,&quot;title&quot;:&quot;N-dimensional arrays&quot;},{&quot;local&quot;:&quot;other-feature-types&quot;,&quot;title&quot;:&quot;Other feature types&quot;},{&quot;local&quot;:&quot;data-loading&quot;,&quot;sections&quot;:[{&quot;local&quot;:&quot;using-totfdataset&quot;,&quot;title&quot;:&quot;Using `to_tf_dataset()`&quot;},{&quot;local&quot;:&quot;when-to-use-totfdataset&quot;,&quot;title&quot;:&quot;When to use to_tf_dataset&quot;},{&quot;local&quot;:&quot;caveats-and-limitations&quot;,&quot;title&quot;:&quot;Caveats and limitations&quot;}],&quot;title&quot;:&quot;Data loading&quot;}],&quot;title&quot;:&quot;Using Datasets with TensorFlow&quot;}" data-svelte="svelte-1phssyn">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/assets/pages/__layout.svelte-hf-doc-builder.css">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/start-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/chunks/vendor-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/chunks/paths-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/pages/__layout.svelte-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/pages/use_with_tensorflow.mdx-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/chunks/Tip-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/chunks/IconCopyLink-hf-doc-builder.js">
<link rel="modulepreload" href="/docs/datasets/v2.3.2/en/_app/chunks/CodeBlock-hf-doc-builder.js">
<h1 class="relative group"><a id="using-datasets-with-tensorflow" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-datasets-with-tensorflow"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Using Datasets with TensorFlow
</span></h1>
<p>This document is a quick introduction to using <code>datasets</code> with TensorFlow, with a particular focus on how to get
<code>tf.Tensor</code> objects out of our datasets, and how to stream data from Hugging Face <code>Dataset</code> objects to Keras methods
like <code>model.fit()</code>.</p>
<h2 class="relative group"><a id="dataset-format" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dataset-format"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Dataset format
</span></h2>
<p>By default, datasets return regular Python objects: integers, floats, strings, lists, etc.</p>
<p>To get TensorFlow tensors instead, you can set the format of the dataset to <code>tf</code>:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]]
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;tf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">2</span>,), dtype=int64, numpy=array([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>])&gt;}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=int64, numpy=
array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
       [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])&gt;}<!-- HTML_TAG_END --></pre></div>
<div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p>A <a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Dataset">Dataset</a> object is a wrapper of an Arrow table, which allows fast reads from arrays in the dataset to TensorFlow tensors.</p></div>
<p>This can be useful for converting your dataset to a dict of <code>Tensor</code> objects, or for writing a generator to load TF
samples from it. If you wish to convert the entire dataset to <code>Tensor</code> objects, simply query the full dataset:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>ds[:]
{<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=int64, numpy=
array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
       [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])&gt;}<!-- HTML_TAG_END --></pre></div>
<h2 class="relative group"><a id="ndimensional-arrays" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ndimensional-arrays"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>N-dimensional arrays
</span></h2>
<p>If your dataset consists of N-dimensional arrays, you will see that by default they are considered as nested lists.
In particular, a TensorFlow formatted dataset outputs a <code>RaggedTensor</code> instead of a single tensor:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],[[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]]
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;tf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.RaggedTensor [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]]&gt;}<!-- HTML_TAG_END --></pre></div>
<p>To get a single tensor, you must explicitly use the Array feature type and specify the shape of your tensors:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, Array2D
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],[[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]]
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;data&quot;</span>: Array2D(shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=<span class="hljs-string">&#x27;int32&#x27;</span>)})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;tf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=int64, numpy=
array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
       [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])&gt;}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=int64, numpy=
array([[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
        [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],
       [[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],
        [<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]])&gt;}<!-- HTML_TAG_END --></pre></div>
<h2 class="relative group"><a id="other-feature-types" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#other-feature-types"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Other feature types
</span></h2>
<p><a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.ClassLabel">ClassLabel</a> data are properly converted to tensors:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, ClassLabel
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;data&quot;</span>: ClassLabel(names=[<span class="hljs-string">&quot;negative&quot;</span>, <span class="hljs-string">&quot;positive&quot;</span>])})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;tf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">3</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">3</span>,), dtype=int64, numpy=array([<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>])&gt;}<!-- HTML_TAG_END --></pre></div>
<p>Strings are also supported:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>text = [<span class="hljs-string">&quot;foo&quot;</span>, <span class="hljs-string">&quot;bar&quot;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [<span class="hljs-number">0</span>, <span class="hljs-number">1</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;text&quot;</span>: text, <span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;tf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;text&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">2</span>,), dtype=string, numpy=array([<span class="hljs-string">b&#x27;foo&#x27;</span>, <span class="hljs-string">b&#x27;bar&#x27;</span>], dtype=<span class="hljs-built_in">object</span>)&gt;,
<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">2</span>,), dtype=int64, numpy=array([<span class="hljs-number">0</span>, <span class="hljs-number">1</span>])&gt;}<!-- HTML_TAG_END --></pre></div>
<p>You can also explicitly format certain columns and leave the other columns unformatted:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;tf&quot;</span>, columns=[<span class="hljs-string">&quot;data&quot;</span>], output_all_columns=<span class="hljs-literal">True</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: &lt;tf.Tensor: shape=(<span class="hljs-number">2</span>,), dtype=int64, numpy=array([<span class="hljs-number">0</span>, <span class="hljs-number">1</span>])&gt;,
<span class="hljs-string">&#x27;text&#x27;</span>: [<span class="hljs-string">&#x27;foo&#x27;</span>, <span class="hljs-string">&#x27;bar&#x27;</span>]}<!-- HTML_TAG_END --></pre></div>
<p>The <a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Image">Image</a> and <a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Audio">Audio</a> feature types are not supported yet.</p>
<h2 class="relative group"><a id="data-loading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#data-loading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Data loading
</span></h2>
<p>Although you can load individual samples and batches just by indexing into your dataset, this won’t work if you want
to use Keras methods like <code>fit()</code> and <code>predict()</code>. You could write a generator function that shuffles and loads batches
from your dataset and <code>fit()</code> on that, but that sounds like a lot of unnecessary work. Instead, if you want to stream
data from your dataset on-the-fly, we recommend converting your dataset to a <code>tf.data.Dataset</code> using the
<code>to_tf_dataset()</code> method.</p>
<p>The <code>tf.data.Dataset</code> class covers a wide range of use-cases - it is often created from Tensors in memory, or using a load function to read files on disk
or external storage. The dataset can be transformed arbitrarily with the <code>map()</code> method, or methods like <code>batch()</code>
and <code>shuffle()</code> can be used to create a dataset that’s ready for training. These methods do not modify the stored data
in any way - instead, the methods build a data pipeline graph that will be executed when the dataset is iterated over,
usually during model training or inference. This is different from the <code>map()</code> method of Hugging Face <code>Dataset</code> objects,
which runs the map function immediately and saves the new or changed columns.</p>
<p>Since the entire data preprocessing pipeline can be compiled in a <code>tf.data.Dataset</code>, this approach allows for massively
parallel, asynchronous data loading and training. However, the requirement for graph compilation can be a limitation,
particularly for Hugging Face tokenizers, which are usually not (yet!) compilable as part of a TF graph. As a result,
we usually advise pre-processing the dataset as a Hugging Face dataset, where arbitrary Python functions can be
used, and then converting to <code>tf.data.Dataset</code> afterwards using <code>to_tf_dataset()</code> to get a batched dataset ready for
training. To see examples of this approach, please see the <a href="https://github.com/huggingface/transformers/tree/main/examples" rel="nofollow">examples</a> or <a href="https://huggingface.co/docs/transformers/notebooks" rel="nofollow">notebooks</a> for <code>transformers</code>.</p>
<h3 class="relative group"><a id="using-totfdataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-totfdataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Using <code>to_tf_dataset()</code></span></h3>
<p>Using <code>to_tf_dataset()</code> is straightforward. Once your dataset is preprocessed and ready, simply call it like so:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = {<span class="hljs-string">&quot;inputs&quot;</span>: [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]], <span class="hljs-string">&quot;labels&quot;</span>: [<span class="hljs-number">0</span>, <span class="hljs-number">1</span>]}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict(data)
<span class="hljs-meta">&gt;&gt;&gt; </span>tf_ds = ds.to_tf_dataset(
    columns=[<span class="hljs-string">&quot;inputs&quot;</span>],
    label_cols=[<span class="hljs-string">&quot;labels&quot;</span>],
    batch_size=<span class="hljs-number">2</span>,
    shuffle=<span class="hljs-literal">True</span>
)<!-- HTML_TAG_END --></pre></div>
<p>The returned <code>tf_ds</code> object here is now fully ready to train on, and can be passed directly to <code>model.fit()</code>! Note
that you set the batch size when creating the dataset, and so you don’t need to specify it when calling <code>fit()</code>:</p>
<div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg>
<div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div>
Copied</div></button></div>
<pre><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>model.fit(tf_ds, epochs=<span class="hljs-number">2</span>)<!-- HTML_TAG_END --></pre></div>
<p>For a full description of the arguments, please see the <a href="/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset">to_tf_dataset()</a> documentation. In many cases,
you will also need to add a <code>collate_fn</code> to your call. This is a function that takes multiple elements of the dataset
and combines them into a single batch. When all elements have the same length, the built-in default collator will
suffice, but for more complex tasks a custom collator may be necessary. In particular, many tasks have samples
with varying sequence lengths which will require a <a href="https://huggingface.co/docs/transformers/main/en/main_classes/data_collator" rel="nofollow">data collator</a> that can pad batches correctly. You can see examples
of this in the <code>transformers</code> NLP <a href="https://github.com/huggingface/transformers/tree/main/examples" rel="nofollow">examples</a> and
<a href="https://huggingface.co/docs/transformers/notebooks" rel="nofollow">notebooks</a>, where variable sequence lengths are very common.</p>
<h3 class="relative group"><a id="when-to-use-totfdataset" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#when-to-use-totfdataset"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>When to use to_tf_dataset
</span></h3>
<p>The astute reader may have noticed at this point that we have offered two approaches to achieve the same goal - if you
want to pass your dataset to a TensorFlow model, you can either convert the dataset to a <code>Tensor</code> or <code>dict</code> of <code>Tensor</code>s
using <code>.with_format(&#39;tf&#39;)</code>, or you can convert the dataset to a <code>tf.data.Dataset</code> with <code>to_tf_dataset()</code>. Either of these
can be passed to <code>model.fit()</code>, so which should you choose?</p>
<p>The key thing to recognize is that when you convert the whole dataset to <code>Tensor</code>s, it is static and fully loaded into
RAM. This is simple and convenient, but if any of the following apply, you should probably use <code>to_tf_dataset()</code>
instead:</p>
<ul><li>Your dataset is too large to fit in RAM. <code>to_tf_dataset()</code> streams only one batch at a time, so even very large
datasets can be handled with this method.</li>
<li>You want to apply random transformations using <code>dataset.with_transform()</code> or the <code>collate_fn</code>. This is
common in several modalities, such as image augmentations when training vision models, or random masking when training
masked language models. Using <code>to_tf_dataset()</code> will apply those transformations
at the moment when a batch is loaded, which means the same samples will get different augmentations each time
they are loaded. This is usually what you want.</li>
<li>Your data has a variable dimension, such as input texts in NLP that consist of varying
numbers of tokens. When you create a batch with samples with a variable dimension, the standard solution is to
pad the shorter samples to the length of the longest one. When you stream samples from a dataset with <code>to_tf_dataset()</code>,
you can apply this padding to each batch via your <code>collate_fn</code>. However, if you want to convert
such a dataset to dense <code>Tensor</code>s, then you will have to pad samples to the length of the longest sample in <em>the
entire dataset!</em> This can result in huge amounts of padding, which wastes memory and reduces your model’s speed.</li></ul>
<h3 class="relative group"><a id="caveats-and-limitations" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#caveats-and-limitations"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a>
<span>Caveats and limitations
</span></h3>
<p>Right now, <code>to_tf_dataset()</code> always returns a batched dataset — we will add support for unbatched datasets soon!</p>
<script type="module" data-hydrate="a5tgxd">
import { start } from "/docs/datasets/v2.3.2/en/_app/start-hf-doc-builder.js";
start({
target: document.querySelector('[data-hydrate="a5tgxd"]').parentNode,
paths: {"base":"/docs/datasets/v2.3.2/en","assets":"/docs/datasets/v2.3.2/en"},
session: {},
route: false,
spa: false,
trailing_slash: "never",
hydrate: {
status: 200,
error: null,
nodes: [
import("/docs/datasets/v2.3.2/en/_app/pages/__layout.svelte-hf-doc-builder.js"),
import("/docs/datasets/v2.3.2/en/_app/pages/use_with_tensorflow.mdx-hf-doc-builder.js")
],
params: {}
}
});
</script>

Xet Storage Details

Size:
39.5 kB
·
Xet hash:
f6a2a65e6177ce3eed8588a93f8b7ae65b55145aa07d20ae06b1bedc25f14bff

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.