Buckets:

rtrm's picture
download
raw
86.7 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Best Practices for Generation with Cache&quot;,&quot;local&quot;:&quot;best-practices-for-generation-with-cache&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;What is Cache and why we should care?&quot;,&quot;local&quot;:&quot;what-is-cache-and-why-we-should-care&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Under the Hood: How Cache Object Works in Attention Mechanism&quot;,&quot;local&quot;:&quot;under-the-hood-how-cache-object-works-in-attention-mechanism&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Generate with Cache&quot;,&quot;local&quot;:&quot;generate-with-cache&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Quantized Cache&quot;,&quot;local&quot;:&quot;quantized-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Offloaded Cache&quot;,&quot;local&quot;:&quot;offloaded-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Static Cache&quot;,&quot;local&quot;:&quot;static-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Offloaded Static Cache&quot;,&quot;local&quot;:&quot;offloaded-static-cache&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Sliding Window Cache&quot;,&quot;local&quot;:&quot;sliding-window-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Sink Cache&quot;,&quot;local&quot;:&quot;sink-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Encoder-Decoder Cache&quot;,&quot;local&quot;:&quot;encoder-decoder-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Model-specific Cache Classes&quot;,&quot;local&quot;:&quot;model-specific-cache-classes&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Iterative Generation with 
Cache&quot;,&quot;local&quot;:&quot;iterative-generation-with-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Re-use Cache to continue generation&quot;,&quot;local&quot;:&quot;re-use-cache-to-continue-generation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Legacy cache format&quot;,&quot;local&quot;:&quot;legacy-cache-format&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_33913/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/scheduler.25b97de1.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/singletons.62a184e0.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.e188933d.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/paths.51881b9e.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.d9030fc9.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/0.05e395f5.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/38.456a4e55.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/Tip.baa67368.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/EditOnGithub.91d95064.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Best Practices for Generation with Cache&quot;,&quot;local&quot;:&quot;best-practices-for-generation-with-cache&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;What is Cache and why we should care?&quot;,&quot;local&quot;:&quot;what-is-cache-and-why-we-should-care&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Under the Hood: How Cache Object Works in Attention Mechanism&quot;,&quot;local&quot;:&quot;under-the-hood-how-cache-object-works-in-attention-mechanism&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Generate with Cache&quot;,&quot;local&quot;:&quot;generate-with-cache&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Quantized Cache&quot;,&quot;local&quot;:&quot;quantized-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Offloaded Cache&quot;,&quot;local&quot;:&quot;offloaded-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Static Cache&quot;,&quot;local&quot;:&quot;static-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Offloaded Static Cache&quot;,&quot;local&quot;:&quot;offloaded-static-cache&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Sliding Window Cache&quot;,&quot;local&quot;:&quot;sliding-window-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Sink Cache&quot;,&quot;local&quot;:&quot;sink-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Encoder-Decoder Cache&quot;,&quot;local&quot;:&quot;encoder-decoder-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Model-specific Cache 
Classes&quot;,&quot;local&quot;:&quot;model-specific-cache-classes&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Iterative Generation with Cache&quot;,&quot;local&quot;:&quot;iterative-generation-with-cache&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Re-use Cache to continue generation&quot;,&quot;local&quot;:&quot;re-use-cache-to-continue-generation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Legacy cache format&quot;,&quot;local&quot;:&quot;legacy-cache-format&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="best-practices-for-generation-with-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#best-practices-for-generation-with-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Best Practices for Generation with Cache</span></h1> <p data-svelte-h="svelte-ivdsz6">Efficient caching is crucial for optimizing the performance of models in various generative tasks,
including text generation, translation, summarization and other transformer-based applications.
Effective caching helps reduce computation time and improve response rates, especially in real-time or resource-intensive applications.</p> <p data-svelte-h="svelte-1ru3n2s">Transformers support various caching methods, leveraging “Cache” classes to abstract and manage the caching logic.
This document outlines best practices for using these classes to maximize performance and efficiency.
Check out all the available <code>Cache</code> classes in the <a href="./internal/generation_utils">API documentation</a>.</p> <h2 class="relative group"><a id="what-is-cache-and-why-we-should-care" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#what-is-cache-and-why-we-should-care"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>What is Cache and why we should care?</span></h2> <p data-svelte-h="svelte-pgg2dh">Imagine you’re having a conversation with someone, and instead of remembering what was said previously, you have to start from scratch every time you respond. This would be slow and inefficient, right? In the world of Transformer models, a similar concept applies, and that’s where Caching keys and values come into play. From now on, I’ll refer to the concept as KV Cache.</p> <p data-svelte-h="svelte-1pcidqj">KV cache is needed to optimize the generation in autoregressive models, where the model predicts text token by token. This process can be slow since the model can generate only one token at a time, and each new prediction is dependent on the previous context. 
That means, to predict token number 1000 in the generation, you need information from the previous 999 tokens, which comes in the form of some matrix multiplications across the representations of those tokens. But to predict token number 1001, you also need the same information from the first 999 tokens, plus additional information from token number 1000. That is where key-value cache is used to optimize the sequential generation process by storing previous calculations to reuse in subsequent tokens, so they don’t need to be computed again.</p> <p data-svelte-h="svelte-19ueo1l">More concretely, key-value cache acts as a memory bank for these generative models, where the model stores key-value pairs derived from self-attention layers for previously processed tokens. By storing this information, the model can avoid redundant computations and instead retrieve keys and values of previous tokens from the cache. Note that caching can be used only in inference and should be disabled when training, otherwise it might cause unexpected errors.</p> <details><summary data-svelte-h="svelte-8dl4q2"><em>For the Curious Minds Who Like to Dive Deep</em></summary> <h3 class="relative group"><a id="under-the-hood-how-cache-object-works-in-attention-mechanism" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#under-the-hood-how-cache-object-works-in-attention-mechanism"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 
1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Under the Hood: How Cache Object Works in Attention Mechanism</span></h3> <p data-svelte-h="svelte-1e25at2">When utilizing a cache object in the input, the Attention module performs several critical steps to integrate past and present information seamlessly.</p> <p data-svelte-h="svelte-1l5g8w7">The Attention module concatenates the current key-values with the past key-values stored in the cache. Essentially, the past and current key-values are combined to compute attention scores, ensuring that the model considers both previous context and new input. The concatenated key-values are used to compute the attention scores, resulting in attention weights of shape <code>(new_tokens_length, past_kv_length + new_tokens_length)</code>.</p> <p data-svelte-h="svelte-8mjlg7">Therefore, when iteratively calling <code>forward()</code> instead of the <code>generate()</code> method, it’s crucial to ensure that the attention mask shape matches the combined length of past and current key-values. The attention mask should have the shape <code>(batch_size, past_kv_length + new_tokens_length)</code>. This is usually handled internally when you call the <code>generate()</code> method. 
If you want to implement your own generation loop with Cache classes, take this into consideration and prepare the attention mask to hold values for both current and past tokens.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-fyl20k">One important concept you need to know when writing your own generation loop is <code>cache_position</code>. In case you want to reuse an already filled Cache object by calling <code>forward()</code>, you have to pass in a valid <code>cache_position</code> which will indicate the positions of inputs in the sequence. Note that <code>cache_position</code> is not affected by padding, and always adds one more position for each token. For example, if the key/value cache contains 10 tokens (no matter how many of them are pad tokens), the cache position for the next token should be <code>torch.tensor([10])</code>.</p></div> <p data-svelte-h="svelte-181iwu3">See an example below for how to implement your own generation loop.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity 
bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, DynamicCache
<span class="hljs-meta">&gt;&gt;&gt; </span>model_id = <span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=<span class="hljs-string">&quot;cuda:0&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(model_id)
<span class="hljs-meta">&gt;&gt;&gt; </span>past_key_values = DynamicCache()
<span class="hljs-meta">&gt;&gt;&gt; </span>messages = [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;Hello, what&#x27;s your name.&quot;</span>}]
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>, return_dict=<span class="hljs-literal">True</span>).to(<span class="hljs-string">&quot;cuda:0&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>generated_ids = inputs.input_ids
<span class="hljs-meta">&gt;&gt;&gt; </span>cache_position = torch.arange(inputs.input_ids.shape[<span class="hljs-number">1</span>], dtype=torch.int64, device=<span class="hljs-string">&quot;cuda:0&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>max_new_tokens = <span class="hljs-number">10</span>
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(max_new_tokens):
<span class="hljs-meta">... </span> outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=<span class="hljs-literal">True</span>)
<span class="hljs-meta">... </span> <span class="hljs-comment"># Greedily sample one next token</span>
<span class="hljs-meta">... </span> next_token_ids = outputs.logits[:, -<span class="hljs-number">1</span>:].argmax(-<span class="hljs-number">1</span>)
<span class="hljs-meta">... </span> generated_ids = torch.cat([generated_ids, next_token_ids], dim=-<span class="hljs-number">1</span>)
...
<span class="hljs-meta">... </span> <span class="hljs-comment"># Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># and expanding attn mask for the new token, as explained above</span>
<span class="hljs-meta">... </span> attention_mask = inputs[<span class="hljs-string">&quot;attention_mask&quot;</span>]
<span class="hljs-meta">... </span> attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[<span class="hljs-number">0</span>], <span class="hljs-number">1</span>))], dim=-<span class="hljs-number">1</span>)
<span class="hljs-meta">... </span> inputs = {<span class="hljs-string">&quot;input_ids&quot;</span>: next_token_ids, <span class="hljs-string">&quot;attention_mask&quot;</span>: attention_mask}
<span class="hljs-meta">... </span> cache_position = cache_position[-<span class="hljs-number">1</span>:] + <span class="hljs-number">1</span> <span class="hljs-comment"># add one more position for the next token</span>
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(tokenizer.batch_decode(generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
<span class="hljs-string">&quot;[INST] Hello, what&#x27;s your name. [/INST] Hello! My name is LLaMA,&quot;</span><!-- HTML_TAG_END --></pre></div></details> <h2 class="relative group"><a id="generate-with-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#generate-with-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Generate with Cache</span></h2> <p data-svelte-h="svelte-9yvdkm">In 🤗 Transformers, we support various Cache types to optimize the performance across different models and tasks. By default, all models generate with caching,
with the <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.DynamicCache">~DynamicCache</a> class being the default cache for most models. It allows us to dynamically grow cache size, by saving more and more keys and values as we generate. If for some reason you don’t want to use caches, you can pass <code>use_cache=False</code> into the <code>generate()</code> method.</p> <p data-svelte-h="svelte-ez9zp4">Refer to the table below to see the difference between cache types and choose the one that best suits your use case. Cache types for which initialization is recommended should be initialized before calling the model and passed to the model as a kwarg. In all other cases you can simply define the desired <code>cache_implementation</code> and we take care of the rest for you.</p> <table data-svelte-h="svelte-yhqrld"><thead><tr><th>Cache Type</th> <th>Memory Efficient</th> <th>Supports torch.compile()</th> <th>Initialization Recommended</th> <th>Latency</th> <th>Long Context Generation</th></tr></thead> <tbody><tr><td>Dynamic Cache</td> <td>No</td> <td>No</td> <td>No</td> <td>Mid</td> <td>No</td></tr> <tr><td>Static Cache</td> <td>No</td> <td>Yes</td> <td>Yes</td> <td>High</td> <td>No</td></tr> <tr><td>Offloaded Cache</td> <td>Yes</td> <td>No</td> <td>No</td> <td>Low</td> <td>Yes</td></tr> <tr><td>Offloaded Static Cache</td> <td>No</td> <td>Yes</td> <td>Yes</td> <td>High</td> <td>Yes</td></tr> <tr><td>Quantized Cache</td> <td>Yes</td> <td>No</td> <td>No</td> <td>Low</td> <td>Yes</td></tr> <tr><td>Sliding Window Cache</td> <td>No</td> <td>Yes</td> <td>Yes</td> <td>High</td> <td>No</td></tr> <tr><td>Sink Cache</td> <td>Yes</td> <td>No</td> <td>Yes</td> <td>Mid</td> <td>Yes</td></tr></tbody></table> <p data-svelte-h="svelte-b5pfm0">These cache classes can be set with a <code>cache_implementation</code> argument when generating. 
To learn about the available options for the <code>cache_implementation</code> flag, please refer to the <a href="./main_classes/text_generation#transformers.GenerationConfig">API Documentation</a>. Now, let’s explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Transformer-based models. We also support <a href="#model-specific-cache-classes">“Model-Specific Cache”</a> classes for models such as Mamba or Jamba; keep reading for more details.</p> <h3 class="relative group"><a id="quantized-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quantized-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quantized Cache</span></h3> <p data-svelte-h="svelte-1yajrsk">The key and value cache can occupy a large portion of memory, becoming a <a href="https://huggingface.co/blog/llama31#inference-memory-requirements" rel="nofollow">bottleneck for long-context generation</a>, especially for Large Language Models.
Quantizing the cache when using <code>generate()</code> can significantly reduce memory requirements at the cost of speed.</p> <p data-svelte-h="svelte-nl9jpl">KV Cache quantization in <code>transformers</code> is largely inspired by the paper <a href="https://arxiv.org/abs/2402.02750" rel="nofollow">“KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache”</a> and currently supports <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.QuantoQuantizedCache">~QuantoQuantizedCache</a> and <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.HQQQuantizedCache">~HQQQuantizedCache</a> classes. For more information on the inner workings see the paper.</p> <p data-svelte-h="svelte-40r57g">To enable quantization of the key-value cache, one needs to indicate <code>cache_implementation=&quot;quantized&quot;</code> in the <code>generation_config</code>.
Quantization related arguments should be passed to the <code>generation_config</code> either as a <code>dict</code> or an instance of a <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.QuantizedCacheConfig">~QuantizedCacheConfig</a> class.
One has to indicate which quantization backend to use in the <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.QuantizedCacheConfig">~QuantizedCacheConfig</a>, the default is <code>quanto</code>.</p> <p data-svelte-h="svelte-1ke22x8">It is recommended to set <code>axis-key/axis-value</code> parameters in the cache config to <code>0</code> if you’re using the <code>quanto</code> backend and to <code>1</code> if you’re using the <code>HQQ</code> backend. For other config values, please use the defaults unless you’re running out of memory. In that case, you may consider decreasing the residual length.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-46etli">Cache quantization can be detrimental in terms of latency if the context length is short and there is enough GPU VRAM available to run without cache quantization. 
It is recommended to seek balance between memory efficiency and latency.</p></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda:0&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer(<span class="hljs-string">&quot;I like rock music because&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-meta">&gt;&gt;&gt; </span>out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">&quot;quantized&quot;</span>, cache_config={<span class="hljs-string">&quot;nbits&quot;</span>: <span class="hljs-number">4</span>, <span class="hljs-string">&quot;backend&quot;</span>: <span class="hljs-string">&quot;quanto&quot;</span>})
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
I like rock music because it<span class="hljs-string">&#x27;s loud and energetic. It&#x27;</span>s a great way to express myself <span class="hljs-keyword">and</span> rel
<span class="hljs-meta">&gt;&gt;&gt; </span>out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
I like rock music because it<span class="hljs-string">&#x27;s loud and energetic. I like to listen to it when I&#x27;</span>m feeling<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="offloaded-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#offloaded-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Offloaded Cache</span></h3> <p data-svelte-h="svelte-1qxqw8d">Similarly to KV cache quantization, <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.OffloadedCache">~OffloadedCache</a> strategy aims to reduce GPU VRAM usage.
It does so by moving the KV cache for most layers to the CPU.
As the model’s <code>forward()</code> method iterates over the layers, this strategy maintains the current layer cache on the GPU.
At the same time, it asynchronously prefetches the next layer cache and sends the previous layer cache back to the CPU.
Unlike KV cache quantization, this strategy always produces the same result as the default KV cache implementation.
Thus, it can serve as a drop-in replacement or a fallback for it.</p> <p data-svelte-h="svelte-1x4pr6r">Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.)
you may notice a small degradation in generation throughput compared to the default KV cache implementation.</p> <p data-svelte-h="svelte-6aujat">To enable KV cache offloading, pass <code>cache_implementation=&quot;offloaded&quot;</code> in the <code>generation_config</code> or directly to the <code>generate()</code> call.
Use <code>cache_implementation=&quot;offloaded_static&quot;</code> for an offloaded static cache (see also <a href="#offloaded-static-cache">Offloaded Static Cache</a> below).</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-meta">&gt;&gt;&gt; </span>ckpt = <span class="hljs-string">&quot;microsoft/Phi-3-mini-4k-instruct&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(ckpt)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda:0&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer(<span class="hljs-string">&quot;Fun fact: The shortest&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-meta">&gt;&gt;&gt; </span>out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">23</span>, cache_implementation=<span class="hljs-string">&quot;offloaded&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
Fun fact: The shortest war <span class="hljs-keyword">in</span> history was between Britain <span class="hljs-keyword">and</span> Zanzibar on August <span class="hljs-number">27</span>, <span class="hljs-number">1896.</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">23</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
Fun fact: The shortest war <span class="hljs-keyword">in</span> history was between Britain <span class="hljs-keyword">and</span> Zanzibar on August <span class="hljs-number">27</span>, <span class="hljs-number">1896.</span><!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-16njdjl">Cache offloading requires a GPU and can be slower than dynamic KV cache. Use it if you are getting CUDA out of memory errors.</p></div> <p data-svelte-h="svelte-f3m4wg">The example below shows how KV cache offloading can be used as a fallback strategy.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- 
HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">resilient_generate</span>(<span class="hljs-params">model, *args, **kwargs</span>):
<span class="hljs-meta">... </span> oom = <span class="hljs-literal">False</span>
<span class="hljs-meta">... </span> <span class="hljs-keyword">try</span>:
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> model.generate(*args, **kwargs)
<span class="hljs-meta">... </span> <span class="hljs-keyword">except</span> torch.cuda.OutOfMemoryError <span class="hljs-keyword">as</span> e:
<span class="hljs-meta">... </span> <span class="hljs-built_in">print</span>(e)
<span class="hljs-meta">... </span> <span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;retrying with cache_implementation=&#x27;offloaded&#x27;&quot;</span>)
<span class="hljs-meta">... </span> oom = <span class="hljs-literal">True</span>
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> oom:
<span class="hljs-meta">... </span> torch.cuda.empty_cache()
<span class="hljs-meta">... </span> kwargs[<span class="hljs-string">&quot;cache_implementation&quot;</span>] = <span class="hljs-string">&quot;offloaded&quot;</span>
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> model.generate(*args, **kwargs)
...
...
<span class="hljs-meta">&gt;&gt;&gt; </span>ckpt = <span class="hljs-string">&quot;microsoft/Phi-3-mini-4k-instruct&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(ckpt)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda:0&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>prompt = [<span class="hljs-string">&quot;okay &quot;</span>*<span class="hljs-number">1000</span> + <span class="hljs-string">&quot;Fun fact: The most&quot;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer(prompt, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-meta">&gt;&gt;&gt; </span>beams = { <span class="hljs-string">&quot;num_beams&quot;</span>: <span class="hljs-number">40</span>, <span class="hljs-string">&quot;num_beam_groups&quot;</span>: <span class="hljs-number">40</span>, <span class="hljs-string">&quot;num_return_sequences&quot;</span>: <span class="hljs-number">40</span>, <span class="hljs-string">&quot;diversity_penalty&quot;</span>: <span class="hljs-number">1.0</span>, <span class="hljs-string">&quot;max_new_tokens&quot;</span>: <span class="hljs-number">23</span>, <span class="hljs-string">&quot;early_stopping&quot;</span>: <span class="hljs-literal">True</span>, }
<span class="hljs-meta">&gt;&gt;&gt; </span>out = resilient_generate(model, **inputs, **beams)
<span class="hljs-meta">&gt;&gt;&gt; </span>responses = tokenizer.batch_decode(out[:,-<span class="hljs-number">28</span>:], skip_special_tokens=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-opu10i">On a GPU with 50 GB of RAM, running this code will print</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->CUDA <span class="hljs-keyword">out</span> <span class="hljs-keyword">of</span> memory. Tried <span class="hljs-keyword">to</span> allocate <span class="hljs-number">4.83</span> GiB. GPU
retrying <span class="hljs-keyword">with</span> cache_implementation=<span class="hljs-string">&#x27;offloaded&#x27;</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-83p2wm">before successfully generating 40 beams.</p> <h3 class="relative group"><a id="static-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#static-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Static Cache</span></h3> <p data-svelte-h="svelte-3lpkcr">Since the “DynamicCache” dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.StaticCache">~StaticCache</a> pre-allocates
a specific maximum size for the keys and values, allowing you to generate up to the maximum length without having to modify cache size. Check the below usage example.</p> <p data-svelte-h="svelte-hxs9qp">For more examples with Static Cache and JIT compilation, take a look at <a href="./llm_optims#static-kv-cache-and-torchcompile">StaticCache &amp; torchcompile</a></p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>, torch_dtype=torch.float16, device_map=<span class="hljs-string">&quot;auto&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer(<span class="hljs-string">&quot;Hello, my name is&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># simply pass the cache implementation=&quot;static&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">&quot;static&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]
<span class="hljs-string">&quot;Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of&quot;</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="offloaded-static-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#offloaded-static-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Offloaded Static Cache</span></h2> <p data-svelte-h="svelte-pwjxol">Like <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.OffloadedCache">~OffloadedCache</a> exists for offloading a “DynamicCache”, there is also an offloaded static cache. It fully supports
JIT optimizations. Just pass <code>cache_implementation=&quot;offloaded_static&quot;</code> in the <code>generation_config</code> or directly to the <code>generate()</code> call.
This will use the <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.OffloadedStaticCache">~OffloadedStaticCache</a> implementation instead.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>, torch_dtype=torch.float16, device_map=<span class="hljs-string">&quot;auto&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer(<span class="hljs-string">&quot;Hello, my name is&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># simply pass the cache implementation=&quot;offloaded_static&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">&quot;offloaded_static&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]
<span class="hljs-string">&quot;Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of&quot;</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="sliding-window-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sliding-window-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Sliding Window Cache</span></h3> <p data-svelte-h="svelte-krlgtc">As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last <code>sliding_window</code> tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile techniques as Static Cache.</p> <p data-svelte-h="svelte-14gr42g">Note that you can use this cache only for models that support sliding window, e.g. 
Mistral models.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;mistralai/Mistral-7B-v0.1&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;mistralai/Mistral-7B-v0.1&quot;</span>, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda:0&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer(<span class="hljs-string">&quot;Yesterday I was on a rock concert and.&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># can be used by passing in cache implementation</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">30</span>, cache_implementation=<span class="hljs-string">&quot;sliding_window&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]
<span class="hljs-string">&quot;Yesterday I was on a rock concert and. I was so excited to see my favorite band. I was so excited that I was jumping up and down and screaming. I was so excited that I&quot;</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="sink-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sink-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Sink Cache</span></h3> <p data-svelte-h="svelte-128onc7">Sink Cache was introduced in <a href="https://arxiv.org/abs/2309.17453" rel="nofollow">“Efficient Streaming Language Models with Attention Sinks”</a>. It allows you to generate long sequences of text (“infinite length” according to the paper) without any fine-tuning. That is achieved by smart handling of previous keys and values, specifically it retains a few initial tokens from the sequence, called “sink tokens”. This is based on the observation that these initial tokens attract a significant portion of attention scores during the generation process. 
Tokens that come after the “sink tokens” are discarded on a sliding-window basis, keeping only the latest <code>window_length</code> tokens. By keeping these initial tokens as “attention sinks,” the model maintains stable performance on very long texts while discarding most of the previous context.</p> <p data-svelte-h="svelte-1vy8lmv">Unlike other cache classes, this one can’t be used directly by indicating a <code>cache_implementation</code>. You have to initialize the cache before calling <code>generate()</code>, as follows.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, SinkCache
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>, torch_dtype=torch.float16).to(<span class="hljs-string">&quot;cuda:0&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer(<span class="hljs-string">&quot;This is a long story about unicorns, fairies and magic.&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># get our cache, specify number of sink tokens and window size</span>
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># Note that the window length already includes the sink tokens, so it has to be larger than the number of sink tokens</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>past_key_values = SinkCache(window_length=<span class="hljs-number">256</span>, num_sink_tokens=<span class="hljs-number">4</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">30</span>, past_key_values=past_key_values)
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]
<span class="hljs-string">&quot;This is a long story about unicorns, fairies and magic. It is a fantasy world where unicorns and fairies live together in harmony. The story follows a young girl named Lily&quot;</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="encoder-decoder-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#encoder-decoder-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Encoder-Decoder Cache</span></h3> <p data-svelte-h="svelte-riglut">The <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.EncoderDecoderCache">~EncoderDecoderCache</a> is a wrapper designed to handle the caching needs of encoder-decoder models. This cache type is specifically built to manage both self-attention and cross-attention caches, ensuring storage and retrieval of past key/values required for these complex models. Cool thing about Encoder-Decoder Cache is that you can set different cache types for the encoder and for the decoder, depending on your use case. 
Currently this cache is only supported in <a href="./model_doc/whisper">Whisper</a> models but we will be adding more models soon.</p> <p data-svelte-h="svelte-159fa3b">In terms of usage, there is nothing special to be done and calling <code>generate()</code> or <code>forward()</code> will handle everything for you.</p> <h3 class="relative group"><a id="model-specific-cache-classes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model-specific-cache-classes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model-specific Cache Classes</span></h3> <p data-svelte-h="svelte-ks33vg">Some models require storing previous keys, values, or states in a specific way, and the above cache classes cannot be used. For such cases, we have several specialized cache classes that are designed for specific models. These models only accept their own dedicated cache classes and do not support using any other cache types. 
Some examples include <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.HybridCache">~HybridCache</a> for <a href="./model_doc/gemma2">Gemma2</a> series models or <a href="/docs/transformers/pr_33913/en/internal/generation_utils#transformers.MambaCache">~MambaCache</a> for <a href="./model_doc/mamba">Mamba</a> architecture models.</p> <h2 class="relative group"><a id="iterative-generation-with-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#iterative-generation-with-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Iterative Generation with Cache</span></h2> <p data-svelte-h="svelte-au9ho2">We have seen how to use each of the cache types when generating. What if you want to use cache in iterative generation setting, for example in applications like chatbots, where interactions involve multiple turns and continuous back-and-forth exchanges. Iterative generation with cache allows these systems to handle ongoing conversations effectively without reprocessing the entire context at each step. 
But there are some tips that you should know before you start implementing:</p> <p data-svelte-h="svelte-108v17k">The general format when doing iterative generation is as below. First you have to initialize an empty cache of the type you want, and you can start feeding in new prompts iteratively. Keeping track of dialogues history and formatting can be done with chat templates, read more on that in <a href="./chat_templating">chat_templating</a></p> <p data-svelte-h="svelte-1yn0icb">In case you are using Sink Cache, you have to crop your inputs to that maximum length because Sink Cache can generate text longer than its maximum window size, but it expects the first input to not exceed the maximum cache length.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span 
class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers.cache_utils <span class="hljs-keyword">import</span> (
<span class="hljs-meta">&gt;&gt;&gt; </span> DynamicCache,
<span class="hljs-meta">&gt;&gt;&gt; </span> SinkCache,
<span class="hljs-meta">&gt;&gt;&gt; </span> StaticCache,
<span class="hljs-meta">&gt;&gt;&gt; </span> SlidingWindowCache,
<span class="hljs-meta">&gt;&gt;&gt; </span> QuantoQuantizedCache,
<span class="hljs-meta">&gt;&gt;&gt; </span> QuantizedCacheConfig,
<span class="hljs-meta">&gt;&gt;&gt; </span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model_id = <span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=<span class="hljs-string">&#x27;auto&#x27;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(model_id)
<span class="hljs-meta">&gt;&gt;&gt; </span>user_prompts = [<span class="hljs-string">&quot;Hello, what&#x27;s your name?&quot;</span>, <span class="hljs-string">&quot;Btw, yesterday I was on a rock concert.&quot;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>past_key_values = DynamicCache()
<span class="hljs-meta">&gt;&gt;&gt; </span>max_cache_length = past_key_values.get_max_length()
<span class="hljs-meta">&gt;&gt;&gt; </span>messages = []
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">for</span> prompt <span class="hljs-keyword">in</span> user_prompts:
<span class="hljs-meta">... </span> messages.append({<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: prompt})
<span class="hljs-meta">... </span> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>, return_dict=<span class="hljs-literal">True</span>).to(model.device)
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> <span class="hljs-built_in">isinstance</span>(past_key_values, SinkCache):
<span class="hljs-meta">... </span> inputs = {k: v[:, -max_cache_length:] <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> inputs.items()}
...
<span class="hljs-meta">... </span> input_length = inputs[<span class="hljs-string">&quot;input_ids&quot;</span>].shape[<span class="hljs-number">1</span>]
...
<span class="hljs-meta">... </span> outputs = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">256</span>, past_key_values=past_key_values)
<span class="hljs-meta">... </span> completion = tokenizer.decode(outputs[<span class="hljs-number">0</span>, input_length: ], skip_special_tokens=<span class="hljs-literal">True</span>)
<span class="hljs-meta">... </span> messages.append({<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;assistant&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: completion})
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(messages)
[{<span class="hljs-string">&#x27;role&#x27;</span>: <span class="hljs-string">&#x27;user&#x27;</span>, <span class="hljs-string">&#x27;content&#x27;</span>: <span class="hljs-string">&quot;Hello, what&#x27;s your name?&quot;</span>}, {<span class="hljs-string">&#x27;role&#x27;</span>: <span class="hljs-string">&#x27;assistant&#x27;</span>, <span class="hljs-string">&#x27;content&#x27;</span>: <span class="hljs-string">&quot; Hello! My name is LLaMA, I&#x27;m a large language model trained by a team of researcher at Meta AI. 😊&quot;</span>}, {<span class="hljs-string">&#x27;role&#x27;</span>: <span class="hljs-string">&#x27;user&#x27;</span>, <span class="hljs-string">&#x27;content&#x27;</span>: <span class="hljs-string">&#x27;Btw, yesterday I was on a rock concert.&#x27;</span>}, {<span class="hljs-string">&#x27;role&#x27;</span>: <span class="hljs-string">&#x27;assistant&#x27;</span>, <span class="hljs-string">&#x27;content&#x27;</span>: <span class="hljs-string">&#x27; Oh, cool! That sounds like a lot of fun! 🎉 Did you enjoy the concert? What was the band like? 
🤔&#x27;</span>}]<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="re-use-cache-to-continue-generation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#re-use-cache-to-continue-generation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Re-use Cache to continue generation</span></h2> <p data-svelte-h="svelte-1v9i2j4">Sometimes you would want to first fill-in cache object with key/values for certain prefix prompt and re-use it several times to generate different sequences from it. 
In that case you can construct a <code>Cache</code> object that will hold the instruction prompt, and re-use it several times with different text sequences.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> copy
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache
<span class="hljs-meta">&gt;&gt;&gt; </span>model_id = <span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(model_id)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># Init StaticCache with big enough max-length (1024 tokens for the below example)</span>
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># You can also init a DynamicCache, if that suits you better</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>prompt_cache = StaticCache(config=model.config, max_batch_size=<span class="hljs-number">1</span>, max_cache_len=<span class="hljs-number">1024</span>, device=<span class="hljs-string">&quot;cuda&quot;</span>, dtype=torch.bfloat16)
<span class="hljs-meta">&gt;&gt;&gt; </span>INITIAL_PROMPT = <span class="hljs-string">&quot;You are a helpful assistant. &quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># This is the common prompt cached, we need to run forward without grad to be able to copy</span>
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">with</span> torch.no_grad():
<span class="hljs-meta">... </span> prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values
<span class="hljs-meta">&gt;&gt;&gt; </span>prompts = [<span class="hljs-string">&quot;Help me to write a blogpost about travelling.&quot;</span>, <span class="hljs-string">&quot;What is the capital of France?&quot;</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>responses = []
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">for</span> prompt <span class="hljs-keyword">in</span> prompts:
<span class="hljs-meta">... </span> new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-meta">... </span> past_key_values = copy.deepcopy(prompt_cache)
<span class="hljs-meta">... </span> outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=<span class="hljs-number">20</span>)
<span class="hljs-meta">... </span> response = tokenizer.batch_decode(outputs)[<span class="hljs-number">0</span>]
<span class="hljs-meta">... </span> responses.append(response)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(responses)
[<span class="hljs-string">&#x27;&lt;s&gt; You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTitle: The Ultimate Guide to Travelling: Tips, Tricks, and&#x27;</span>, <span class="hljs-string">&#x27;&lt;s&gt; You are a helpful assistant. What is the capital of France?\n\nYes, the capital of France is Paris.&lt;/s&gt;&#x27;</span>]<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="legacy-cache-format" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#legacy-cache-format"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Legacy cache format</span></h2> <p data-svelte-h="svelte-wtnkn">Prior to the introduction of the <code>Cache</code> object, the cache of LLMs used to be a tuple of tuples of tensors. The legacy
format has a dynamic size, growing as we generate text — very similar to <code>DynamicCache</code>. If your project depends on
this legacy format, you can seamlessly convert it to a <code>DynamicCache</code> and back.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, DynamicCache
<span class="hljs-meta">&gt;&gt;&gt; </span>tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>, torch_dtype=torch.float16, device_map=<span class="hljs-string">&quot;auto&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>inputs = tokenizer(<span class="hljs-string">&quot;Hello, my name is&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># `return_dict_in_generate=True` is required to return the cache. `return_legacy_cache` forces the returned cache</span>
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># to be of the legacy type</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>generation_outputs = model.generate(**inputs, return_dict_in_generate=<span class="hljs-literal">True</span>, return_legacy_cache=<span class="hljs-literal">True</span>, max_new_tokens=<span class="hljs-number">5</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># We can convert a legacy cache to a DynamicCache -- and the other way around. This is helpful if you have custom</span>
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-comment"># logic to manipulate a cache in a specific format.</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
<span class="hljs-meta">&gt;&gt;&gt; </span>legacy_format_cache = cache.to_legacy_cache()<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/kv_cache.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
// Client bootstrap for this page: publishes path settings, then loads the
// "start" and "app" entry modules and starts the app on this script's parent element.
{
// Assigned without a declaration, so in this classic (non-module) script the
// name is not confined to the enclosing block.
// NOTE(review): presumably read by the imported bundles to resolve paths — confirm.
__sveltekit_z647wz = {
// `assets` and `base` carry the same URL prefix for this docs build.
assets: "/docs/transformers/pr_33913/en",
base: "/docs/transformers/pr_33913/en",
env: {}
};
// Parent element of the <script> tag running this code. It is read here,
// synchronously, before the asynchronous imports below resolve.
const element = document.currentScript.parentElement;
// Two null entries, passed through unchanged as `data` to kit.start().
// NOTE(review): presumably one slot per entry in `node_ids` — confirm.
const data = [null,null];
// Load both entry modules in parallel; the callback runs only once both have resolved.
Promise.all([
import("/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"),
import("/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js")
]).then(([kit, app]) => {
// Hand the app module, the target element and the initial page state to the start module.
kit.start(app, element, {
// NOTE(review): presumably identifies the layout/page nodes to render — confirm.
node_ids: [0, 38],
data,
// No form result or error is supplied for the initial render.
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
86.7 kB
·
Xet hash:
a925bea0322b322d97376d6967a5e6c19bda8aeb584a903db863c7ee9cfa033b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.