Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Attention Interface&quot;,&quot;local&quot;:&quot;attention-interface&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Customizing attention function&quot;,&quot;local&quot;:&quot;customizing-attention-function&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Dynamically switching attention function&quot;,&quot;local&quot;:&quot;dynamically-switching-attention-function&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Different attention per backbone in multimodal models&quot;,&quot;local&quot;:&quot;different-attention-per-backbone-in-multimodal-models&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;What about new args needed in my custom attention function?&quot;,&quot;local&quot;:&quot;what-about-new-args-needed-in-my-custom-attention-function&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Accessing current available implementations&quot;,&quot;local&quot;:&quot;accessing-current-available-implementations&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Attention Mask Interface&quot;,&quot;local&quot;:&quot;attention-mask-interface&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"> | |
| <link href="/docs/transformers/pr_33892/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/scheduler.31fdf58d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/singletons.9860629f.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.252883d5.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/paths.e85c0ec8.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/preload-helper.40847a0e.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/index.2f76fdf0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/0.ca4aafa4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/nodes/7.c762cf5b.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CopyLLMTxtMenu.ff482081.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.71f274cc.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/IconCopy.ac192424.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_33892/en/_app/immutable/chunks/CodeBlock.ab12f8e1.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Attention Interface&quot;,&quot;local&quot;:&quot;attention-interface&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Customizing attention function&quot;,&quot;local&quot;:&quot;customizing-attention-function&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Dynamically switching attention function&quot;,&quot;local&quot;:&quot;dynamically-switching-attention-function&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Different attention per backbone in multimodal models&quot;,&quot;local&quot;:&quot;different-attention-per-backbone-in-multimodal-models&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;What about new args needed in my custom attention function?&quot;,&quot;local&quot;:&quot;what-about-new-args-needed-in-my-custom-attention-function&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Accessing current available implementations&quot;,&quot;local&quot;:&quot;accessing-current-available-implementations&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Attention Mask Interface&quot;,&quot;local&quot;:&quot;attention-mask-interface&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 
32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="attention-interface" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#attention-interface"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" 
fill="currentColor"></path></svg></span></a> <span>Attention Interface</span></h1> <p data-svelte-h="svelte-q2swd">This page describes how to use the <code>AttentionInterface</code> in order to register custom attention functions to use with | |
| supported models.</p> <h2 class="relative group"><a id="customizing-attention-function" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#customizing-attention-function"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Customizing attention function</span></h2> <p data-svelte-h="svelte-tb8yar">Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping. | |
| By default, we provide the implementation for <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow"><code>sdpa</code></a>, | |
| <a href="https://github.com/Dao-AILab/flash-attention" rel="nofollow"><code>flash_attention_2</code></a> and <a href="https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention" rel="nofollow"><code>flex_attention</code></a> | |
| as well as <code>eager</code>, which is a simple matrix multiplication without any optimization on top.<br> | |
| This is the setting you can usually choose when instantiating a model:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM | |
| model_id = <span class="hljs-string">"meta-llama/Llama-3.2-1B"</span> | |
| <span class="hljs-comment"># Here, using flash attention as an example</span> | |
| model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=<span class="hljs-string">"flash_attention_2"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6n784k">But what if you wanted to create your own attention function? Or simply play around with existing ones, adding | |
| a few statements here and there? You can now do so with the <code>AttentionInterface</code>! Here is an example:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AttentionInterface | |
| <span class="hljs-keyword">from</span> transformers.integrations.sdpa_attention <span class="hljs-keyword">import</span> sdpa_attention_forward | |
| <span class="hljs-keyword">import</span> torch | |
| model_id = <span class="hljs-string">"meta-llama/Llama-3.2-1B"</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">my_new_sdpa</span>(<span class="hljs-params">*args, **kwargs</span>): | |
|     <span class="hljs-built_in">print</span>(<span class="hljs-string">"I just entered the attention computation"</span>) | |
|     <span class="hljs-keyword">return</span> sdpa_attention_forward(*args, **kwargs) | |
| AttentionInterface.register(<span class="hljs-string">"my_new_sdpa"</span>, my_new_sdpa) | |
| model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=<span class="hljs-string">"my_new_sdpa"</span>) | |
| <span class="hljs-comment"># Try running the forward with the new attention function</span> | |
| model(torch.ones(<span class="hljs-number">1</span>, <span class="hljs-number">5</span>, dtype=<span class="hljs-built_in">int</span>))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-seluzu">You will see it prints “I just entered the attention computation” as many times as there are layers in the model (with this example, 16 times).</p> <h2 class="relative group"><a id="dynamically-switching-attention-function" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dynamically-switching-attention-function"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Dynamically switching attention function</span></h2> <p data-svelte-h="svelte-y4p893">You could dynamically change the model’s attention function as well:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" 
height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Back to use original sdpa implementation</span> | |
| model.set_attn_implementation(<span class="hljs-string">"sdpa"</span>) | |
| model(torch.ones(<span class="hljs-number">1</span>, <span class="hljs-number">5</span>, dtype=<span class="hljs-built_in">int</span>))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fjssz0">and it will stop printing the statements, as it now uses the <code>sdpa</code> attention.<br> | |
| This allows you to quickly change the attention function without needing to reload the model!</p> <h2 class="relative group"><a id="different-attention-per-backbone-in-multimodal-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#different-attention-per-backbone-in-multimodal-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Different attention per backbone in multimodal models</span></h2> <p data-svelte-h="svelte-1ygkhpi">For multimodal models, different attention functions may work better for each backbone module. For example, some vision backbones perform better in fp32, which is incompatible with FlashAttention. 
To continue using FlashAttention while keeping the vision encoder in fp32, create a dict and map each config to an attention implementation as shown below.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForImageTextToText | |
| model_id = <span class="hljs-string">"facebook/chameleon-7b"</span> | |
| attention_implementation_per_backbone = {<span class="hljs-string">"vision_config"</span>: <span class="hljs-string">"sdpa"</span>, <span class="hljs-string">"text_config"</span>: <span class="hljs-string">"flash_attention_2"</span>} | |
| model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation=attention_implementation_per_backbone) | |
| <span class="hljs-comment"># <span class="hljs-doctag">NOTE:</span> keys in the attention implementation have to be the same as the sub-config names</span> | |
| <span class="hljs-keyword">for</span> key <span class="hljs-keyword">in</span> attention_implementation_per_backbone: | |
|     <span class="hljs-keyword">assert</span> key <span class="hljs-keyword">in</span> model.config.sub_configs, <span class="hljs-string">f"Invalid key in `attention_implementation`: <span class="hljs-subst">{key}</span>"</span> | |
| <span class="hljs-comment"># You can omit certain backbones - the default attention function (SDPA) will be used</span> | |
| <span class="hljs-comment"># This is equivalent to the previous example</span> | |
| model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={<span class="hljs-string">"text_config"</span>: <span class="hljs-string">"flash_attention_2"</span>}) | |
| <span class="hljs-comment"># Set the same attention implementation for all backbones with single string, same as in non-multimodal models</span> | |
| model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation=<span class="hljs-string">"eager"</span>) | |
| <span class="hljs-comment"># Alternatively use a dict with an empty key for global configuration</span> | |
| model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation={<span class="hljs-string">""</span>: <span class="hljs-string">"eager"</span>})<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="what-about-new-args-needed-in-my-custom-attention-function" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#what-about-new-args-needed-in-my-custom-attention-function"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>What about new args needed in my custom attention function?</span></h2> <p data-svelte-h="svelte-1mzkqc4">But indeed, what if the new function requires a new arg to be properly used? It’s no issue! Models supporting the | |
| <code>AttentionInterface</code> propagate kwargs all the way to the Attention layers, and to the used attention function. That way, | |
| you can simply pass the arg (as a kwarg, i.e. you need to pass it by name) in the model’s forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, they must follow the signature and return format of other attention functions, i.e.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AttentionInterface | |
| <span class="hljs-keyword">from</span> transformers.integrations.sdpa_attention <span class="hljs-keyword">import</span> sdpa_attention_forward | |
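| <span class="hljs-keyword">from</span> typing <span class="hljs-keyword">import</span> <span class="hljs-type">Optional</span> | |
| <span class="hljs-keyword">import</span> torch | |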
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">custom_attention</span>(<span class="hljs-params"> | |
|     module: torch.nn.Module, <span class="hljs-comment"># required arg</span> | |
|     query: torch.Tensor, <span class="hljs-comment"># required arg</span> | |
|     key: torch.Tensor, <span class="hljs-comment"># required arg</span> | |
|     value: torch.Tensor, <span class="hljs-comment"># required arg</span> | |
|     attention_mask: <span class="hljs-type">Optional</span>[torch.Tensor], <span class="hljs-comment"># required arg</span> | |
|     a_new_kwargs = <span class="hljs-literal">None</span>, <span class="hljs-comment"># You can now add as many kwargs as you need</span> | |
|     another_new_kwargs = <span class="hljs-literal">None</span>, <span class="hljs-comment"># You can now add as many kwargs as you need</span> | |
|     **kwargs, <span class="hljs-comment"># You need to accept **kwargs as models will pass other args</span> | |
| </span>) -> <span class="hljs-built_in">tuple</span>[torch.Tensor, <span class="hljs-type">Optional</span>[torch.Tensor]]: | |
|     ... <span class="hljs-comment"># do your magic!</span> | |
|     <span class="hljs-keyword">return</span> attn_output, attn_weights <span class="hljs-comment"># attn_weights are optional here</span> | |
| AttentionInterface.register(<span class="hljs-string">"custom"</span>, custom_attention) | |
| model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=<span class="hljs-string">"custom"</span>) | |
| <span class="hljs-comment"># Forward pass with the new kwargs</span> | |
| model(torch.ones(<span class="hljs-number">1</span>, <span class="hljs-number">5</span>, dtype=<span class="hljs-built_in">int</span>), a_new_kwargs=..., another_new_kwargs=...)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-upb3ef">If in doubt about what args/kwargs a given model sends to the attention function, simply check that model’s modeling code on <a href="https://github.com/huggingface/transformers/tree/main/src/transformers/models" rel="nofollow">GitHub</a>!</p> <h2 class="relative group"><a id="accessing-current-available-implementations" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#accessing-current-available-implementations"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Accessing current available implementations</span></h2> <p data-svelte-h="svelte-7h0zcj">Most of the time, you will simply need to <code>register</code> a new function. If, however, you need to access an existing one, | |
| and/or perform a few checks, the preferred way is to use the global <code>ALL_ATTENTION_FUNCTIONS</code>. It behaves the same way you | |
| would expect from a usual Python dictionary:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">>>> </span><span class="hljs-keyword">from</span> transformers.modeling_utils <span class="hljs-keyword">import</span> ALL_ATTENTION_FUNCTIONS | |
| <span class="hljs-meta">>>> </span><span class="hljs-built_in">list</span>(ALL_ATTENTION_FUNCTIONS.keys()) | |
| [<span class="hljs-string">'flash_attention_2'</span>, <span class="hljs-string">'flex_attention'</span>, <span class="hljs-string">'sdpa'</span>] | |
| <span class="hljs-meta">>>> </span>ALL_ATTENTION_FUNCTIONS[<span class="hljs-string">"sdpa"</span>] | |
| &lt;function transformers.integrations.sdpa_attention.sdpa_attention_forward&gt; | |
| <span class="hljs-meta">>>> </span>ALL_ATTENTION_FUNCTIONS.get(<span class="hljs-string">"sdpa"</span>, <span class="hljs-literal">None</span>) | |
| &lt;function transformers.integrations.sdpa_attention.sdpa_attention_forward&gt; | |
| <span class="hljs-comment"># You can also globally `register` a new function directly on it</span> | |
| <span class="hljs-meta">>>> </span>ALL_ATTENTION_FUNCTIONS.register(<span class="hljs-string">"new_func"</span>, new_func)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="attention-mask-interface" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#attention-mask-interface"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Attention Mask Interface</span></h2> <p data-svelte-h="svelte-15uniki">Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens | |
| the query tokens should attend to. This is now possible with the <code>AttentionMaskInterface</code>! It works in the same way as | |
| the <code>AttentionInterface</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AttentionMaskInterface | |
| <span class="hljs-keyword">from</span> transformers.masking_utils <span class="hljs-keyword">import</span> sdpa_mask | |
| <span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">my_new_sdpa_mask</span>(<span class="hljs-params">*args, **kwargs</span>): | |
|     <span class="hljs-built_in">print</span>(<span class="hljs-string">"I just entered the attention mask computation"</span>) | |
|     <span class="hljs-keyword">return</span> sdpa_mask(*args, **kwargs) | |
| AttentionMaskInterface.register(<span class="hljs-string">"my_new_sdpa_mask"</span>, my_new_sdpa_mask)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-7vvru9">You have to register it because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor). | |
| By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped | |
| and <code>attention_mask=None</code> will be passed along to the Attention layers.</p> <p data-svelte-h="svelte-10po5zn">The default signature of the attention mask functions is the following:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">custom_attention_mask</span>(<span class="hljs-params"> | |
| batch_size: <span class="hljs-built_in">int</span>, <span class="hljs-comment"># required arg</span> | |
| cache_position: torch.Tensor, <span class="hljs-comment"># required arg</span> | |
| kv_length: <span class="hljs-built_in">int</span>, <span class="hljs-comment"># required arg</span> | |
| kv_offset: <span class="hljs-built_in">int</span> = <span class="hljs-number">0</span>, <span class="hljs-comment"># required arg</span> | |
| mask_function: <span class="hljs-type">Callable</span> = causal_mask_function, <span class="hljs-comment"># required arg</span> | |
| attention_mask: <span class="hljs-type">Optional</span>[torch.Tensor] = <span class="hljs-literal">None</span>, <span class="hljs-comment"># required arg</span> | |
| **kwargs, <span class="hljs-comment"># a few additional args may be passed as kwargs, especially the model's config is always passed</span> | |
| </span>) -> <span class="hljs-type">Optional</span>[torch.Tensor]:<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gqwmwq">It mostly works thanks to the <code>mask_function</code>, which is a <code>Callable</code> in the form of <a href="https://pytorch.org/blog/flexattention/" rel="nofollow">torch’s mask_mod functions</a>, taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.</p> <p data-svelte-h="svelte-ks3jto">If you cannot use the <code>mask_function</code> to create your mask for some reason, you can try to work around it by doing something similar to our <a href="https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py" rel="nofollow">torch export workaround</a>.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/attention_interface.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
// Client-side bootstrap for this page: publishes the path settings on a
// global object, then loads the two entry modules and hands control to
// kit.start(). The bare block scope keeps `element` and `data` out of the
// global namespace.
| { | |
// Assigned without a declaration, so in this classic (non-module) script it
// becomes a global. `assets` and `base` hold the same URL prefix here and
// `env` is empty. NOTE(review): presumably read by the imported modules
// to resolve paths - confirm against the SvelteKit runtime.
| __sveltekit_16tnnm8 = { | |
| assets: "/docs/transformers/pr_33892/en", | |
| base: "/docs/transformers/pr_33892/en", | |
| env: {} | |
| }; | |
// Captured now, before the asynchronous imports resolve:
// document.currentScript is only set while the script runs synchronously,
// so it could not be read inside the .then() callback below.
| const element = document.currentScript.parentElement; | |
// NOTE(review): two null entries, presumably one per id in `node_ids`
// below - confirm.
| const data = [null,null]; | |
// Load both entry modules in parallel; these are the same start/app files
// that the page head lists as modulepreload links.
| Promise.all([ | |
| import("/docs/transformers/pr_33892/en/_app/immutable/entry/start.b2c4257a.js"), | |
| import("/docs/transformers/pr_33892/en/_app/immutable/entry/app.05ef1f97.js") | |
| ]).then(([kit, app]) => { | |
// Start the app on the element that contains this script tag.
// NOTE(review): `node_ids` presumably selects the layout/page nodes to
// hydrate; `form` and `error` are passed as null - confirm semantics.
| kit.start(app, element, { | |
| node_ids: [0, 7], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 39.7 kB
- Xet hash:
- c71c79d9704940f0a199b356ed583e6871b1308edd96da1f5de720136987d18c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.