<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;LLM inference optimization&quot;,&quot;local&quot;:&quot;llm-inference-optimization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Static kv-cache and torch.compile&quot;,&quot;local&quot;:&quot;static-kv-cache-and-torchcompile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Speculative decoding&quot;,&quot;local&quot;:&quot;speculative-decoding&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Prompt lookup decoding&quot;,&quot;local&quot;:&quot;prompt-lookup-decoding&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Attention optimizations&quot;,&quot;local&quot;:&quot;attention-optimizations&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;FlashAttention-2&quot;,&quot;local&quot;:&quot;flashattention-2&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;PyTorch scaled dot product attention&quot;,&quot;local&quot;:&quot;pytorch-scaled-dot-product-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Quantization&quot;,&quot;local&quot;:&quot;quantization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/scheduler.25b97de1.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/singletons.0f2b7d5f.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.e188933d.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/paths.3d04d2c6.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.d9030fc9.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/0.026d2fdd.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/38.02e8bac1.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/Tip.baa67368.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/EditOnGithub.91d95064.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/HfOption.1e589c90.js">
<link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/stores.c3f24f16.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;LLM inference optimization&quot;,&quot;local&quot;:&quot;llm-inference-optimization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Static kv-cache and torch.compile&quot;,&quot;local&quot;:&quot;static-kv-cache-and-torchcompile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Speculative decoding&quot;,&quot;local&quot;:&quot;speculative-decoding&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Prompt lookup decoding&quot;,&quot;local&quot;:&quot;prompt-lookup-decoding&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Attention optimizations&quot;,&quot;local&quot;:&quot;attention-optimizations&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;FlashAttention-2&quot;,&quot;local&quot;:&quot;flashattention-2&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;PyTorch scaled dot product attention&quot;,&quot;local&quot;:&quot;pytorch-scaled-dot-product-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Quantization&quot;,&quot;local&quot;:&quot;quantization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="llm-inference-optimization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#llm-inference-optimization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 
0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>LLM inference optimization</span></h1> <p data-svelte-h="svelte-1kjw8nj">Large language models (LLMs) have pushed text generation applications, such as chat and code completion models, to the next level by producing text that displays a high level of understanding and fluency. But what makes LLMs so powerful - namely their size - also presents challenges for inference.</p> <p data-svelte-h="svelte-1l3iyos">Basic inference is slow because LLMs have to be called repeatedly to generate the next token. The input sequence increases as generation progresses, which takes longer and longer for the LLM to process. LLMs also have billions of parameters, making it a challenge to store and handle all those weights in memory.</p> <p data-svelte-h="svelte-1wb7yan">This guide will show you how to use the optimization techniques available in Transformers to accelerate LLM inference.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-utq275">Hugging Face also provides <a href="https://hf.co/docs/text-generation-inference" rel="nofollow">Text Generation Inference (TGI)</a>, a library dedicated to deploying and serving highly optimized LLMs for inference. 
It includes deployment-oriented optimization features not included in Transformers, such as continuous batching for increasing throughput and tensor parallelism for multi-GPU inference.</p></div> <h2 class="relative group"><a id="static-kv-cache-and-torchcompile" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#static-kv-cache-and-torchcompile"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Static kv-cache and torch.compile</span></h2> <p data-svelte-h="svelte-bq37ce">During decoding, an LLM computes the key-value (kv) values for each input token. Because it is autoregressive, the generated output becomes part of the next input, so the model recomputes the same kv values at every step, which is inefficient.</p> <p data-svelte-h="svelte-w4nvh4">To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. 
However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of <a href="./perf_torch_compile"><code>torch.compile</code></a>, a powerful optimization tool that fuses PyTorch code into fast and optimized kernels. We have an entire guide dedicated to kv-caches <a href="./kv_cache">here</a>.</p> <p data-svelte-h="svelte-jt9td8">The <em>static kv-cache</em> solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with <code>torch.compile</code> for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1cpz1yf">Currently, only <a href="./model_doc/llama2">Llama</a> and a few other models support static kv-cache and <code>torch.compile</code>. 
Check <a href="https://github.com/huggingface/transformers/issues/28981" rel="nofollow">this issue</a> for a live model compatibility list.</p></div> <p data-svelte-h="svelte-1439wi2">There are three flavors of static kv-cache usage, depending on the complexity of your task:</p> <ol data-svelte-h="svelte-165n3j5"><li>Basic usage: simply set a flag in <code>generation_config</code> (recommended);</li> <li>Advanced usage: handle a cache object for multi-turn generation or a custom generation loop;</li> <li>Advanced usage: compile the entire <code>generate</code> function into a single graph, if having a single graph is relevant for you.</li></ol> <p data-svelte-h="svelte-etip6m">Select the correct tab below for further instructions on each of these flavors.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1m3bnto">Regardless of the strategy used with <code>torch.compile</code>, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. 
The <a href="https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of" rel="nofollow"><code>pad_to_multiple_of</code> tokenizer flag</a> is your friend!</p></div> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">basic usage: generation_config </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">advanced usage: control Static Cache </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">advanced usage: end-to-end generate compilation </div></div> <div class="language-select"><p data-svelte-h="svelte-oys3sj">For this example, let’s use the <a href="https://hf.co/google/gemma-2b" rel="nofollow">Gemma</a> model. 
All we need to do is to:</p> <ol data-svelte-h="svelte-1fw1ivc"><li>Access the model’s <code>generation_config</code> attribute and set the <code>cache_implementation</code> to “static”;</li> <li>Call <code>torch.compile</code> on the model to compile the forward pass with the static kv-cache.</li></ol> <p data-svelte-h="svelte-9wooxy">And that’s it!</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> os
os.environ[<span class="hljs-string">&quot;TOKENIZERS_PARALLELISM&quot;</span>] = <span class="hljs-string">&quot;false&quot;</span> <span class="hljs-comment"># To prevent long warnings :)</span>
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;google/gemma-2b&quot;</span>)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;google/gemma-2b&quot;</span>, device_map=<span class="hljs-string">&quot;auto&quot;</span>)
model.generation_config.cache_implementation = <span class="hljs-string">&quot;static&quot;</span>
model.forward = torch.<span class="hljs-built_in">compile</span>(model.forward, mode=<span class="hljs-string">&quot;reduce-overhead&quot;</span>, fullgraph=<span class="hljs-literal">True</span>)
input_text = <span class="hljs-string">&quot;The theory of special relativity states &quot;</span>
input_ids = tokenizer(input_text, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
outputs = model.generate(**input_ids)
<span class="hljs-built_in">print</span>(tokenizer.batch_decode(outputs, skip_special_tokens=<span class="hljs-literal">True</span>))
[<span class="hljs-string">&#x27;The theory of special relativity states 1. The speed of light is constant in all inertial reference&#x27;</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yth3sl">Under the hood, <code>generate</code> will attempt to reuse the same cache object, removing the need for re-compilation at each call. Avoiding re-compilation is critical to get the most out of <code>torch.compile</code>, and you should be aware of the following:</p> <ol data-svelte-h="svelte-u5kopl"><li>If the batch size changes or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation;</li> <li>The first couple of calls of the compiled function are slower, as the function is being compiled.</li></ol> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1u2oe25">For a more advanced usage of the static cache, such as multi-turn conversations, we recommend instantiating and manipulating the cache object outside <a href="/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>. 
See the advanced usage tab.</p></div> </div> <h2 class="relative group"><a id="speculative-decoding" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#speculative-decoding"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Speculative decoding</span></h2> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-112w74b">For a more in-depth explanation, take a look at the <a href="https://hf.co/blog/assisted-generation" rel="nofollow">Assisted Generation: a new direction toward low-latency text generation</a> blog post!</p></div> <p data-svelte-h="svelte-18z9bzk">Another issue with autoregression is that for each input token you need to load the model weights each time during the forward pass. This is slow and cumbersome for LLMs which have billions of parameters. 
Speculative decoding alleviates this slowdown by using a second smaller and faster assistant model to generate candidate tokens that are verified by the larger LLM in a single forward pass. If the verified tokens are correct, the LLM essentially gets them for “free” without having to generate them itself. There is no degradation in accuracy because the verification forward pass ensures the same outputs are generated as if the LLM had generated them on its own.</p> <p data-svelte-h="svelte-1nm1j4o">To get the largest speed up, the assistant model should be a lot smaller than the LLM so that it can generate tokens quickly. The assistant and LLM model must also share the same tokenizer to avoid re-encoding and decoding tokens.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-eq929k">Speculative decoding is only supported for the greedy search and sampling decoding strategies, and it also doesn’t support batched inputs.</p></div> <p data-svelte-h="svelte-4l211j">Enable speculative decoding by loading an assistant model and passing it to the <a href="/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a> method.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">greedy search </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">sampling </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button 
class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer
<span class="hljs-keyword">import</span> torch
device = <span class="hljs-string">&quot;cuda&quot;</span> <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> <span class="hljs-string">&quot;cpu&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;facebook/opt-1.3b&quot;</span>)
inputs = tokenizer(<span class="hljs-string">&quot;Einstein&#x27;s theory of relativity states&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(device)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;facebook/opt-1.3b&quot;</span>).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;facebook/opt-125m&quot;</span>).to(device)
outputs = model.generate(**inputs, assistant_model=assistant_model)
tokenizer.batch_decode(outputs, skip_special_tokens=<span class="hljs-literal">True</span>)
[<span class="hljs-string">&quot;Einstein&#x27;s theory of relativity states that the speed of light is constant. &quot;</span>]<!-- HTML_TAG_END --></pre></div> </div> <h3 class="relative group"><a id="prompt-lookup-decoding" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#prompt-lookup-decoding"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Prompt lookup decoding</span></h3> <p data-svelte-h="svelte-123yl06">Prompt lookup decoding is a variant of speculative decoding that is also compatible with greedy search and sampling. Prompt lookup works especially well for input-grounded tasks - such as summarization - where there are often overlapping words between the prompt and the output. These overlapping n-grams are used as the LLM candidate tokens.</p> <p data-svelte-h="svelte-1xnaq70">To enable prompt lookup decoding, specify the number of tokens that should be overlapping in the <code>prompt_lookup_num_tokens</code> parameter. 
Then you can pass this parameter to the <a href="/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a> method.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">greedy decoding </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">sampling </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer
<span class="hljs-keyword">import</span> torch
device = <span class="hljs-string">&quot;cuda&quot;</span> <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> <span class="hljs-string">&quot;cpu&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;facebook/opt-1.3b&quot;</span>)
inputs = tokenizer(<span class="hljs-string">&quot;The second law of thermodynamics states&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(device)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;facebook/opt-1.3b&quot;</span>).to(device)
outputs = model.generate(**inputs, prompt_lookup_num_tokens=<span class="hljs-number">3</span>)
<span class="hljs-built_in">print</span>(tokenizer.batch_decode(outputs, skip_special_tokens=<span class="hljs-literal">True</span>))
[<span class="hljs-string">&#x27;The second law of thermodynamics states that entropy increases with temperature. &#x27;</span>]<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="attention-optimizations" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#attention-optimizations"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Attention optimizations</span></h2> <p data-svelte-h="svelte-zxt0p7">A known issue with transformer models is that the self-attention mechanism grows quadratically in compute and memory with the number of input tokens. This limitation is only magnified in LLMs, which handle much longer sequences. 
To address this, try FlashAttention-2 or PyTorch’s scaled dot product attention (SDPA), which are more memory-efficient attention implementations and can accelerate inference.</p> <h3 class="relative group"><a id="flashattention-2" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#flashattention-2"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>FlashAttention-2</span></h3> <p data-svelte-h="svelte-pa091d">FlashAttention and <a href="./perf_infer_gpu_one#flashattention-2">FlashAttention-2</a> break up the attention computation into smaller chunks and reduce the number of intermediate read/write operations to GPU memory to speed up inference. 
FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over sequence length dimension and better partitioning work on the hardware to reduce synchronization and communication overhead.</p> <p data-svelte-h="svelte-1ym14uu">To use FlashAttention-2, set <code>attn_implementation=&quot;flash_attention_2&quot;</code> in the <a href="/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained">from_pretrained()</a> method.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, BitsAndBytesConfig
<span class="hljs-keyword">import</span> torch
quant_config = BitsAndBytesConfig(load_in_8bit=<span class="hljs-literal">True</span>)
model = AutoModelForCausalLM.from_pretrained(
<span class="hljs-string">&quot;google/gemma-2b&quot;</span>,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
attn_implementation=<span class="hljs-string">&quot;flash_attention_2&quot;</span>,
)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="pytorch-scaled-dot-product-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pytorch-scaled-dot-product-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PyTorch scaled dot product attention</span></h3> <p data-svelte-h="svelte-16n93px">Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch’s C++ implementation. SDPA chooses the most performant attention algorithm if you’re using a CUDA backend. 
For other backends, SDPA defaults to the PyTorch C++ implementation.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-12yxz1f">SDPA supports FlashAttention-2 as long as you have the latest PyTorch version installed.</p></div> <p data-svelte-h="svelte-cbkj1o">Use the <a href="https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow">torch.backends.cuda.sdp_kernel</a> context manager to explicitly enable or disable any of the three attention algorithms. For example, set <code>enable_flash=True</code> to enable FlashAttention.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- 
HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;google/gemma-2b&quot;</span>)
model = AutoModelForCausalLM.from_pretrained(
    <span class="hljs-string">&quot;google/gemma-2b&quot;</span>,
    torch_dtype=torch.bfloat16,
)
<span class="hljs-comment"># Placeholder prompt; replace it with your own input text</span>
inputs = tokenizer(<span class="hljs-string">&quot;The theory of relativity states&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-keyword">with</span> torch.backends.cuda.sdp_kernel(enable_flash=<span class="hljs-literal">True</span>, enable_math=<span class="hljs-literal">False</span>, enable_mem_efficient=<span class="hljs-literal">False</span>):
    outputs = model.generate(**inputs)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="quantization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quantization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quantization</span></h2> <p data-svelte-h="svelte-x4c749">Quantization reduces the size of the LLM weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if you’re constrained by your GPU’s memory. 
If you aren’t limited by your GPU, you don’t necessarily need to quantize your model because it can incur a small latency cost (except for AWQ and fused AWQ modules) due to the extra step required to quantize and dequantize the weights.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1b7gbou">There are many quantization libraries (see the <a href="./quantization">Quantization</a> guide for more details) available, such as Quanto, AQLM, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the <a href="https://hf.co/blog/overview-quantization-transformers" rel="nofollow">Overview of natively supported quantization schemes in 🤗 Transformers</a> blog post which compares AutoGPTQ and bitsandbytes.</p></div> <p data-svelte-h="svelte-1i2aaqz">Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating how much memory it costs to load <a href="https://huggingface.co/mistralai/Mistral-7B-v0.1" rel="nofollow">Mistral-7B-v0.1</a>.</p> <iframe src="https://hf-accelerate-model-memory-usage.hf.space" frameborder="0" width="850" height="450"></iframe> <p data-svelte-h="svelte-l6iyvl">To load Mistral-7B-v0.1 in half-precision, set the <code>torch_dtype</code> parameter in the <a href="/docs/transformers/main/en/model_doc/auto#transformers.AutoModel.from_pretrained">from_pretrained()</a> method to <code>torch.bfloat16</code>. 
This requires 13.74GB of memory.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-keyword">import</span> torch
model = AutoModelForCausalLM.from_pretrained(
    <span class="hljs-string">&quot;mistralai/Mistral-7B-v0.1&quot;</span>, torch_dtype=torch.bfloat16, device_map=<span class="hljs-string">&quot;auto&quot;</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-a32zwr">To load a quantized model (8-bit or 4-bit) for inference, try <a href="https://hf.co/docs/bitsandbytes" rel="nofollow">bitsandbytes</a> and set the <code>load_in_4bit</code> or <code>load_in_8bit</code> parameter to <code>True</code>. Loading the model in 8-bit requires only 6.87 GB of memory.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
<span class="hljs-keyword">import</span> torch
quant_config = BitsAndBytesConfig(load_in_8bit=<span class="hljs-literal">True</span>)
model = AutoModelForCausalLM.from_pretrained(
    <span class="hljs-string">&quot;mistralai/Mistral-7B-v0.1&quot;</span>, quantization_config=quant_config, device_map=<span class="hljs-string">&quot;auto&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <p></p>
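<p>The two memory figures quoted in this section (13.74 GB in bfloat16 and 6.87 GB in 8-bit) differ by a factor of two because weight memory scales with the number of bits stored per parameter. The sketch below illustrates that relationship only. The helper name <code>estimate_weight_memory_gib</code> and the example parameter count are hypothetical and not part of Transformers, and the estimate ignores activations, the KV cache, and any quantization metadata, so it should not be expected to reproduce the calculator’s numbers exactly.</p>

```python
def estimate_weight_memory_gib(num_params: int, bits_per_param: int) -> float:
    """Rough, weight-only memory estimate in GiB (1 GiB = 1024**3 bytes).

    Hypothetical helper for illustration; it is not a Transformers API.
    """
    if num_params < 0 or bits_per_param <= 0:
        raise ValueError("num_params must be >= 0 and bits_per_param must be > 0")
    total_bytes = num_params * bits_per_param / 8
    return total_bytes / 1024**3


# Assumed parameter count, chosen only to illustrate the scaling.
num_params = 7_000_000_000

# Halving the bits per parameter halves the weight memory.
for label, bits in [("bfloat16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label:>8}: {estimate_weight_memory_gib(num_params, bits):.2f} GiB")
```

<p>For real models, the Model Memory Calculator above is the better source because it reads the actual parameter shapes from the checkpoint.</p>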
<script>
{
__sveltekit_1xexzbk = {
assets: "/docs/transformers/main/en",
base: "/docs/transformers/main/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/transformers/main/en/_app/immutable/entry/start.2135b7e6.js"),
import("/docs/transformers/main/en/_app/immutable/entry/app.24372c84.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 38],
data,
form: null,
error: null
});
});
}
</script>
