import torch
from torch import Tensor as T
from torch.nn.attention.flex_attention import (
    BlockMask,
    _mask_mod_signature,
    and_masks,
    create_block_mask,
    flex_attention,
    or_masks,
)

# ---------------------------------------------------------------------------
# Two compiled variants of flex_attention
# ---------------------------------------------------------------------------
# _decode:  fullgraph=True, static shapes.
#           Used for decode steps (S_q == 1) where shapes are fixed and
#           the call will be captured inside a CUDA graph.  fullgraph=True
#           avoids graph breaks that would corrupt the capture.
#
# _prefill: dynamic=True, symbolic shapes.
#           Used for prefill steps (S_q > 1) where the sequence length
#           varies per image.  dynamic=True lets one compiled graph handle
#           all lengths without recompilation.  Prefill is never inside a
#           CUDA graph, so symbolic shape guards are fine.
# NOTE(review): callers are expected to pick the variant matching their
# phase (decode vs. prefill) — confirm at the call sites.
compiled_flex_attn_decode = torch.compile(flex_attention, fullgraph=True)
compiled_flex_attn_prefill = torch.compile(flex_attention, dynamic=True)


def offset_mask_mod(mask_mod: _mask_mod_signature, offset: int):
    """Get a mask mod function with an offset applied to the query positions."""

    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + offset, kv)

    return _mask_mod


def get_causal_mask_mod() -> _mask_mod_signature:
    """Causal mask that prevents attention to future tokens."""

    def _causal_mask(b: T, h: T, q_idx: T, kv_idx: T) -> T:
        return q_idx >= kv_idx

    return _causal_mask


def get_document_mask_mod(batch: T, eos_id: int) -> _mask_mod_signature:
    """Creates a document mask that prevents attention across document boundaries.

    Args:
        batch: Input batch tensor with shape [b, s, h, d]
        eos_id: End-of-sequence token ID that marks document boundaries

    Returns:
        A mask modifier function that implements document-level masking.
    """
    # batch is [b, s, h, d] shape
    eos_mask = batch == eos_id
    eos_mask[:, -1] = True
    cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1)
    sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32)
    sequence_indices[:, 1:] = cumulative_mask[:, :-1]

    def document_mask(b: T, h: T, q_idx: T, kv_idx: T) -> T:
        return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx]

    return document_mask


def get_non_left_pad_mask_mod(batch: T, pad_id: int) -> _mask_mod_signature:
    """Prevent model from attending to the left-padded token required for correct batch inference."""

    non_pad_mask_id = torch.cumsum(batch != pad_id, dim=1)

    # Left-most pad tokens have cumulative id == 0.
    def mask_mod(b, h, q_idx, kv_idx):
        return non_pad_mask_id[b, kv_idx] > 0

    return mask_mod


def get_image_prefix_mask_mod(
    batch: T, soi_id: int, eoi_id: int
) -> _mask_mod_signature:
    # batch is [b, s, h, d] shape
    soi_mask = batch == soi_id
    eoi_mask = batch == eoi_id
    acc_soi_mask = torch.cumsum(soi_mask, dim=1)
    acc_eoi_mask = torch.cumsum(eoi_mask, dim=1)
    # Get every tokens between two soi_id and eoi_id exclusive of eoi_id
    img_mask = (acc_soi_mask - acc_eoi_mask) > 0

    # Create a tensor that assigns each token to its image number
    # Each image starts with SOI token, so we can use acc_soi_mask to track image numbers
    img_indices = acc_soi_mask * img_mask

    def image_prefix_mask_mod(b, h, q_idx, kv_idx):
        # Check if both tokens are image tokens and belong to the same image
        is_img_tokens = img_mask[b, q_idx] & img_mask[b, kv_idx]
        is_same_image = img_indices[b, q_idx] == img_indices[b, kv_idx]
        return is_img_tokens & is_same_image

    return image_prefix_mask_mod


# Compiled once at import time; dynamic=True so one compiled graph serves
# every mask size without recompilation.
# NOTE: can't use mode='reduce-overhead' here because it uses internal CUDA
# graph trees on private streams, causing manual capture to record empty
# graphs.
_compiled_create_block_mask = torch.compile(
    create_block_mask, dynamic=True
)


@torch.inference_mode()
def create_attention_mask(*args, **kwargs) -> BlockMask:
    """Thin pass-through to the compiled ``create_block_mask``.

    NOTE: Compiled for performance/memory reasons on large masks; wrapped in
    ``inference_mode`` so grad-mode flips at call sites don't trigger
    recompiles.
    """
    mask = _compiled_create_block_mask(*args, **kwargs)
    return mask


def create_batch_attention_mask(
    input_batch: T,
    *,
    pad_token_id: int,
    eos_token_id: int,
    soi_token_id: int,
    eoi_token_id: int,
    max_len: int | None = None,
) -> BlockMask:
    """Build the combined FlexAttention mask for the batch engine.

    Composes causal + document + non-left-pad + image-prefix masks: text
    tokens obey the block-causal rules, while image tokens additionally get
    full attention within their own image span.
    """
    batch_size, seq_len = input_batch.size()

    text_rules = and_masks(
        get_causal_mask_mod(),
        get_document_mask_mod(input_batch, eos_token_id),
        get_non_left_pad_mask_mod(input_batch, pad_token_id),
    )
    image_rules = get_image_prefix_mask_mod(
        batch=input_batch,
        soi_id=soi_token_id,
        eoi_id=eoi_token_id,
    )
    combined = or_masks(image_rules, text_rules)

    length = max_len or seq_len
    return create_attention_mask(combined, batch_size, None, length, length)