Buckets:

hf-doc-build/doc-dev / datasets /main /en /use_with_pytorch.html
rtrm's picture
download
raw
59.3 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Use with PyTorch&quot;,&quot;local&quot;:&quot;use-with-pytorch&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Dataset format&quot;,&quot;local&quot;:&quot;dataset-format&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;N-dimensional arrays&quot;,&quot;local&quot;:&quot;n-dimensional-arrays&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Other feature types&quot;,&quot;local&quot;:&quot;other-feature-types&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Data loading&quot;,&quot;local&quot;:&quot;data-loading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Optimize data loading&quot;,&quot;local&quot;:&quot;optimize-data-loading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Use multiple Workers&quot;,&quot;local&quot;:&quot;use-multiple-workers&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Stream data&quot;,&quot;local&quot;:&quot;stream-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Checkpoint and resume&quot;,&quot;local&quot;:&quot;checkpoint-and-resume&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Distributed&quot;,&quot;local&quot;:&quot;distributed&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/datasets/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/entry/start.4d44eea4.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/scheduler.bdbef820.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/singletons.36b689ad.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/index.8a885b74.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/paths.27092e28.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/entry/app.d83067e8.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/index.c0aea24a.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/nodes/0.bfb01985.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/nodes/49.ecaba42f.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/Tip.31005f7d.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/CodeBlock.6ccca92e.js">
<link rel="modulepreload" href="/docs/datasets/main/en/_app/immutable/chunks/EditOnGithub.725ee0c1.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Use with PyTorch&quot;,&quot;local&quot;:&quot;use-with-pytorch&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Dataset format&quot;,&quot;local&quot;:&quot;dataset-format&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;N-dimensional arrays&quot;,&quot;local&quot;:&quot;n-dimensional-arrays&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Other feature types&quot;,&quot;local&quot;:&quot;other-feature-types&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Data loading&quot;,&quot;local&quot;:&quot;data-loading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Optimize data loading&quot;,&quot;local&quot;:&quot;optimize-data-loading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Use multiple Workers&quot;,&quot;local&quot;:&quot;use-multiple-workers&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Stream data&quot;,&quot;local&quot;:&quot;stream-data&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Checkpoint and resume&quot;,&quot;local&quot;:&quot;checkpoint-and-resume&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Distributed&quot;,&quot;local&quot;:&quot;distributed&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="use-with-pytorch" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#use-with-pytorch"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" 
width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Use with PyTorch</span></h1> <p data-svelte-h="svelte-1116k3w">This document is a quick introduction to using <code>datasets</code> with PyTorch, with a particular focus on how to get
<code>torch.Tensor</code> objects out of our datasets, and how to use a PyTorch <code>DataLoader</code> and a Hugging Face <code>Dataset</code>
with the best performance.</p> <h2 class="relative group"><a id="dataset-format" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dataset-format"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Dataset format</span></h2> <p data-svelte-h="svelte-j9f3ms">By default, datasets return regular Python objects: integers, floats, strings, lists, etc.</p> <p data-svelte-h="svelte-wuauow">To get PyTorch tensors instead, you can set the format of the dataset to <code>torch</code> using <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.with_format">Dataset.with_format()</a>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]]
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>])}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])}<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1bbq9ig">A <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset">Dataset</a> object is a wrapper of an Arrow table, which allows fast zero-copy reads from arrays in the dataset to PyTorch tensors.</p></div> <p data-svelte-h="svelte-1ezbzoy">To load the data as tensors on a GPU, specify the <code>device</code> argument:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span 
class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span>device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span> <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> <span class="hljs-string">&quot;cpu&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;torch&quot;</span>, device=device)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], device=<span class="hljs-string">&#x27;cuda:0&#x27;</span>)}<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="n-dimensional-arrays" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#n-dimensional-arrays"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>N-dimensional arrays</span></h3> <p data-svelte-h="svelte-smjp9l">If your dataset consists of N-dimensional arrays, you will see that by default they are stacked into a single tensor if the shape is fixed, and returned as a list of tensors if the shape varies:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],[[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]] <span class="hljs-comment"># fixed shape</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])}<!-- HTML_TAG_END --></pre></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>]],[[<span class="hljs-number">4</span>, <span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]] <span class="hljs-comment"># varying shape</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: [tensor([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>]), tensor([<span class="hljs-number">3</span>])]}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gw41y9">However, this logic often requires slow shape comparisons and data copies.
To avoid this, you must explicitly use the <code>Array</code> feature type and specify the shape of your tensors:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, Array2D
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],[[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]]
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;data&quot;</span>: Array2D(shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=<span class="hljs-string">&#x27;int32&#x27;</span>)})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],
[[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],
[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]])}<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="other-feature-types" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#other-feature-types"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Other feature types</span></h3> <p data-svelte-h="svelte-9al131"><a href="/docs/datasets/main/en/package_reference/main_classes#datasets.ClassLabel">ClassLabel</a> data are properly converted to tensors:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, ClassLabel
<span class="hljs-meta">&gt;&gt;&gt; </span>labels = [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;label&quot;</span>: ClassLabel(names=[<span class="hljs-string">&quot;negative&quot;</span>, <span class="hljs-string">&quot;positive&quot;</span>])})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;label&quot;</span>: labels}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">3</span>]
{<span class="hljs-string">&#x27;label&#x27;</span>: tensor([<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>])}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hobffv">String and binary objects are unchanged, since PyTorch only supports numbers.</p> <p data-svelte-h="svelte-1g2r59q">The <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.Image">Image</a> and <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.Audio">Audio</a> feature types are also supported.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1go8nao">To use the <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.Image">Image</a> feature type, you’ll need to install the <code>vision</code> extra as
<code>pip install datasets[vision]</code>.</p></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, Audio, Image
<span class="hljs-meta">&gt;&gt;&gt; </span>images = [<span class="hljs-string">&quot;path/to/image.png&quot;</span>] * <span class="hljs-number">10</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;image&quot;</span>: Image()})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;image&quot;</span>: images}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;image&quot;</span>].shape
torch.Size([<span class="hljs-number">512</span>, <span class="hljs-number">512</span>, <span class="hljs-number">4</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;image&#x27;</span>: tensor([[[<span class="hljs-number">255</span>, <span class="hljs-number">215</span>, <span class="hljs-number">106</span>, <span class="hljs-number">255</span>],
[<span class="hljs-number">255</span>, <span class="hljs-number">215</span>, <span class="hljs-number">106</span>, <span class="hljs-number">255</span>],
...,
[<span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>],
[<span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>]]], dtype=torch.uint8)}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>][<span class="hljs-string">&quot;image&quot;</span>].shape
torch.Size([<span class="hljs-number">2</span>, <span class="hljs-number">512</span>, <span class="hljs-number">512</span>, <span class="hljs-number">4</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;image&#x27;</span>: tensor([[[[<span class="hljs-number">255</span>, <span class="hljs-number">215</span>, <span class="hljs-number">106</span>, <span class="hljs-number">255</span>],
[<span class="hljs-number">255</span>, <span class="hljs-number">215</span>, <span class="hljs-number">106</span>, <span class="hljs-number">255</span>],
...,
[<span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>],
[<span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>]]]], dtype=torch.uint8)}<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-118qika">To use the <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.Audio">Audio</a> feature type, you’ll need to install the <code>audio</code> extra as
<code>pip install datasets[audio]</code>.</p></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, Audio, Image
<span class="hljs-meta">&gt;&gt;&gt; </span>audio = [<span class="hljs-string">&quot;path/to/audio.wav&quot;</span>] * <span class="hljs-number">10</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;audio&quot;</span>: Audio()})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;audio&quot;</span>: audio}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;audio&quot;</span>][<span class="hljs-string">&quot;array&quot;</span>]
tensor([ <span class="hljs-number">6.1035e-05</span>, <span class="hljs-number">1.5259e-05</span>, <span class="hljs-number">1.6785e-04</span>, ..., -<span class="hljs-number">1.5259e-05</span>,
-<span class="hljs-number">1.5259e-05</span>, <span class="hljs-number">1.5259e-05</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;audio&quot;</span>][<span class="hljs-string">&quot;sampling_rate&quot;</span>]
tensor(<span class="hljs-number">44100</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="data-loading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#data-loading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Data loading</span></h2> <p data-svelte-h="svelte-12en5kh">Like <code>torch.utils.data.Dataset</code> objects, a <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset">Dataset</a> can be passed directly to a PyTorch <code>DataLoader</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader
<span class="hljs-meta">&gt;&gt;&gt; </span>data = np.random.rand(<span class="hljs-number">16</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>label = np.random.randint(<span class="hljs-number">0</span>, <span class="hljs-number">2</span>, size=<span class="hljs-number">16</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data, <span class="hljs-string">&quot;label&quot;</span>: label}).with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>dataloader = DataLoader(ds, batch_size=<span class="hljs-number">4</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> dataloader:
<span class="hljs-meta">... </span> <span class="hljs-built_in">print</span>(batch)
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([<span class="hljs-number">0.0047</span>, <span class="hljs-number">0.4979</span>, <span class="hljs-number">0.6726</span>, <span class="hljs-number">0.8105</span>]), <span class="hljs-string">&#x27;label&#x27;</span>: tensor([<span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>])}
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([<span class="hljs-number">0.4832</span>, <span class="hljs-number">0.2723</span>, <span class="hljs-number">0.4259</span>, <span class="hljs-number">0.2224</span>]), <span class="hljs-string">&#x27;label&#x27;</span>: tensor([<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>])}
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([<span class="hljs-number">0.5837</span>, <span class="hljs-number">0.3444</span>, <span class="hljs-number">0.4658</span>, <span class="hljs-number">0.6417</span>]), <span class="hljs-string">&#x27;label&#x27;</span>: tensor([<span class="hljs-number">0</span>, <span class="hljs-number">1</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>])}
{<span class="hljs-string">&#x27;data&#x27;</span>: tensor([<span class="hljs-number">0.7022</span>, <span class="hljs-number">0.1225</span>, <span class="hljs-number">0.7228</span>, <span class="hljs-number">0.8259</span>]), <span class="hljs-string">&#x27;label&#x27;</span>: tensor([<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>])}<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="optimize-data-loading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#optimize-data-loading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Optimize data loading</span></h3> <p data-svelte-h="svelte-nw1d6c">There are several ways you can increase the speed your data is loaded which can save you time, especially if you are working with large datasets.
PyTorch offers parallelized data loading, retrieving batches of indices instead of individually, and streaming to iterate over the dataset without downloading it on disk.</p> <h4 class="relative group"><a id="use-multiple-workers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#use-multiple-workers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Use multiple Workers</span></h4> <p data-svelte-h="svelte-1wyypch">You can parallelize data loading with the <code>num_workers</code> argument of a PyTorch <code>DataLoader</code> and get a higher throughput.</p> <p data-svelte-h="svelte-1u9vub">Under the hood, the <code>DataLoader</code> starts <code>num_workers</code> processes.
Each process reloads the dataset passed to the <code>DataLoader</code> and is used to query examples.
Reloading the dataset inside a worker doesn’t fill up your RAM, since it simply memory-maps the dataset again from your disk.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, load_from_disk
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader
<span class="hljs-meta">&gt;&gt;&gt; </span>data = np.random.rand(<span class="hljs-number">10_000</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data}).save_to_disk(<span class="hljs-string">&quot;my_dataset&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = load_from_disk(<span class="hljs-string">&quot;my_dataset&quot;</span>).with_format(<span class="hljs-string">&quot;torch&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>dataloader = DataLoader(ds, batch_size=<span class="hljs-number">32</span>, num_workers=<span class="hljs-number">4</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="stream-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#stream-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Stream data</span></h3> <p data-svelte-h="svelte-a5l4ib">Stream a dataset by loading it as an <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.IterableDataset">IterableDataset</a>. This allows you to progressively iterate over a remote dataset without downloading it on disk and or over local data files.
Learn more about which type of dataset is best for your use case in the <a href="./about_mapstyle_vs_iterable">choosing between a regular dataset or an iterable dataset</a> guide.</p> <p data-svelte-h="svelte-1t3p6zv">An iterable dataset from <code>datasets</code> inherits from <code>torch.utils.data.IterableDataset</code> so you can pass it to a <code>torch.utils.data.DataLoader</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, load_dataset
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader
<span class="hljs-meta">&gt;&gt;&gt; </span>data = np.random.rand(<span class="hljs-number">10_000</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data}).push_to_hub(<span class="hljs-string">&quot;&lt;username&gt;/my_dataset&quot;</span>) <span class="hljs-comment"># Upload to the Hugging Face Hub</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>my_iterable_dataset = load_dataset(<span class="hljs-string">&quot;&lt;username&gt;/my_dataset&quot;</span>, streaming=<span class="hljs-literal">True</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>dataloader = DataLoader(my_iterable_dataset, batch_size=<span class="hljs-number">32</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-gqi02o">If the dataset is split in several shards (i.e. if the dataset consists of multiple data files), then you can stream in parallel using <code>num_workers</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>my_iterable_dataset = load_dataset(<span class="hljs-string">&quot;deepmind/code_contests&quot;</span>, streaming=<span class="hljs-literal">True</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>my_iterable_dataset.n_shards
<span class="hljs-number">39</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>dataloader = DataLoader(my_iterable_dataset, batch_size=<span class="hljs-number">32</span>, num_workers=<span class="hljs-number">4</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-i7csup">In this case each worker is given a subset of the list of shards to stream from.</p> <h3 class="relative group"><a id="checkpoint-and-resume" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#checkpoint-and-resume"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Checkpoint and resume</span></h3> <p data-svelte-h="svelte-1x5nhsq">If you need a DataLoader that you can checkpoint and resume in the middle of training, you can use the <code>StatefulDataLoader</code> from <a href="https://github.com/pytorch/data" rel="nofollow">torchdata</a>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" 
aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> torchdata.stateful_dataloader <span class="hljs-keyword">import</span> StatefulDataLoader
<span class="hljs-meta">&gt;&gt;&gt; </span>my_iterable_dataset = load_dataset(<span class="hljs-string">&quot;deepmind/code_contests&quot;</span>, streaming=<span class="hljs-literal">True</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>dataloader = StatefulDataLoader(my_iterable_dataset, batch_size=<span class="hljs-number">32</span>, num_workers=<span class="hljs-number">4</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># save in the middle of training</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>state_dict = dataloader.state_dict()
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># and resume later</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>dataloader.load_state_dict(state_dict)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1f8p0ie">This is possible thanks to <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.IterableDataset.state_dict">IterableDataset.state_dict()</a> and <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.IterableDataset.load_state_dict">IterableDataset.load_state_dict()</a>.</p> <h3 class="relative group"><a id="distributed" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#distributed"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Distributed</span></h3> <p data-svelte-h="svelte-1xpq9l0">To split your dataset across your training nodes, you can use <a href="/docs/datasets/main/en/package_reference/main_classes#datasets.distributed.split_dataset_by_node">datasets.distributed.split_dataset_by_node()</a>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 
text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> os
<span class="hljs-keyword">from</span> datasets.distributed <span class="hljs-keyword">import</span> split_dataset_by_node
ds = split_dataset_by_node(ds, rank=<span class="hljs-built_in">int</span>(os.environ[<span class="hljs-string">&quot;RANK&quot;</span>]), world_size=<span class="hljs-built_in">int</span>(os.environ[<span class="hljs-string">&quot;WORLD_SIZE&quot;</span>]))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-utgciv">This works for both map-style datasets and iterable datasets.
The dataset is split for the node at rank <code>rank</code> in a pool of nodes of size <code>world_size</code>.</p> <p data-svelte-h="svelte-1a3gkys">For map-style datasets:</p> <p data-svelte-h="svelte-41cx6v">Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.</p> <p data-svelte-h="svelte-1kujsme">For iterable datasets:</p> <p data-svelte-h="svelte-nsc411">If the dataset has a number of shards that is a multiple of <code>world_size</code> (i.e. if <code>dataset.n_shards % world_size == 0</code>),
then the shards are evenly assigned across the nodes, which is the most efficient configuration.
Otherwise, each node keeps 1 example out of <code>world_size</code>, skipping the other examples.</p> <p data-svelte-h="svelte-19jtkan">This can also be combined with a <code>torch.utils.data.DataLoader</code> if you want each node to use multiple workers to load the data.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/datasets/blob/main/docs/source/use_with_pytorch.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
  // Path configuration for this page. The name is assigned without a
  // declaration on purpose, so in this classic (non-module) script it lands on
  // the global object. NOTE(review): presumably the imported bundles look it
  // up by this exact name — confirm before renaming.
  __sveltekit_w3org2 = {
    assets: "/docs/datasets/main/en",
    base: "/docs/datasets/main/en",
    env: {}
  };

  // Read synchronously: document.currentScript only refers to this <script>
  // while it is executing, not later inside the promise callback.
  const mountPoint = document.currentScript.parentElement;
  const nodeData = [null, null];

  // Options handed to kit.start() once both bundles are available.
  const startOptions = {
    node_ids: [0, 49],
    data: nodeData,
    form: null,
    error: null
  };

  // Request both entry bundles in parallel; order matters for the indexes below.
  const entryModules = [
    import("/docs/datasets/main/en/_app/immutable/entry/start.4d44eea4.js"),
    import("/docs/datasets/main/en/_app/immutable/entry/app.d83067e8.js")
  ];

  Promise.all(entryModules).then((loaded) => {
    const kit = loaded[0];
    const app = loaded[1];
    kit.start(app, mountPoint, startOptions);
  });
}
</script>

Xet Storage Details

Size:
59.3 kB
·
Xet hash:
63600c321dcb5fe7833dddc335f5fbef8751569f90b7059024e2458e90e6c3e0

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.