Attention backends
All attention implementations perform the same computation: every token is compared to every other token. The difference is how the computation is performed. Basic (eager) attention scales poorly because it materializes the full attention matrix in memory. That matrix grows quadratically with sequence length, and moving it in and out of memory creates bottlenecks that slow down inference. Optimized implementations rearrange the math to reduce memory traffic for faster, more affordable inference.
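The following sketch makes the "same computation" point concrete in plain PyTorch, with no Transformers code involved. It compares a naive eager function, which builds the full score matrix, against `torch.nn.functional.scaled_dot_product_attention`. The `eager_attention` helper is illustrative and is not the Transformers eager implementation.

```python
import math

import torch
import torch.nn.functional as F


def eager_attention(query, key, value):
    """Naive attention that materializes the full (q_len, kv_len) score matrix."""
    # scores has shape (batch, heads, q_len, kv_len); this is the quadratic-memory step
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    weights = scores.softmax(dim=-1)
    return weights @ value


torch.manual_seed(0)
# Shapes are (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))

eager_out = eager_attention(q, k, v)
# The fused SDPA op computes the same result without handing back the score matrix
sdpa_out = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(eager_out, sdpa_out, atol=1e-5))
```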
The AttentionInterface provides optimized attention implementations. It decouples the attention implementation from the model implementation to simplify experimentation with different functions. Add new backends easily with this consistent interface.
| attention backend | description |
|---|---|
| "flash_attention_3" | improves FlashAttention-2 by also overlapping operations and fusing forward and backward passes more tightly |
| "flash_attention_2" | tiles computations into smaller blocks and uses fast on-chip memory |
| "flex_attention" | framework for specifying custom attention patterns (sparse, block-local, sliding window) without writing low-level kernels by hand |
| "sdpa" | built-in PyTorch implementation of scaled dot product attention |
| "paged\|flash_attention_3" | paged version of FlashAttention-3 |
| "paged\|flash_attention_2" | paged version of FlashAttention-2 |
| "paged\|sdpa" | paged version of SDPA |
| "paged\|eager" | paged version of eager |
Set an attention backend
Use the attn_implementation argument in from_pretrained() to instantiate a model with a specific attention function.
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_2"
)
```
Use set_attn_implementation() to switch between attention backends at runtime without reloading the model.
```python
model.set_attn_implementation("sdpa")
```
Kernels
Download and load compiled compute kernels directly from the Hub at runtime with the Kernels library. This avoids packaging issues from mismatched PyTorch or CUDA versions.
Kernels automatically register to AttentionInterface upon detection. You don’t need to install the FlashAttention package explicitly.
```python
import torch
from transformers import AutoModelForCausalLM

# The kernel is downloaded from the Hub and registered with AttentionInterface
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="kernels-community/flash-attn2"
)
```
SDPA context manager
PyTorch’s scaled dot product attention (SDPA) selects the fastest attention function for CUDA backends automatically. It defaults to the PyTorch C++ implementation for other backends.
Force SDPA to use a specific implementation with the torch.nn.attention.sdpa_kernel context manager.
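```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="sdpa"
)
inputs = tokenizer("The attention mechanism", return_tensors="pt").to(model.device)

# Only the FlashAttention kernel is allowed inside this block
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs)
```
Backbone-specific attention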
Multimodal models use different backbones for each modality. Optimize performance by assigning specific attention functions to each backbone. For example, some vision backbones perform better in fp32, which FlashAttention does not support because it only runs in fp16 and bf16.
Map vision backbones to different attention functions with a dict while the text backbone continues to use FlashAttention. Keys in the attn_implementation dict must match the model's sub-config names.
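This section describes three dict rules. A single string applies to every backbone, sub-configs left out of the dict fall back to SDPA, and the empty key sets a global value. The following pure-Python resolver summarizes them. `resolve_attn_implementation` is a hypothetical helper, not the Transformers implementation. How an explicit key interacts with the empty key is an assumption of this sketch.

```python
from typing import Dict, Iterable, Union

DEFAULT_BACKEND = "sdpa"  # used for sub-configs missing from the dict


def resolve_attn_implementation(
    attn_implementation: Union[str, Dict[str, str]],
    sub_config_names: Iterable[str],
) -> Dict[str, str]:
    """Hypothetical helper: map each sub-config name to an attention backend."""
    names = list(sub_config_names)
    # A single string applies to every backbone
    if isinstance(attn_implementation, str):
        return {name: attn_implementation for name in names}
    # Keys must match sub-config names (the empty key is the global fallback)
    unknown = set(attn_implementation) - set(names) - {""}
    if unknown:
        raise ValueError(f"Invalid keys in attn_implementation: {sorted(unknown)}")
    # Assumption of this sketch: explicit keys win over the empty key
    fallback = attn_implementation.get("", DEFAULT_BACKEND)
    return {name: attn_implementation.get(name, fallback) for name in names}


subs = ["text_config", "vision_config"]
print(resolve_attn_implementation({"text_config": "flash_attention_2"}, subs))
print(resolve_attn_implementation("eager", subs))
print(resolve_attn_implementation({"": "eager"}, subs))
```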
```python
from transformers import AutoConfig, AutoModelForImageTextToText

attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"}

# Check the keys against the sub-config names before loading the weights
config = AutoConfig.from_pretrained("facebook/chameleon-7b")
for key in attention_implementation_per_backbone:
    assert key in config.sub_configs, f"Invalid key in `attention_implementation`: {key}"

model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation=attention_implementation_per_backbone
)
```
Omit certain backbones from the dict to use the default attention function (SDPA).
```python
model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation={"text_config": "flash_attention_2"}
)
```
Set the same attention function for all backbones with a single string.
```python
model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation="eager"
)
```
Set the attention function globally with an empty key.
```python
model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation={"": "eager"}
)
```
Create a new attention function
Customize or create new attention functions by adding them to the attention registry with AttentionInterface.register(). Models use these functions through the attn_implementation argument.
Register a matching attention mask function when you register a custom attention function. If the custom attn_implementation name is not registered in AttentionMaskInterface, Transformers skips mask creation and passes attention_mask=None to the attention layers. Your attention function must then handle causal, padding, packing, or sliding-window constraints itself, or those constraints can be silently dropped.
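The following sketch shows an attention function that enforces the causal constraint itself when it receives attention_mask=None. It makes three assumptions. Query, key, and value have the same number of heads (no grouped-query attention). A non-None mask is boolean, with True meaning "may attend". The output uses the (batch, seq, heads, head_dim) layout. `causal_fallback_attention` is a hypothetical name. It follows the module, query, key, value, attention_mask argument order used in this section.

```python
import math
from typing import Optional

import torch


def causal_fallback_attention(
    module: Optional[torch.nn.Module],
    query: torch.Tensor,  # (batch, heads, q_len, head_dim)
    key: torch.Tensor,  # (batch, heads, kv_len, head_dim)
    value: torch.Tensor,  # (batch, heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    **kwargs,  # other arguments the model passes are ignored
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical attention function that is causal even without a mask."""
    if scaling is None:
        scaling = 1.0 / math.sqrt(query.size(-1))
    scores = (query @ key.transpose(-2, -1)) * scaling
    q_len, kv_len = scores.shape[-2:]
    if attention_mask is None:
        # No mask was created upstream, so apply the causal constraint here.
        # Queries are aligned to the end of the keys so cached decoding still works.
        q_pos = torch.arange(kv_len - q_len, kv_len, device=scores.device)[:, None]
        kv_pos = torch.arange(kv_len, device=scores.device)[None, :]
        scores = scores.masked_fill(kv_pos > q_pos, float("-inf"))
    else:
        # Assumes a boolean mask where True means "may attend"
        scores = scores.masked_fill(~attention_mask, float("-inf"))
    weights = scores.softmax(dim=-1)
    # Return (batch, q_len, heads, head_dim)
    attn_output = (weights @ value).transpose(1, 2).contiguous()
    return attn_output, weights


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))
# Full sequence: matches PyTorch's built-in causal SDPA
out, _ = causal_fallback_attention(None, q, k, v, attention_mask=None)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True).transpose(1, 2)
# One decoding step against 5 cached keys: the new query sees every key
step, _ = causal_fallback_attention(None, q[:, :, -1:], k, v, attention_mask=None)
step_ref = torch.nn.functional.scaled_dot_product_attention(q[:, :, -1:], k, v).transpose(1, 2)
```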
This example customizes the attention function to print a statement for each layer. It keeps the mask in the original implementation by registering masking_utils.sdpa_mask as the attention mask function.
```python
import torch
from transformers import AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
from transformers.masking_utils import sdpa_mask

def my_new_sdpa(*args, **kwargs):
    print("I just entered the attention computation")
    return sdpa_attention_forward(*args, **kwargs)

AttentionInterface.register("my_new_sdpa", my_new_sdpa)
AttentionMaskInterface.register("my_new_sdpa", sdpa_mask)  # must have the same name as the registered attention function

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="my_new_sdpa")
model(torch.ones(1, 5, dtype=int))
```
You can also add new arguments to the attention function. Models supporting AttentionInterface propagate kwargs to attention layers and the attention function. Pass arguments as kwargs in the model's forward function. Custom attention functions must follow this signature and return format.
```python
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
from transformers.masking_utils import sdpa_mask

def custom_attention(
    module: torch.nn.Module,  # required arg
    query: torch.Tensor,  # required arg
    key: torch.Tensor,  # required arg
    value: torch.Tensor,  # required arg
    attention_mask: Optional[torch.Tensor],  # required arg
    a_new_kwargs=None,  # You can now add as many kwargs as you need
    another_new_kwargs=None,  # You can now add as many kwargs as you need
    **kwargs,  # You need to accept **kwargs as models will pass other args
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Do your magic here; this placeholder delegates to the built-in SDPA function
    attn_output, attn_weights = sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)
    return attn_output, attn_weights  # attn_weights are optional here

AttentionInterface.register("custom", custom_attention)
AttentionMaskInterface.register("custom", sdpa_mask)  # to leave the existing mask untouched

model_id = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
```
Check a model's modeling code to confirm what arguments and kwargs it sends to the attention function.
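The following sketch is a concrete instance of that signature with one extra kwarg. `logit_softcap` is a hypothetical argument name for tanh logit soft-capping, a technique some models use. The sketch assumes equal head counts for query, key, and value and a boolean mask. It calls the function directly to show how the new kwarg and unrelated kwargs are handled. In a model, you would register it with AttentionInterface.register() and pass logit_softcap in the forward call.

```python
import math
from typing import Optional

import torch


def softcapped_attention(
    module: Optional[torch.nn.Module],
    query: torch.Tensor,  # (batch, heads, q_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    logit_softcap: Optional[float] = None,  # hypothetical new kwarg
    **kwargs,  # absorbs everything else the model passes
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if logit_softcap is not None:
        # Smoothly bound the logits to (-logit_softcap, logit_softcap)
        scores = logit_softcap * torch.tanh(scores / logit_softcap)
    if attention_mask is not None:
        # Assumes a boolean mask where True means "may attend"
        scores = scores.masked_fill(~attention_mask, float("-inf"))
    weights = scores.softmax(dim=-1)
    return (weights @ value).transpose(1, 2).contiguous(), weights


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
# Unrelated kwargs such as position_ids land in **kwargs and are ignored
plain, _ = softcapped_attention(None, q, k, v, None, position_ids=torch.arange(4))
capped, capped_weights = softcapped_attention(None, q, k, v, None, logit_softcap=0.5)
```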
AttentionMaskInterface
AttentionMaskInterface is the registry the create_*_mask functions consult to convert a mask into the format the active attention backend expects. FlexAttention needs a BlockMask, SDPA needs a 4D tensor, and FlashAttention needs the base 2D padding mask. Register a custom backend, or override the formatter for an existing one, with AttentionMaskInterface.register().
```python
import torch
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask

def my_new_sdpa_mask(*args, **kwargs):
    print("I just entered the attention mask computation")
    return sdpa_mask(*args, **kwargs)

AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
```
Without a registered formatter for the active attn_implementation, mask creation is skipped and attention_mask=None is passed to the attention layers.
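A toy registry makes the lookup and the None fallback concrete. `MASK_FORMATTERS`, `to_4d_bool_mask`, and `build_mask` are hypothetical stand-ins, not the Transformers API. The 4D boolean layout (batch, 1, q_len, kv_len), with True meaning "attend", is an assumption for SDPA-style backends.

```python
from typing import Callable, Dict, Optional

import torch

# Hypothetical stand-in for a mask registry (not the Transformers implementation)
MASK_FORMATTERS: Dict[str, Callable[[torch.Tensor], Optional[torch.Tensor]]] = {}


def to_4d_bool_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    """Turn a 2D (batch, seq) padding mask into a 4D causal boolean mask."""
    batch, seq = padding_mask.shape
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
    # (batch, 1, q_len, kv_len): True where the query may attend the key
    return causal[None, None] & padding_mask.bool()[:, None, None, :]


MASK_FORMATTERS["sdpa"] = to_4d_bool_mask
MASK_FORMATTERS["flash_attention_2"] = lambda padding_mask: padding_mask  # keeps the 2D mask


def build_mask(attn_implementation: str, padding_mask: torch.Tensor) -> Optional[torch.Tensor]:
    formatter = MASK_FORMATTERS.get(attn_implementation)
    # No registered formatter: skip mask creation and return None
    return None if formatter is None else formatter(padding_mask)


padding_mask = torch.tensor([[1, 1, 1, 0]])  # the last position is padding
print(build_mask("sdpa", padding_mask).shape)
print(build_mask("my_unregistered_backend", padding_mask))
```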
Registered functions must match this signature.
```python
from typing import Callable, Optional

import torch
from transformers.masking_utils import causal_mask_function

def custom_attention_mask(
    batch_size: int,  # required arg
    q_length: int,  # required arg
    kv_length: int,  # required arg
    q_offset: int = 0,  # required arg
    kv_offset: int = 0,  # required arg
    mask_function: Callable = causal_mask_function,  # required arg
    attention_mask: Optional[torch.Tensor] = None,  # required arg
    **kwargs,  # a few additional args may be passed as kwargs; notably, the model's config is always passed
) -> Optional[torch.Tensor]:
    ...  # build the mask in the format your attention backend expects
```
The mask_function argument is a Callable that mimics the mask_mod functions of PyTorch's FlexAttention. It takes 4 indices (batch_idx, head_idx, q_idx, kv_idx) and returns a boolean indicating whether that position contributes to the attention computation. This is the same primitive shape used by or_mask_function and and_mask_function in the Build an attention mask section.
Use this workaround for torch.export if mask_function fails to create a mask.
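The following sketch writes two 4-index mask functions in plain Python and evaluates them over a grid, to show what the mask_function primitive does. It uses batch_idx to look up each sequence's real length. `causal`, `padding`, and `evaluate` are illustrative names, not imports from transformers.masking_utils. Real backends evaluate mask functions in a vectorized way rather than in Python loops.

```python
def causal(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    # Each query attends to itself and earlier positions
    return kv_idx <= q_idx


def padding(lengths: list[int]):
    """Build a mask function that hides key positions past each sequence's real length."""

    def mask_fn(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return kv_idx < lengths[batch_idx]

    return mask_fn


def evaluate(mask_fn, batch_idx: int, q_len: int, kv_len: int, head_idx: int = 0) -> list[list[int]]:
    """Evaluate a 4-index mask function over one (q_len, kv_len) grid; 1 means 'attend'."""
    return [[int(mask_fn(batch_idx, head_idx, q, kv)) for kv in range(kv_len)] for q in range(q_len)]


pad = padding([4, 2])  # the second sequence has only 2 real tokens


def causal_and_padding(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    return causal(batch_idx, head_idx, q_idx, kv_idx) and pad(batch_idx, head_idx, q_idx, kv_idx)


for row in evaluate(causal_and_padding, batch_idx=1, q_len=4, kv_len=4):
    print(row)
```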
Build an attention mask
Build attention masks with the create_*_mask functions in transformers.masking_utils. Each function reads the active attention backend from the model config, looks up the backend’s mask formatter in AttentionMaskInterface, and returns the format that backend expects. You don’t need to invert, expand, or cast the mask yourself.
Pick the function that matches the attention pattern.
| function | use case |
|---|---|
| create_causal_mask | decoder-only models where each token attends to itself and earlier tokens |
| create_bidirectional_mask | encoder models, or cross-attention from a decoder to encoder states |
| create_sliding_window_causal_mask | decoder models with a sliding-window attention pattern |
| create_chunked_causal_mask | decoder models that chunk the sequence into fixed-size blocks |
| create_bidirectional_sliding_window_mask | encoder models with a sliding-window attention pattern |
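Each row of the table corresponds to a different attention pattern. The following pure-Python sketch expresses each pattern as a 4-index mask function and renders a small grid for each, so the shapes are easy to compare. The helper names are illustrative. The exact window and chunk boundary conventions are assumptions of this sketch and may differ from the library's.

```python
def causal(b, h, q, kv):
    return kv <= q


def bidirectional(b, h, q, kv):
    return True


def sliding_window_causal(window: int):
    # Causal, limited to the last `window` positions including the query itself
    return lambda b, h, q, kv: q - window < kv <= q


def chunked_causal(chunk_size: int):
    # Causal, limited to the query's own fixed-size chunk
    return lambda b, h, q, kv: kv <= q and kv // chunk_size == q // chunk_size


def bidirectional_sliding_window(window: int):
    # Both directions, within `window - 1` positions of the query
    return lambda b, h, q, kv: abs(q - kv) < window


def grid(mask_fn, n: int) -> list[list[int]]:
    """Evaluate a mask function on an n x n grid; 1 means 'attend'."""
    return [[int(mask_fn(0, 0, q, kv)) for kv in range(n)] for q in range(n)]


patterns = {
    "causal": causal,
    "bidirectional": bidirectional,
    "sliding_window_causal (window=3)": sliding_window_causal(3),
    "chunked_causal (chunk_size=2)": chunked_causal(2),
    "bidirectional_sliding_window (window=2)": bidirectional_sliding_window(2),
}
for name, fn in patterns.items():
    print(name)
    for row in grid(fn, 5):
        print("  ", row)
```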
The legacy mask helpers (get_extended_attention_mask, create_extended_attention_mask_for_decoder, and invert_attention_mask) emit a deprecation warning and will be removed in a future release. Use the create_*_mask functions instead.
Call create_causal_mask inside a decoder forward pass. Pass the config, the input embeddings, the user-provided 2D attention_mask, and the cache. The function uses the embeddings to read the batch size, query length, dtype, and device, and uses the cache to compute the key length.
```python
from transformers.masking_utils import create_causal_mask

# Inside a decoder's forward()
attention_mask = create_causal_mask(
    config=self.config,
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
)
```
Add extra constraints on top of the base mask with the or_mask_function and and_mask_function arguments. Use or_mask_function to let additional positions attend, and and_mask_function to restrict the base pattern further. Both follow the 4-index mask_function signature described in the AttentionMaskInterface section. They take (batch_idx, head_idx, q_idx, kv_idx) and return a boolean.
or_mask_function and and_mask_function can express any attention pattern, but they're slower than the built-in patterns and are not compatible with ExecuTorch. The overhead is most noticeable on smaller models (~200M parameters), where mask creation takes a larger share of forward-pass time. Reach for them only when the standard create_*_mask functions can't express what you need.
For example, overlay a function that returns True everywhere on a causal mask to turn it into a fully bidirectional one. The union with the causal pattern lets every token attend to every other token.
```python
# Inside a decoder's forward()
mask_kwargs = {
    "config": self.config,
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "past_key_values": past_key_values,
    "position_ids": position_ids,
    "or_mask_function": lambda *args: torch.tensor(True, dtype=torch.bool),
}
attention_mask = create_causal_mask(**mask_kwargs)
```
During generation, generate() builds masks through create_masks_for_generate, which dispatches to the right create_*_mask based on the model config. Override it on a model class to plug in a custom masking strategy for generation.
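The following plain-Python sketch checks the union and intersection semantics of or_mask_function and and_mask_function on a small grid. OR-ing causal with an always-True function gives a fully bidirectional pattern. AND-ing causal with a same-document check restricts attention inside packed sequences. `or_masks`, `and_masks`, and the `doc_ids` packing are local helpers written for this sketch, not code taken from transformers.masking_utils.

```python
def causal(b, h, q, kv):
    return kv <= q


def or_masks(*mask_fns):
    # A position is allowed if ANY of the functions allows it
    return lambda b, h, q, kv: any(fn(b, h, q, kv) for fn in mask_fns)


def and_masks(*mask_fns):
    # A position is allowed only if EVERY function allows it
    return lambda b, h, q, kv: all(fn(b, h, q, kv) for fn in mask_fns)


def grid(mask_fn, n: int) -> list[list[int]]:
    """Evaluate a mask function on an n x n grid; 1 means 'attend'."""
    return [[int(mask_fn(0, 0, q, kv)) for kv in range(n)] for q in range(n)]


# Union with a function that is True everywhere: every token attends to every token
fully_bidirectional = or_masks(causal, lambda b, h, q, kv: True)

# Intersection with a same-document check: causal attention inside packed documents
doc_ids = [0, 0, 1, 1]  # hypothetical packing of two 2-token documents
packed_causal = and_masks(causal, lambda b, h, q, kv: doc_ids[q] == doc_ids[kv])

print(grid(fully_bidirectional, 4))
print(grid(packed_causal, 4))
```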
Bidirectional attention
Decoder-only models use causal (unidirectional) attention by default, where each token only attends to itself and previous tokens. Set is_causal=False to switch to bidirectional attention, where every token attends to every other token. This lets you use decoder-only models as text encoders, for example, to generate embeddings.
This only works for causal (decoder) models. It does not turn encoder models into decoder models.
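For a single attention call, switching to bidirectional attention amounts to dropping the causal mask. The following torch-only sketch (not Transformers code) shows why this changes the embeddings. Under causal attention, only the last position already sees the whole sequence, so only its output matches the bidirectional one.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 6, 8) for _ in range(3))

causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
bidirectional_out = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# The last token attends to every token in both modes; the first token does not
last_matches = torch.allclose(causal_out[:, :, -1], bidirectional_out[:, :, -1], atol=1e-5)
first_matches = torch.allclose(causal_out[:, :, 0], bidirectional_out[:, :, 0], atol=1e-5)
print(last_matches, first_matches)
```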
Set is_causal=False in the model config to make bidirectional attention the default for every forward pass.
```python
from transformers import AutoConfig, AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
config.is_causal = False
model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", config=config)
inputs = tokenizer("Bidirectional attention for embeddings", return_tensors="pt")

# all forward passes now use bidirectional attention
outputs = model(**inputs)
```
Pass is_causal in the forward call instead of the model config to switch between causal and bidirectional attention without loading the model twice. The kwarg temporarily overrides the config value, which is restored after the call.
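```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B")
inputs = tokenizer("Bidirectional attention for embeddings", return_tensors="pt")

# run with bidirectional attention
outputs = model(**inputs, is_causal=False)
# run with default causal attention
outputs = model(**inputs)
```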
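The override-and-restore behavior of the is_causal kwarg resembles a context manager wrapped around the forward pass. The following hypothetical pure-Python sketch shows that pattern. `override_config` is not a Transformers function, and a SimpleNamespace stands in for a real config. The `finally` clause restores the original values even if the wrapped code raises.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def override_config(config, **overrides):
    """Hypothetical helper: set config attributes only for the duration of a block."""
    previous = {name: getattr(config, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(config, name, value)
        yield config
    finally:
        # Restore the original values, even if the forward pass raised
        for name, value in previous.items():
            setattr(config, name, value)


config = SimpleNamespace(is_causal=True)
with override_config(config, is_causal=False):
    inside = config.is_causal  # a forward pass here would see bidirectional attention
after = config.is_causal
print(inside, after)
```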