| import torch |
|
|
|
|
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,
    accept_index: torch.Tensor,
    accept_token_num: torch.Tensor,
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    uniform_samples_for_final_sampling: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    threshold_single: float = 1.0,
    threshold_acc: float = 1.0,
    deterministic: bool = True,
) -> None:
    """Dispatch to the ``sgl_kernel.tree_speculative_sampling_target_only`` op.

    All arguments are forwarded positionally, in declaration order, to the
    op's ``default`` overload. The op's result is discarded and this wrapper
    returns ``None``.

    NOTE(review): presumably the kernel reports its results by writing into
    some of the tensors passed in (e.g. ``predicts``, ``accept_index``,
    ``accept_token_num``) -- confirm against the kernel implementation.
    """
    # Resolve the op first so a missing registration surfaces before anything else.
    kernel = torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default

    tensor_args = (
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        uniform_samples_for_final_sampling,
        target_probs,
        draft_probs,
    )
    sampling_opts = (threshold_single, threshold_acc, deterministic)

    kernel(*tensor_args, *sampling_opts)
|
|
|
|
def verify_tree_greedy(
    predicts: torch.Tensor,
    accept_index: torch.Tensor,
    accept_token_num: torch.Tensor,
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    target_predict: torch.Tensor,
) -> None:
    """Dispatch to the ``sgl_kernel.verify_tree_greedy`` op.

    The eight tensors are forwarded positionally, in declaration order, to
    the op's ``default`` overload. The op's result is discarded and this
    wrapper returns ``None``.

    NOTE(review): presumably ``predicts``, ``accept_index`` and
    ``accept_token_num`` are output buffers filled by the kernel -- confirm
    against the kernel implementation.
    """
    greedy_verify_op = torch.ops.sgl_kernel.verify_tree_greedy.default
    greedy_verify_op(
        predicts, accept_index, accept_token_num,
        candidates,
        retrive_index, retrive_next_token, retrive_next_sibling,
        target_predict,
    )
|
|
|
|
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
    tree_mask_mode: int,
) -> None:
    """Dispatch to the ``sgl_kernel.build_tree_kernel_efficient`` op.

    The eight tensors followed by the four integers are forwarded
    positionally, in declaration order, to the op's ``default`` overload.
    The op's result is discarded and this wrapper returns ``None``.

    NOTE(review): presumably ``tree_mask``, ``positions`` and the
    ``retrive_*`` tensors are filled in by the kernel, and the accepted
    values of ``tree_mask_mode`` are defined on the kernel side -- confirm
    against the kernel implementation.
    """
    build_op = torch.ops.sgl_kernel.build_tree_kernel_efficient.default

    tensors = (
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
    )
    int_params = (topk, depth, draft_token_num, tree_mask_mode)

    build_op(*tensors, *int_params)
|
|
|
|
def reconstruct_indices_from_tree_mask(
    tree_mask: torch.Tensor,
    verified_seq_len: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    batch_size: int,
    draft_token_num: int,
) -> None:
    """Dispatch to the ``sgl_kernel.reconstruct_indices_from_tree_mask`` op.

    The six tensors followed by ``batch_size`` and ``draft_token_num`` are
    forwarded positionally, in declaration order, to the op's ``default``
    overload. The op's result is discarded and this wrapper returns ``None``.

    NOTE(review): presumably ``positions`` and the ``retrive_*`` tensors are
    outputs derived from ``tree_mask`` by the kernel -- confirm against the
    kernel implementation.
    """
    reconstruct_op = torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default

    call_args = (
        tree_mask,
        verified_seq_len,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        batch_size,
        draft_token_num,
    )
    reconstruct_op(*call_args)
|
|
|
|
def segment_packbits(
    x: torch.Tensor,
    input_indptr: torch.Tensor,
    output_indptr: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
) -> None:
    """Dispatch to the ``sgl_kernel.segment_packbits`` op on the current CUDA stream.

    The five arguments are forwarded positionally to the op's ``default``
    overload, followed by the raw handle of the stream returned by
    ``torch.cuda.current_stream()``. The op's result is discarded and this
    wrapper returns ``None``.

    NOTE(review): presumably ``y`` is the output buffer receiving the packed
    bits of ``x``, with ``input_indptr`` / ``output_indptr`` delimiting the
    per-segment ranges -- confirm against the kernel implementation.
    """
    # Look the op up before querying the stream, mirroring the order in which
    # a single call expression would evaluate its callee and then its arguments.
    packbits_op = torch.ops.sgl_kernel.segment_packbits.default
    stream_handle = torch.cuda.current_stream().cuda_stream

    packbits_op(x, input_indptr, output_indptr, y, batch_size, stream_handle)
|
|