Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Attention backends","local":"attention-backends","sections":[{"title":"Set an attention backend","local":"set-an-attention-backend","sections":[{"title":"Kernels","local":"kernels","sections":[],"depth":3},{"title":"SDPA context manager","local":"sdpa-context-manager","sections":[],"depth":3}],"depth":2},{"title":"Backbone-specific attention","local":"backbone-specific-attention","sections":[],"depth":2},{"title":"Create a new attention function","local":"create-a-new-attention-function","sections":[],"depth":2},{"title":"AttentionMaskInterface","local":"attentionmaskinterface","sections":[],"depth":2},{"title":"Build an attention mask","local":"build-an-attention-mask","sections":[],"depth":2},{"title":"Bidirectional attention","local":"bidirectional-attention","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/main/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/start.3d6cca8a.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/scheduler.31fdf58d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/singletons.6af0ff6e.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.252883d5.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/paths.299a376b.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/entry/app.b6ccab0d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/preload-helper.c438fa0a.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/index.2f76fdf0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/0.f2629ed0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/nodes/9.d8bcdc60.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/CopyLLMTxtMenu.ad38f6ea.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.fd2f7a8a.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/IconCopy.ac192424.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/CodeBlock.e52df5d6.js"> | |
| <link rel="modulepreload" href="/docs/transformers/main/en/_app/immutable/chunks/HfOption.fb051768.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Attention backends","local":"attention-backends","sections":[{"title":"Set an attention backend","local":"set-an-attention-backend","sections":[{"title":"Kernels","local":"kernels","sections":[],"depth":3},{"title":"SDPA context manager","local":"sdpa-context-manager","sections":[],"depth":3}],"depth":2},{"title":"Backbone-specific attention","local":"backbone-specific-attention","sections":[],"depth":2},{"title":"Create a new attention function","local":"create-a-new-attention-function","sections":[],"depth":2},{"title":"AttentionMaskInterface","local":"attentionmaskinterface","sections":[],"depth":2},{"title":"Build an attention mask","local":"build-an-attention-mask","sections":[],"depth":2},{"title":"Bidirectional attention","local":"bidirectional-attention","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="attention-backends" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#attention-backends"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>Attention backends</span></h1> <p data-svelte-h="svelte-blwksn">All attention implementations perform the same computation. Every token is compared to every other token. The difference is <em>how</em> the computation is performed. Basic (eager) attention scales poorly because it materializes the full attention matrix in memory, creating bottlenecks that slow down inference. Optimized implementations rearrange the math to reduce memory traffic for faster, more affordable inference.</p> <p data-svelte-h="svelte-r8qqxs">The <a href="/docs/transformers/main/en/internal/modeling_utils#transformers.AttentionInterface">AttentionInterface</a> provides optimized attention implementations. It decouples the attention implementation from the model implementation to simplify experimentation with different functions. Add new backends easily with this consistent interface.</p> <table data-svelte-h="svelte-1p5os4e"><thead><tr><th scope="col">attention backend</th> <th scope="col">description</th></tr></thead> <tbody><tr><td><code>"flash_attention_3"</code></td> <td>improves on FlashAttention-2 for Hopper GPUs by overlapping computation with data movement and interleaving matmul and softmax operations</td></tr> <tr><td><code>"flash_attention_2"</code></td> <td>tiles computations into smaller blocks and uses fast on-chip memory</td></tr> <tr><td><code>"flex_attention"</code></td> <td>framework for specifying custom attention patterns (sparse, block-local, sliding window) without writing low-level kernels by hand</td></tr> <tr><td><code>"sdpa"</code></td> <td>built-in PyTorch implementation of <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow">scaled dot product attention</a></td></tr> <tr><td><code>"eager"</code></td> <td>plain PyTorch implementation that materializes the full attention matrix</td></tr> <tr><td><code>"paged|flash_attention_3"</code></td> <td>paged version of FlashAttention-3</td></tr> <tr><td><code>"paged|flash_attention_2"</code></td> <td>paged version of FlashAttention-2</td></tr> <tr><td><code>"paged|sdpa"</code></td> 
<td>paged version of SDPA</td></tr> <tr><td><code>"paged|eager"</code></td> <td>paged version of eager attention</td></tr></tbody></table> <h2 class="relative group"><a id="set-an-attention-backend" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#set-an-attention-backend"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Set an attention backend</span></h2> <p data-svelte-h="svelte-1vb89rb">Use the <code>attn_implementation</code> argument in <a href="/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained">from_pretrained()</a> to instantiate a model with a specific attention function.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM | |
| model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.2-1B"</span>, attn_implementation=<span class="hljs-string">"flash_attention_2"</span> | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1sameut">Switch between attention backends at runtime without reloading the model using <a href="/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.set_attn_implementation">set_attn_implementation()</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START -->model.set_attn_implementation(<span class="hljs-string">"sdpa"</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="kernels" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#kernels"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" 
height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Kernels</span></h3> <p data-svelte-h="svelte-c68q69">Download and load compiled compute kernels directly from the <a href="https://huggingface.co/models?other=kernels" rel="nofollow">Hub</a> at runtime with the <a href="https://huggingface.co/docs/kernels/index" rel="nofollow">Kernels</a> library. This avoids packaging issues from mismatched PyTorch or CUDA versions.</p> <p data-svelte-h="svelte-1t8tjbu">Kernels automatically register to <a href="/docs/transformers/main/en/internal/modeling_utils#transformers.AttentionInterface">AttentionInterface</a> upon detection. 
You don’t need to install the FlashAttention package explicitly.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM | |
| model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.2-1B"</span>, attn_implementation=<span class="hljs-string">"kernels-community/flash-attn2"</span> | |
| )<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="sdpa-context-manager" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sdpa-context-manager"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>SDPA context manager</span></h3> <p data-svelte-h="svelte-1bvapr3">PyTorch’s scaled dot product attention (SDPA) selects the fastest attention function for CUDA backends automatically. 
It defaults to the PyTorch C++ implementation for other backends.</p> <p data-svelte-h="svelte-141spqq">Force SDPA to use a specific implementation with the <a href="https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html" rel="nofollow">torch.nn.attention.sdpa_kernel</a> context manager.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> torch.nn.attention <span class="hljs-keyword">import</span> SDPBackend, sdpa_kernel | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM | |
| model = AutoModelForCausalLM.from_pretrained( | |
| <span class="hljs-string">"meta-llama/Llama-3.2-1B"</span>, attn_implementation=<span class="hljs-string">"sdpa"</span> | |
| ) | |
| <span class="hljs-keyword">with</span> sdpa_kernel(SDPBackend.FLASH_ATTENTION): | |
|     outputs = model.generate(**inputs)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="backbone-specific-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#backbone-specific-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Backbone-specific attention</span></h2> <p data-svelte-h="svelte-7clk6k">Multimodal models use different backbones for each modality. Optimize performance by assigning specific attention functions to each backbone. For example, some vision backbones perform better in fp32, which FlashAttention does not support.</p> <p data-svelte-h="svelte-18si8a3">Map vision backbones to different attention functions with a dict while the text backbone continues to use FlashAttention. 
Keys in the <code>attn_implementation</code> dict must match the names of the model config’s sub-configs (<code>config.sub_configs</code>).</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoConfig, AutoModelForImageTextToText | |
| attention_implementation_per_backbone = {<span class="hljs-string">"vision_config"</span>: <span class="hljs-string">"sdpa"</span>, <span class="hljs-string">"text_config"</span>: <span class="hljs-string">"flash_attention_2"</span>} | |
| config = AutoConfig.from_pretrained(<span class="hljs-string">"facebook/chameleon-7b"</span>) | |
| <span class="hljs-keyword">for</span> key <span class="hljs-keyword">in</span> attention_implementation_per_backbone: | |
|     <span class="hljs-keyword">assert</span> key <span class="hljs-keyword">in</span> config.sub_configs, <span class="hljs-string">f"Invalid key <span class="hljs-subst">{key!r}</span> in `attn_implementation`"</span> | |
| model = AutoModelForImageTextToText.from_pretrained( | |
| <span class="hljs-string">"facebook/chameleon-7b"</span>, attn_implementation=attention_implementation_per_backbone | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1nnklnr">Omit certain backbones from the dict to use the default attention function (SDPA).</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START -->model = AutoModelForImageTextToText.from_pretrained( | |
| <span class="hljs-string">"facebook/chameleon-7b"</span>, attn_implementation={<span class="hljs-string">"text_config"</span>: <span class="hljs-string">"flash_attention_2"</span>} | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1n9mqjt">Set the same attention function for all backbones with a single string.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START -->model = AutoModelForImageTextToText.from_pretrained( | |
| <span class="hljs-string">"facebook/chameleon-7b"</span>, attn_implementation=<span class="hljs-string">"eager"</span> | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1tuj59m">Set the attention function globally with an empty key.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START -->model = AutoModelForImageTextToText.from_pretrained( | |
| <span class="hljs-string">"facebook/chameleon-7b"</span>, attn_implementation={<span class="hljs-string">""</span>: <span class="hljs-string">"eager"</span>} | |
| )<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="create-a-new-attention-function" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#create-a-new-attention-function"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Create a new attention function</span></h2> <p data-svelte-h="svelte-dr91k7">Customize or create new attention functions by adding them to the attention registry with <a href="/docs/transformers/main/en/internal/modeling_utils#transformers.AttentionInterface.register">AttentionInterface.register()</a>. Models use these functions through the <code>attn_implementation</code> argument.</p> <blockquote class="warning" data-svelte-h="svelte-1m93d5q"><p><br> | |
| Register a matching attention mask function when you register a custom attention function. If the custom <code>attn_implementation</code> name is not registered in <a href="/docs/transformers/main/en/internal/modeling_utils#transformers.AttentionMaskInterface">AttentionMaskInterface</a>, Transformers skips mask creation and passes <code>attention_mask=None</code> to the attention layers. In that case, your attention function must enforce causal, padding, packing, and sliding-window constraints itself; otherwise those constraints are silently dropped.</p></blockquote> <p data-svelte-h="svelte-uxj9xe">This example customizes the attention function to print a statement each time a layer computes attention. It reuses the original SDPA masking by registering <code>masking_utils.sdpa_mask</code> as the attention mask function under the same name.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre 
class="language-python "><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface | |
| <span class="hljs-keyword">from</span> transformers.integrations.sdpa_attention <span class="hljs-keyword">import</span> sdpa_attention_forward | |
| <span class="hljs-keyword">from</span> transformers.masking_utils <span class="hljs-keyword">import</span> sdpa_mask | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">my_new_sdpa</span>(<span class="hljs-params">*args, **kwargs</span>): | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"I just entered the attention computation"</span>) | |
| <span class="hljs-keyword">return</span> sdpa_attention_forward(*args, **kwargs) | |
| AttentionInterface.register(<span class="hljs-string">"my_new_sdpa"</span>, my_new_sdpa) | |
| AttentionMaskInterface.register(<span class="hljs-string">"my_new_sdpa"</span>, sdpa_mask) <span class="hljs-comment"># must have the same name as the registered attention function</span> | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.2-1B"</span>, attn_implementation=<span class="hljs-string">"my_new_sdpa"</span>) | |
| model(torch.ones(<span class="hljs-number">1</span>, <span class="hljs-number">5</span>, dtype=<span class="hljs-built_in">int</span>))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-i19ndg">You can also add new arguments to the attention function. Models supporting <a href="/docs/transformers/main/en/internal/modeling_utils#transformers.AttentionInterface">AttentionInterface</a> propagate kwargs to attention layers and the attention function. Pass arguments as kwargs in the model’s forward function. Custom attention functions must follow this signature and return format.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-python "><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
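| <span class="hljs-keyword">from</span> typing <span class="hljs-keyword">import</span> <span class="hljs-type">Optional</span> | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface | |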
| <span class="hljs-keyword">from</span> transformers.integrations.sdpa_attention <span class="hljs-keyword">import</span> sdpa_attention_forward | |
| <span class="hljs-keyword">from</span> transformers.masking_utils <span class="hljs-keyword">import</span> sdpa_mask | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">custom_attention</span>(<span class="hljs-params"> | |
| module: torch.nn.Module, <span class="hljs-comment"># required arg</span> | |
| query: torch.Tensor, <span class="hljs-comment"># required arg</span> | |
| key: torch.Tensor, <span class="hljs-comment"># required arg</span> | |
| value: torch.Tensor, <span class="hljs-comment"># required arg</span> | |
| attention_mask: <span class="hljs-type">Optional</span>[torch.Tensor], <span class="hljs-comment"># required arg</span> | |
| a_new_kwargs = <span class="hljs-literal">None</span>, <span class="hljs-comment"># You can now add as many kwargs as you need</span> | |
| another_new_kwargs = <span class="hljs-literal">None</span>, <span class="hljs-comment"># You can now add as many kwargs as you need</span> | |
| **kwargs, <span class="hljs-comment"># You need to accept **kwargs as models will pass other args</span> | |
| </span>) -> <span class="hljs-built_in">tuple</span>[torch.Tensor, <span class="hljs-type">Optional</span>[torch.Tensor]]: | |
| ... <span class="hljs-comment"># do your magic!</span> | |
| <span class="hljs-keyword">return</span> attn_output, attn_weights <span class="hljs-comment"># attn_weights are optional here</span> | |
| AttentionInterface.register(<span class="hljs-string">"custom"</span>, custom_attention) | |
| AttentionMaskInterface.register(<span class="hljs-string">"custom"</span>, sdpa_mask) <span class="hljs-comment"># to leave the existing mask untouched</span> | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.2-1B"</span>, attn_implementation=<span class="hljs-string">"custom"</span>) | |
| model(torch.ones(<span class="hljs-number">1</span>, <span class="hljs-number">5</span>, dtype=<span class="hljs-built_in">int</span>), a_new_kwargs=..., another_new_kwargs=...)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1j9bhmb">Check a model’s <a href="https://github.com/huggingface/transformers/tree/main/src/transformers/models" rel="nofollow">modeling code</a> to confirm what arguments and kwargs it sends to the attention function.</p> <h2 class="relative group"><a id="attentionmaskinterface" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#attentionmaskinterface"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>AttentionMaskInterface</span></h2> <p data-svelte-h="svelte-1bnoum4"><a href="/docs/transformers/main/en/internal/modeling_utils#transformers.AttentionMaskInterface">AttentionMaskInterface</a> is the registry the <a href="#build-an-attention-mask"><code>create_*_mask</code></a> functions consult to convert a mask into the format the active attention backend expects. 
FlexAttention needs a <a href="https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#torch.nn.attention.flex_attention.BlockMask" rel="nofollow">BlockMask</a>, SDPA needs a 4D tensor, and FlashAttention needs the base 2D padding mask. Register a custom backend, or override the formatter for an existing one, with <a href="/docs/transformers/main/en/internal/modeling_utils#transformers.AttentionInterface.register">AttentionMaskInterface.register()</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-python "><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AttentionMaskInterface | |
| <span class="hljs-keyword">from</span> transformers.masking_utils <span class="hljs-keyword">import</span> sdpa_mask | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">my_new_sdpa_mask</span>(<span class="hljs-params">*args, **kwargs</span>): | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"I just entered the attention mask computation"</span>) | |
| <span class="hljs-keyword">return</span> sdpa_mask(*args, **kwargs) | |
| AttentionMaskInterface.register(<span class="hljs-string">"my_new_sdpa_mask"</span>, my_new_sdpa_mask)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lzrdh2">Without a registered formatter for the active <code>attn_implementation</code>, mask creation is skipped and <code>attention_mask=None</code> passes to the attention layers.</p> <p data-svelte-h="svelte-131bxrv">Registered functions must match this signature.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-python "><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">custom_attention_mask</span>(<span class="hljs-params"> | |
| batch_size: <span class="hljs-built_in">int</span>, <span class="hljs-comment"># required arg</span> | |
| q_length: <span class="hljs-built_in">int</span>, <span class="hljs-comment"># required arg</span> | |
| kv_length: <span class="hljs-built_in">int</span>, <span class="hljs-comment"># required arg</span> | |
| q_offset: <span class="hljs-built_in">int</span> = <span class="hljs-number">0</span>, <span class="hljs-comment"># required arg</span> | |
| kv_offset: <span class="hljs-built_in">int</span> = <span class="hljs-number">0</span>, <span class="hljs-comment"># required arg</span> | |
| mask_function: <span class="hljs-type">Callable</span> = causal_mask_function, <span class="hljs-comment"># required arg</span> | |
| attention_mask: <span class="hljs-type">Optional</span>[torch.Tensor] = <span class="hljs-literal">None</span>, <span class="hljs-comment"># required arg</span> | |
| **kwargs, <span class="hljs-comment"># a few additional args may be passed as kwargs, especially the model's config is always passed</span> | |
| </span>) -> <span class="hljs-type">Optional</span>[torch.Tensor]:<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-116pzew">The <code>mask_function</code> argument is a <code>Callable</code> that mimics PyTorch’s <a href="https://pytorch.org/blog/flexattention/" rel="nofollow">mask_mod</a> functions. It takes 4 indices <code>(batch_idx, head_idx, q_idx, kv_idx)</code> and returns a boolean indicating whether that position contributes to the attention computation. This is the same primitive shape used by <code>or_mask_function</code> and <code>and_mask_function</code> in <a href="#build-an-attention-mask">Build an attention mask</a>.</p> <blockquote class="tip" data-svelte-h="svelte-mfw371"><p>Use this <a href="https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py" rel="nofollow">workaround</a> for torch.export if <code>mask_function</code> fails to create a mask.</p></blockquote> <h2 class="relative group"><a id="build-an-attention-mask" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#build-an-attention-mask"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Build an 
attention mask</span></h2> <p data-svelte-h="svelte-t1i881">Build attention masks with the <code>create_*_mask</code> functions in <a href="https://github.com/huggingface/transformers/blob/main/src/transformers/masking_utils.py#L894" rel="nofollow">transformers.masking_utils</a>. Each function reads the active attention backend from the model config, looks up the backend’s mask formatter in <a href="/docs/transformers/main/en/internal/modeling_utils#transformers.AttentionMaskInterface">AttentionMaskInterface</a>, and returns the format that backend expects. You don’t need to invert, expand, or cast the mask yourself.</p> <p data-svelte-h="svelte-1cn7xq9">Pick the function that matches the attention pattern.</p> <table data-svelte-h="svelte-1vx10ti"><thead><tr><th>function</th> <th>use case</th></tr></thead> <tbody><tr><td><code>create_causal_mask</code></td> <td>decoder-only models where each token attends to itself and earlier tokens</td></tr> <tr><td><code>create_bidirectional_mask</code></td> <td>encoder models, or cross-attention from a decoder to encoder states</td></tr> <tr><td><code>create_sliding_window_causal_mask</code></td> <td>decoder models with a sliding-window attention pattern</td></tr> <tr><td><code>create_chunked_causal_mask</code></td> <td>decoder models that chunk the sequence into fixed-size blocks</td></tr> <tr><td><code>create_bidirectional_sliding_window_mask</code></td> <td>encoder models with a sliding-window attention pattern</td></tr></tbody></table> <blockquote class="warning" data-svelte-h="svelte-xquzt5"><p>The legacy callable mask helpers - <code>get_extended_attention_mask</code>, <code>create_extended_attention_mask_for_decoder</code>, <code>invert_attention_mask</code> - emit a deprecation warning and will be removed in a future release. 
Use the <code>create_*_mask</code> functions instead.</p></blockquote> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">causal attention </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">encoder self-attention </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">cross-attention </div></div> <div class="language-select"><p data-svelte-h="svelte-146le8o">Call <code>create_causal_mask</code> inside a decoder forward pass. Pass the config, the input embeddings, the user-provided 2D <code>attention_mask</code>, and the cache. 
The function uses the embeddings to read the batch size, query length, dtype, and device, and uses the cache to compute the key length.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers.masking_utils <span class="hljs-keyword">import</span> create_causal_mask | |
| attention_mask = create_causal_mask( | |
| config=self.config, | |
| inputs_embeds=inputs_embeds, | |
| attention_mask=attention_mask, | |
| past_key_values=past_key_values, | |
| )<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-3li0th">Add extra constraints on top of the base mask with the <code>or_mask_function</code> and <code>and_mask_function</code> arguments. Use <code>or_mask_function</code> to let additional positions attend, and <code>and_mask_function</code> to restrict the base pattern further. Both follow the 4-index <code>mask_function</code> signature described in <a href="#attentionmaskinterface">AttentionMaskInterface</a>. They take <code>(batch_idx, head_idx, q_idx, kv_idx)</code> and return a boolean.</p> <blockquote class="warning" data-svelte-h="svelte-1p927k4"><p><code>or_mask_function</code> and <code>and_mask_function</code> can express any attention pattern, but they’re slower than the built-in patterns and are not compatible with ExecuTorch. The overhead is most noticeable on smaller models (~200M parameters), where mask creation takes a larger share of forward-pass time. Reach for them only when the standard <code>create_*_mask</code> functions can’t express what you need.</p></blockquote> <p data-svelte-h="svelte-okyvje">For example, overlay a function that returns <code>True</code> everywhere on a causal mask to turn it into a fully bidirectional one. 
The union with the causal pattern lets every token attend to every other token.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START -->mask_kwargs = { | |
| <span class="hljs-string">"config"</span>: self.config, | |
| <span class="hljs-string">"inputs_embeds"</span>: inputs_embeds, | |
| <span class="hljs-string">"attention_mask"</span>: attention_mask, | |
| <span class="hljs-string">"past_key_values"</span>: past_key_values, | |
| <span class="hljs-string">"position_ids"</span>: position_ids, | |
| <span class="hljs-string">"or_mask_function"</span>: <span class="hljs-keyword">lambda</span> *args: torch.tensor(<span class="hljs-literal">True</span>, dtype=torch.<span class="hljs-built_in">bool</span>), | |
| } | |
| attention_mask = create_causal_mask(**mask_kwargs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18pok1y">During generation, <a href="/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate">generate()</a> builds masks through <code>create_masks_for_generate</code>, which dispatches to the right <code>create_*_mask</code> based on the model config. Override it on a model class to plug in a custom masking strategy for generation.</p> <h2 class="relative group"><a id="bidirectional-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#bidirectional-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Bidirectional attention</span></h2> <p data-svelte-h="svelte-wl6asd">Decoder-only models use causal (unidirectional) attention by default, where each token only attends to itself and previous tokens. Set <code>is_causal=False</code> to switch to bidirectional attention, where every token attends to every other token. 
This lets you use decoder-only models as text encoders, for example, to generate embeddings.</p> <blockquote class="note" data-svelte-h="svelte-qdfb43"><p>This only works for causal (decoder) models. It does not turn encoder models into decoder models.</p></blockquote> <p data-svelte-h="svelte-1e3c7s6">Set <code>is_causal=False</code> in the model config to make bidirectional attention the default for every forward pass.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModel, AutoConfig | |
| config = AutoConfig.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.2-1B"</span>) | |
| config.is_causal = <span class="hljs-literal">False</span> | |
| model = AutoModel.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.2-1B"</span>, config=config) | |
| <span class="hljs-comment"># all forward passes now use bidirectional attention</span> | |
| outputs = model(**inputs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6n6eoj">Pass <code>is_causal</code> in the forward call instead of the model config to switch between causal and bidirectional attention without loading the model twice. The kwarg temporarily overrides the config and is restored after the call.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class="language-py "><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModel | |
| model = AutoModel.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.2-1B"</span>) | |
| <span class="hljs-comment"># run with bidirectional attention</span> | |
| outputs = model(**inputs, is_causal=<span class="hljs-literal">False</span>) | |
| <span class="hljs-comment"># run with default causal attention</span> | |
| outputs = model(**inputs)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/attention_interface.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| // Client-side bootstrap emitted by the SvelteKit build: it configures the runtime and hydrates this page. | |
| // The bare block scope keeps `element` and `data` from leaking into the global scope. | |
| { | |
| // Runtime config for SvelteKit. It is assigned without a declaration, so it becomes a global | |
| // (this is a classic, non-module script). Assets and routes share the same docs path prefix. | |
| __sveltekit_rv114u = { | |
| assets: "/docs/transformers/main/en", | |
| base: "/docs/transformers/main/en", | |
| env: {} | |
| }; | |
| // document.currentScript is only available while this inline script runs synchronously, | |
| // so the hydration target (this script's parent element) is captured before any async work. | |
| const element = document.currentScript.parentElement; | |
| // Server-loaded data for the route nodes; both entries are null for this static page. | |
| const data = [null,null]; | |
| // Load the SvelteKit runtime entry and the app entry in parallel, then start the client. | |
| Promise.all([ | |
| import("/docs/transformers/main/en/_app/immutable/entry/start.3d6cca8a.js"), | |
| import("/docs/transformers/main/en/_app/immutable/entry/app.b6ccab0d.js") | |
| ]).then(([kit, app]) => { | |
| // NOTE(review): node_ids presumably index the root layout (0) and this page's node (9) | |
| // in the build's route manifest - confirm against the generated nodes/ chunks. | |
| kit.start(app, element, { | |
| node_ids: [0, 9], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 62.4 kB
- Xet hash:
- 4188e1fc116ca444697b18091961b943eafaae3a48e4f0dff7af425e873224bb
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.