# Instantiate a big model

A barrier to accessing very large pretrained models is the amount of memory required.
When loading a pretrained PyTorch model, you usually:

1. Create a model with random weights.
2. Load your pretrained weights.
3. Put those pretrained weights in the model.

The first two steps both require a full version of the model in memory, and if the model weighs several GBs, you may not have enough memory for two copies of it. This problem is amplified in distributed training environments because each process loads a pretrained model and stores two copies in memory.

<Tip>

The randomly created model is initialized with "empty" tensors, which take up space in memory without filling it. The random values are whatever was in this chunk of memory at the time. To improve loading speed, the [`_fast_init`](https://github.com/huggingface/transformers/blob/c9f6e5e35156e068b227dd9b15521767f6afd4d2/src/transformers/modeling_utils.py#L2710) parameter is set to `True` by default to skip the random initialization for all weights that are correctly loaded.

</Tip>

This guide will show you how Transformers can help you load large pretrained models despite their memory requirements.

## Sharded checkpoints

From Transformers v4.18.0, a checkpoint larger than 10GB is automatically sharded by the [save_pretrained()](/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel.save_pretrained) method.
The checkpoint is split into several smaller partial checkpoints, and an index file is created that maps parameter names to the files they're stored in.

The maximum shard size is controlled with the `max_shard_size` parameter. It defaults to 5GB because smaller shards are easier to load on free-tier GPU instances without running out of memory.

For example, let's shard [BioMistral/BioMistral-7B](https://hf.co/BioMistral/BioMistral-7B).
```py
>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="5GB")
...     print(sorted(os.listdir(tmp_dir)))
['config.json', 'generation_config.json', 'model-00001-of-00006.safetensors', 'model-00002-of-00006.safetensors', 'model-00003-of-00006.safetensors', 'model-00004-of-00006.safetensors', 'model-00005-of-00006.safetensors', 'model-00006-of-00006.safetensors', 'model.safetensors.index.json']
```

The sharded checkpoint is reloaded with the [from_pretrained()](/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel.from_pretrained) method.
```py
>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="5GB")
...     new_model = AutoModel.from_pretrained(tmp_dir)
```

The main advantage of sharded checkpoints for big models is that each shard is loaded after the previous one, which caps the memory usage to only the model size and the largest shard size.

You could also directly load a sharded checkpoint inside a model without the [from_pretrained()](/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel.from_pretrained) method (similar to PyTorch's `load_state_dict()` method for a full checkpoint). In this case, use the [load_sharded_checkpoint()](/docs/transformers/pr_33913/en/main_classes/model#transformers.modeling_utils.load_sharded_checkpoint) method.
```py
>>> from transformers.modeling_utils import load_sharded_checkpoint

>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="5GB")
...     load_sharded_checkpoint(model, tmp_dir)
```

### Shard metadata

The index file determines which keys are in the checkpoint and where the corresponding weights are stored.
This file is loaded like any other JSON file, and you can get a dictionary from it.

```py
>>> import json

>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="5GB")
...     with open(os.path.join(tmp_dir, "model.safetensors.index.json"), "r") as f:
...         index = json.load(f)

>>> print(index.keys())
dict_keys(['metadata', 'weight_map'])
```

The `metadata` key provides the total model size.

```py
>>> index["metadata"]
{'total_size': 28966928384}
```

The `weight_map` key maps each parameter name (as it appears in a PyTorch model's `state_dict`) to the shard it's stored in.
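Because `weight_map` is an ordinary dictionary, it is easy to invert, for example to list which parameters live in each shard. A minimal sketch, using a hand-written map in the same shape as the real one shown next (the entries are copied from the example output, not read from a checkpoint):

```python
from collections import defaultdict

# Hand-written stand-in with the same shape as a real index["weight_map"].
weight_map = {
    "lm_head.weight": "model-00006-of-00006.safetensors",
    "model.embed_tokens.weight": "model-00001-of-00006.safetensors",
    "model.layers.0.input_layernorm.weight": "model-00001-of-00006.safetensors",
    "model.layers.0.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
}

# Invert the mapping: shard file -> parameter names stored in it.
params_per_shard = defaultdict(list)
for name, shard in weight_map.items():
    params_per_shard[shard].append(name)

for shard, names in sorted(params_per_shard.items()):
    print(shard, len(names))
```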
```py
>>> index["weight_map"]
{'lm_head.weight': 'model-00006-of-00006.safetensors',
 'model.embed_tokens.weight': 'model-00001-of-00006.safetensors',
 'model.layers.0.input_layernorm.weight': 'model-00001-of-00006.safetensors',
 'model.layers.0.mlp.down_proj.weight': 'model-00001-of-00006.safetensors',
 ...
}
```

## Accelerate's Big Model Inference

<Tip>

Make sure you have Accelerate v0.9.0 or later and PyTorch v1.9.0 or later installed.

</Tip>

From Transformers v4.20.0, the [from_pretrained()](/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel.from_pretrained) method is supercharged with Accelerate's [Big Model Inference](https://hf.co/docs/accelerate/usage_guides/big_modeling) feature to efficiently handle really big models!
Big Model Inference creates a *model skeleton* on PyTorch's [**meta**](https://pytorch.org/docs/main/meta.html) device. The randomly initialized parameters are only created when the pretrained weights are loaded. This way, you aren't keeping two copies of the model in memory at the same time (one for the randomly initialized model and one for the pretrained weights), and the maximum memory consumed is only the full model size.

To enable Big Model Inference in Transformers, set `low_cpu_mem_usage=True` in the [from_pretrained()](/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel.from_pretrained) method.

```py
from transformers import AutoModelForCausalLM

gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", low_cpu_mem_usage=True)
```

Accelerate automatically dispatches the model weights across all available devices, starting with the fastest device (GPU) first and then offloading to the slower devices (CPU and even hard drive). This is enabled by setting `device_map="auto"` in the [from_pretrained()](/docs/transformers/pr_33913/en/main_classes/model#transformers.PreTrainedModel.from_pretrained) method. When you pass the `device_map` parameter, `low_cpu_mem_usage` is automatically set to `True` so you don't need to specify it.
```py
from transformers import AutoModelForCausalLM

# these loading methods are equivalent
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto")
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto", low_cpu_mem_usage=True)
```

You can also write your own `device_map` by mapping each layer to a device. It should map all model parameters to a device, but you don't have to detail where all the submodules of a layer go if the entire layer is on the same device.

```py
device_map = {"model.layers.1": 0, "model.layers.14": 1, "model.layers.31": "cpu", "lm_head": "disk"}
```

Access the `hf_device_map` attribute to see how Accelerate split the model across devices.

```py
gemma.hf_device_map
```
preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'model.embed_tokens'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.0'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.1'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.2'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.3'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.4'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.5'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.6'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.7'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.8'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.9'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.10'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.11'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.12'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.13'</span>: <span class="hljs-number">0</span>,
 <span class="hljs-string">'model.layers.14'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.15'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.16'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.17'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.18'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.19'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.20'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.21'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.22'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.23'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.24'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.25'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.26'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.27'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.28'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.29'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.30'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.layers.31'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'model.norm'</span>: <span class="hljs-string">'cpu'</span>,
 <span class="hljs-string">'lm_head'</span>: <span class="hljs-string">'cpu'</span>}<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="model-data-type" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model-data-type"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model data type</span></h2> <p data-svelte-h="svelte-wrnzpy">PyTorch model weights are normally instantiated as torch.float32, which can be an issue if you try to load a model in a different data type. 
For example, you’d need twice as much memory to load the weights in torch.float32 and then again to load them in your desired data type, like torch.float16.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-10yedw6">Due to how PyTorch is designed, the <code>torch_dtype</code> parameter only supports floating data types.</p></div> <p data-svelte-h="svelte-62b14u">To avoid wasting memory like this, explicitly set the <code>torch_dtype</code> parameter to the desired data type or set <code>torch_dtype="auto"</code> to load the weights with the most optimal memory pattern (the data type is automatically derived from the model weights).</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">specific dtype </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">auto dtype </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM

gemma = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"google/gemma-7b"</span>, torch_dtype=torch.float16)<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-hi16m2">You can also set the data type to use for models instantiated from scratch.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoConfig, AutoModel

my_config = AutoConfig.from_pretrained(<span class="hljs-string">"google/gemma-2b"</span>, torch_dtype=torch.float16)
model = AutoModel.from_config(my_config)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/big_models.md" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
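<p>The memory tradeoff described above can be checked directly. The following is a minimal sketch using plain PyTorch (no model download) to show that half-precision weights occupy half the bytes of float32 weights; the tensor shape is arbitrary and chosen only for illustration:</p>
<pre class=""><!-- HTML_TAG_START -->
```python
import torch

# float32 stores 4 bytes per element, float16 stores 2, so loading a
# checkpoint in float32 before casting roughly doubles peak memory.
w32 = torch.ones(1024, 1024, dtype=torch.float32)
w16 = w32.to(torch.float16)

print(w32.element_size() * w32.nelement())  # 4194304 bytes (4 MiB)
print(w16.element_size() * w16.nelement())  # 2097152 bytes (2 MiB)
```
<!-- HTML_TAG_END --></pre>
<p>This is why setting <code>torch_dtype</code> up front (or <code>torch_dtype="auto"</code>) avoids materializing a full float32 copy of the weights first.</p>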
<script>
	{
		__sveltekit_z647wz = {
			assets: "/docs/transformers/pr_33913/en",
			base: "/docs/transformers/pr_33913/en",
			env: {}
		};
		const element = document.currentScript.parentElement;
		const data = [null,null];
		Promise.all([
			import("/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"),
			import("/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js")
		]).then(([kit, app]) => {
			kit.start(app, element, {
				node_ids: [0, 11],
				data,
				form: null,
				error: null
			});
		});
	}
</script>