Buckets:

download
raw
84.9 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Launching distributed training from Jupyter Notebooks&quot;,&quot;local&quot;:&quot;launching-distributed-training-from-jupyter-notebooks&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Configuring the Environment&quot;,&quot;local&quot;:&quot;configuring-the-environment&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Preparing the Dataset and Model&quot;,&quot;local&quot;:&quot;preparing-the-dataset-and-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Writing the Training Function&quot;,&quot;local&quot;:&quot;writing-the-training-function&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using the notebook_launcher&quot;,&quot;local&quot;:&quot;using-the-notebooklauncher&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Debugging&quot;,&quot;local&quot;:&quot;debugging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Conclusion&quot;,&quot;local&quot;:&quot;conclusion&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/accelerate/pr_4021/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/scheduler.b9285784.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/singletons.7547c222.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.6d423e5c.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/paths.d42c9205.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/preload-helper.b0bd19d1.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.26bc89a1.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/0.0e7c56e8.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/6.afe78a43.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/Tip.e4eba3d6.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.7a0ae628.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/CodeBlock.844ff9c3.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Launching distributed training from Jupyter Notebooks&quot;,&quot;local&quot;:&quot;launching-distributed-training-from-jupyter-notebooks&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Configuring the Environment&quot;,&quot;local&quot;:&quot;configuring-the-environment&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Preparing the Dataset and Model&quot;,&quot;local&quot;:&quot;preparing-the-dataset-and-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Writing the Training Function&quot;,&quot;local&quot;:&quot;writing-the-training-function&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using the notebook_launcher&quot;,&quot;local&quot;:&quot;using-the-notebooklauncher&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Debugging&quot;,&quot;local&quot;:&quot;debugging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Conclusion&quot;,&quot;local&quot;:&quot;conclusion&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 
dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="launching-distributed-training-from-jupyter-notebooks" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#launching-distributed-training-from-jupyter-notebooks"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 
0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Launching distributed training from Jupyter Notebooks</span></h1> <p data-svelte-h="svelte-1j74re8">This tutorial teaches you how to fine-tune a computer vision model with 🤗 Accelerate from a Jupyter Notebook on a distributed system.
You will also learn how to set up a few requirements needed for ensuring your environment is configured properly, your data has been prepared properly, and finally how to launch training.</p> <blockquote class="tip"><p data-svelte-h="svelte-1tjynav">This tutorial is also available as a Jupyter Notebook <a href="https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_cv_example.ipynb" rel="nofollow">here</a></p></blockquote> <h2 class="relative group"><a id="configuring-the-environment" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#configuring-the-environment"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Configuring the Environment</span></h2> <p data-svelte-h="svelte-1wr7aay">Before any training can be performed, an Accelerate config file must exist in the system. 
Usually this can be done by running the following in a terminal and answering the prompts:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jph9ta">However, if general defaults are fine and you are <em>not</em> running on a TPU, Accelerate has a utility to quickly write your device configuration into a config file via <a href="/docs/accelerate/pr_4021/en/package_reference/utilities#accelerate.commands.config.default.write_basic_config">utils.write_basic_config()</a>.</p> <p data-svelte-h="svelte-1gywp69">The following code will restart Jupyter after writing the configuration, as CUDA runtime or XPU runtime was called to perform this.</p> <blockquote class="warning"><p data-svelte-h="svelte-kppmtb">CUDA and XPU can’t be initialized more than once on a 
multi-device system. It’s fine to debug in the notebook and have calls to CUDA/XPU, but in order to finally train, a full cleanup and restart will need to be performed.</p></blockquote> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> os
<span class="hljs-keyword">from</span> accelerate.utils <span class="hljs-keyword">import</span> write_basic_config
write_basic_config() <span class="hljs-comment"># Write a config file</span>
os._exit(<span class="hljs-number">00</span>) <span class="hljs-comment"># Restart the notebook</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="preparing-the-dataset-and-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preparing-the-dataset-and-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preparing the Dataset and Model</span></h2> <p data-svelte-h="svelte-1mtlp3l">Next you should prepare your dataset. 
As mentioned earlier, great care should be taken when preparing the <code>DataLoaders</code> and model to make sure that <strong>nothing</strong> is put on <em>any</em> GPU.</p> <p data-svelte-h="svelte-1gkh70d">If you do, it is recommended to put that specific code into a function and call that from within the notebook launcher interface, which will be shown later.</p> <p data-svelte-h="svelte-myu7hv">Make sure the dataset is downloaded based on the directions <a href="https://github.com/huggingface/accelerate/tree/main/examples#simple-vision-example" rel="nofollow">here</a></p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> os, re, torch, PIL
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">from</span> torch.optim.lr_scheduler <span class="hljs-keyword">import</span> OneCycleLR
<span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader, Dataset
<span class="hljs-keyword">from</span> torchvision.transforms <span class="hljs-keyword">import</span> Compose, RandomResizedCrop, Resize, ToTensor
<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
<span class="hljs-keyword">from</span> accelerate.utils <span class="hljs-keyword">import</span> set_seed
<span class="hljs-keyword">from</span> timm <span class="hljs-keyword">import</span> create_model<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-h07c4l">First you need to create a function to extract the class name based on a filename:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> os
data_dir = <span class="hljs-string">&quot;../../images&quot;</span>
fnames = os.listdir(data_dir)
fname = fnames[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(fname)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->beagle_32.jpg<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-60ulbw">In the case here, the label is <code>beagle</code>. 
Using regex you can extract the label from the filename:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> re
<span class="hljs-keyword">def</span> <span class="hljs-title function_">extract_label</span>(<span class="hljs-params">fname</span>):
    stem = fname.split(os.path.sep)[-<span class="hljs-number">1</span>]
    <span class="hljs-keyword">return</span> re.search(<span class="hljs-string">r&quot;^(.*)_\d+\.jpg$&quot;</span>, stem).groups()[<span class="hljs-number">0</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->extract_label(fname)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1kzkru5">And you can see it properly returned the right name for our file:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">&quot;beagle&quot;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-eeriqt">Next a <code>Dataset</code> class should be made to handle grabbing the image and the label:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform 
-translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">class</span> <span class="hljs-title class_">PetsDataset</span>(<span class="hljs-title class_ inherited__">Dataset</span>):
    <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, file_names, image_transform=<span class="hljs-literal">None</span>, label_to_id=<span class="hljs-literal">None</span></span>):
        self.file_names = file_names
        self.image_transform = image_transform
        self.label_to_id = label_to_id
    <span class="hljs-keyword">def</span> <span class="hljs-title function_">__len__</span>(<span class="hljs-params">self</span>):
        <span class="hljs-keyword">return</span> <span class="hljs-built_in">len</span>(self.file_names)
    <span class="hljs-keyword">def</span> <span class="hljs-title function_">__getitem__</span>(<span class="hljs-params">self, idx</span>):
        fname = self.file_names[idx]
        raw_image = PIL.Image.<span class="hljs-built_in">open</span>(fname)
        image = raw_image.convert(<span class="hljs-string">&quot;RGB&quot;</span>)
        <span class="hljs-keyword">if</span> self.image_transform <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>:
            image = self.image_transform(image)
        label = extract_label(fname)
        <span class="hljs-keyword">if</span> self.label_to_id <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>:
            label = self.label_to_id[label]
        <span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;image&quot;</span>: image, <span class="hljs-string">&quot;label&quot;</span>: label}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1erq0yi">Now to build the dataset. Outside the training function you can find and declare all the filenames and labels and use them as references inside the
