<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Caching&quot;,&quot;local&quot;:&quot;caching&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Attention matrices&quot;,&quot;local&quot;:&quot;attention-matrices&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Cache class&quot;,&quot;local&quot;:&quot;cache-class&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Cache storage implementation&quot;,&quot;local&quot;:&quot;cache-storage-implementation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Cache position&quot;,&quot;local&quot;:&quot;cache-position&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_33892/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/scheduler.31fdf58d.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/singletons.9860629f.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.252883d5.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/paths.e85c0ec8.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/preload-helper.40847a0e.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.2f76fdf0.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/0.ca4aafa4.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/10.9a7aa00b.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CopyLLMTxtMenu.ff482081.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.71f274cc.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/IconCopy.ac192424.js">
<link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CodeBlock.ab12f8e1.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Caching&quot;,&quot;local&quot;:&quot;caching&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Attention matrices&quot;,&quot;local&quot;:&quot;attention-matrices&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Cache class&quot;,&quot;local&quot;:&quot;cache-class&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Cache storage implementation&quot;,&quot;local&quot;:&quot;cache-storage-implementation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Cache position&quot;,&quot;local&quot;:&quot;cache-position&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="caching" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#caching"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 
79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Caching</span></h1> <p data-svelte-h="svelte-1vgfrh0">Imagine you’re having a conversation with someone, and instead of remembering what they previously said, they have to start from scratch every time you respond. This would be slow and inefficient, right?</p> <p data-svelte-h="svelte-110unin">You can extend this analogy to transformer models. Autoregressive model generation can be slow because it makes a prediction one token at a time. Each new prediction is dependent on all the previous context.</p> <p data-svelte-h="svelte-1sc0ng3">To predict the 1000th token, the model requires information from the previous 999 tokens. The information is represented as matrix multiplications across the token representations.</p> <p data-svelte-h="svelte-stwnf4">To predict the 1001th token, you need the same information from the previous 999 tokens in addition to any information from the 1000th token. This is a lot of matrix multiplications a model has to compute over and over for each token!</p> <p data-svelte-h="svelte-k1lwaa">A key-value (KV) cache eliminates this inefficiency by storing kv pairs derived from the attention layers of previously processed tokens. The stored kv pairs are retrieved from the cache and reused for subsequent tokens, avoiding the need to recompute.</p> <blockquote class="warning" data-svelte-h="svelte-1f3w2hm"><p>Caching should only be used for <strong>inference</strong>. It may cause unexpected errors if it’s enabled during training.</p></blockquote> <p data-svelte-h="svelte-aj1f1y">To better understand how and why caching works, let’s take a closer look at the structure of the attention matrices.</p> <h2 class="relative group"><a id="attention-matrices" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#attention-matrices"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Attention matrices</span></h2> <p>The <strong data-svelte-h="svelte-1et3y0d">scaled dot-product attention</strong> is calculated as shown below for a batch of size <code data-svelte-h="svelte-1y90nls">b</code>, number of attention heads <code data-svelte-h="svelte-1blmdqe">h</code>, sequence length so far <code data-svelte-h="svelte-18tc35m">T</code>, and dimension per attention head <code data-svelte-h="svelte-krx0ij">d_head</code>.
<!-- HTML_TAG_START --><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mtext>Attention</mtext><mo stretchy="false">(</mo><mi>Q</mi><mo separator="true">,</mo><mi>K</mi><mo separator="true">,</mo><mi>V</mi><mo stretchy="false">)</mo><mo>=</mo><mtext>softmax</mtext><mrow><mo fence="true">(</mo><mfrac><mrow><mi>Q</mi><msup><mi>K</mi><mi mathvariant="normal"></mi></msup></mrow><msqrt><msub><mi>d</mi><mtext>head</mtext></msub></msqrt></mfrac><mo>×</mo><mtext>mask</mtext><mo fence="true">)</mo></mrow><mi>V</mi></mrow><annotation encoding="application/x-tex">
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \times \text{mask} \right) V
</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord">Attention</span></span><span class="mopen">(</span><span class="mord mathnormal">Q</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:2.4761em;vertical-align:-0.95em;"></span><span class="mord text"><span class="mord">softmax</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size3">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.5261em;"><span style="top:-2.2528em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8572em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord"><span class="mord mathnormal">d</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">head</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><span style="top:-2.8172em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="1.08em" viewBox="0 0 400000 1080" preserveAspectRatio="xMinYMin slice"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"/></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.1828em;"><span></span></span></span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">Q</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8491em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord text"><span class="mord">mask</span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size3">)</span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span></span></span></span></span><!-- HTML_TAG_END --></p> <p data-svelte-h="svelte-11ldl4s">The query (<code>Q</code>), key (<code>K</code>), and value (<code>V</code>) matrices are projections from the input embeddings of shape <code>(b, h, T, d_head)</code>.</p> <p>For causal attention, the mask prevents the model from attending to future tokens. 
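To make the formula concrete, here is a minimal PyTorch sketch of masked scaled dot-product attention over the shapes above. It is illustrative only, and it applies the causal mask additively with `-inf` before the softmax, a common equivalent of the multiplicative mask in the formula.

```py
import torch

b, h, T, d_head = 1, 2, 4, 8                    # batch, heads, sequence length so far, head dim
Q = torch.randn(b, h, T, d_head)
K = torch.randn(b, h, T, d_head)
V = torch.randn(b, h, T, d_head)

scores = Q @ K.transpose(-2, -1) / d_head**0.5  # (b, h, T, T)
causal_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
weights = torch.softmax(scores + causal_mask, dim=-1)  # each token attends only to itself and the past
attn_output = weights @ V                       # (b, h, T, d_head)
```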
Once a token is processed, its representation never changes with respect to future tokens, which means $K_{\text{past}}$ and $V_{\text{past}}$ can be cached and reused to compute the last token's representation.
<!-- HTML_TAG_START --><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mtext>Attention</mtext><mo stretchy="false">(</mo><msub><mi>q</mi><mi>t</mi></msub><mo separator="true">,</mo><mo stretchy="false">[</mo><munder><munder><mrow><msub><mi>k</mi><mn>1</mn></msub><mo separator="true">,</mo><msub><mi>k</mi><mn>2</mn></msub><mo separator="true">,</mo><mo></mo><mo separator="true">,</mo><msub><mi>k</mi><mrow><mi>t</mi><mo></mo><mn>1</mn></mrow></msub></mrow><mo stretchy="true"></mo></munder><mtext>cached</mtext></munder><mo separator="true">,</mo><msub><mi>k</mi><mi>t</mi></msub><mo stretchy="false">]</mo><mo separator="true">,</mo><mo stretchy="false">[</mo><munder><munder><mrow><msub><mi>v</mi><mn>1</mn></msub><mo separator="true">,</mo><msub><mi>v</mi><mn>2</mn></msub><mo separator="true">,</mo><mo></mo><mo separator="true">,</mo><msub><mi>v</mi><mrow><mi>t</mi><mo></mo><mn>1</mn></mrow></msub></mrow><mo stretchy="true"></mo></munder><mtext>cached</mtext></munder><mo separator="true">,</mo><msub><mi>v</mi><mi>t</mi></msub><mo stretchy="false">]</mo><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">
\text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}])
</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:2.2924em;vertical-align:-1.5424em;"></span><span class="mord text"><span class="mord">Attention</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mopen">[</span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944em;"><span style="top:-1.4576em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">cached</span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6944em;"><span class="svg-align" style="top:-2.1437em;"><span class="pstrut" style="height:3em;"></span><span class="stretchy" style="height:0.548em;min-width:1.6em;"><span class="brace-left" style="height:0.548em;"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMinYMin slice"><path d="M0 6l6-6h17c12.688 0 19.313.3 20 1 4 4 7.313 8.3 10 13
35.313 51.3 80.813 93.8 136.5 127.5 55.688 33.7 117.188 55.8 184.5 66.5.688
0 2 .3 4 1 18.688 2.7 76 4.3 172 5h399450v120H429l-6-1c-124.688-8-235-61.7
-331-161C60.687 138.7 32.312 99.3 7 54L0 41V6z"/></svg></span><span class="brace-center" style="height:0.548em;"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMidYMin slice"><path d="M199572 214
c100.7 8.3 195.3 44 280 108 55.3 42 101.7 93 139 153l9 14c2.7-4 5.7-8.7 9-14
53.3-86.7 123.7-153 211-199 66.7-36 137.3-56.3 212-62h199568v120H200432c-178.3
11.7-311.7 78.3-403 201-6 8-9.7 12-11 12-.7.7-6.7 1-18 1s-17.3-.3-18-1c-1.3 0
-5-4-11-12-44.7-59.3-101.3-106.3-170-141s-145.3-54.3-229-60H0V214z"/></svg></span><span class="brace-right" style="height:0.548em;"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMaxYMin slice"><path d="M399994 0l6 6v35l-6 11c-56 104-135.3 181.3-238 232-57.3
28.7-117 45-179 50H-300V214h399897c43.3-7 81-15 113-26 100.7-33 179.7-91 237
-174 2.7-5 6-9 10-13 .7-1 7.3-1 20-1h17z"/></svg></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0315em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0315em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0315em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight"></span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.8563em;"><span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.5424em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0315em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">]</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mopen">[</span><span class="mord 
munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.4306em;"><span style="top:-1.4576em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">cached</span></span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.4306em;"><span class="svg-align" style="top:-2.1437em;"><span class="pstrut" style="height:3em;"></span><span class="stretchy" style="height:0.548em;min-width:1.6em;"><span class="brace-left" style="height:0.548em;"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMinYMin slice"><path d="M0 6l6-6h17c12.688 0 19.313.3 20 1 4 4 7.313 8.3 10 13
35.313 51.3 80.813 93.8 136.5 127.5 55.688 33.7 117.188 55.8 184.5 66.5.688
0 2 .3 4 1 18.688 2.7 76 4.3 172 5h399450v120H429l-6-1c-124.688-8-235-61.7
-331-161C60.687 138.7 32.312 99.3 7 54L0 41V6z"/></svg></span><span class="brace-center" style="height:0.548em;"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMidYMin slice"><path d="M199572 214
c100.7 8.3 195.3 44 280 108 55.3 42 101.7 93 139 153l9 14c2.7-4 5.7-8.7 9-14
53.3-86.7 123.7-153 211-199 66.7-36 137.3-56.3 212-62h199568v120H200432c-178.3
11.7-311.7 78.3-403 201-6 8-9.7 12-11 12-.7.7-6.7 1-18 1s-17.3-.3-18-1c-1.3 0
-5-4-11-12-44.7-59.3-101.3-106.3-170-141s-145.3-54.3-229-60H0V214z"/></svg></span><span class="brace-right" style="height:0.548em;"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMaxYMin slice"><path d="M399994 0l6 6v35l-6 11c-56 104-135.3 181.3-238 232-57.3
28.7-117 45-179 50H-300V214h399897c43.3-7 81-15 113-26 100.7-33 179.7-91 237
-174 2.7-5 6-9 10-13 .7-1 7.3-1 20-1h17z"/></svg></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight"></span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2083em;"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.8563em;"><span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.5424em;"><span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">])</span></span></span></span></span><!-- HTML_TAG_END --></p> <p>At inference time, you only need the last token’s query to compute the 
representation <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>x</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex"> x_t </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> that predicts the next token $ t+1 $. At each step, the new key and value vectors are <strong data-svelte-h="svelte-ho7h3">stored</strong> in the cache and <strong data-svelte-h="svelte-ygjkh1">appended</strong> to the past keys and values.
<!-- HTML_TAG_START --><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi>K</mi><mtext>cache</mtext></msub><mo></mo><mtext>concat</mtext><mo stretchy="false">(</mo><msub><mi>K</mi><mtext>past</mtext></msub><mo separator="true">,</mo><msub><mi>k</mi><mi>t</mi></msub><mo stretchy="false">)</mo><mo separator="true">,</mo><mspace width="1em"/><msub><mi>V</mi><mtext>cache</mtext></msub><mo></mo><mtext>concat</mtext><mo stretchy="false">(</mo><msub><mi>V</mi><mtext>past</mtext></msub><mo separator="true">,</mo><msub><mi>v</mi><mi>t</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">
K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t)
</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.0715em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">cache</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em;"></span><span class="mord text"><span class="mord">concat</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.07153em;">K</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0715em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">past</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0315em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:1em;"></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.2222em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">cache</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" 
style="height:1.0361em;vertical-align:-0.2861em;"></span><span class="mord text"><span class="mord">concat</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.2222em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">past</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.2806em;"><span style="top:-2.55em;margin-left:-0.0359em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span><!-- HTML_TAG_END --></p> <p data-svelte-h="svelte-fo5dj4">Attention is calculated independently in each layer of the model, and caching is done on a per-layer basis.</p> <p data-svelte-h="svelte-ebwicu">Refer to the table below to compare how caching improves efficiency.</p> <table data-svelte-h="svelte-1s32lyj"><thead><tr><th>without caching</th> <th>with caching</th></tr></thead> <tbody><tr><td>for each step, recompute all previous <code>K</code> and <code>V</code></td> <td>for each step, only compute current <code>K</code> and <code>V</code></td></tr> <tr><td>attention cost per step is <strong>quadratic</strong> with sequence length</td> <td>attention cost per step is <strong>linear</strong> with sequence length (memory grows linearly, but compute/token remains low)</td></tr></tbody></table> <h2 class="relative group"><a id="cache-class" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cache-class"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Cache class</span></h2> <p data-svelte-h="svelte-ranj33">A basic KV cache interface takes a key and value tensor for the current token and returns the updated <code>K</code> and <code>V</code> 
```py
new_K, new_V = cache.update(k_t, v_t, layer_idx)
attn_output = attn_layer_idx_fn(q_t, new_K, new_V)
```

When you use Transformers' [Cache](/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache) class, the self-attention module performs several critical steps to integrate past and present information.

1. The attention module concatenates the current kv pairs with the past kv pairs stored in the cache. This creates attention weights with the shape `(new_tokens_length, past_kv_length + new_tokens_length)`. The current and past kv pairs are essentially combined to compute the attention scores, ensuring a model is aware of previous context and the current input.

2. When the `forward` method is called iteratively, it's crucial that the attention mask shape matches the combined length of the past and current kv pairs. The attention mask should have the shape `(batch_size, past_kv_length + new_tokens_length)`. This is typically handled internally in [generate()](/docs/transformers/pr_33892/en/main_classes/text_generation#transformers.GenerationMixin.generate), but if you want to implement your own generation loop with [Cache](/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache), keep this in mind! The attention mask should hold the past and current token values.

3. It is also important to be aware of the `cache_position`. This matters if you want to reuse a prefilled [Cache](/docs/transformers/pr_33892/en/internal/generation_utils#transformers.Cache) with the `forward` method, because you have to pass a valid `cache_position` value, which indicates the input positions in the sequence. `cache_position` is unaffected by padding, and it always adds one more position for each token. For example, if a kv cache contains 10 tokens, regardless of pad tokens, the cache position for the next token should be `torch.tensor([10])`. The sketch after this list spells out these shapes.
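To make steps 2 and 3 concrete, here is a minimal sketch of the expected shapes. The numbers are hypothetical: a cache prefilled with 10 tokens and a single new token being processed.

```py
import torch

batch_size, past_kv_length, new_tokens_length = 1, 10, 1

# Step 2: the attention mask covers the past *and* current tokens.
attention_mask = torch.ones(batch_size, past_kv_length + new_tokens_length)

# Step 3: cache_position holds the new token's absolute position, unaffected by padding.
cache_position = torch.tensor([past_kv_length])  # torch.tensor([10])
```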
## Cache storage implementation

Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`.

Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `StaticSlidingWindowLayer`), which mostly changes how the sequence length is handled and how the cache is updated.

The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:

```py
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
```

Other layer types like `StaticLayer` and `StaticSlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `StaticSlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
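Because their shapes are fixed up front, static layers avoid the shape changes that force recompilation under `torch.compile`. As a hedged sketch (assuming a `model` and `inputs` prepared as in the example below), a fixed-length cache can be requested through `generate`'s `cache_implementation` argument rather than constructed by hand:

```py
# Sketch only: swaps the default dynamic cache for a static, fixed-length one.
generated_ids = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")
```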
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModelForCausalLM, DynamicCache
<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
device = Accelerator().device
model_id = <span class="hljs-string">&quot;meta-llama/Llama-2-7b-chat-hf&quot;</span>
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache(config=model.config)
messages = [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;Hello, what&#x27;s your name.&quot;</span>}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">&quot;pt&quot;</span>, return_dict=<span class="hljs-literal">True</span>).to(model.device)
generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[<span class="hljs-number">1</span>], dtype=torch.int64, device=model.device)
max_new_tokens = <span class="hljs-number">10</span>
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(max_new_tokens):
outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=<span class="hljs-literal">True</span>)
<span class="hljs-comment"># Greedily sample one next token</span>
next_token_ids = outputs.logits[:, -<span class="hljs-number">1</span>:].argmax(-<span class="hljs-number">1</span>)
generated_ids = torch.cat([generated_ids, next_token_ids], dim=-<span class="hljs-number">1</span>)
<span class="hljs-comment"># Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token</span>
<span class="hljs-comment"># and expanding attn mask for the new token, as explained above</span>
attention_mask = inputs[<span class="hljs-string">&quot;attention_mask&quot;</span>]
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[<span class="hljs-number">0</span>], <span class="hljs-number">1</span>))], dim=-<span class="hljs-number">1</span>)
inputs = {<span class="hljs-string">&quot;input_ids&quot;</span>: next_token_ids, <span class="hljs-string">&quot;attention_mask&quot;</span>: attention_mask}
cache_position = cache_position[-<span class="hljs-number">1</span>:] + <span class="hljs-number">1</span> <span class="hljs-comment"># add one more position for the next token</span>
<span class="hljs-built_in">print</span>(tokenizer.batch_decode(generated_ids, skip_special_tokens=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>])
<span class="hljs-string">&quot;[INST] Hello, what&#x27;s your name. [/INST] Hello! My name is LLaMA,&quot;</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="cache-position" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cache-position"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Cache position</span></h2> <p data-svelte-h="svelte-anb3re">The cache position tracks where to insert new tokens in the attention cache. It represents the <em>absolute</em> position of each token in the context, independent of padding or batch structure. Suppose you already cached <code>N</code> tokens and are now processing <code>K</code> new tokens. The cache position for the new tokens will range from <code>N</code> to <code>N + K - 1</code>. In other words, you’re processing tokens at positions - <code>[N, N + 1, N + 2, ..., N + K - 1]</code>.</p> <p data-svelte-h="svelte-ecafas">Cache position is used internally for two purposes:</p> <ol data-svelte-h="svelte-afdkur"><li>Selecting new tokens to process in the input sequence and ensuring only tokens that haven’t been cached yet are passed to the model’s <code>forward</code>.</li> <li>Storing key/value pairs at the correct positions in the cache. 
The generation loop usually takes care of the cache position, but if you're writing a custom generation method, it is important that cache positions are accurate since they are used to write and read key/value states into fixed slots.

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator

device = Accelerator().device

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [{"role": "user", "content": "You are a helpful assistant."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10)
```
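The cache object can also be created explicitly and handed to `generate`, which prefills it during the call; the same object can then be reused afterwards instead of recomputing the shared prefix. A minimal sketch, assuming the `model`, `tokenizer`, and `inputs` from the example above:

```py
from transformers import DynamicCache

# Sketch: pass an explicit cache to generate(); it is filled during the call
# and can be inspected or reused afterwards (generate() derives cache positions internally).
past_key_values = DynamicCache(config=model.config)
generated_ids = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=10)
```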
