| import pytest |
| import torch |
| import torch.nn.functional as F |
| from sgl_kernel import tree_speculative_sampling_target_only |
|
|
# Parametrization table. Each tuple is ordered as:
#   (threshold_single, threshold_acc,
#    expected_predicts, expected_accept_index, expected_accept_token_num)
# NOTE(review): the -1 entries coincide with the value the test uses to
# pre-fill its output buffers, so they presumably mark slots the kernel never
# wrote -- confirm against the kernel's documentation.
test_cases = [
    # Both thresholds set to 1.
    (
        1,
        1,
        [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
        [[0, 3, 4, 5], [6, 10, 11, -1]],
        [3, 2],
    ),
    # Both thresholds set to 0.
    (
        0,
        0,
        [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
        [[0, 1, 2, -1], [6, 10, 11, -1]],
        [2, 2],
    ),
]
|
|
|
|
@pytest.mark.parametrize(
    "threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
    test_cases,
)
def test_tree_speculative_sampling_target_only(
    threshold_single,
    threshold_acc,
    expected_predicts,
    expected_accept_index,
    expected_accept_token_num,
):
    """Run the kernel on a fixed 2 x 6 candidate batch and compare its outputs.

    The inputs (candidates, retrieval tensors, target logits) are hard-coded;
    only the two thresholds and the expected outputs vary per parametrized
    case. The kernel writes into ``predicts``, ``accept_index`` and
    ``accept_token_num``, which are then compared with the expectations as
    plain Python lists.

    Args:
        threshold_single: Forwarded unchanged as the kernel's
            ``threshold_single`` argument.
        threshold_acc: Forwarded unchanged as the kernel's ``threshold_acc``
            argument.
        expected_predicts: Expected flat contents of ``predicts``; its length
            also sizes that buffer.
        expected_accept_index: Expected nested-list contents of
            ``accept_index``; the length of its first row sizes the buffer's
            second dimension.
        expected_accept_token_num: Expected contents of ``accept_token_num``.
    """
    device = "cuda"
    int64_on_device = {"dtype": torch.int64, "device": device}
    int32_on_device = {"dtype": torch.int32, "device": device}
    float32_on_device = {"dtype": torch.float32, "device": device}

    # Candidate token ids, one row per batch entry.
    candidates = torch.tensor(
        [[0, 1, 2, 3, 4, 5], [7, 8, 9, 10, 11, 12]],
        **int64_on_device,
    )
    # Consecutive flat positions 0..11 laid out as (batch, draft token).
    retrive_index = torch.tensor(
        [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]],
        **int64_on_device,
    )
    # NOTE(review): the next two tensors presumably encode the draft tree as
    # "first child" / "next sibling" links with -1 meaning "none" -- confirm
    # against the kernel's documentation.
    retrive_next_token = torch.tensor(
        [[1, 2, -1, 4, 5, -1], [4, 2, 3, -1, 5, -1]],
        **int64_on_device,
    )
    retrive_next_sibling = torch.tensor(
        [[-1, 3, -1, -1, -1, -1], [-1, -1, -1, -1, 1, -1]],
        **int64_on_device,
    )

    # Flat logits of 1 over a 20-entry vocabulary, with one boosted entry per
    # (batch, position) row.
    boosted_logit = 10
    target_logits = torch.full((2, 6, 20), 1, **float32_on_device)
    explicit_peaks = (
        (0, 0, 3),
        (0, 3, 4),
        (0, 4, 5),
        (1, 0, 11),
        (1, 4, 12),
    )
    for batch_row, position, token_id in explicit_peaks:
        target_logits[batch_row, position, token_id] = boosted_logit

    # Every row that received no explicit peak gets its peak at token 18.
    lacks_peak = torch.max(target_logits, dim=-1).values < boosted_logit
    target_logits[lacks_peak, 18] = boosted_logit

    batch_size, num_draft_tokens = candidates.shape
    num_spec_step = len(expected_accept_index[0])

    # Output buffers, filled in place by the kernel.
    predicts = torch.full((len(expected_predicts),), -1, **int32_on_device)
    accept_index = torch.full((batch_size, num_spec_step), -1, **int32_on_device)
    accept_token_num = torch.zeros((batch_size,), **int32_on_device)

    # Dividing by a temperature of 0.01 stretches the 10-vs-1 logit gap to
    # 1000-vs-100, so each softmax row is dominated by its boosted token.
    temperatures = torch.tensor([0.01, 0.01], **float32_on_device)
    per_batch_temperature = temperatures.view(-1, 1, 1)
    target_probs = F.softmax(target_logits / per_batch_temperature, dim=-1)
    draft_probs = torch.zeros_like(target_probs, **float32_on_device)

    # Uniform draws handed to the kernel; generated in the same order as the
    # kernel arguments they feed.
    uniform_draws = torch.rand(batch_size, num_draft_tokens, **float32_on_device)
    final_uniform_draws = torch.rand(batch_size, device=device).to(torch.float32)

    tree_speculative_sampling_target_only(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        uniform_samples=uniform_draws,
        uniform_samples_for_final_sampling=final_uniform_draws,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=threshold_single,
        threshold_acc=threshold_acc,
        deterministic=True,
    )

    observed_predicts = predicts.tolist()
    assert (
        observed_predicts == expected_predicts
    ), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"

    observed_accept_index = accept_index.tolist()
    assert (
        observed_accept_index == expected_accept_index
    ), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"

    observed_accept_token_num = accept_token_num.tolist()
    assert (
        observed_accept_token_num == expected_accept_token_num
    ), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"
|
|
|
|
# Script entry point: when this file is executed directly (rather than
# collected by a pytest run), hand this file's own path to pytest.
if __name__ == "__main__":
    pytest.main([__file__])
|
|