launched function:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->fnames = [os.path.join(<span class="hljs-string">&quot;../../images&quot;</span>, fname) <span class="hljs-keyword">for</span> fname <span class="hljs-keyword">in</span> fnames <span class="hljs-keyword">if</span> fname.endswith(<span class="hljs-string">&quot;.jpg&quot;</span>)]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17fsfbj">Next gather all the labels:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" 
focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->all_labels = [extract_label(fname) <span class="hljs-keyword">for</span> fname <span class="hljs-keyword">in</span> fnames]
id_to_label = <span class="hljs-built_in">list</span>(<span class="hljs-built_in">set</span>(all_labels))
id_to_label.sort()
label_to_id = {lbl: i <span class="hljs-keyword">for</span> i, lbl <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(id_to_label)}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1wyd1v3">Next, you should make a <code>get_dataloaders</code> function that will return your built dataloaders for you. As mentioned earlier, if data is automatically
sent to the GPU or a TPU device when building your <code>DataLoaders</code>, they must be built using this method.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">get_dataloaders</span>(<span class="hljs-params">batch_size: <span class="hljs-built_in">int</span> = <span class="hljs-number">64</span></span>):
<span class="hljs-string">&quot;Builds a set of dataloaders with a batch_size&quot;</span>
random_perm = np.random.permutation(<span class="hljs-built_in">len</span>(fnames))
cut = <span class="hljs-built_in">int</span>(<span class="hljs-number">0.8</span> * <span class="hljs-built_in">len</span>(fnames))
train_split = random_perm[:cut]
eval_split = random_perm[cut:]
<span class="hljs-comment"># For training a simple RandomResizedCrop will be used</span>
train_tfm = Compose([RandomResizedCrop((<span class="hljs-number">224</span>, <span class="hljs-number">224</span>), scale=(<span class="hljs-number">0.5</span>, <span class="hljs-number">1.0</span>)), ToTensor()])
train_dataset = PetsDataset([fnames[i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> train_split], image_transform=train_tfm, label_to_id=label_to_id)
<span class="hljs-comment"># For evaluation a deterministic Resize will be used</span>
eval_tfm = Compose([Resize((<span class="hljs-number">224</span>, <span class="hljs-number">224</span>)), ToTensor()])
eval_dataset = PetsDataset([fnames[i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> eval_split], image_transform=eval_tfm, label_to_id=label_to_id)
<span class="hljs-comment"># Instantiate dataloaders</span>
train_dataloader = DataLoader(train_dataset, shuffle=<span class="hljs-literal">True</span>, batch_size=batch_size, num_workers=<span class="hljs-number">4</span>)
eval_dataloader = DataLoader(eval_dataset, shuffle=<span class="hljs-literal">False</span>, batch_size=batch_size * <span class="hljs-number">2</span>, num_workers=<span class="hljs-number">4</span>)
<span class="hljs-keyword">return</span> train_dataloader, eval_dataloader<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s9102n">Finally, you should import the scheduler to be used later:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch.optim.lr_scheduler <span class="hljs-keyword">import</span> OneCycleLR<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="writing-the-training-function" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#writing-the-training-function"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Writing the Training Function</span></h2> <p data-svelte-h="svelte-k7ji6m">Now you can build the training loop. <a href="/docs/accelerate/pr_4021/en/package_reference/launchers#accelerate.notebook_launcher">notebook_launcher()</a> works by passing in a function to call that will be run across the distributed system.</p> <p data-svelte-h="svelte-oa85ci">Here is a basic training loop for the animal classification problem:</p> <blockquote class="tip"><p data-svelte-h="svelte-1eykicx">The code has been split up to allow for explanations on each section. 
A full version that can be copied and pasted will be available at the end.</p></blockquote> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">training_loop</span>(<span class="hljs-params">mixed_precision=<span class="hljs-string">&quot;fp16&quot;</span>, seed: <span class="hljs-built_in">int</span> = <span class="hljs-number">42</span>, batch_size: <span class="hljs-built_in">int</span> = <span class="hljs-number">64</span></span>):
set_seed(seed)
accelerator = Accelerator(mixed_precision=mixed_precision)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-b1peea">First you should set the seed and create an <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a> object as early in the training loop as possible.</p> <blockquote class="warning"><p data-svelte-h="svelte-1aol2se">If training on the TPU, your training loop should take in the model as a parameter and it should be instantiated
outside of the training loop function. See the <a href="../concept_guides/training_tpu">TPU best practices</a>
to learn why</p></blockquote> <p data-svelte-h="svelte-1ojop43">Next you should build your dataloaders and create your model:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> train_dataloader, eval_dataloader = get_dataloaders(batch_size)
model = create_model(<span class="hljs-string">&quot;resnet50d&quot;</span>, pretrained=<span class="hljs-literal">True</span>, num_classes=<span class="hljs-built_in">len</span>(label_to_id))<!-- HTML_TAG_END --></pre></div> <blockquote class="tip"><p data-svelte-h="svelte-j432sg">You build the model here so that the seed also controls the new weight initialization</p></blockquote> <p data-svelte-h="svelte-2qfnag">As you are performing transfer learning in this example, the encoder of the model starts out frozen so that initially only the head of the
model is trained:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> <span class="hljs-keyword">for</span> param <span class="hljs-keyword">in</span> model.parameters():
param.requires_grad = <span class="hljs-literal">False</span>
<span class="hljs-keyword">for</span> param <span class="hljs-keyword">in</span> model.get_classifier().parameters():
param.requires_grad = <span class="hljs-literal">True</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-16kq69y">Normalizing the batches of images will make training a little faster:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> mean = torch.tensor(model.default_cfg[<span class="hljs-string">&quot;mean&quot;</span>])[<span class="hljs-literal">None</span>, :, <span class="hljs-literal">None</span>, <span class="hljs-literal">None</span>]
std = torch.tensor(model.default_cfg[<span class="hljs-string">&quot;std&quot;</span>])[<span class="hljs-literal">None</span>, :, <span class="hljs-literal">None</span>, <span class="hljs-literal">None</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1x7hf7y">To make these constants available on the active device, you should move them to the Accelerator’s device:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> mean = mean.to(accelerator.device)
std = std.to(accelerator.device)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1drdkyl">Next instantiate the rest of the PyTorch classes used for training:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> optimizer = torch.optim.Adam(params=model.parameters(), lr=<span class="hljs-number">3e-2</span> / <span class="hljs-number">25</span>)
lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=<span class="hljs-number">3e-2</span>, epochs=<span class="hljs-number">5</span>, steps_per_epoch=<span class="hljs-built_in">len</span>(train_dataloader))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ohcs5p">Then pass everything to <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.prepare">prepare()</a>.</p> <blockquote class="tip"><p data-svelte-h="svelte-1y658j6">There is no specific order to remember; you just need to unpack the objects in the same order you gave them to the prepare method.</p></blockquote> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-8iqamy">Now train the model:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">5</span>):
model.train()
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
inputs = (batch[<span class="hljs-string">&quot;image&quot;</span>] - mean) / std
outputs = model(inputs)
loss = torch.nn.functional.cross_entropy(outputs, batch[<span class="hljs-string">&quot;label&quot;</span>])
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-uq3liu">The evaluation loop will look slightly different compared to the training loop. The number of elements seen and the number of
correct predictions in each batch will be accumulated in two variables:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> model.<span class="hljs-built_in">eval</span>()
accurate = <span class="hljs-number">0</span>
num_elems = <span class="hljs-number">0</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vssgg7">Next you have the rest of your standard PyTorch loop:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> eval_dataloader:
inputs = (batch[<span class="hljs-string">&quot;image&quot;</span>] - mean) / std
<span class="hljs-keyword">with</span> torch.no_grad():
outputs = model(inputs)
predictions = outputs.argmax(dim=-<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-o67l3v">Finally, there is one last major difference.</p> <p data-svelte-h="svelte-1unx2og">When performing distributed evaluation, the predictions and labels need to be passed through
<a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.gather">gather()</a> so that all of the data is available on the current device and a properly calculated metric can be achieved:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch[<span class="hljs-string">&quot;label&quot;</span>])
num_elems += accurate_preds.shape[<span class="hljs-number">0</span>]
accurate += accurate_preds.long().<span class="hljs-built_in">sum</span>()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fzdj38">Now you just need to calculate the actual metric for this problem, and you can print it on the main process using <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.print">print()</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> eval_metric = accurate.item() / num_elems
accelerator.<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;epoch <span class="hljs-subst">{epoch}</span>: <span class="hljs-subst">{<span class="hljs-number">100</span> * eval_metric:<span class="hljs-number">.2</span>f}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1kobhd">A full version of this training loop is available below:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">training_loop</span>(<span class="hljs-params">mixed_precision=<span class="hljs-string">&quot;fp16&quot;</span>, seed: <span class="hljs-built_in">int</span> = <span class="hljs-number">42</span>, batch_size: <span class="hljs-built_in">int</span> = <span class="hljs-number">64</span></span>):
set_seed(seed)
<span class="hljs-comment"># Initialize accelerator</span>
accelerator = Accelerator(mixed_precision=mixed_precision)
<span class="hljs-comment"># Build dataloaders</span>
train_dataloader, eval_dataloader = get_dataloaders(batch_size)
<span class="hljs-comment"># Instantiate the model (you build the model here so that the seed also controls new weight initializations)</span>
model = create_model(<span class="hljs-string">&quot;resnet50d&quot;</span>, pretrained=<span class="hljs-literal">True</span>, num_classes=<span class="hljs-built_in">len</span>(label_to_id))
<span class="hljs-comment"># Freeze the base model</span>
<span class="hljs-keyword">for</span> param <span class="hljs-keyword">in</span> model.parameters():
param.requires_grad = <span class="hljs-literal">False</span>
<span class="hljs-keyword">for</span> param <span class="hljs-keyword">in</span> model.get_classifier().parameters():
param.requires_grad = <span class="hljs-literal">True</span>
<span class="hljs-comment"># You can normalize the batches of images to be a bit faster</span>
mean = torch.tensor(model.default_cfg[<span class="hljs-string">&quot;mean&quot;</span>])[<span class="hljs-literal">None</span>, :, <span class="hljs-literal">None</span>, <span class="hljs-literal">None</span>]
std = torch.tensor(model.default_cfg[<span class="hljs-string">&quot;std&quot;</span>])[<span class="hljs-literal">None</span>, :, <span class="hljs-literal">None</span>, <span class="hljs-literal">None</span>]
<span class="hljs-comment"># To make these constants available on the active device, set it to the accelerator device</span>
mean = mean.to(accelerator.device)
std = std.to(accelerator.device)
<span class="hljs-comment"># Instantiate the optimizer</span>
optimizer = torch.optim.Adam(params=model.parameters(), lr=<span class="hljs-number">3e-2</span> / <span class="hljs-number">25</span>)
<span class="hljs-comment"># Instantiate the learning rate scheduler</span>
lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=<span class="hljs-number">3e-2</span>, epochs=<span class="hljs-number">5</span>, steps_per_epoch=<span class="hljs-built_in">len</span>(train_dataloader))
<span class="hljs-comment"># Prepare everything</span>
<span class="hljs-comment"># There is no specific order to remember, you just need to unpack the objects in the same order you gave them to the</span>
<span class="hljs-comment"># prepare method.</span>
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
<span class="hljs-comment"># Now you train the model</span>
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">5</span>):
model.train()
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
inputs = (batch[<span class="hljs-string">&quot;image&quot;</span>] - mean) / std
outputs = model(inputs)
loss = torch.nn.functional.cross_entropy(outputs, batch[<span class="hljs-string">&quot;label&quot;</span>])
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
model.<span class="hljs-built_in">eval</span>()
accurate = <span class="hljs-number">0</span>
num_elems = <span class="hljs-number">0</span>
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> eval_dataloader:
inputs = (batch[<span class="hljs-string">&quot;image&quot;</span>] - mean) / std
<span class="hljs-keyword">with</span> torch.no_grad():
outputs = model(inputs)
predictions = outputs.argmax(dim=-<span class="hljs-number">1</span>)
accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch[<span class="hljs-string">&quot;label&quot;</span>])
num_elems += accurate_preds.shape[<span class="hljs-number">0</span>]
accurate += accurate_preds.long().<span class="hljs-built_in">sum</span>()
eval_metric = accurate.item() / num_elems
<span class="hljs-comment"># Use accelerator.print to print only on the main process.</span>
accelerator.<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;epoch <span class="hljs-subst">{epoch}</span>: <span class="hljs-subst">{<span class="hljs-number">100</span> * eval_metric:<span class="hljs-number">.2</span>f}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="using-the-notebooklauncher" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-the-notebooklauncher"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using the notebook_launcher</span></h2> <p data-svelte-h="svelte-1pnm5tu">All that’s left is to use the <a href="/docs/accelerate/pr_4021/en/package_reference/launchers#accelerate.notebook_launcher">notebook_launcher()</a>.</p> <p data-svelte-h="svelte-7o5jv1">You pass in the function, the arguments (as a tuple), and the number of processes to train on. 
(See the <a href="../package_reference/launchers">documentation</a> for more information)</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->args = (<span class="hljs-string">&quot;fp16&quot;</span>, <span class="hljs-number">42</span>, <span class="hljs-number">64</span>)
notebook_launcher(training_loop, args, num_processes=<span class="hljs-number">2</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-rkokvk">In the case of running on multiple nodes, you need to set up a Jupyter session at each node and run the launching cell at the same time.</p> <p data-svelte-h="svelte-192rw96">For an environment containing 2 nodes (computers) with 8 GPUs each and the main computer with an IP address of “172.31.43.8”, it would look like so:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->notebook_launcher(training_loop, args, master_addr=<span class="hljs-string">&quot;172.31.43.8&quot;</span>, node_rank=<span class="hljs-number">0</span>, num_nodes=<span class="hljs-number">2</span>, num_processes=<span class="hljs-number">8</span>)<!-- HTML_TAG_END --></pre></div> <p 
data-svelte-h="svelte-3ig3ul">And in the second Jupyter session on the other machine:</p> <blockquote class="tip"><p data-svelte-h="svelte-1dj1i7x">Notice how the <code>node_rank</code> has changed</p></blockquote> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->notebook_launcher(training_loop, args, master_addr=<span class="hljs-string">&quot;172.31.43.8&quot;</span>, node_rank=<span class="hljs-number">1</span>, num_nodes=<span class="hljs-number">2</span>, num_processes=<span class="hljs-number">8</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10flsnd">In the case of running on the TPU, it would look like so:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition 
duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = create_model(<span class="hljs-string">&quot;resnet50d&quot;</span>, pretrained=<span class="hljs-literal">True</span>, num_classes=<span class="hljs-built_in">len</span>(label_to_id))
args = (model, <span class="hljs-string">&quot;fp16&quot;</span>, <span class="hljs-number">42</span>, <span class="hljs-number">64</span>)
notebook_launcher(training_loop, args, num_processes=<span class="hljs-number">8</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lkd19f">To launch the training process with elasticity, enabling fault tolerance, you can use the <code>elastic_launch</code> feature provided by PyTorch. This requires setting additional parameters such as <code>rdzv_backend</code> and <code>max_restarts</code>. Here is an example of how to use <code>notebook_launcher</code> with elastic capabilities:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->notebook_launcher(
training_loop,
args,
num_processes=<span class="hljs-number">2</span>,
max_restarts=<span class="hljs-number">3</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-svq1hp">As it’s running it will print the progress as well as state how many devices you ran on. This tutorial was run with two GPUs:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->Launching training on <span class="hljs-number">2</span> GPUs.
epoch <span class="hljs-number">0</span>: <span class="hljs-number">88.12</span>
epoch <span class="hljs-number">1</span>: <span class="hljs-number">91.73</span>
epoch <span class="hljs-number">2</span>: <span class="hljs-number">92.58</span>
epoch <span class="hljs-number">3</span>: <span class="hljs-number">93.90</span>
epoch <span class="hljs-number">4</span>: <span class="hljs-number">94.71</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-9wooxy">And that’s it!</p> <p data-svelte-h="svelte-9dwmdu">Please note that <a href="/docs/accelerate/pr_4021/en/package_reference/launchers#accelerate.notebook_launcher">notebook_launcher()</a> ignores the Accelerate config file. To launch based on the config, use:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="debugging" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#debugging"><span><svg class="" xmlns="http://www.w3.org/2000/svg" 
xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Debugging</span></h2> <p data-svelte-h="svelte-g8b06q">A common issue when running the <code>notebook_launcher</code> is receiving a “CUDA/XPU has already been initialized” error. This usually stems
from an import or prior code in the notebook that makes a call to the PyTorch <code>torch.cuda</code> or <code>torch.xpu</code> sublibrary. To help narrow down what went wrong,
you can launch the <code>notebook_launcher</code> with <code>ACCELERATE_DEBUG_MODE=yes</code> in your environment and an additional check
will be made when spawning to verify that a regular process can be created and can utilize CUDA/XPU without issue. (Your CUDA/XPU code can still be run afterwards.)</p> <h2 class="relative group"><a id="conclusion" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#conclusion"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Conclusion</span></h2> <p data-svelte-h="svelte-ixf1ne">This notebook showed how to perform distributed training from inside of a Jupyter Notebook. 
Some key notes to remember:</p> <ul data-svelte-h="svelte-1aesmj1"><li>Make sure to save any code that uses CUDA/XPU (or CUDA/XPU imports) for the function passed to <a href="/docs/accelerate/pr_4021/en/package_reference/launchers#accelerate.notebook_launcher">notebook_launcher()</a></li> <li>Set the <code>num_processes</code> to be the number of devices used for training (such as number of GPUs, XPUs, CPUs, TPUs, etc.)</li> <li>If using the TPU, declare your model outside the training loop function</li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/accelerate/blob/main/docs/source/basic_tutorials/notebook.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
  // Page bootstrap: publish the path configuration global, then load the two
  // entry modules and hand control to kit.start().
  __sveltekit_1q7nz6m = {
    assets: "/docs/accelerate/pr_4021/en",
    base: "/docs/accelerate/pr_4021/en",
    env: {}
  };

  // Captured up front, before the asynchronous callback below runs.
  const mountTarget = document.currentScript.parentElement;

  // Options object passed as the third argument to kit.start().
  const startOptions = {
    node_ids: [0, 6],
    data: [null, null],
    form: null,
    error: null
  };

  // Both dynamic imports are kicked off together; order is [start, app].
  const pendingEntries = [
    import("/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js"),
    import("/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js")
  ];

  Promise.all(pendingEntries).then((loaded) => {
    const kit = loaded[0];
    const app = loaded[1];
    kit.start(app, mountTarget, startOptions);
  });
}
</script>

Xet Storage Details

Size:
84.9 kB
·
Xet hash:
f99de48dc747cd0cf3bc64c1d71ad7afe70a7e5dacfa667529513c5a872f7793

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.