import unittest

import torch

from specforge.core.dflash import create_dflash_block_mask, create_dflash_sdpa_mask
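
# Shapes assumed throughout (inferred from how the tests below call the two
# helpers, not from their implementation):
#   anchor_positions : (B, N) int tensor, anchor position of each draft block
#                      within the context of length S
#   block_keep_mask  : (B, N) bool tensor, False marks a fully masked-out block
#   Q_LEN = N * block_size draft queries, KV_LEN = S + N * block_size keys
# create_dflash_sdpa_mask is expected to return a dense boolean
# (B, 1, Q_LEN, KV_LEN) mask; create_dflash_block_mask a flex_attention
# BlockMask over the same logical shape.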


def _reference_dflash_mask(anchor_positions, block_keep_mask, S, block_size, device):
    """Element-level reference mask mirroring the mask_mod inside create_dflash_block_mask.

    This uses plain Python loops so correctness is obvious by inspection.
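
    Visibility rule (mirrors the loop below): a query in draft block i of
    batch b sees context tokens at positions kv < anchor_positions[b, i] plus
    the draft tokens of its own block; blocks with block_keep_mask False see
    nothing. E.g. with S=4, block_size=2, anchors=[[2]], both draft queries
    attend exactly to context positions {0, 1} and draft positions {4, 5}.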
    """
    B, N = anchor_positions.shape
    Q_LEN = N * block_size
    KV_LEN = S + N * block_size

    mask = torch.zeros(B, 1, Q_LEN, KV_LEN, dtype=torch.bool, device=device)
    for b in range(B):
        for q_idx in range(Q_LEN):
            q_block_id = q_idx // block_size
            anchor_pos = anchor_positions[b, q_block_id].item()
            is_valid = block_keep_mask[b, q_block_id].item()
            if not is_valid:
                continue
            for kv_idx in range(KV_LEN):
                is_context = kv_idx < S
                ctx_visible = is_context and (kv_idx < anchor_pos)

                is_draft = kv_idx >= S
                kv_block_id = (kv_idx - S) // block_size
                draft_visible = is_draft and (q_block_id == kv_block_id)

                if ctx_visible or draft_visible:
                    mask[b, 0, q_idx, kv_idx] = True
    return mask


class TestDFlashMask(unittest.TestCase):

    def setUp(self):
        # Skip cleanly on CPU-only machines instead of erroring in every test.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is required for these mask tests")
        torch.manual_seed(42)
        self.device = torch.device("cuda")

    def _compare_masks(self, anchor_positions, block_keep_mask, S, block_size):
        """Compare create_dflash_sdpa_mask against element-level reference (ground truth)."""
        anchor_positions = anchor_positions.to(self.device)
        block_keep_mask = block_keep_mask.to(self.device)

        sdpa_mask = create_dflash_sdpa_mask(
            anchor_positions=anchor_positions,
            block_keep_mask=block_keep_mask,
            S=S,
            block_size=block_size,
            device=self.device,
        )
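        # A (B, 1, Q_LEN, KV_LEN) boolean mask with True = "may attend" is the
        # attn_mask layout torch.nn.functional.scaled_dot_product_attention
        # accepts (the size-1 dim broadcasts over heads); we assume
        # create_dflash_sdpa_mask follows that convention, and the reference
        # below produces the same layout.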

        ref_mask = _reference_dflash_mask(
            anchor_positions=anchor_positions,
            block_keep_mask=block_keep_mask,
            S=S,
            block_size=block_size,
            device=self.device,
        )

        self.assertEqual(
            sdpa_mask.shape,
            ref_mask.shape,
            f"Shape mismatch: sdpa {sdpa_mask.shape} vs ref {ref_mask.shape}",
        )
        self.assertTrue(
            torch.equal(sdpa_mask, ref_mask),
            f"Mask mismatch with S={S}, block_size={block_size}, "
            f"anchors={anchor_positions.tolist()}, keep={block_keep_mask.tolist()}\n"
            f"Diff positions: {(sdpa_mask != ref_mask).nonzero(as_tuple=False).tolist()}",
        )

    def _compare_block_mask_consistency(
        self, anchor_positions, block_keep_mask, S, block_size
    ):
        """Verify create_dflash_block_mask block-level mask is consistent with reference."""
        anchor_positions = anchor_positions.to(self.device)
        block_keep_mask = block_keep_mask.to(self.device)

        block_mask = create_dflash_block_mask(
            anchor_positions=anchor_positions,
            block_keep_mask=block_keep_mask,
            S=S,
            block_size=block_size,
            device=self.device,
        )

        ref_mask = _reference_dflash_mask(
            anchor_positions=anchor_positions,
            block_keep_mask=block_keep_mask,
            S=S,
            block_size=block_size,
            device=self.device,
        )

        dense_blocks = block_mask.to_dense()  # (B, H, Q_blocks, KV_blocks)
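        # 128 is flex_attention's default sparse block size; we assume
        # create_dflash_block_mask builds its BlockMask with that default.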
        BM_BLOCK = 128
        B, N = anchor_positions.shape
        Q_LEN = N * block_size
        KV_LEN = S + N * block_size
        n_q_blocks = (Q_LEN + BM_BLOCK - 1) // BM_BLOCK
        n_kv_blocks = (KV_LEN + BM_BLOCK - 1) // BM_BLOCK

        ref_int = ref_mask.squeeze(1).int()  # (B, Q_LEN, KV_LEN)
        for b in range(B):
            for qi in range(n_q_blocks):
                for ki in range(n_kv_blocks):
                    q_start = qi * BM_BLOCK
                    q_end = min(q_start + BM_BLOCK, Q_LEN)
                    k_start = ki * BM_BLOCK
                    k_end = min(k_start + BM_BLOCK, KV_LEN)
                    has_nonzero = ref_int[b, q_start:q_end, k_start:k_end].any().item()
                    block_val = dense_blocks[b, 0, qi, ki].item()
                    if has_nonzero:
                        self.assertEqual(
                            block_val,
                            1,
                            f"Block ({qi},{ki}) for batch {b} should be 1 but got 0",
                        )

    def test_basic_single_batch_single_block(self):
        """Single batch, single draft block."""
        anchor_positions = torch.tensor([[64]])
        block_keep_mask = torch.tensor([[True]])
        self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4)

    def test_basic_single_batch_multi_block(self):
        """Single batch, multiple draft blocks."""
        anchor_positions = torch.tensor([[32, 64, 96]])
        block_keep_mask = torch.tensor([[True, True, True]])
        self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4)

    def test_multi_batch(self):
        """Multiple batches with different anchors."""
        anchor_positions = torch.tensor([[16, 48, 80], [32, 64, 100]])
        block_keep_mask = torch.tensor([[True, True, True], [True, True, True]])
        self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4)

    def test_invalid_blocks(self):
        """Some blocks are masked out (block_keep_mask=False)."""
        anchor_positions = torch.tensor([[20, 50, 80, 110]])
        block_keep_mask = torch.tensor([[True, False, True, False]])
        self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4)

    def test_all_blocks_invalid(self):
        """All blocks invalid — mask should be all zeros."""
        anchor_positions = torch.tensor([[30, 60]])
        block_keep_mask = torch.tensor([[False, False]])
        self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4)

    def test_anchor_at_zero(self):
        """Anchor at position 0 — no context tokens visible."""
        anchor_positions = torch.tensor([[0, 64]])
        block_keep_mask = torch.tensor([[True, True]])
        self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4)

    def test_anchor_at_boundary(self):
        """Anchor exactly at S — all context tokens visible."""
        anchor_positions = torch.tensor([[128]])
        block_keep_mask = torch.tensor([[True]])
        self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4)

    def test_large_block_size(self):
        """Larger draft block size."""
        anchor_positions = torch.tensor([[50, 150]])
        block_keep_mask = torch.tensor([[True, True]])
        self._compare_masks(anchor_positions, block_keep_mask, S=256, block_size=16)

    def test_block_size_1(self):
        """Minimal block_size=1."""
        anchor_positions = torch.tensor([[10, 30, 50]])
        block_keep_mask = torch.tensor([[True, True, True]])
        self._compare_masks(anchor_positions, block_keep_mask, S=64, block_size=1)

    def test_mixed_validity_multi_batch(self):
        """Multi-batch with mixed block validity patterns."""
        anchor_positions = torch.tensor([[10, 40, 70, 100], [20, 50, 80, 110]])
        block_keep_mask = torch.tensor(
            [[True, False, True, True], [False, True, False, True]]
        )
        self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=8)

    def test_various_context_lengths(self):
        """Sweep over various context lengths."""
        for S in [64, 128, 256, 512]:
            with self.subTest(S=S):
                anchor_positions = torch.tensor([[S // 4, S // 2, 3 * S // 4]])
                block_keep_mask = torch.tensor([[True, True, True]])
                self._compare_masks(
                    anchor_positions, block_keep_mask, S=S, block_size=4
                )

    def test_various_block_sizes(self):
        """Sweep over various draft block sizes."""
        for block_size in [1, 2, 4, 8, 16]:
            with self.subTest(block_size=block_size):
                anchor_positions = torch.tensor([[32, 80]])
                block_keep_mask = torch.tensor([[True, True]])
                self._compare_masks(
                    anchor_positions, block_keep_mask, S=128, block_size=block_size
                )

    def test_many_blocks(self):
        """Large number of draft blocks."""
        N = 32
        anchors = torch.arange(10, 10 + N * 4, 4).unsqueeze(0)
        keep = torch.ones(1, N, dtype=torch.bool)
        keep[0, ::3] = False
        self._compare_masks(anchors, keep, S=256, block_size=4)

    def test_consecutive_anchors(self):
        """Anchors placed consecutively."""
        anchor_positions = torch.tensor([[0, 1, 2, 3]])
        block_keep_mask = torch.tensor([[True, True, True, True]])
        self._compare_masks(anchor_positions, block_keep_mask, S=64, block_size=4)

    def test_random_stress(self):
        """Randomized stress test with multiple random configurations."""
        rng = torch.Generator().manual_seed(123)
        for trial in range(5):
            with self.subTest(trial=trial):
                B = torch.randint(1, 4, (1,), generator=rng).item()
                N = torch.randint(1, 8, (1,), generator=rng).item()
                S = 64 * torch.randint(1, 5, (1,), generator=rng).item()
                block_size = [1, 2, 4, 8][
                    torch.randint(0, 4, (1,), generator=rng).item()
                ]

                anchor_positions = torch.stack(
                    [
                        torch.randperm(S, generator=rng)[:N].sort().values
                        for _ in range(B)
                    ]
                )
                block_keep_mask = torch.rand(B, N, generator=rng) > 0.3

                self._compare_masks(
                    anchor_positions, block_keep_mask, S=S, block_size=block_size
                )

    def test_block_mask_consistency(self):
        """Verify BlockMask block-level mask is consistent with element-level reference."""
        anchor_positions = torch.tensor([[32, 64, 96]])
        block_keep_mask = torch.tensor([[True, True, True]])
        self._compare_block_mask_consistency(
            anchor_positions, block_keep_mask, S=128, block_size=4
        )

    def test_block_mask_consistency_mixed(self):
        """Verify BlockMask consistency with mixed validity."""
        anchor_positions = torch.tensor([[10, 40, 70, 100], [20, 50, 80, 110]])
        block_keep_mask = torch.tensor(
            [[True, False, True, True], [False, True, False, True]]
        )
        self._compare_block_mask_consistency(
            anchor_positions, block_keep_mask, S=128, block_size=8
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)