import unittest

import torch

from specforge.core.dflash import create_dflash_block_mask, create_dflash_sdpa_mask
|
|
| def _reference_dflash_mask(anchor_positions, block_keep_mask, S, block_size, device): |
| """Element-level reference mask mirroring the mask_mod inside create_dflash_block_mask. |
| |
| This uses plain Python loops so correctness is obvious by inspection. |
| """ |
| B, N = anchor_positions.shape |
| Q_LEN = N * block_size |
| KV_LEN = S + N * block_size |
|
|
| mask = torch.zeros(B, 1, Q_LEN, KV_LEN, dtype=torch.bool, device=device) |
| for b in range(B): |
| for q_idx in range(Q_LEN): |
| q_block_id = q_idx // block_size |
| anchor_pos = anchor_positions[b, q_block_id].item() |
| is_valid = block_keep_mask[b, q_block_id].item() |
| if not is_valid: |
| continue |
| for kv_idx in range(KV_LEN): |
| is_context = kv_idx < S |
| ctx_visible = is_context and (kv_idx < anchor_pos) |
|
|
| is_draft = kv_idx >= S |
| kv_block_id = (kv_idx - S) // block_size |
| draft_visible = is_draft and (q_block_id == kv_block_id) |
|
|
| if ctx_visible or draft_visible: |
| mask[b, 0, q_idx, kv_idx] = True |
| return mask |
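
# A tiny worked example of the rules above (illustrative numbers, not taken from the
# dflash kernels): with S=4 context tokens, a single draft block of block_size=2,
# anchor_positions=[[2]] and block_keep_mask=[[True]], each of the two query rows
# sees context keys 0 and 1 (kv_idx < anchor 2) plus the two draft keys of its own
# block, so every row of the reference mask reads [1, 1, 0, 0, 1, 1].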
|
|
|
|
| class TestDFlashMask(unittest.TestCase): |
|
|
    def setUp(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available; these tests build masks on GPU")
        torch.manual_seed(42)
        self.device = torch.device("cuda")
|
|
| def _compare_masks(self, anchor_positions, block_keep_mask, S, block_size): |
| """Compare create_dflash_sdpa_mask against element-level reference (ground truth).""" |
| anchor_positions = anchor_positions.to(self.device) |
| block_keep_mask = block_keep_mask.to(self.device) |
|
|
| sdpa_mask = create_dflash_sdpa_mask( |
| anchor_positions=anchor_positions, |
| block_keep_mask=block_keep_mask, |
| S=S, |
| block_size=block_size, |
| device=self.device, |
| ) |
|
|
| ref_mask = _reference_dflash_mask( |
| anchor_positions=anchor_positions, |
| block_keep_mask=block_keep_mask, |
| S=S, |
| block_size=block_size, |
| device=self.device, |
| ) |
|
|
| self.assertEqual( |
| sdpa_mask.shape, |
| ref_mask.shape, |
| f"Shape mismatch: sdpa {sdpa_mask.shape} vs ref {ref_mask.shape}", |
| ) |
| self.assertTrue( |
| torch.equal(sdpa_mask, ref_mask), |
| f"Mask mismatch with S={S}, block_size={block_size}, " |
| f"anchors={anchor_positions.tolist()}, keep={block_keep_mask.tolist()}\n" |
| f"Diff positions: {(sdpa_mask != ref_mask).nonzero(as_tuple=False).tolist()}", |
| ) |
|
|
| def _compare_block_mask_consistency( |
| self, anchor_positions, block_keep_mask, S, block_size |
| ): |
| """Verify create_dflash_block_mask block-level mask is consistent with reference.""" |
| anchor_positions = anchor_positions.to(self.device) |
| block_keep_mask = block_keep_mask.to(self.device) |
|
|
| block_mask = create_dflash_block_mask( |
| anchor_positions=anchor_positions, |
| block_keep_mask=block_keep_mask, |
| S=S, |
| block_size=block_size, |
| device=self.device, |
| ) |
|
|
| ref_mask = _reference_dflash_mask( |
| anchor_positions=anchor_positions, |
| block_keep_mask=block_keep_mask, |
| S=S, |
| block_size=block_size, |
| device=self.device, |
| ) |
|
|
| dense_blocks = block_mask.to_dense() |
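        # Assumption: create_dflash_block_mask is built with flex_attention's default
        # BLOCK_SIZE of 128, so to_dense() is laid out in 128x128 tiles; adjust
        # BM_BLOCK if the kernel ever uses a different block size.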
| BM_BLOCK = 128 |
| B, N = anchor_positions.shape |
| Q_LEN = N * block_size |
| KV_LEN = S + N * block_size |
| n_q_blocks = (Q_LEN + BM_BLOCK - 1) // BM_BLOCK |
| n_kv_blocks = (KV_LEN + BM_BLOCK - 1) // BM_BLOCK |
|
|
| ref_int = ref_mask.squeeze(1).int() |
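        # Only the forward implication is asserted: every BM_BLOCK x BM_BLOCK tile
        # containing a visible element must be marked for computation. The converse
        # is not needed for correctness, since the mask_mod still zeroes masked
        # elements inside any extra tiles the BlockMask chooses to compute.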
| for b in range(B): |
| for qi in range(n_q_blocks): |
| for ki in range(n_kv_blocks): |
| q_start = qi * BM_BLOCK |
| q_end = min(q_start + BM_BLOCK, Q_LEN) |
| k_start = ki * BM_BLOCK |
| k_end = min(k_start + BM_BLOCK, KV_LEN) |
| has_nonzero = ref_int[b, q_start:q_end, k_start:k_end].any().item() |
| block_val = dense_blocks[b, 0, qi, ki].item() |
| if has_nonzero: |
| self.assertEqual( |
| block_val, |
| 1, |
| f"Block ({qi},{ki}) for batch {b} should be 1 but got 0", |
| ) |
|
|
| def test_basic_single_batch_single_block(self): |
| """Single batch, single draft block.""" |
| anchor_positions = torch.tensor([[64]]) |
| block_keep_mask = torch.tensor([[True]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) |
|
|
| def test_basic_single_batch_multi_block(self): |
| """Single batch, multiple draft blocks.""" |
| anchor_positions = torch.tensor([[32, 64, 96]]) |
| block_keep_mask = torch.tensor([[True, True, True]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) |
|
|
| def test_multi_batch(self): |
| """Multiple batches with different anchors.""" |
| anchor_positions = torch.tensor([[16, 48, 80], [32, 64, 100]]) |
| block_keep_mask = torch.tensor([[True, True, True], [True, True, True]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) |
|
|
| def test_invalid_blocks(self): |
| """Some blocks are masked out (block_keep_mask=False).""" |
| anchor_positions = torch.tensor([[20, 50, 80, 110]]) |
| block_keep_mask = torch.tensor([[True, False, True, False]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) |
|
|
| def test_all_blocks_invalid(self): |
| """All blocks invalid — mask should be all zeros.""" |
| anchor_positions = torch.tensor([[30, 60]]) |
| block_keep_mask = torch.tensor([[False, False]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) |
|
|
| def test_anchor_at_zero(self): |
| """Anchor at position 0 — no context tokens visible.""" |
| anchor_positions = torch.tensor([[0, 64]]) |
| block_keep_mask = torch.tensor([[True, True]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) |
|
|
| def test_anchor_at_boundary(self): |
| """Anchor exactly at S — all context tokens visible.""" |
| anchor_positions = torch.tensor([[128]]) |
| block_keep_mask = torch.tensor([[True]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=4) |
|
|
| def test_large_block_size(self): |
| """Larger draft block size.""" |
| anchor_positions = torch.tensor([[50, 150]]) |
| block_keep_mask = torch.tensor([[True, True]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=256, block_size=16) |
|
|
| def test_block_size_1(self): |
| """Minimal block_size=1.""" |
| anchor_positions = torch.tensor([[10, 30, 50]]) |
| block_keep_mask = torch.tensor([[True, True, True]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=64, block_size=1) |
|
|
| def test_mixed_validity_multi_batch(self): |
| """Multi-batch with mixed block validity patterns.""" |
| anchor_positions = torch.tensor([[10, 40, 70, 100], [20, 50, 80, 110]]) |
| block_keep_mask = torch.tensor( |
| [[True, False, True, True], [False, True, False, True]] |
| ) |
| self._compare_masks(anchor_positions, block_keep_mask, S=128, block_size=8) |
|
|
| def test_various_context_lengths(self): |
| """Sweep over various context lengths.""" |
| for S in [64, 128, 256, 512]: |
| with self.subTest(S=S): |
| anchor_positions = torch.tensor([[S // 4, S // 2, 3 * S // 4]]) |
| block_keep_mask = torch.tensor([[True, True, True]]) |
| self._compare_masks( |
| anchor_positions, block_keep_mask, S=S, block_size=4 |
| ) |
|
|
| def test_various_block_sizes(self): |
| """Sweep over various draft block sizes.""" |
| for block_size in [1, 2, 4, 8, 16]: |
| with self.subTest(block_size=block_size): |
| anchor_positions = torch.tensor([[32, 80]]) |
| block_keep_mask = torch.tensor([[True, True]]) |
| self._compare_masks( |
| anchor_positions, block_keep_mask, S=128, block_size=block_size |
| ) |
|
|
| def test_many_blocks(self): |
| """Large number of draft blocks.""" |
| N = 32 |
| anchors = torch.arange(10, 10 + N * 4, 4).unsqueeze(0) |
| keep = torch.ones(1, N, dtype=torch.bool) |
| keep[0, ::3] = False |
| self._compare_masks(anchors, keep, S=256, block_size=4) |
|
|
| def test_consecutive_anchors(self): |
| """Anchors placed consecutively.""" |
| anchor_positions = torch.tensor([[0, 1, 2, 3]]) |
| block_keep_mask = torch.tensor([[True, True, True, True]]) |
| self._compare_masks(anchor_positions, block_keep_mask, S=64, block_size=4) |
|
|
| def test_random_stress(self): |
| """Randomized stress test with multiple random configurations.""" |
| rng = torch.Generator().manual_seed(123) |
| for trial in range(5): |
| with self.subTest(trial=trial): |
| B = torch.randint(1, 4, (1,), generator=rng).item() |
| N = torch.randint(1, 8, (1,), generator=rng).item() |
| S = 64 * torch.randint(1, 5, (1,), generator=rng).item() |
| block_size = [1, 2, 4, 8][ |
| torch.randint(0, 4, (1,), generator=rng).item() |
| ] |
|
|
| anchor_positions = torch.stack( |
| [ |
| torch.randperm(S, generator=rng)[:N].sort().values |
| for _ in range(B) |
| ] |
| ) |
| block_keep_mask = torch.rand(B, N, generator=rng) > 0.3 |
|
|
| self._compare_masks( |
| anchor_positions, block_keep_mask, S=S, block_size=block_size |
| ) |
|
|
| def test_block_mask_consistency(self): |
| """Verify BlockMask block-level mask is consistent with element-level reference.""" |
| anchor_positions = torch.tensor([[32, 64, 96]]) |
| block_keep_mask = torch.tensor([[True, True, True]]) |
| self._compare_block_mask_consistency( |
| anchor_positions, block_keep_mask, S=128, block_size=4 |
| ) |
|
|
| def test_block_mask_consistency_mixed(self): |
| """Verify BlockMask consistency with mixed validity.""" |
| anchor_positions = torch.tensor([[10, 40, 70, 100], [20, 50, 80, 110]]) |
| block_keep_mask = torch.tensor( |
| [[True, False, True, True], [False, True, False, True]] |
| ) |
| self._compare_block_mask_consistency( |
| anchor_positions, block_keep_mask, S=128, block_size=8 |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| unittest.main(verbosity=2) |
|
|