from typing import Optional, Union

import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple

try:
    import flashinfer.sampling as _flashinfer_sampling

    _has_flashinfer = True
except ImportError:
    _has_flashinfer = False
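
# The public wrappers below prefer flashinfer's fused sampling kernels when the
# package is importable, and fall back to sgl_kernel's own kernels otherwise
# (flashinfer is also bypassed on MUSA devices).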


def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    probs = probs.float()
    # per-request top-k values (if given) are passed to the kernel as an int tensor
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    # the kernel fills renorm_probs with the top-k renormalized distribution
    torch.ops.sgl_kernel.top_k_renorm_probs.default(
        probs, renorm_probs, maybe_top_k_arr, top_k_val
    )
    return renorm_probs


def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
    Fused GPU kernel for renormalizing probabilities by top-k thresholding.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold
        for re-normalizing probabilities, should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.

    Returns
    -------
    renorm_probs: torch.Tensor
        Renormalized probabilities, shape ``(batch_size, num_classes)``.
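
    Examples
    --------
    A minimal usage sketch (illustrative; assumes a CUDA device):

    >>> import torch
    >>> probs = torch.softmax(torch.randn(4, 5, device="cuda"), dim=-1)
    >>> renorm_probs = top_k_renorm_probs(probs, top_k=3)
    >>> renorm_probs.shape
    torch.Size([4, 5])
    >>> # entries outside the top-3 are zeroed and each row sums to (approximately) one
    >>> bool(torch.allclose(renorm_probs.sum(dim=-1), torch.ones(4, device="cuda")))
    True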

    Note
    ----
    This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
    ``top_k_sampling_from_probs``.
    """
    if probs.device.type == "musa" or not _has_flashinfer:
        return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
    else:
        return _flashinfer_sampling.top_k_renorm_probs(probs, top_k)


top_k_renorm_prob = top_k_renorm_probs


def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    probs = probs.float()
    # per-request top-p values (if given) are passed to the kernel as a float tensor
    maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    # the kernel fills renorm_probs with the top-p renormalized distribution
    torch.ops.sgl_kernel.top_p_renorm_probs.default(
        probs, renorm_probs, maybe_top_p_arr, top_p_val
    )
    return renorm_probs


def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
    Fused GPU kernel for renormalizing probabilities by top-p thresholding.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    top_p: Union[torch.Tensor, float]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold
        for re-normalizing probabilities, should be in ``(0, 1)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We mask out the probabilities smaller than the threshold at which the cumulative sum
        of ``probs[probs >= threshold]`` reaches ``top_p``, and renormalize the remaining probabilities.

    Returns
    -------
    renorm_probs: torch.Tensor
        Renormalized probabilities, shape ``(batch_size, num_classes)``.
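
    Examples
    --------
    A minimal usage sketch (illustrative; assumes a CUDA device):

    >>> import torch
    >>> probs = torch.softmax(torch.randn(4, 5, device="cuda"), dim=-1)
    >>> renorm_probs = top_p_renorm_probs(probs, top_p=0.9)
    >>> renorm_probs.shape
    torch.Size([4, 5])
    >>> # masked-out entries are zero and each row sums to (approximately) one
    >>> bool(torch.allclose(renorm_probs.sum(dim=-1), torch.ones(4, device="cuda")))
    True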

    Note
    ----
    This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
    ``top_p_sampling_from_probs``.
    """
    if probs.device.type == "musa" or not _has_flashinfer:
        return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
    else:
        return _flashinfer_sampling.top_p_renorm_probs(probs, top_p)


top_p_renorm_prob = top_p_renorm_probs


def _top_k_mask_logits_internal(
    logits: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    logits = logits.float()
    # per-request top-k values (if given) are passed to the kernel as an int tensor
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    mask_logits = torch.empty_like(logits)
    # the kernel fills mask_logits: top-k entries are kept, the rest become -inf
    torch.ops.sgl_kernel.top_k_mask_logits.default(
        logits, mask_logits, maybe_top_k_arr, top_k_val
    )
    return mask_logits


def top_k_mask_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
    Fused GPU kernel for masking logits by top-k thresholding.

    Parameters
    ----------
    logits: torch.Tensor
        Logits before softmax, shape ``(batch_size, num_classes)``.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold
        for masking logits, should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k logits and set the rest to negative infinity.

    Returns
    -------
    masked_logits: torch.Tensor
        Masked logits, shape ``(batch_size, num_classes)``.

    Examples
    --------

    >>> import torch
    >>> torch.manual_seed(42)
    >>> batch_size = 4
    >>> vocab_size = 5
    >>> top_k = 3
    >>> logits = torch.randn(batch_size, vocab_size).to(0)
    >>> logits
    tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
            [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
            [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
            [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
    >>> masked_logits = top_k_mask_logits(logits, top_k)
    >>> masked_logits
    tensor([[ 1.9269,  1.4873,  0.9007,    -inf,    -inf],
            [ 1.0783,  0.8008,  1.6806,    -inf,    -inf],
            [   -inf,  0.2415, -0.2316,  0.0418,    -inf],
            [ 0.8599, -0.3097,    -inf,  0.8034,    -inf]], device='cuda:0')

    Note
    ----
    The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.
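
    For instance, the following two computations should agree up to floating-point
    error (an illustrative sketch, reusing ``logits`` and ``top_k`` from the example above)::

        probs_via_mask = torch.softmax(top_k_mask_logits(logits, top_k), dim=-1)
        probs_via_renorm = top_k_renorm_probs(torch.softmax(logits, dim=-1), top_k)
        torch.testing.assert_close(probs_via_mask, probs_via_renorm)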

    See Also
    --------
    top_k_renorm_probs
    """
    if logits.device.type == "musa" or not _has_flashinfer:
        return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))
    else:
        return _flashinfer_sampling.top_k_mask_logits(logits, top_k)