import torch
from torch import Tensor as T
from torch.nn.attention.flex_attention import (
BlockMask,
_mask_mod_signature,
and_masks,
create_block_mask,
flex_attention,
or_masks,
)
# ---------------------------------------------------------------------------
# Two compiled variants of flex_attention
# ---------------------------------------------------------------------------
# _decode: fullgraph=True, static shapes.
#   Used for decode steps (S_q == 1) where shapes are fixed and
#   the call will be captured inside a CUDA graph. fullgraph=True
#   avoids graph breaks that would corrupt the capture.
#
# _prefill: dynamic=True, symbolic shapes.
#   Used for prefill steps (S_q > 1) where the sequence length
#   varies per image. dynamic=True lets one compiled graph handle
#   all lengths without recompilation. Prefill is never inside a
#   CUDA graph, so symbolic shape guards are fine.
compiled_flex_attn_decode = torch.compile(flex_attention, fullgraph=True)  # decode (S_q == 1) path
compiled_flex_attn_prefill = torch.compile(flex_attention, dynamic=True)  # prefill (S_q > 1) path
def offset_mask_mod(mask_mod: _mask_mod_signature, offset: int):
"""Get a mask mod function with an offset applied to the query positions."""
def _mask_mod(b, h, q, kv):
return mask_mod(b, h, q + offset, kv)
return _mask_mod
def get_causal_mask_mod() -> _mask_mod_signature:
"""Causal mask that prevents attention to future tokens."""
def _causal_mask(b: T, h: T, q_idx: T, kv_idx: T) -> T:
return q_idx >= kv_idx
return _causal_mask
def get_document_mask_mod(batch: T, eos_id: int) -> _mask_mod_signature:
"""Creates a document mask that prevents attention across document boundaries.
Args:
batch: Input batch tensor with shape [b, s, h, d]
eos_id: End-of-sequence token ID that marks document boundaries
Returns:
A mask modifier function that implements document-level masking.
"""
# batch is [b, s, h, d] shape
eos_mask = batch == eos_id
eos_mask[:, -1] = True
cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1)
sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32)
sequence_indices[:, 1:] = cumulative_mask[:, :-1]
def document_mask(b: T, h: T, q_idx: T, kv_idx: T) -> T:
return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx]
return document_mask
def get_non_left_pad_mask_mod(batch: T, pad_id: int) -> _mask_mod_signature:
"""Prevent model from attending to the left-padded token required for correct batch inference."""
non_pad_mask_id = torch.cumsum(batch != pad_id, dim=1)
# Left-most pad tokens have cumulative id == 0.
def mask_mod(b, h, q_idx, kv_idx):
return non_pad_mask_id[b, kv_idx] > 0
return mask_mod
def get_image_prefix_mask_mod(
batch: T, soi_id: int, eoi_id: int
) -> _mask_mod_signature:
# batch is [b, s, h, d] shape
soi_mask = batch == soi_id
eoi_mask = batch == eoi_id
acc_soi_mask = torch.cumsum(soi_mask, dim=1)
acc_eoi_mask = torch.cumsum(eoi_mask, dim=1)
# Get every tokens between two soi_id and eoi_id exclusive of eoi_id
img_mask = (acc_soi_mask - acc_eoi_mask) > 0
# Create a tensor that assigns each token to its image number
# Each image starts with SOI token, so we can use acc_soi_mask to track image numbers
img_indices = acc_soi_mask * img_mask
def image_prefix_mask_mod(b, h, q_idx, kv_idx):
# Check if both tokens are image tokens and belong to the same image
is_img_tokens = img_mask[b, q_idx] & img_mask[b, kv_idx]
is_same_image = img_indices[b, q_idx] == img_indices[b, kv_idx]
return is_img_tokens & is_same_image
return image_prefix_mask_mod
# Compiled once at import time; dynamic=True so one compiled graph serves
# block-mask creation for varying (batch, sequence) sizes without recompiling.
_compiled_create_block_mask = torch.compile(
    create_block_mask, dynamic=True
)  # Note: can't use mode = 'reduce-overhead' here because it uses internal CUDA graph trees on private streams, causing manual capture to record empty graphs
@torch.inference_mode()
def create_attention_mask(*args, **kwargs) -> BlockMask:
    """
    NOTE: We compile this for performance/memory reasons in large masks. To reduce
    recompiles due to grad_mode flips, we always run mask creation under inference_mode.
    """
    # Thin pass-through: all arguments are forwarded unchanged to the
    # compiled create_block_mask (same signature as torch's create_block_mask).
    return _compiled_create_block_mask(*args, **kwargs)
def create_batch_attention_mask(
    input_batch: T,
    *,
    pad_token_id: int,
    eos_token_id: int,
    soi_token_id: int,
    eoi_token_id: int,
    max_len: int | None = None,
) -> BlockMask:
    """Build the combined FlexAttention mask for the batch engine.

    The text mask is the AND of causal, per-document, and non-left-pad masks;
    image spans are then OR-ed in so image tokens attend bidirectionally
    within their own image.
    """
    batch_size, seq_len = input_batch.size()

    # Causal AND same-document AND not-left-padding: the text-token rule.
    text_mask_mod = and_masks(
        get_causal_mask_mod(),
        get_document_mask_mod(input_batch, eos_token_id),
        get_non_left_pad_mask_mod(input_batch, pad_token_id),
    )
    # Bidirectional attention inside each image span.
    image_mask_mod = get_image_prefix_mask_mod(
        batch=input_batch,
        soi_id=soi_token_id,
        eoi_id=eoi_token_id,
    )
    combined_mask_mod = or_masks(image_mask_mod, text_mask_mod)

    # Falsy max_len (None or 0) falls back to the batch's sequence length.
    max_len = max_len or seq_len
    return create_attention_mask(combined_mask_mod, batch_size, None, max_len, max_len)