| import pytest |
| import torch |
| import torch.nn.functional as F |
| from sgl_kernel import verify_tree_greedy |
|
|
|
|
| def test_verify_tree_greedy(): |
| candidates = torch.tensor( |
| [ |
| [0, 1, 2, 3, 4, 5], |
| [7, 8, 9, 10, 11, 12], |
| ], |
| dtype=torch.int64, |
| device="cuda", |
| ) |
| retrive_index = torch.tensor( |
| [ |
| [0, 1, 2, 3, 4, 5], |
| [6, 7, 8, 9, 10, 11], |
| ], |
| dtype=torch.int64, |
| device="cuda", |
| ) |
| retrive_next_token = torch.tensor( |
| [ |
| [1, 2, -1, 4, 5, -1], |
| [4, 2, 3, -1, 5, -1], |
| ], |
| dtype=torch.int64, |
| device="cuda", |
| ) |
| retrive_next_sibling = torch.tensor( |
| [ |
| [-1, 3, -1, -1, -1, -1], |
| [-1, -1, -1, -1, 1, -1], |
| ], |
| dtype=torch.int64, |
| device="cuda", |
| ) |
|
|
| target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda") |
| target_logits[0, 0, 3] = 10 |
| target_logits[0, 3, 4] = 10 |
| target_logits[0, 4, 5] = 10 |
| target_logits[1, 0, 11] = 10 |
| target_logits[1, 4, 12] = 10 |
| for i in range(target_logits.shape[0]): |
| for j in range(target_logits.shape[1]): |
| if torch.max(target_logits[i][j]) < 10: |
| target_logits[i][j][18] = 10 |
|
|
| target_predict = torch.argmax(target_logits, dim=-1) |
| predict_shape = (12,) |
|
|
| bs = candidates.shape[0] |
| num_spec_step = 4 |
|
|
| predicts = torch.full( |
| predict_shape, -1, dtype=torch.int32, device="cuda" |
| ) |
| accept_index = torch.full( |
| (bs, num_spec_step), -1, dtype=torch.int32, device="cuda" |
| ) |
| accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") |
|
|
| verify_tree_greedy( |
| predicts=predicts, |
| accept_index=accept_index, |
| accept_token_num=accept_token_num, |
| candidates=candidates, |
| retrive_index=retrive_index, |
| retrive_next_token=retrive_next_token, |
| retrive_next_sibling=retrive_next_sibling, |
| target_predict=target_predict, |
| ) |
|
|
| |
| assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] |
| assert accept_index.tolist() == [ |
| [0, 3, 4, 5], |
| [6, 10, 11, -1], |
| ] |
| assert accept_token_num.tolist() == [3, 2] |
|
|
|
|
| if __name__ == "__main__": |
| pytest.main([__file__]) |
|
|