File size: 13,705 Bytes
6f0b660 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 |
"""
Partially inspired by torchtune's flex attention implementation
Citation:
@software{torchtune,
title = {torchtune: PyTorch's finetuning library},
author = {torchtune maintainers and contributors},
url = {https//github.com/pytorch/torchtune},
license = {BSD-3-Clause},
month = apr,
year = {2024}
}
"""
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from packaging import version
from ..utils import is_torch_flex_attn_available, logging
from ..utils.import_utils import _torch_version, is_torch_less_or_equal, is_torchdynamo_compiling
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
logger = logging.get_logger(__name__)
class WrappedFlexAttention:
"""
We are doing a singleton class so that flex attention is compiled once when it's first called.
"""
_instance = None
_is_flex_compiled = False
_compiled_flex_attention = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
# Create a new instance if one doesn't already exist
cls._instance = super().__new__(cls)
return cls._instance
@torch.compiler.disable(recursive=False)
def __init__(self, training):
"""
Initialize or update the singleton instance.
"""
if not self._is_flex_compiled or training != self.training:
self.training = training
if is_torch_less_or_equal("2.5.1"):
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training
elif version.parse(_torch_version).base_version == "2.6.0" and training:
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)
# Fallback, usually the most recent torch 2.7.x+ versions
else:
self._compiled_flex_attention = torch.compile(flex_attention)
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
def compile_friendly_flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
training=False,
**kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
# Do not use compiled version if already compiling forward (it raises issues)
flex_attention_compiled = WrappedFlexAttention(training)() if not is_torchdynamo_compiling() else flex_attention
return flex_attention_compiled(
query,
key,
value,
**kwargs,
)
Offset = Union[torch.Tensor, int]
# TODO: deprecate / rename to make_flex_block_mask for clarity as it's not only causal anymore
def make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[tuple[Offset, Offset]] = None,
is_causal: Optional[bool] = True,
) -> "BlockMask":
"""
IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`,
and will be removed in a future version without warnings. New code should not use it. It is only kept here
for BC for now, while models using it are being patched accordingly.
Create a block (causal) document mask for a batch of sequences, both packed and unpacked.
Create Block (causal) logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full (causal) block
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
Args:
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
of shape (batch_size, total_seq_len). e.g.
For unpacked sequence:
[[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]]
For packed sequence:
[[1, 1, 1, 2, 2, 2, 0],
[1, 1, 2, 2, 2, 3, 3]]
Returns:
BlockMask
"""
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
key_length = total_seq_len
if not query_length:
query_length = total_seq_len
# older torch (2.5.x) cannot handle sequences not in multiples of 128 (default block size)
pad_len = ((key_length // flex_default_block_size) + 1) * flex_default_block_size
attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, pad_len - key_length))
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Combines the chunk mask with the causal mask for chunked attention.
"""
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
return chunk_mask & causal_doc_mask
def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Utilizes default attention mask to enable encoder and encoder-decoder
attention masks.
"""
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
# kv indexing is crucial in order to work correctly
padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0
final_mask = padding_mask & document_mask
return final_mask
if not is_causal:
mask_mod_maybe_combined = default_mask_mod
else:
mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
if offsets is not None:
q_offset = offsets[0].to(device)
kv_offset = offsets[1].to(device)
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = mask_mod_maybe_combined
return create_block_mask(
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
# compiling the mask is not BC with older torch
_compile=not is_torch_less_or_equal("2.5.1"),
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Union[torch.Tensor, "BlockMask"],
scaling: Optional[float] = None,
softcap: Optional[float] = None,
head_mask: Optional[torch.Tensor] = None,
s_aux: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if head_mask is not None:
logger.warning_once(
"`flex_attention` does not support `head_mask`. Please set your attention to `eager` if you want this feature."
)
if kwargs.get("dropout", 0.0) > 0:
raise ValueError(
"`flex_attention` does not support `dropout`. Please use it with inference"
" only (`model.eval()`) or turn off the attention dropout in the respective config."
)
block_mask = None
score_mask = None
if isinstance(attention_mask, BlockMask):
block_mask = attention_mask
else:
score_mask = attention_mask
if score_mask is not None:
score_mask = score_mask[:, :, :, : key.shape[-2]]
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
if softcap is not None:
score = softcap * torch.tanh(score / softcap)
if score_mask is not None:
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
if head_mask is not None:
score = score + head_mask[batch_idx][head_idx][0][0]
# Note: attention sinks cannot be correctly implemented in score_mod
# because it requires operating on the full attention matrix before softmax.
# ==> this is done after flex attention
return score
enable_gqa = True
num_local_query_heads = query.shape[1]
# When running TP this helps:
if (num_local_query_heads & (num_local_query_heads - 1)) != 0:
key = repeat_kv(key, query.shape[1] // key.shape[1])
value = repeat_kv(value, query.shape[1] // value.shape[1])
enable_gqa = False
kernel_options = kwargs.get("kernel_options")
# On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
return_lse = query.device.type != "cpu"
if not return_lse and s_aux is not None:
raise ValueError(
"Attention sinks cannot be run on CPU with flex attention. Please switch to a different device, e.g. CUDA"
)
flex_attention_output = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=return_lse,
training=module.training,
)
# lse is returned in float32
if return_lse:
attention_output, lse = flex_attention_output # type: ignore[misc]
lse = lse.to(value.dtype)
if s_aux is not None:
# Apply attention sinks by renormalizing using LSE
batch_size, num_heads, seq_len_q, _ = attention_output.shape # batch, num_heads, seq_len, head_dim
sinks = s_aux.view(1, -1, 1, 1).expand(batch_size, num_heads, seq_len_q, 1)
# We need to compute the normalization that includes the sinks
# since log(sum(exp(scores))) = lse, exp(log(sum(exp(scores)))) = exp(lse)
# NB: log(sum(exp(scores)) + exp(sink)) = log(exp(lse) + exp(sink))
lse_expanded = lse.unsqueeze(-1) # [batch, num_heads, seq_len, 1]
combined_lse = torch.logsumexp(torch.cat([lse_expanded, sinks], dim=-1), dim=-1, keepdim=True)
# Use new_norm / old_norm = exp(combined_lse - lse) to compute renorm and apply
renorm_factor = torch.exp(lse_expanded - combined_lse)
attention_output = attention_output * renorm_factor
else:
attention_output = flex_attention_output # type: ignore[assignment]
lse = None
attention_output = attention_output.transpose(1, 2).contiguous()
return attention_output, lse
|