<meta charset="utf-8" /><meta name="hf:doc:metadata" content='{"title":"KV cache strategies","local":"kv-cache-strategies","sections":[{"title":"Default cache","local":"default-cache","sections":[],"depth":2},{"title":"Memory efficient caches","local":"memory-efficient-caches","sections":[{"title":"Offloaded cache","local":"offloaded-cache","sections":[],"depth":3},{"title":"Quantized cache","local":"quantized-cache","sections":[],"depth":3},{"title":"Sink cache","local":"sink-cache","sections":[],"depth":3}],"depth":2},{"title":"Speed optimized caches","local":"speed-optimized-caches","sections":[{"title":"Static cache","local":"static-cache","sections":[],"depth":3},{"title":"Offloaded static cache","local":"offloaded-static-cache","sections":[],"depth":3},{"title":"Sliding window cache","local":"sliding-window-cache","sections":[],"depth":3}],"depth":2},{"title":"Model caches","local":"model-caches","sections":[{"title":"Encoder-decoder cache","local":"encoder-decoder-cache","sections":[],"depth":3},{"title":"Model-specific caches","local":"model-specific-caches","sections":[],"depth":3}],"depth":2},{"title":"Iterative generation","local":"iterative-generation","sections":[],"depth":2},{"title":"Prefill a cache","local":"prefill-a-cache","sections":[],"depth":2}],"depth":1}'>
<link href="/docs/transformers/pr_36839/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/scheduler.01eeda35.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/singletons.177df05e.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.4862150a.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/paths.517376d1.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.6dd51b66.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/0.8897c14d.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/42.86f6f5d0.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/Tip.de9bae2b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/CodeBlock.864da1b0.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/EditOnGithub.7faefd25.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/HfOption.f7f04550.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/stores.318eade7.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content='{"title":"KV cache strategies","local":"kv-cache-strategies","sections":[{"title":"Default cache","local":"default-cache","sections":[],"depth":2},{"title":"Memory efficient caches","local":"memory-efficient-caches","sections":[{"title":"Offloaded cache","local":"offloaded-cache","sections":[],"depth":3},{"title":"Quantized cache","local":"quantized-cache","sections":[],"depth":3},{"title":"Sink cache","local":"sink-cache","sections":[],"depth":3}],"depth":2},{"title":"Speed optimized caches","local":"speed-optimized-caches","sections":[{"title":"Static cache","local":"static-cache","sections":[],"depth":3},{"title":"Offloaded static cache","local":"offloaded-static-cache","sections":[],"depth":3},{"title":"Sliding window cache","local":"sliding-window-cache","sections":[],"depth":3}],"depth":2},{"title":"Model caches","local":"model-caches","sections":[{"title":"Encoder-decoder cache","local":"encoder-decoder-cache","sections":[],"depth":3},{"title":"Model-specific caches","local":"model-specific-caches","sections":[],"depth":3}],"depth":2},{"title":"Iterative generation","local":"iterative-generation","sections":[],"depth":2},{"title":"Prefill a cache","local":"prefill-a-cache","sections":[],"depth":2}],"depth":1}'><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="kv-cache-strategies" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#kv-cache-strategies"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1
1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>KV cache strategies</span></h1> <p data-svelte-h="svelte-1hxt7ht">The key-value (KV) vectors are used to calculate attention scores. Autoregressive models predict one token at a time, and each prediction depends on all of the previous tokens. Without a cache, the model recomputes the KV vectors of those previous tokens at <em>every</em> step, repeating the same computations over and over.</p> <p data-svelte-h="svelte-enffd7">A KV <em>cache</em> stores these keys and values so they can be reused instead of recomputed. Efficient caching is crucial for model performance because it reduces computation time and speeds up generation. Refer to the <a href="./cache_explanation.md">Caching</a> doc for a more detailed explanation of how a cache works.</p> <p data-svelte-h="svelte-1ddilb6">Transformers offers several <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.Cache">Cache</a> classes that implement different caching mechanisms. Some of these <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.Cache">Cache</a> classes are optimized to save memory, while others are designed to maximize generation speed.
Use the table below to compare cache types and select the best one for your use case.</p> <table data-svelte-h="svelte-1w6mv1t"><thead><tr><th>Cache Type</th> <th>Memory Efficient</th> <th>Supports torch.compile()</th> <th>Initialization Recommended</th> <th>Generation Speed</th> <th>Long Context Generation</th></tr></thead> <tbody><tr><td>Dynamic Cache</td> <td>No</td> <td>No</td> <td>No</td> <td>Mid</td> <td>No</td></tr> <tr><td>Static Cache</td> <td>No</td> <td>Yes</td> <td>Yes</td> <td>High</td> <td>No</td></tr> <tr><td>Offloaded Cache</td> <td>Yes</td> <td>No</td> <td>No</td> <td>Low</td> <td>Yes</td></tr> <tr><td>Offloaded Static Cache</td> <td>No</td> <td>Yes</td> <td>Yes</td> <td>High</td> <td>Yes</td></tr> <tr><td>Quantized Cache</td> <td>Yes</td> <td>No</td> <td>No</td> <td>Low</td> <td>Yes</td></tr> <tr><td>Sliding Window Cache</td> <td>No</td> <td>Yes</td> <td>Yes</td> <td>High</td> <td>No</td></tr> <tr><td>Sink Cache</td> <td>Yes</td> <td>No</td> <td>Yes</td> <td>Mid</td> <td>Yes</td></tr></tbody></table> <p data-svelte-h="svelte-1idz27a">This guide introduces you to the different <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.Cache">Cache</a> classes and shows you how to use them for generation.</p> <h2 class="relative group"><a id="default-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#default-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196
79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Default cache</span></h2> <p data-svelte-h="svelte-105fdwl">The <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> is the default cache class for most models. It allows the cache size to grow dynamically in order to store an increasing number of keys and values as generation progresses.</p> <p data-svelte-h="svelte-1wehsmj">Disable the cache by configuring <code>use_cache=False</code> in <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: 
transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, torch_dtype=torch.float16).to(<span class="hljs-string">"cuda:0"</span>)
inputs = tokenizer(<span class="hljs-string">"I like rock music because"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device)
<span class="hljs-comment"># use_cache=False disables the KV cache, so keys and values are recomputed at every step</span>
model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, use_cache=<span class="hljs-literal">False</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-nbl90g">You can also initialize a cache class first and then pass it to the model's <a href="https://hf.co/docs/transformers/internal/generation_utils#transformers.generation.GenerateDecoderOnlyOutput.past_key_values" rel="nofollow">past_key_values</a> parameter. This cache initialization strategy is only recommended for some cache types.</p> <p data-svelte-h="svelte-109dkae">In most other cases, it’s easier to define the cache strategy in the <a href="https://hf.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.cache_implementation" rel="nofollow">cache_implementation</a> parameter.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color:
transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, DynamicCache
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, torch_dtype=torch.float16).to(<span class="hljs-string">"cuda:0"</span>)
inputs = tokenizer(<span class="hljs-string">"I like rock music because"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device)
<span class="hljs-comment"># create the cache up front and pass it to generate() through past_key_values</span>
past_key_values = DynamicCache()
out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, past_key_values=past_key_values)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="memory-efficient-caches" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#memory-efficient-caches"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Memory efficient caches</span></h2> <p data-svelte-h="svelte-15jbxrx">The KV cache can occupy a significant portion of memory and become a <a href="https://hf.co/blog/llama31#inference-memory-requirements" rel="nofollow">bottleneck</a> for long-context generation. Memory efficient caches trade some generation speed for reduced memory usage.
This is especially important for large language models (LLMs) or when your hardware is memory constrained.</p> <h3 class="relative group"><a id="offloaded-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#offloaded-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Offloaded cache</span></h3> <p data-svelte-h="svelte-fj09l4">The <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.OffloadedCache">OffloadedCache</a> saves GPU memory by moving the KV cache for most model layers to the CPU. Only the current layer's cache is kept on the GPU while the model's <code>forward</code> pass iterates over the layers. <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.OffloadedCache">OffloadedCache</a> asynchronously prefetches the next layer's cache and sends the previous layer's cache back to the CPU.</p> <p data-svelte-h="svelte-1h7b3c0">This cache strategy always generates the same result as <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> and works as a drop-in replacement or fallback.
You may want to use <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.OffloadedCache">OffloadedCache</a> if you have a GPU and you’re getting out-of-memory (OOM) errors.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1i3749q">You may notice a small degradation in generation throughput compared to <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> depending on your model and generation choices (context size, number of generated tokens, number of beams, etc.).</p></div> <p data-svelte-h="svelte-1p7jbl8">Enable <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.OffloadedCache">OffloadedCache</a> by configuring <code>cache_implementation="offloaded"</code> in either <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationConfig">GenerationConfig</a> or <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" 
height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM
ckpt = <span class="hljs-string">"microsoft/Phi-3-mini-4k-instruct"</span>
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(<span class="hljs-string">"cuda:0"</span>)
inputs = tokenizer(<span class="hljs-string">"Fun fact: The shortest"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device)
<span class="hljs-comment"># offload most of the KV cache to the CPU to save GPU memory</span>
out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">23</span>, cache_implementation=<span class="hljs-string">"offloaded"</span>)
<span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
<span class="hljs-comment"># Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jz2lpc">The example below shows how you can fall back to <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.OffloadedCache">OffloadedCache</a> if you run out of memory.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM

<span class="hljs-keyword">def</span> <span class="hljs-title function_">resilient_generate</span>(<span class="hljs-params">model, *args, **kwargs</span>):
    oom = <span class="hljs-literal">False</span>
    <span class="hljs-keyword">try</span>:
        <span class="hljs-keyword">return</span> model.generate(*args, **kwargs)
    <span class="hljs-keyword">except</span> torch.cuda.OutOfMemoryError <span class="hljs-keyword">as</span> e:
        <span class="hljs-built_in">print</span>(e)
        <span class="hljs-built_in">print</span>(<span class="hljs-string">"retrying with cache_implementation='offloaded'"</span>)
        oom = <span class="hljs-literal">True</span>
    <span class="hljs-comment"># retry outside the except block so the failed attempt's tensors can be released first</span>
    <span class="hljs-keyword">if</span> oom:
        torch.cuda.empty_cache()
        kwargs[<span class="hljs-string">"cache_implementation"</span>] = <span class="hljs-string">"offloaded"</span>
        <span class="hljs-keyword">return</span> model.generate(*args, **kwargs)

ckpt = <span class="hljs-string">"microsoft/Phi-3-mini-4k-instruct"</span>
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(<span class="hljs-string">"cuda:0"</span>)
<span class="hljs-comment"># a deliberately memory-hungry setup: a long prompt and 40 diverse beams</span>
prompt = [<span class="hljs-string">"okay "</span>*<span class="hljs-number">1000</span> + <span class="hljs-string">"Fun fact: The most"</span>]
inputs = tokenizer(prompt, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device)
beams = { <span class="hljs-string">"num_beams"</span>: <span class="hljs-number">40</span>, <span class="hljs-string">"num_beam_groups"</span>: <span class="hljs-number">40</span>, <span class="hljs-string">"num_return_sequences"</span>: <span class="hljs-number">40</span>, <span class="hljs-string">"diversity_penalty"</span>: <span class="hljs-number">1.0</span>, <span class="hljs-string">"max_new_tokens"</span>: <span class="hljs-number">23</span>, <span class="hljs-string">"early_stopping"</span>: <span class="hljs-literal">True</span>, }
out = resilient_generate(model, **inputs, **beams)
responses = tokenizer.batch_decode(out[:,-<span class="hljs-number">28</span>:], skip_special_tokens=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="quantized-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quantized-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quantized cache</span></h3> <p data-svelte-h="svelte-26oktx">The <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.QuantizedCache">QuantizedCache</a> reduces memory requirements by quantizing the KV values to a lower precision. <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.QuantizedCache">QuantizedCache</a> currently supports two quantization backends.</p> <ul data-svelte-h="svelte-d9n5a0"><li><a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.HQQQuantizedCache">HQQQuantizedCache</a> supports int2, int4, and int8 datatypes.</li> <li><a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.QuantoQuantizedCache">QuantoQuantizedCache</a> supports int2 and int4 datatypes.
This is the default quantization backend.</li></ul> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1n6vxi">Quantizing the cache can hurt latency if the context is short and there is enough GPU memory to generate without cache quantization. Try to find a balance between memory efficiency and latency.</p></div> <p data-svelte-h="svelte-8xkcq2">Enable <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.QuantizedCache">QuantizedCache</a> by configuring <code>cache_implementation="quantized"</code> in <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationConfig">GenerationConfig</a>, and indicate the quantization backend in <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.QuantizedCacheConfig">QuantizedCacheConfig</a>. Pass any additional quantization-related parameters either as a dict or as an instance of <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.QuantizedCacheConfig">QuantizedCacheConfig</a>. Use the default values for these additional parameters unless you’re running out of memory.
In that case, consider decreasing <code>residual_length</code>.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">HQQQuantizedCache </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">Quanto </div></div> <div class="language-select"><p data-svelte-h="svelte-1ia7ahf">For <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.HQQQuantizedCache">HQQQuantizedCache</a>, we recommend setting the <code>axis_key</code> and <code>axis_value</code> parameters to <code>1</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, torch_dtype=torch.float16).to(<span class="hljs-string">"cuda:0"</span>)
inputs = tokenizer(<span class="hljs-string">"I like rock music because"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device)
out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">"quantized"</span>, cache_config={<span class="hljs-string">"axis_key"</span>: <span class="hljs-number">1</span>, <span class="hljs-string">"axis_value"</span>: <span class="hljs-number">1</span>, <span class="hljs-string">"backend"</span>: <span class="hljs-string">"HQQ"</span>})
<span class="hljs-built_in">print</span>(tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
<span class="hljs-comment"># I like rock music because it's loud and energetic. It's a great way to express myself and rel</span><!-- HTML_TAG_END --></pre></div> </div> <h3 class="relative group"><a id="sink-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sink-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Sink cache</span></h3> <p data-svelte-h="svelte-1v80sjz"><a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.SinkCache">SinkCache</a> can generate very long sequences (“infinite length” according to the StreamingLLM paper, <em>Efficient Streaming Language Models with Attention Sinks</em>) by retaining only a few initial tokens from the sequence. These are called <em>sink tokens</em> because they account for a significant portion of the attention scores during generation. Subsequent tokens are discarded on a sliding-window basis, so only the most recent tokens that fit within <code>window_length</code> are kept.
This means most of the previous knowledge is discarded.</p> <p data-svelte-h="svelte-z2ocp4">The sink tokens allow a model to maintain stable performance even when it’s dealing with very long text sequences.</p> <p data-svelte-h="svelte-3gss7p">Enable <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.SinkCache">SinkCache</a> by initializing it first with the <a href="https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.window_length" rel="nofollow">window_length</a> and <a href="https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.num_sink_tokens" rel="nofollow">num_sink_tokens</a> parameters before passing it to <a href="https://hf.co/docs/transformers/internal/generation_utils#transformers.generation.GenerateDecoderOnlyOutput.past_key_values" rel="nofollow">past_key_values</a> in <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute 
bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, SinkCache | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, torch_dtype=torch.float16).to(<span class="hljs-string">"cuda:0"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"This is a long story about unicorns, fairies and magic."</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| past_key_values = SinkCache(window_length=<span class="hljs-number">256</span>, num_sink_tokens=<span class="hljs-number">4</span>) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">30</span>, past_key_values=past_key_values) | |
| tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>] | |
| <span class="hljs-string">"This is a long story about unicorns, fairies and magic. It is a fantasy world where unicorns and fairies live together in harmony. The story follows a young girl named Lily"</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="speed-optimized-caches" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#speed-optimized-caches"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Speed optimized caches</span></h2> <p data-svelte-h="svelte-1uld3pp">The default <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> prevents you from taking advantage of just-in-time (JIT) optimizations because the cache size isn’t fixed. JIT optimizations enable you to minimize latency at the expense of higher memory usage. 
All of the following cache types are compatible with JIT optimizations like <a href="./llm_optims#static-kv-cache-and-torchcompile">torch.compile</a> to accelerate generation.</p> <h3 class="relative group"><a id="static-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#static-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Static cache</span></h3> <p data-svelte-h="svelte-8und9c">A <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.StaticCache">StaticCache</a> pre-allocates a specific maximum cache size for the kv pairs. 
You can generate up to the maximum cache size without needing to modify it.</p> <p data-svelte-h="svelte-1el4nli">Enable <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.StaticCache">StaticCache</a> by configuring <code>cache_implementation="static"</code> in <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, torch_dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"Hello, my name is"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">"static"</span>) | |
| tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>] | |
| <span class="hljs-string">"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="offloaded-static-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#offloaded-static-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Offloaded static cache</span></h3> <p data-svelte-h="svelte-f4z67d">The <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.OffloadedStaticCache">OffloadedStaticCache</a> is very similar to the <a href="#offloaded-cache">OffloadedCache</a>, except that the cache is pre-allocated to a fixed maximum size. 
Otherwise, <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.OffloadedStaticCache">OffloadedStaticCache</a> only keeps the current layer cache on the GPU and the rest are moved to the CPU.</p> <p data-svelte-h="svelte-qgwwb7">Enable <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.OffloadedStaticCache">OffloadedStaticCache</a> by configuring <code>cache_implementation="offloaded_static"</code> in <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span>, torch_dtype=torch.float16, device_map=<span class="hljs-string">"auto"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"Hello, my name is"</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">20</span>, cache_implementation=<span class="hljs-string">"offloaded_static"</span>) | |
| tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>] | |
| <span class="hljs-string">"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zdzdtx">Cache offloading requires a CUDA GPU.</p> <h3 class="relative group"><a id="sliding-window-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sliding-window-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Sliding window cache</span></h3> <p data-svelte-h="svelte-1jlffr1"><a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.SlidingWindowCache">SlidingWindowCache</a> implements a sliding window over the previous kv pairs, and only keeps the last <code>sliding_window</code> tokens. This cache type is designed to only work with models that support <em>sliding window attention</em>, such as <a href="./model_doc/mistral">Mistral</a>. 
Older kv states are discarded and replaced by new kv states.</p> <p data-svelte-h="svelte-6z5f77">Enable <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.SlidingWindowCache">SlidingWindowCache</a> by configuring <code>cache_implementation="sliding_window"</code> in <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"mistralai/Mistral-7B-v0.1"</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"mistralai/Mistral-7B-v0.1"</span>, torch_dtype=torch.float16).to(<span class="hljs-string">"cuda:0"</span>) | |
| inputs = tokenizer(<span class="hljs-string">"Yesterday I was on a rock concert and."</span>, return_tensors=<span class="hljs-string">"pt"</span>).to(model.device) | |
| out = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">30</span>, cache_implementation=<span class="hljs-string">"sliding_window"</span>) | |
| tokenizer.batch_decode(out, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="model-caches" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model-caches"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model caches</span></h2> <p data-svelte-h="svelte-kqp2g5">Some model types, like encoder-decoder models or <a href="./model_doc/gemma2">Gemma2</a> and <a href="./model_doc/mamba">Mamba</a>, have dedicated cache classes.</p> <h3 class="relative group"><a id="encoder-decoder-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#encoder-decoder-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 
0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Encoder-decoder cache</span></h3> <p data-svelte-h="svelte-vr36da"><a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.EncoderDecoderCache">EncoderDecoderCache</a> is designed for encoder-decoder models. It manages both the self-attention and cross-attention caches, storing and retrieving the previous kv pairs for each. You can set a different cache type for the self-attention and cross-attention caches individually.</p> <p data-svelte-h="svelte-b180yg">This cache type doesn’t require any setup. It can be used when calling <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a> or a model’s <code>forward</code> method.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1b9d0xx">The <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.EncoderDecoderCache">EncoderDecoderCache</a> currently only supports <a href="./model_doc/whisper">Whisper</a>.</p></div> <h3 class="relative group"><a id="model-specific-caches" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model-specific-caches"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" 
role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model-specific caches</span></h3> <p data-svelte-h="svelte-u1f4zv">Some models have a unique way of storing past kv pairs or states that is not compatible with any other cache classes.</p> <p data-svelte-h="svelte-5zl4jv"><a href="./model_doc/gemma2">Gemma2</a> requires <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.HybridCache">HybridCache</a>, which uses a combination of <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.SlidingWindowCache">SlidingWindowCache</a> for sliding window attention and <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.StaticCache">StaticCache</a> for global attention under the hood.</p> <p data-svelte-h="svelte-6314k6"><a href="./model_doc/mamba">Mamba</a> requires <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.MambaCache">MambaCache</a> because the model doesn’t have an attention mechanism or kv states.</p> <h2 class="relative group"><a id="iterative-generation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#iterative-generation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" 
aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Iterative generation</span></h2> <p data-svelte-h="svelte-67pkdu">A cache also works in iterative generation settings, such as chatbots, where there is back-and-forth interaction with a model. Like regular generation, iterative generation with a cache allows a model to efficiently handle ongoing conversations without recomputing the entire context at each step.</p> <p data-svelte-h="svelte-1ejbghf">For iterative generation with a cache, start by initializing an empty cache and then feed in your new prompts. Keep track of dialogue history with a <a href="./chat_templating">chat template</a>.</p> <p data-svelte-h="svelte-1b914t7">If you’re using <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.SinkCache">SinkCache</a>, truncate the inputs to the maximum cache length. <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.SinkCache">SinkCache</a> can generate text that exceeds its maximum window size. 
However, the first input shouldn’t exceed the maximum cache length.</p> <p data-svelte-h="svelte-18xhg0c">The example below demonstrates how to use a cache for iterative generation.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM | |
| <span class="hljs-keyword">from</span> transformers.cache_utils <span class="hljs-keyword">import</span> ( | |
|     DynamicCache, | |
|     SinkCache, | |
|     StaticCache, | |
|     SlidingWindowCache, | |
|     QuantoQuantizedCache, | |
|     QuantizedCacheConfig, | |
| ) | |
| model_id = <span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span> | |
| model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=<span class="hljs-string">'auto'</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| user_prompts = [<span class="hljs-string">"Hello, what's your name?"</span>, <span class="hljs-string">"Btw, yesterday I was on a rock concert."</span>] | |
| past_key_values = DynamicCache() | |
| max_cache_length = past_key_values.get_max_length() | |
| messages = [] | |
| <span class="hljs-keyword">for</span> prompt <span class="hljs-keyword">in</span> user_prompts: | |
|     messages.append({<span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>, <span class="hljs-string">"content"</span>: prompt}) | |
|     inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"pt"</span>, return_dict=<span class="hljs-literal">True</span>).to(model.device) | |
|     <span class="hljs-keyword">if</span> <span class="hljs-built_in">isinstance</span>(past_key_values, SinkCache): | |
|         <span class="hljs-comment"># Keep only the most recent tokens so the input fits in the SinkCache window</span> | |
|         inputs = {k: v[:, -max_cache_length:] <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> inputs.items()} | |
|     input_length = inputs[<span class="hljs-string">"input_ids"</span>].shape[<span class="hljs-number">1</span>] | |
|     outputs = model.generate(**inputs, do_sample=<span class="hljs-literal">False</span>, max_new_tokens=<span class="hljs-number">256</span>, past_key_values=past_key_values) | |
|     <span class="hljs-comment"># Decode only the newly generated tokens and append them to the dialogue history</span> | |
|     completion = tokenizer.decode(outputs[<span class="hljs-number">0</span>, input_length:], skip_special_tokens=<span class="hljs-literal">True</span>) | |
|     messages.append({<span class="hljs-string">"role"</span>: <span class="hljs-string">"assistant"</span>, <span class="hljs-string">"content"</span>: completion})<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="prefill-a-cache" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#prefill-a-cache"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Prefill a cache</span></h2> <p data-svelte-h="svelte-brorkr">In some situations, you may want to fill a <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.Cache">Cache</a> with kv pairs for a certain prefix prompt and reuse it to generate different sequences.</p> <p data-svelte-h="svelte-1bie09z">The example below initializes a <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.StaticCache">StaticCache</a> and then caches an initial prompt. 
Now you can generate several sequences from the prefilled prompt.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> copy | |
| <span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache | |
| model_id = <span class="hljs-string">"meta-llama/Llama-2-7b-chat-hf"</span> | |
| model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=<span class="hljs-string">"cuda"</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| <span class="hljs-comment"># Initialize a StaticCache whose max_cache_len (1024 tokens here) fits the prefix, each prompt, and the generated tokens</span> | |
| <span class="hljs-comment"># A DynamicCache also works if you don't need a fixed-size cache</span> | |
| prompt_cache = StaticCache(config=model.config, max_batch_size=<span class="hljs-number">1</span>, max_cache_len=<span class="hljs-number">1024</span>, device=<span class="hljs-string">"cuda"</span>, dtype=torch.bfloat16) | |
| INITIAL_PROMPT = <span class="hljs-string">"You are a helpful assistant. "</span> | |
| inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors=<span class="hljs-string">"pt"</span>).to(<span class="hljs-string">"cuda"</span>) | |
| <span class="hljs-comment"># Fill the cache with the shared prompt. Run without gradients so the cached tensors can be deep-copied later</span> | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
|     prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values | |
| prompts = [<span class="hljs-string">"Help me to write a blogpost about travelling."</span>, <span class="hljs-string">"What is the capital of France?"</span>] | |
| responses = [] | |
| <span class="hljs-keyword">for</span> prompt <span class="hljs-keyword">in</span> prompts: | |
|     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors=<span class="hljs-string">"pt"</span>).to(<span class="hljs-string">"cuda"</span>) | |
|     <span class="hljs-comment"># generate() writes new tokens into the cache in place, so give each prompt its own copy</span> | |
|     past_key_values = copy.deepcopy(prompt_cache) | |
|     outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=<span class="hljs-number">20</span>) | |
|     response = tokenizer.batch_decode(outputs)[<span class="hljs-number">0</span>] | |
|     responses.append(response) | |
| <span class="hljs-built_in">print</span>(responses)<!-- HTML_TAG_END --></pre></div> <p></p> | |
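<p>To see why each prompt gets its own <code>copy.deepcopy</code> of the prefilled cache, the sketch below models a cache as plain Python lists. <code>ToyKVCache</code> and <code>toy_generate</code> are hypothetical stand-ins, not Transformers APIs. The toy assumes, as in the example above, that generation skips input tokens already present in the cache and appends new entries to the cache in place.</p>

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyKVCache:
    """Hypothetical stand-in for a KV cache: one (key, value) entry per token.

    Real caches hold per-layer tensors; lists of strings are enough to show
    how a prefilled prefix is reused.
    """

    keys: list = field(default_factory=list)
    values: list = field(default_factory=list)

    def update(self, tokens):
        # Entries are appended in place, like a cache growing during generation.
        for tok in tokens:
            self.keys.append(f"k({tok})")
            self.values.append(f"v({tok})")

    def __len__(self):
        return len(self.keys)


def toy_generate(cache, input_tokens, max_new_tokens=3):
    """Hypothetical generate(): skips input tokens that are already cached."""
    uncached = input_tokens[len(cache):]  # only the suffix after the prefix is processed
    cache.update(uncached)
    generated = [f"tok{i}" for i in range(max_new_tokens)]
    cache.update(generated)  # each decoded token is cached as well
    return generated


prefix = ["You", "are", "a", "helpful", "assistant", "."]
prompt_cache = ToyKVCache()
prompt_cache.update(prefix)  # computed once, shared by every prompt

prompts = [["Write", "a", "blogpost"], ["Capital", "of", "France", "?"]]
cache_lengths = []
for prompt in prompts:
    past = copy.deepcopy(prompt_cache)  # independent copy per request
    toy_generate(past, prefix + prompt)
    cache_lengths.append(len(past))

print(len(prompt_cache), cache_lengths)  # 6 [12, 13]

# Without a copy, the first request's entries leak into the shared cache.
shared = ToyKVCache()
shared.update(prefix)
toy_generate(shared, prefix + prompts[0])
print(len(shared))  # 12: no longer just the 6-token prefix
```

<p>The shared <code>prompt_cache</code> still holds only the 6 prefix entries after both requests, while each copy grows independently. The same reasoning applies to the Transformers example: passing <code>prompt_cache</code> to <code>generate()</code> directly would leave the first response's keys and values in the cache seen by the second prompt.</p>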
| <script> | |
| { | |
| __sveltekit_1bm5psi = { | |
| assets: "/docs/transformers/pr_36839/en", | |
| base: "/docs/transformers/pr_36839/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js"), | |
| import("/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 42], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |