| import pytest |
| import torch |
| import torch.nn.functional as F |
| from sgl_kernel import reconstruct_indices_from_tree_mask |
|
|
|
|
def test_reconstruct_indices_from_tree_mask():
    """Check index reconstruction for a single 4-token draft tree.

    Reading row ``i`` of the (row-major, 4x4) ``tree_mask`` as the set of
    tokens token ``i`` may attend to (itself plus its ancestors), the mask
    below encodes this tree::

        0            depth 0
        |-- 1        depth 1
        `-- 2        depth 1
            `-- 3    depth 2

    The expected outputs asserted at the end are consistent with:
      * ``retrive_next_token[i]``   -> first child of token ``i`` (-1 if none)
      * ``retrive_next_sibling[i]`` -> next sibling of token ``i`` (-1 if none)
      * ``positions[i]``            -> ``seq_lens[b] + depth(i)``
    (interpretation inferred from the expected values; confirm against the
    kernel documentation).
    """
    # The kernel operates on CUDA tensors; skip rather than error on
    # machines without a GPU.
    if not torch.cuda.is_available():
        pytest.skip("reconstruct_indices_from_tree_mask requires a CUDA device")

    bs = 1
    num_branch_token = 4
    device = "cuda"
    seq_lens = torch.tensor([12], device=device, dtype=torch.int64)

    # Output buffers, pre-filled with -1 and written in place by the kernel.
    out_shape = (bs, num_branch_token)
    retrive_index = torch.full(out_shape, -1, device=device, dtype=torch.int64)
    retrive_next_token = torch.full(out_shape, -1, device=device, dtype=torch.int64)
    retrive_next_sibling = torch.full(
        out_shape, -1, device=device, dtype=torch.int64
    )
    positions = torch.empty((bs * num_branch_token,), device=device, dtype=torch.int64)

    # One row per token; flattened to the 1-D layout the kernel is given.
    tree_mask = torch.tensor(
        [
            [1, 0, 0, 0],  # token 0: root
            [1, 1, 0, 0],  # token 1: attends to 0 -> child of 0
            [1, 0, 1, 0],  # token 2: attends to 0 -> child of 0
            [1, 0, 1, 1],  # token 3: attends to 0, 2 -> child of 2
        ],
        device=device,
        dtype=torch.bool,
    ).flatten()

    reconstruct_indices_from_tree_mask(
        tree_mask,
        seq_lens,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        bs,
        num_branch_token,
    )

    assert retrive_index.tolist() == [[0, 1, 2, 3]], f"{retrive_index=}"
    assert retrive_next_token.tolist() == [[1, -1, 3, -1]], f"{retrive_next_token=}"
    assert retrive_next_sibling.tolist() == [
        [-1, 2, -1, -1]
    ], f"{retrive_next_sibling=}"
    assert positions.tolist() == [12, 13, 13, 14], f"{positions=}"
|
|
|
|
if __name__ == "__main__":
    # Run the module only through pytest. Also calling the test function
    # directly would execute it twice and bypass pytest's skip handling.
    # Propagate pytest's exit code so script runs report failures.
    raise SystemExit(pytest.main([__file__]))
|
|