<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Caching&quot;,&quot;local&quot;:&quot;caching&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Cache class&quot;,&quot;local&quot;:&quot;cache-class&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Legacy cache format&quot;,&quot;local&quot;:&quot;legacy-cache-format&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_36839/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/scheduler.01eeda35.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/singletons.177df05e.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.4862150a.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/paths.517376d1.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.6dd51b66.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/0.8897c14d.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/8.46114246.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/Tip.de9bae2b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/CodeBlock.864da1b0.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/EditOnGithub.7faefd25.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Caching&quot;,&quot;local&quot;:&quot;caching&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Cache class&quot;,&quot;local&quot;:&quot;cache-class&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Legacy cache format&quot;,&quot;local&quot;:&quot;legacy-cache-format&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="caching" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#caching"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Caching</span></h1> <p data-svelte-h="svelte-1vgfrh0">Imagine you’re having a conversation with someone, and instead of remembering what they previously said, they have to start from scratch every time you respond. This would be slow and inefficient, right?</p> <p data-svelte-h="svelte-110unin">You can extend this analogy to transformer models. 
Autoregressive generation can be slow because the model predicts one token at a time, and each new prediction depends on all of the previous context.</p> <p data-svelte-h="svelte-1sc0ng3">To predict the 1000th token, the model needs information from the previous 999 tokens. This information is computed with matrix multiplications over the token representations.</p> <p data-svelte-h="svelte-stwnf4">To predict the 1001st token, the model needs the same information from the previous 999 tokens plus information from the 1000th token. Without caching, the model has to recompute the same matrix multiplications for the earlier tokens at every step!</p> <p data-svelte-h="svelte-k1lwaa">A key-value (KV) cache eliminates this inefficiency by storing the kv pairs (keys and values) computed by the attention layers for previously processed tokens. The stored kv pairs are retrieved from the cache and reused for subsequent tokens, so they don't have to be recomputed.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-16dnkl8">Caching should only be used for <strong>inference</strong>. 
It may cause unexpected errors if it’s enabled during training.</p></div> <h2 class="relative group"><a id="cache-class" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cache-class"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Cache class</span></h2> <p data-svelte-h="svelte-84etsa">When you use Transformers’ <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.Cache">Cache</a> class, the self-attention module performs several critical steps to integrate past and present information.</p> <ol data-svelte-h="svelte-qrsqjb"><li><p>The attention module concatenates the current kv pairs with the past kv pairs stored in the cache. This produces attention weights with the shape <code>(new_tokens_length, past_kv_length + new_tokens_length)</code>. Combining the current and past kv pairs to compute the attention scores ensures the model takes both the previous context and the current input into account.</p></li> <li><p>When the <code>forward</code> method is called iteratively, it’s crucial that the attention mask shape matches the combined length of the past and current kv pairs. 
The attention mask should have the shape <code>(batch_size, past_kv_length + new_tokens_length)</code>. This is typically handled internally in <a href="/docs/transformers/pr_36839/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a>, but keep it in mind if you implement your own generation loop with <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.Cache">Cache</a>: the attention mask must cover both the past and the current tokens.</p></li> <li><p>Pay attention to <code>cache_position</code>, which indicates the positions of the input tokens in the sequence. You must pass a valid <code>cache_position</code> value if you want to reuse a prefilled <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.Cache">Cache</a> with the <code>forward</code> method. <code>cache_position</code> is unaffected by padding and always advances by one position per token. For example, if a kv cache contains 10 tokens, counting any pad tokens, the cache position for the next token should be <code>torch.tensor([10])</code>.</p></li></ol> <p data-svelte-h="svelte-1mm88cx">The example below demonstrates how to create a generation loop with <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a>. 
As discussed above, the attention mask is extended by one position for each new token so that it covers both the past and current tokens, and the cache position is incremented by <code>1</code> for the next token.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, DynamicCache
model_id = <span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=<span class="hljs-string">&quot;cuda:0&quot;</span>)
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache()
messages = [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;Hello, what&#x27;s your name.&quot;</span>}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>, return_dict=<span class="hljs-literal">True</span>).to(<span class="hljs-string">&quot;cuda:0&quot;</span>)
generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[<span class="hljs-number">1</span>], dtype=torch.int64, device=<span class="hljs-string">&quot;cuda:0&quot;</span>)
max_new_tokens = <span class="hljs-number">10</span>
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(max_new_tokens):
    outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=<span class="hljs-literal">True</span>)
    <span class="hljs-comment"># Greedy decoding: pick the most likely next token</span>
    next_token_ids = outputs.logits[:, -<span class="hljs-number">1</span>:].argmax(-<span class="hljs-number">1</span>)
    generated_ids = torch.cat([generated_ids, next_token_ids], dim=-<span class="hljs-number">1</span>)
    <span class="hljs-comment"># Prepare inputs for the next step: feed only the unprocessed tokens (here, the single new token)</span>
    <span class="hljs-comment"># and extend the attention mask by one position for that token, as explained above</span>
    attention_mask = inputs[<span class="hljs-string">&quot;attention_mask&quot;</span>]
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[<span class="hljs-number">0</span>], <span class="hljs-number">1</span>))], dim=-<span class="hljs-number">1</span>)
    inputs = {<span class="hljs-string">&quot;input_ids&quot;</span>: next_token_ids, <span class="hljs-string">&quot;attention_mask&quot;</span>: attention_mask}
    cache_position = cache_position[-<span class="hljs-number">1</span>:] + <span class="hljs-number">1</span>  <span class="hljs-comment"># add one more position for the next token</span>
<span class="hljs-built_in">print</span>(tokenizer.batch_decode(generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
<span class="hljs-string">&quot;[INST] Hello, what&#x27;s your name. [/INST] Hello! My name is LLaMA,&quot;</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="legacy-cache-format" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#legacy-cache-format"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Legacy cache format</span></h2> <p data-svelte-h="svelte-frtlpc">Before the <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.Cache">Cache</a> class, the cache used to be stored as a tuple of tuples of tensors. 
The outer tuple has one entry per layer, and each entry is a <code>(key, value)</code> tuple of tensors. This format is dynamic because it grows as text is generated, similar to <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a>.</p> <p data-svelte-h="svelte-216z6t">If your project depends on this legacy format, you can convert between <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> and a tuple of tuples as shown below with the <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache.from_legacy_cache">DynamicCache.from_legacy_cache()</a> and <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache.to_legacy_cache">DynamicCache.to_legacy_cache()</a> methods. This is helpful if you have custom logic that manipulates a cache in a specific format.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, DynamicCache
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>, torch_dtype=torch.float16, device_map=<span class="hljs-string">&quot;auto&quot;</span>)
inputs = tokenizer(<span class="hljs-string">&quot;Hello, my name is&quot;</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>).to(model.device)
<span class="hljs-comment"># `return_dict_in_generate=True` is required to return the cache, and `return_legacy_cache=True` forces the returned cache</span>
<span class="hljs-comment"># into the legacy format</span>
generation_outputs = model.generate(**inputs, return_dict_in_generate=<span class="hljs-literal">True</span>, return_legacy_cache=<span class="hljs-literal">True</span>, max_new_tokens=<span class="hljs-number">5</span>)
cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
legacy_format_cache = cache.to_legacy_cache()<!-- HTML_TAG_END --></pre></div> <p></p>
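<p>The savings described in the introduction can be illustrated with a toy example. The sketch below is plain PyTorch, not the Transformers implementation: it assumes a single attention head with random projection matrices (<code>w_q</code>, <code>w_k</code>, <code>w_v</code>) and a hypothetical helper <code>attend</code>. After the prompt is processed, each step computes the key and value for the new token only, appends them to a growing cache, and gets attention weights of shape <code>(new_tokens_length, past_kv_length + new_tokens_length)</code>. The cached outputs should match recomputing causal attention over the whole sequence.</p>

```python
import torch

# Toy single-head causal self-attention, for illustration only (NOT the Transformers code).
# The projection matrices are random; float64 keeps the numerical comparison tight.
torch.manual_seed(0)
d_model = 8
w_q, w_k, w_v = (torch.randn(d_model, d_model, dtype=torch.float64) for _ in range(3))


def attend(q, k, v, past_len):
    """Hypothetical helper: attention of the new queries `q` over all keys/values.

    q: (new_tokens_length, d_model); k, v: (past_kv_length + new_tokens_length, d_model).
    """
    scores = q @ k.T / d_model**0.5  # (new_tokens_length, past_kv_length + new_tokens_length)
    new_len, total_len = scores.shape
    # Causal mask: the query at absolute position past_len + i only sees keys at positions <= past_len + i
    query_pos = torch.arange(past_len, past_len + new_len).unsqueeze(1)
    key_pos = torch.arange(total_len).unsqueeze(0)
    scores = scores.masked_fill(key_pos > query_pos, float("-inf"))
    weights = scores.softmax(dim=-1)
    return weights @ v, weights


x = torch.randn(6, d_model, dtype=torch.float64)  # hidden states for a 6-token sequence

# Without a cache: compute keys and values for every token from scratch
full_out, _ = attend(x @ w_q, x @ w_k, x @ w_v, past_len=0)

# With a cache: process a 4-token prompt once, then one token per step
k_cache, v_cache = x[:4] @ w_k, x[:4] @ w_v
out, _ = attend(x[:4] @ w_q, k_cache, v_cache, past_len=0)
step_outputs = [out]
for t in range(4, 6):
    new = x[t : t + 1]
    # Only the new token's key and value are computed; past ones are reused from the cache
    k_cache = torch.cat([k_cache, new @ w_k])
    v_cache = torch.cat([v_cache, new @ w_v])
    out, weights = attend(new @ w_q, k_cache, v_cache, past_len=t)
    print(weights.shape)  # (1, t + 1), i.e. (new_tokens_length, past_kv_length + new_tokens_length)
    step_outputs.append(out)

cached_out = torch.cat(step_outputs)
print(torch.allclose(full_out, cached_out))  # True
```

<p>Real models keep this kind of cache for every layer, which is the bookkeeping that <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> takes care of for you.</p>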
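<p>The attention mask and <code>cache_position</code> rules from the <a href="#cache-class">Cache class</a> section can be checked without loading a model. This is a hypothetical, model-free sketch: it assumes a left-padded batch of two 5-token prompts and replays the same update rules as the <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache">DynamicCache</a> generation loop, with the forward pass left out. Before each step, the mask length should equal the last cache position plus one, and padding never shifts <code>cache_position</code>.</p>

```python
import torch

# Hypothetical left-padded batch of two prompts, each 5 positions long.
# In the mask, 0 marks a pad token: row 0 has 2 left pads, row 1 has none.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
# cache_position ignores padding: the prompt occupies cache slots 0..4 in every row.
cache_position = torch.arange(attention_mask.shape[1])

history = []
for _ in range(3):
    # Before each forward call, the mask must span past_kv_length + new_tokens_length,
    # which is exactly one more than the last cache position being written.
    assert attention_mask.shape[1] == cache_position[-1].item() + 1
    history.append((attention_mask.shape[1], cache_position.tolist()))

    # (A real loop would call model(..., attention_mask=attention_mask, cache_position=cache_position) here.)

    # Append a column of ones for the single new token in every row...
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )
    # ...and move to the next cache slot, one position per token.
    cache_position = cache_position[-1:] + 1

print(history)
# [(5, [0, 1, 2, 3, 4]), (6, [5]), (7, [6])]
```

<p>Even though the first row holds only three real prompt tokens, its next cache position is still <code>5</code>, because pad tokens occupy cache slots too.</p>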
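<p>To make the legacy layout concrete, the following sketch builds a tuple-of-tuples cache by hand from plain PyTorch tensors. It assumes the <code>(batch_size, num_heads, seq_len, head_dim)</code> key/value shape used by decoder models such as Llama; the dimensions and the helpers <code>append_step</code> and <code>crop</code> are hypothetical examples of custom cache logic, not Transformers APIs.</p>

```python
import torch

# Hypothetical model dimensions, chosen only for illustration.
num_layers, batch_size, num_heads, head_dim = 2, 1, 4, 16


def random_kv(seq_len):
    """One (key, value) pair, each shaped (batch_size, num_heads, seq_len, head_dim)."""
    shape = (batch_size, num_heads, seq_len, head_dim)
    return torch.randn(shape), torch.randn(shape)


# Legacy format: a tuple with one entry per layer, each a (key, value) tuple of tensors.
legacy_cache = tuple(random_kv(seq_len=5) for _ in range(num_layers))


def append_step(cache, new_kv):
    """Hypothetical helper: grow every layer by concatenating along the sequence axis (dim=-2)."""
    return tuple(
        (torch.cat([k, new_k], dim=-2), torch.cat([v, new_v], dim=-2))
        for (k, v), (new_k, new_v) in zip(cache, new_kv)
    )


def crop(cache, max_length):
    """Hypothetical helper: keep only the first `max_length` cached positions in every layer."""
    return tuple((k[..., :max_length, :], v[..., :max_length, :]) for k, v in cache)


# Each generation step adds one position per layer, so the cache grows dynamically.
for _ in range(3):
    legacy_cache = append_step(legacy_cache, tuple(random_kv(seq_len=1) for _ in range(num_layers)))

print(legacy_cache[0][0].shape)  # torch.Size([1, 4, 8, 16])
cropped = crop(legacy_cache, max_length=6)
print(cropped[0][0].shape)  # torch.Size([1, 4, 6, 16])
```

<p>A tuple of tuples like this is the input that <a href="/docs/transformers/pr_36839/en/internal/generation_utils#transformers.DynamicCache.from_legacy_cache">DynamicCache.from_legacy_cache()</a> expects, so it could be converted back to a cache object after your custom processing.</p>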
<script>
{
__sveltekit_1bm5psi = {
assets: "/docs/transformers/pr_36839/en",
base: "/docs/transformers/pr_36839/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js"),
import("/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 8],
data,
form: null,
error: null
});
});
}
</script>
