Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"KV cache strategies","local":"kv-cache-strategies","sections":[{"title":"Default cache","local":"default-cache","sections":[],"depth":2},{"title":"Fixed-size cache","local":"fixed-size-cache","sections":[],"depth":2},{"title":"Cache offloading","local":"cache-offloading","sections":[],"depth":2},{"title":"Quantized cache","local":"quantized-cache","sections":[],"depth":2},{"title":"Encoder-decoder cache","local":"encoder-decoder-cache","sections":[],"depth":2},{"title":"Model-specific caches","local":"model-specific-caches","sections":[],"depth":2},{"title":"Iterative generation","local":"iterative-generation","sections":[],"depth":2},{"title":"Prefill a cache (prefix caching)","local":"prefill-a-cache-prefix-caching","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/pr_33892/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/scheduler.31fdf58d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/singletons.9860629f.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.252883d5.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/paths.e85c0ec8.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/preload-helper.40847a0e.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.2f76fdf0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/0.ca4aafa4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/49.c89ff723.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CopyLLMTxtMenu.ff482081.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.71f274cc.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/IconCopy.ac192424.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CodeBlock.ab12f8e1.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"KV cache strategies","local":"kv-cache-strategies","sections":[{"title":"Default cache","local":"default-cache","sections":[],"depth":2},{"title":"Fixed-size cache","local":"fixed-size-cache","sections":[],"depth":2},{"title":"Cache offloading","local":"cache-offloading","sections":[],"depth":2},{"title":"Quantized cache","local":"quantized-cache","sections":[],"depth":2},{"title":"Encoder-decoder cache","local":"encoder-decoder-cache","sections":[],"depth":2},{"title":"Model-specific caches","local":"model-specific-caches","sections":[],"depth":2},{"title":"Iterative generation","local":"iterative-generation","sections":[],"depth":2},{"title":"Prefill a cache (prefix caching)","local":"prefill-a-cache-prefix-caching","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="kv-cache-strategies" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#kv-cache-strategies"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>KV cache strategies</span></h1> <p data-svelte-h="svelte-1hxt7ht">The key-value (KV) vectors are used to calculate attention scores. For autoregressive models, the KV vectors of all previous tokens are recalculated <em>every</em> time because the model predicts one token at a time. Each prediction depends on the previous tokens, which means the model repeats the same computations at every step.</p> <p data-svelte-h="svelte-og33ji">A KV <em>cache</em> stores these calculations so they can be reused without recomputing them. Efficient caching is crucial for optimizing model performance because it reduces computation time and improves response rates. Refer to the <a href="./cache_explanation">Caching</a> doc for a more detailed explanation of how a cache works.</p> <p data-svelte-h="svelte-1uwd5hm">Transformers offers several <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache">Cache</a> classes that implement different caching mechanisms. Some of these <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache">Cache</a> classes are optimized to save memory while others are designed to maximize generation speed. 
Refer to the table below to compare cache types and use it to help you select the best cache for your use case.</p> <table data-svelte-h="svelte-dxv7tp"><thead><tr><th>Cache Type</th> <th>Supports sliding layers</th> <th>Supports offloading</th> <th>Supports torch.compile()</th> <th>Expected memory usage</th></tr></thead> <tbody><tr><td>Dynamic Cache</td> <td>Yes</td> <td>Yes</td> <td>No</td> <td>Medium</td></tr> <tr><td>Static Cache</td> <td>Yes</td> <td>Yes</td> <td>Yes</td> <td>High</td></tr> <tr><td>Quantized Cache</td> <td>No</td> <td>No </td> <td>No</td> <td>Low</td></tr></tbody></table> <p data-svelte-h="svelte-abjss">This guide introduces you to the different <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache">Cache</a> classes and shows you how to use them for generation.</p> <h2 class="relative group"><a id="default-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#default-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Default cache</span></h2> <p data-svelte-h="svelte-19997on">The <a 
href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> is the default cache class for all models. It allows the cache size to grow dynamically in order to store an increasing number of keys and values as generation progresses.</p> <p data-svelte-h="svelte-1wo1duj">Note that for models using sliding window attention (Mistral, Gemma2,…) or chunked attention (Llama4), the cache will stop growing when the layers using these types of attention have reached their maximum size (the sliding window or chunk size).</p> <p data-svelte-h="svelte-nkyv69">Disable the cache by configuring <code>use_cache=False</code> in <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START 
--><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"I like rock music because"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, use_cache=<span class="hljs-literal">False</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ryljih">Cache classes can also be initialized first and then passed to the model’s <a href="https://hf.co/docs/transformers/internal/generation_utils#transformers.generation.GenerateDecoderOnlyOutput.past_key_values" rel="nofollow">past_key_values</a> parameter. This can be useful for more fine-grained control, or for more advanced usage such as context caching.</p> <p data-svelte-h="svelte-1potdr8">In most cases, it’s easier to define the cache strategy in the <a href="https://hf.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.cache_implementation" rel="nofollow">cache_implementation</a> parameter.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; 
border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, DynamicCache | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"I like rock music because"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| past_key_values = DynamicCache(config=model.config) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, past_key_values=past_key_values)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="fixed-size-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fixed-size-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fixed-size cache</span></h2> <p data-svelte-h="svelte-17n0w1w">The default <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> prevents you from taking advantage of most just-in-time (JIT) optimizations because the cache size isn’t fixed. JIT optimizations enable you to minimize latency at the expense of memory usage. 
All of the following cache types are compatible with JIT optimizations like <a href="./llm_optims#static-kv-cache-and-torchcompile">torch.compile</a> to accelerate generation.</p> <p data-svelte-h="svelte-j87w6s">A fixed-size cache (<a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.StaticCache">StaticCache</a>) pre-allocates a specific maximum cache size for the KV pairs. You can generate up to the maximum cache size without needing to modify it. However, having a fixed (usually large) size for the key/value states means that while generating, a lot of tokens will actually be masked because they should not take part in the attention. This makes it easy to <code>compile</code> the decoding stage, but it wastes attention computation on the masked tokens. As with all things, it’s a trade-off: it works very well if you generate several sequences of roughly the same length, but may be sub-optimal if you have, for example, one very long sequence followed by only short sequences (the fixed cache size would be large, so a lot of it would be wasted on the short sequences). Make sure you understand the impact if you use it!</p> <p data-svelte-h="svelte-m0x91u">As for <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a>, note that for models using sliding window attention (Mistral, Gemma2,…) or chunked attention (Llama4), the cache will never be larger than the sliding window/chunk size on layers using these types of attention, even if the maximum length specified is larger.</p> <p data-svelte-h="svelte-15m2q0y">You can enable <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.StaticCache">StaticCache</a> by configuring <code>cache_implementation="static"</code> in <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>. 
This will also turn on automatic <code>compilation</code> of the decoding stage for greedy and sample decoding strategies.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"Hello, my name is"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>] | |
| <span class="hljs-string">"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="cache-offloading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cache-offloading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Cache offloading</span></h2> <p data-svelte-h="svelte-15jbxrx">The KV cache can occupy a significant portion of memory and become a <a href="https://hf.co/blog/llama31#inference-memory-requirements" rel="nofollow">bottleneck</a> for long-context generation. Memory-efficient caches focus on trading off speed for reduced memory usage. This is especially important for large language models (LLMs) and if your hardware is memory-constrained.</p> <p data-svelte-h="svelte-7m6bkf">Offloading the cache saves GPU memory by moving the KV cache for all model layers except one to the CPU. Only the current layer’s cache is maintained on the GPU during a model’s <code>forward</code> iteration over the layers. 
It asynchronously prefetches the next layer’s cache and sends the current layer’s cache back to the CPU after the attention computation.</p> <p data-svelte-h="svelte-m6bxw0">You may want to consider offloading if you have a small GPU and you’re getting out-of-memory (OOM) errors.</p> <blockquote class="warning" data-svelte-h="svelte-1fbhhd3"><p>You may notice a small degradation in generation throughput compared to a full on-device cache, depending on your model and generation choices (context size, number of generated tokens, number of beams, etc.). This is because moving the key/value states back and forth requires some work.</p></blockquote> <p data-svelte-h="svelte-bhzvc1">Offloading is available for both <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> and <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.StaticCache">StaticCache</a>. You can enable it by configuring <code>cache_implementation="offloaded"</code> for the dynamic version, or <code>cache_implementation="offloaded_static"</code> for the static version, in either <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationConfig">GenerationConfig</a> or <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>. | |
| You can also instantiate your own <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> or <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.StaticCache">StaticCache</a> with the <code>offloading=True</code> option, and pass this cache to <code>generate</code> or to your model’s <code>forward</code> (for example, <code>past_key_values=DynamicCache(config=model.config, offloading=True)</code> for a dynamic cache).</p> <p data-svelte-h="svelte-1fef5ix">Note that the two <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache">Cache</a> classes mentioned above have an additional option, <code>offload_only_non_sliding</code>, when you instantiate them directly. | |
| This additional argument decides whether the layers using sliding window/chunk attention (if any) will be offloaded as well. Since | |
| these layers are usually short anyway, it may be better to avoid offloading them, as offloading may incur a speed penalty. By default, this option is <code>False</code> for <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a>, and <code>True</code> for <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.StaticCache">StaticCache</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| ckpt = <span class="hljs-string">"microsoft/Phi-3-mini-4k-instruct"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(ckpt) | |
| model = AutoModelForCausalLM.from_pretrained(ckpt, dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"Fun fact: The shortest"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">23</span>, cache_implementation=<span class="hljs-string">"offloaded"</span>) | |
| <span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]) | |
| Fun fact: The shortest war <span class="hljs-keyword">in</span> history was between Britain <span class="hljs-keyword">and</span> Zanzibar on August <span class="hljs-number">27</span>, <span class="hljs-number">1896.</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-pkg9iq">The example below shows how you can fallback to an offloaded cache if you run out of memory:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| <span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">resilient_generate</span>(<span class="hljs-params">model, *args, **kwargs</span>): | |
| oom = <span class="hljs-literal">False</span> | |
| device = Accelerator().device | |
| torch_device_module = <span class="hljs-built_in">getattr</span>(torch, device.type, torch.cuda) | |
| <span class="hljs-keyword">try</span>: | |
| <span class="hljs-keyword">return</span> model.generate(*args, **kwargs) | |
| <span class="hljs-keyword">except</span> torch.OutOfMemoryError <span class="hljs-keyword">as</span> e: | |
| <span class="hljs-built_in">print</span>(e) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"retrying with cache_implementation='offloaded'"</span>) | |
| oom = <span class="hljs-literal">True</span> | |
| <span class="hljs-keyword">if</span> oom: | |
| torch_device_module.empty_cache() | |
| kwargs[<span class="hljs-string">"cache_implementation"</span>] = <span class="hljs-string">"offloaded"</span> | |
| <span class="hljs-keyword">return</span> model.generate(*args, **kwargs) | |
| ckpt = <span class="hljs-string">"microsoft/Phi-3-mini-4k-instruct"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(ckpt) | |
| model = AutoModelForCausalLM.from_pretrained(ckpt, dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| prompt = [<span class="hljs-string">"okay "</span>*<span class="hljs-number">1000</span> + <span class="hljs-string">"Fun fact: The most"</span>] | |
| inputs = tokenizer(prompt, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| beams = { <span class="hljs-string">"num_beams"</span>: <span class="hljs-number">40</span>, <span class="hljs-string">"num_return_sequences"</span>: <span class="hljs-number">20</span>, <span class="hljs-string">"max_new_tokens"</span>: <span class="hljs-number">23</span>, <span class="hljs-string">"early_stopping"</span>: <span class="hljs-literal">True</span>, } | |
| out = resilient_generate(model, **inputs, **beams) | |
| responses = tokenizer.batch_decode(out[:,-<span class="hljs-number">28</span>:], skip_special_tokens=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="quantized-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quantized-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quantized cache</span></h2> <p data-svelte-h="svelte-ac6m09">The <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.QuantizedCache">QuantizedCache</a> reduces memory requirements by quantizing the KV values to a lower precision. <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.QuantizedCache">QuantizedCache</a> currently supports two quantization backends:</p> <ul data-svelte-h="svelte-14p9huo"><li><code>hqq</code> supports int2, int4, and int8 datatypes.</li> <li><code>quanto</code> supports int2 and int4 datatypes. 
This is the default quantization backend.</li></ul> <blockquote class="warning" data-svelte-h="svelte-8cw27g"><p>Quantizing the cache can harm latency if the context length is short and there is enough GPU memory available for generation without enabling cache quantization. Try to find a balance between memory efficiency and latency.</p></blockquote> <p data-svelte-h="svelte-1l0iias">Enable <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.QuantizedCache">QuantizedCache</a> by configuring <code>cache_implementation="quantized"</code> in <a href="/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationConfig">GenerationConfig</a>. Pass the quantization backend, and any additional quantization-related parameters, as a dict in <code>cache_config</code>. You should use the default values for these additional parameters unless you’re running out of memory. In that case, consider decreasing the residual length (<code>residual_length</code>).</p> | |
| <hfoptions id="quantized-cache"> | |
| <p data-svelte-h="svelte-1ur1l6q">For the <code>hqq</code> backend, we recommend setting the <code>axis_key</code> and <code>axis_value</code> parameters to <code>1</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, QuantizedCache | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"I like rock music because"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">"quantized"</span>, cache_config={<span class="hljs-string">"backend"</span>: <span class="hljs-string">"hqq"</span>}) | |
| <span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]) | |
| I like rock music because it<span class="hljs-string">'s loud and energetic. It'</span>s a great way to express myself <span class="hljs-keyword">and</span> rel<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-14czee2">For the <code>quanto</code> backend, we recommend setting the <code>axis_key</code> and <code>axis_value</code> parameters to <code>0</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"I like rock music because"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">"quantized"</span>, cache_config={<span class="hljs-string">"nbits"</span>: <span class="hljs-number">4</span>, <span class="hljs-string">"backend"</span>: <span class="hljs-string">"quanto"</span>}) | |
| <span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]) | |
| I like rock music because it<span class="hljs-string">'s loud and energetic. It'</span>s a great way to express myself <span class="hljs-keyword">and</span> rel<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="encoder-decoder-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#encoder-decoder-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Encoder-decoder cache</span></h2> <p data-svelte-h="svelte-1t6uudo"><a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.EncoderDecoderCache">EncoderDecoderCache</a> is designed for encoder-decoder models. It manages both the self-attention and cross-attention caches to ensure storage and retrieval of previous kv pairs. It is possible to individually set a different cache type for the encoder and decoder.</p> <p data-svelte-h="svelte-azkxhp">This cache type doesn’t require any setup. 
It is a simple wrapper around two <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache">Cache</a> objects, one for self-attention and one for cross-attention, which the model uses independently.</p> <h2 class="relative group"><a id="model-specific-caches" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model-specific-caches"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model-specific caches</span></h2> <p data-svelte-h="svelte-u1f4zv">Some models have a unique way of storing past kv pairs or states that is not compatible with the other cache classes.</p> <p data-svelte-h="svelte-w4fm68">Mamba models, such as <a href="./model_doc/mamba">Mamba</a>, require a specific cache because the model doesn’t have an attention mechanism or kv states. 
Thus, they are not compatible with the above <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache">Cache</a> classes.</p> <h2 class="relative group"><a id="iterative-generation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#iterative-generation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Iterative generation</span></h2> <p data-svelte-h="svelte-67pkdu">A cache can also work in iterative generation settings where there is back-and-forth interaction with a model (chatbots). Like regular generation, iterative generation with a cache allows a model to efficiently handle ongoing conversations without recomputing the entire context at each step.</p> <p data-svelte-h="svelte-1ejbghf">For iterative generation with a cache, start by initializing an empty cache class and then you can feed in your new prompts. Keep track of dialogue history with a <a href="./chat_templating">chat template</a>.</p> <p data-svelte-h="svelte-1kmby8j">The following example demonstrates <a href="https://huggingface.co/meta-llama/Llama-2-7b-chat-hf" rel="nofollow">Llama-2-7b-chat-hf</a>. 
If you’re using a different chat-style model, <a href="/docs/transformers/pr_33892/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template">apply_chat_template()</a> may process messages differently. It might cut out important tokens depending on how the Jinja template is written.</p> <p data-svelte-h="svelte-fhggpa">For example, some models use special <code>&lt;think&gt; ... &lt;/think&gt;</code> tokens during reasoning. These could get lost during re-encoding, causing indexing issues. You might need to manually remove or adjust extra tokens from the completions to keep things stable.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer,AutoModelForCausalLM, DynamicCache, StaticCache | |
| model_id = <span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span> | |
| model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=<span class="hljs-string">'auto'</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| user_prompts = [<span class="hljs-string">"Hello, what's your name?"</span>, <span class="hljs-string">"Btw, yesterday I was on a rock concert."</span>] | |
| past_key_values = DynamicCache(config=model.config) | |
| messages = [] | |
| <span class="hljs-keyword">for</span> prompt <span class="hljs-keyword">in</span> user_prompts: | |
| messages.append({<span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>, <span class="hljs-string">"content"</span>: prompt}) | |
| inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"pt"</span>, return_dict=<span class="hljs-literal">True</span>).to(model.device) | |
| input_length = inputs[<span class="hljs-string">"input_ids"</span>].shape[<span class="hljs-number">1</span>] | |
| outputs = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">256</span>, past_key_values=past_key_values) | |
| completion = tokenizer.decode(outputs[<span class="hljs-number">0</span>, input_length: ], skip_special_tokens=<span class="hljs-literal">True</span>) | |
| messages.append({<span class="hljs-string">"role"</span>: <span class="hljs-string">"assistant"</span>, <span class="hljs-string">"content"</span>: completion})<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="prefill-a-cache-prefix-caching" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#prefill-a-cache-prefix-caching"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Prefill a cache (prefix caching)</span></h2> <p data-svelte-h="svelte-tg25qt">In some situations, you may want to fill a <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache">Cache</a> with kv pairs for a certain prefix prompt and reuse it to generate different sequences.</p> <p data-svelte-h="svelte-1sp5dl9">The example below initializes a <a href="/docs/transformers/pr_33892/en/internal/generation_utils#transformers.StaticCache">StaticCache</a>, and then caches an initial prompt. 
Now you can generate several sequences from the prefilled prompt.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> copy | |
| <span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache | |
| model_id = <span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span> | |
| model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map={<span class="hljs-string">""</span>: <span class="hljs-number">0</span>}) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| <span class="hljs-comment"># Init StaticCache with big enough max-length (1024 tokens for the below example)</span> | |
| <span class="hljs-comment"># You can also init a DynamicCache, if that suits you better</span> | |
| prompt_cache = StaticCache(config=model.config, max_cache_len=<span class="hljs-number">1024</span>) | |
| INITIAL_PROMPT = <span class="hljs-string">"You are a helpful assistant. "</span> | |
| inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device.<span class="hljs-built_in">type</span>) | |
| <span class="hljs-comment"># Cache the shared prefix once; run the forward pass without grad so the resulting cache can be deep-copied for each prompt</span> | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| prompt_cache = model(**inputs_initial_prompt, past_key_values = prompt_cache).past_key_values | |
| prompts = [<span class="hljs-string">"Help me to write a blogpost about travelling."</span>, <span class="hljs-string">"What is the capital of France?"</span>] | |
| responses = [] | |
| <span class="hljs-keyword">for</span> prompt <span class="hljs-keyword">in</span> prompts: | |
| new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device.<span class="hljs-built_in">type</span>) | |
| past_key_values = copy.deepcopy(prompt_cache) | |
| outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=<span class="hljs-number">20</span>) | |
| response = tokenizer.batch_decode(outputs)[<span class="hljs-number">0</span>] | |
| responses.append(response) | |
| <span class="hljs-built_in">print</span>(responses)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/kv_cache.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { /* SvelteKit client bootstrap; it appears to be build-generated (hashed global name and entry paths), so regenerate it rather than hand-editing. The enclosing braces keep the `element` and `data` consts block-scoped. */ | |
| __sveltekit_16tnnm8 = { /* assigned without a declaration in a classic (non-module) script, so this becomes a page-global; presumably the hashed name must match the bundled client code. TODO confirm */ | |
| assets: "/docs/transformers/pr_33892/en", | |
| base: "/docs/transformers/pr_33892/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; /* document.currentScript is only set while this script executes synchronously, so the mount target is captured here, before the async imports resolve */ | |
| const data = [null,null]; /* presumably one (empty) data slot per entry in node_ids below. TODO confirm */ | |
| Promise.all([ | |
| import("/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"), /* same entry modules that the modulepreload hints in the head fetch early */ | |
| import("/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js") | |
| ]).then(([kit, app]) => { /* both entry modules loaded: hand the app module and the captured mount element to kit.start */ | |
| kit.start(app, element, { | |
| node_ids: [0, 49], /* node 0 matches the preloaded nodes/0.ca4aafa4.js; 49 is presumably this page's node. TODO confirm */ | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 56.7 kB
- Xet hash:
- 28f768f415d883399ba57b0f697032685567cdbfb88c9ae7d0e46cab8569b057
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.