Buckets:

rtrm's picture
download
raw
46.8 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Optimizing inference&quot;,&quot;local&quot;:&quot;optimizing-inference&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Static kv-cache and torch.compile&quot;,&quot;local&quot;:&quot;static-kv-cache-and-torchcompile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Decoding strategies&quot;,&quot;local&quot;:&quot;decoding-strategies&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Speculative decoding&quot;,&quot;local&quot;:&quot;speculative-decoding&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Prompt lookup decoding&quot;,&quot;local&quot;:&quot;prompt-lookup-decoding&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Attention&quot;,&quot;local&quot;:&quot;attention&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;FlashAttention-2&quot;,&quot;local&quot;:&quot;flashattention-2&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;PyTorch scaled dot product attention&quot;,&quot;local&quot;:&quot;pytorch-scaled-dot-product-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Quantization&quot;,&quot;local&quot;:&quot;quantization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_33892/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/scheduler.31fdf58d.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/singletons.9860629f.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.252883d5.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/paths.e85c0ec8.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/preload-helper.40847a0e.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.2f76fdf0.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/0.ca4aafa4.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/50.5e3ee4b7.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CopyLLMTxtMenu.ff482081.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.71f274cc.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/IconCopy.ac192424.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CodeBlock.ab12f8e1.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/HfOption.fb051768.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Optimizing inference&quot;,&quot;local&quot;:&quot;optimizing-inference&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Static kv-cache and torch.compile&quot;,&quot;local&quot;:&quot;static-kv-cache-and-torchcompile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Decoding strategies&quot;,&quot;local&quot;:&quot;decoding-strategies&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Speculative decoding&quot;,&quot;local&quot;:&quot;speculative-decoding&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Prompt lookup decoding&quot;,&quot;local&quot;:&quot;prompt-lookup-decoding&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Attention&quot;,&quot;local&quot;:&quot;attention&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;FlashAttention-2&quot;,&quot;local&quot;:&quot;flashattention-2&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;PyTorch scaled dot product attention&quot;,&quot;local&quot;:&quot;pytorch-scaled-dot-product-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Quantization&quot;,&quot;local&quot;:&quot;quantization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md 
max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="optimizing-inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#optimizing-inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 
0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Optimizing inference</span></h1> <p data-svelte-h="svelte-1f9oh7n">Inference with large language models (LLMs) can be challenging because they have to store and handle billions of parameters. To load a 70B parameter <a href="https://hf.co/meta-llama/Llama-2-70b-hf" rel="nofollow">Llama 2</a> model, it requires 256GB of memory for full precision weights and 128GB of memory for half-precision weights. The most powerful GPUs today - the A100 and H100 - only have 80GB of memory.</p> <p data-svelte-h="svelte-1olgvni">On top of the memory requirements, inference is slow because LLMs are called repeatedly to generate the next token. 
The input sequence grows as generation progresses, so each step takes longer and longer to process.</p> <p data-svelte-h="svelte-16ypsvh">This guide will show you how to optimize LLM inference to accelerate generation and reduce memory usage.</p> <blockquote class="tip" data-svelte-h="svelte-aag96e"><p>Try out <a href="https://hf.co/docs/text-generation-inference" rel="nofollow">Text Generation Inference (TGI)</a>, a Hugging Face library dedicated to deploying and serving highly optimized LLMs for inference.</p></blockquote> <h2 class="relative group"><a id="static-kv-cache-and-torchcompile" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#static-kv-cache-and-torchcompile"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Static kv-cache and torch.compile</span></h2> <p data-svelte-h="svelte-1a9ib4g">LLMs compute key-value (kv) pairs for each input token, and they repeat the same kv computation at every step because the generated output becomes part of the input. 
However, performing the same kv computation every time is not very efficient.</p> <p data-svelte-h="svelte-7nmn0g">A <em>kv-cache</em> stores the past keys and values instead of recomputing them each time. As a result, the kv-cache is dynamic and grows with each generation step, which prevents you from taking advantage of <a href="./perf_torch_compile">torch.compile</a>, a powerful optimization method that fuses PyTorch code into optimized kernels.</p> <p data-svelte-h="svelte-1abkxhz">The <em>static kv-cache</em> solves this issue by pre-allocating the kv-cache size to a maximum value, so you can combine it with <a href="./perf_torch_compile">torch.compile</a> for up to a 4x speedup. Your speedup may vary depending on the model size (larger models have a smaller speedup) and hardware.</p> <blockquote class="warning" data-svelte-h="svelte-ehqgyg"><p>Follow this <a href="https://github.com/huggingface/transformers/issues/28981" rel="nofollow">issue</a> to track which models (Llama, Gemma, Mistral, etc.) 
support a static kv-cache and torch.compile.</p></blockquote> <p data-svelte-h="svelte-13l87jg">Depending on your task, there are several ways you can use the static kv-cache.</p> <ol data-svelte-h="svelte-1b7xosv"><li>For basic use cases, set <a href="https://hf.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.cache_implementation" rel="nofollow">cache_implementation</a> to <code>&quot;static&quot;</code> (recommended).</li> <li>For multi-turn generation or a custom generation loop, initialize and handle <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.StaticCache">StaticCache</a> directly.</li> <li>For more unique hardware or use cases, it may be better to compile the entire <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a> function into a single graph.</li></ol> <blockquote class="tip" data-svelte-h="svelte-1984c1r"><p>Regardless of how you use the static kv-cache and torch.compile, left-pad your inputs with <a href="https://hf.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of" rel="nofollow">pad_to_multiple_of</a> to a limited set of values to avoid shape-related recompilations.</p></blockquote> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">1. cache_implementation </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">2. 
StaticCache </div></div> <div class="language-select"><ol data-svelte-h="svelte-k5lr94"><li>Set the <a href="https://hf.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.cache_implementation" rel="nofollow">cache_implementation</a> to <code>&quot;static&quot;</code> in a model’s <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationConfig">GenerationConfig</a>.</li> <li>Call <a href="./perf_torch_compile">torch.compile</a> to compile the forward pass with the static kv-cache.</li></ol> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> os
os.environ[<span class="hljs-string">&quot;TOKENIZERS_PARALLELISM&quot;</span>] = <span class="hljs-string">&quot;false&quot;</span> <span class="hljs-comment"># To prevent long warnings :)</span>
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;google/gemma-2b&quot;</span>)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;google/gemma-2b&quot;</span>, dtype=<span class="hljs-string">&quot;auto&quot;</span>, device_map=<span class="hljs-string">&quot;auto&quot;</span>)
model.generation_config.cache_implementation = <span class="hljs-string">&quot;static&quot;</span>
model.forward = torch.<span class="hljs-built_in">compile</span>(model.forward, mode=<span class="hljs-string">&quot;reduce-overhead&quot;</span>, fullgraph=<span class="hljs-literal">True</span>)
input_text = <span class="hljs-string">&quot;The theory of special relativity states &quot;</span>
input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device.<span class="hljs-built_in">type</span>)
outputs = model.generate(**input_ids)
<span class="hljs-built_in">print</span>(tokenizer.batch_decode(outputs, skip_special_tokens=<span class="hljs-literal">True</span>))
[<span class="hljs-string">&#x27;The theory of special relativity states 1. The speed of light is constant in all inertial reference&#x27;</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hf8rms">Under the hood, <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a> attempts to reuse the same cache object to avoid recompilation at each call, which is critical to get the most out of <a href="./perf_torch_compile">torch.compile</a>. Be aware of the following to avoid triggering recompilation or if generation is slower than expected.</p> <ol data-svelte-h="svelte-13b9uyo"><li>If the batch size changes or the maximum output length increases between calls, the cache is reinitialized and recompiled.</li> <li>The first several calls of the compiled function are slower because it is being compiled.</li></ol> </div> <h2 class="relative group"><a id="decoding-strategies" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#decoding-strategies"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Decoding strategies</span></h2> <p 
data-svelte-h="svelte-a4gxio">Decoding can also be optimized to accelerate generation. You can use a lightweight assistant model to generate candidate tokens faster than the LLM itself or you can use a variant of this decoding strategy that works especially well for input-grounded tasks.</p> <h3 class="relative group"><a id="speculative-decoding" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#speculative-decoding"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Speculative decoding</span></h3> <blockquote class="tip" data-svelte-h="svelte-8lyoq0"><p>For a more in-depth explanation, take a look at the <a href="https://hf.co/blog/assisted-generation" rel="nofollow">Assisted Generation: a new direction toward low-latency text generation</a> blog post!</p></blockquote> <p data-svelte-h="svelte-nv0ghv">For each input token, the model weights are loaded each time during the forward pass, which is slow and cumbersome when a model has billions of parameters. 
Speculative decoding alleviates this slowdown by using a second smaller and faster assistant model to generate candidate tokens that are verified by the larger model in a single forward pass. If the verified tokens are correct, the LLM essentially gets them for “free” without having to generate them itself. There is no degradation in accuracy because the verification forward pass ensures the same outputs are generated as if the LLM had generated them on its own.</p> <p data-svelte-h="svelte-1nm1j4o">To get the largest speedup, the assistant model should be a lot smaller than the LLM so that it can generate tokens quickly. The assistant model and the LLM must also share the same tokenizer to avoid re-encoding and decoding tokens.</p> <blockquote class="warning" data-svelte-h="svelte-1u3gmsf"><p>Speculative decoding is only supported for the greedy search and sampling decoding strategies, and it doesn’t support batched inputs.</p></blockquote> <p data-svelte-h="svelte-8p8rt0">Enable speculative decoding by loading an assistant model and passing it to <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">greedy search </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">sampling </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" 
xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer
<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
<span class="hljs-keyword">import</span> torch
device = Accelerator().device
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;facebook/opt-1.3b&quot;</span>)
inputs = tokenizer(<span class="hljs-string">&quot;Einstein&#x27;s theory of relativity states&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(device)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;facebook/opt-1.3b&quot;</span>, dtype=<span class="hljs-string">&quot;auto&quot;</span>).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;facebook/opt-125m&quot;</span>).to(device)
outputs = model.generate(**inputs, assistant_model=assistant_model)
tokenizer.batch_decode(outputs, skip_special_tokens=<span class="hljs-literal">True</span>)
[<span class="hljs-string">&quot;Einstein&#x27;s theory of relativity states that the speed of light is constant. &quot;</span>]<!-- HTML_TAG_END --></pre></div> </div> <h3 class="relative group"><a id="prompt-lookup-decoding" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#prompt-lookup-decoding"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Prompt lookup decoding</span></h3> <p data-svelte-h="svelte-123yl06">Prompt lookup decoding is a variant of speculative decoding that is also compatible with greedy search and sampling. Prompt lookup works especially well for input-grounded tasks - such as summarization - where there are often overlapping words between the prompt and output. These overlapping n-grams are used as the LLM candidate tokens.</p> <p data-svelte-h="svelte-q83pq8">To enable prompt lookup decoding, specify the number of tokens that should be overlapping in the <a href="https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig.prompt_lookup_num_tokens" rel="nofollow">prompt_lookup_num_tokens</a> parameter. 
Then pass this parameter to <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">greedy decoding </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">sampling </div></div> <div class="language-select"><div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer
<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
<span class="hljs-keyword">import</span> torch
device = Accelerator().device
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;facebook/opt-1.3b&quot;</span>)
inputs = tokenizer(<span class="hljs-string">&quot;The second law of thermodynamics states&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(device)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;facebook/opt-1.3b&quot;</span>, dtype=<span class="hljs-string">&quot;auto&quot;</span>).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;facebook/opt-125m&quot;</span>).to(device)
outputs = model.generate(**inputs, prompt_lookup_num_tokens=<span class="hljs-number">3</span>)
<span class="hljs-built_in">print</span>(tokenizer.batch_decode(outputs, skip_special_tokens=<span class="hljs-literal">True</span>))
[<span class="hljs-string">&#x27;The second law of thermodynamics states that entropy increases with temperature. &#x27;</span>]<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Attention</span></h2> <p data-svelte-h="svelte-119vxvi">A known issue with transformer models is that the self-attention mechanism grows quadratically in compute and memory with the number of input tokens. This limitation is only magnified in LLMs, which handle much longer sequences. 
To address this, try FlashAttention-2 or PyTorch’s scaled dot product attention (SDPA), which are more memory-efficient attention implementations.</p> <h3 class="relative group"><a id="flashattention-2" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#flashattention-2"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>FlashAttention-2</span></h3> <p data-svelte-h="svelte-1noxdna">FlashAttention and <a href="./perf_infer_gpu_one#flashattention-2">FlashAttention-2</a> break up the attention computation into smaller chunks and reduce the number of intermediate read/write operations to the GPU memory to speed up inference. 
FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over the sequence length dimension and better partitioning work on the hardware to reduce synchronization and communication overhead.</p> <p data-svelte-h="svelte-zw82mr">To use FlashAttention-2, set <a href="https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation" rel="nofollow">attn_implementation</a> to <code>&quot;flash_attention_2&quot;</code> in <a href="/docs/transformers/pr_33892/en/main_classes/model#transformers.PreTrainedModel.from_pretrained">from_pretrained()</a> or call <code>model.set_attention_implementation(&quot;flash_attention_2&quot;)</code> to dynamically update the <a href="./attention_interface">attention interface</a> after the model is loaded.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, BitsAndBytesConfig
quant_config = BitsAndBytesConfig(load_in_8bit=<span class="hljs-literal">True</span>)
model = AutoModelForCausalLM.from_pretrained(
<span class="hljs-string">&quot;google/gemma-2b&quot;</span>,
quantization_config=quant_config,
dtype=torch.bfloat16,
attn_implementation=<span class="hljs-string">&quot;flash_attention_2&quot;</span>,
)
<span class="hljs-comment"># Change the model&#x27;s attention dynamically after loading</span>
model = AutoModelForCausalLM.from_pretrained(
<span class="hljs-string">&quot;google/gemma-2b&quot;</span>,
quantization_config=quant_config,
dtype=torch.bfloat16
)
model.set_attention_implementation(<span class="hljs-string">&quot;flash_attention_2&quot;</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="pytorch-scaled-dot-product-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pytorch-scaled-dot-product-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PyTorch scaled dot product attention</span></h3> <p data-svelte-h="svelte-16n93px">Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch’s C++ implementation. SDPA chooses the most performant attention algorithm if you’re using a CUDA backend. 
For other backends, SDPA defaults to the PyTorch C++ implementation.</p> <blockquote class="tip" data-svelte-h="svelte-12rxand"><p>SDPA automatically supports FlashAttention-2 as long as you have the latest PyTorch version installed.</p></blockquote> <p data-svelte-h="svelte-xcpcjc">Use the <a href="https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html" rel="nofollow">torch.nn.attention.sdpa_kernel</a> context manager to explicitly enable or disable any of the four attention algorithms. For example, use <code>SDPBackend.FLASH_ATTENTION</code> to enable FlashAttention.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> torch.nn.attention <span class="hljs-keyword">import</span> SDPBackend, sdpa_kernel
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
<span class="hljs-string">&quot;google/gemma-2b&quot;</span>,
dtype=torch.bfloat16,
)
<span class="hljs-keyword">with</span> sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="quantization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quantization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quantization</span></h2> <p data-svelte-h="svelte-1h1vcbj">Quantization reduces the size of model weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if you’re constrained by GPU memory.</p> <p data-svelte-h="svelte-bsjn3b">If you aren’t limited by your GPU, you don’t necessarily need to quantize your model because it can increase latency slightly (except for AWQ and fused AWQ modules) due to the extra step required to quantize and dequantize the weights.</p> <blockquote class="tip" data-svelte-h="svelte-1easuy2"><p>There are many quantization libraries (see the <a href="./quantization">Quantization</a> guide for more details) available, such as Quanto, AQLM, VPTQ, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. 
We also recommend reading the <a href="https://hf.co/blog/overview-quantization-transformers" rel="nofollow">Overview of natively supported quantization schemes in 🤗 Transformers</a> blog post, which compares AutoGPTQ and bitsandbytes.</p></blockquote> <p data-svelte-h="svelte-14tgx8l">Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating the memory required to load <a href="https://hf.co/mistralai/Mistral-7B-v0.1" rel="nofollow">Mistral-7B-v0.1</a>.</p> <iframe src="https://hf-accelerate-model-memory-usage.hf.space" title="Model Memory Calculator" frameborder="0" width="850" height="450"></iframe> <p data-svelte-h="svelte-g7thb2">To load a model in half-precision, set the <a href="https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.dtype" rel="nofollow">dtype</a> parameter in <a href="/docs/transformers/pr_33892/en/model_doc/auto#transformers.AutoModel.from_pretrained">from_pretrained()</a> to <code>torch.bfloat16</code>. 
This requires 13.74GB of memory.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-keyword">import</span> torch
model = AutoModelForCausalLM.from_pretrained(
<span class="hljs-string">&quot;mistralai/Mistral-7B-v0.1&quot;</span>, dtype=torch.bfloat16, device_map=<span class="hljs-string">&quot;auto&quot;</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1f44nt3">To load a quantized model (8-bit or 4-bit), try <a href="https://hf.co/docs/bitsandbytes" rel="nofollow">bitsandbytes</a> and set the <a href="https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.BitsAndBytesConfig.load_in_4bit" rel="nofollow">load_in_4bit</a> or <a href="https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.BitsAndBytesConfig.load_in_8bit" rel="nofollow">load_in_8bit</a> parameters to <code>True</code>. Loading the model in 8-bits only requires 6.87 GB of memory.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
<span class="hljs-keyword">import</span> torch
quant_config = BitsAndBytesConfig(load_in_8bit=<span class="hljs-literal">True</span>)
model = AutoModelForCausalLM.from_pretrained(
<span class="hljs-string">&quot;mistralai/Mistral-7B-v0.1&quot;</span>, quantization_config=quant_config, device_map=<span class="hljs-string">&quot;auto&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/llm_optims.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
  // Client bootstrap for this page: publish the deployment paths on a global,
  // then load the two client entry modules and start the app.
  __sveltekit_16tnnm8 = {
    assets: "/docs/transformers/pr_33892/en",
    base: "/docs/transformers/pr_33892/en",
    env: {}
  };

  // Captured up front, before the asynchronous imports resolve
  // (document.currentScript is only set while this script itself is executing).
  const mountPoint = document.currentScript.parentElement;

  // Start-up arguments handed to kit.start(); values are emitted by the build.
  const pageData = [null, null];
  const startOptions = {
    node_ids: [0, 50],
    data: pageData,
    form: null,
    error: null
  };

  const entryModules = [
    import("/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"),
    import("/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js")
  ];

  // Wait for both modules, then start the app on this script's parent element.
  Promise.all(entryModules).then((loaded) => {
    const kit = loaded[0];
    const app = loaded[1];
    kit.start(app, mountPoint, startOptions);
  });
}
</script>

Xet Storage Details

Size:
46.8 kB
·
Xet hash:
55641508428251785bfdcccd81ab565fc092da1261cd5801c5ca019c82541cea

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.