from typing import Optional, Union
import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple

try:
    import flashinfer.sampling as _flashinfer_sampling

    _has_flashinfer = True
except ImportError:
    _has_flashinfer = False
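
# The public wrappers below use flashinfer's fused sampling kernels when the
# package is available; on MUSA devices, or when flashinfer is missing, they
# fall back to the sgl_kernel ops registered under ``torch.ops.sgl_kernel``.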


def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    # The kernel expects float32 probabilities and, when per-request thresholds
    # are given, an int32 tensor of k values.
    probs = probs.float()
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernel.top_k_renorm_probs.default(
        probs, renorm_probs, maybe_top_k_arr, top_k_val
    )
    return renorm_probs


def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py

    Fused GPU kernel for renormalizing probabilities by top-k thresholding.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the
        top-k threshold for re-normalizing probabilities; should be in
        ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k probabilities, set the rest to zero, and renormalize
        the probabilities.

    Returns
    -------
    renorm_probs: torch.Tensor
        Renormalized probabilities, shape ``(batch_size, num_classes)``.

    Note
    ----
    This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should
    be equivalent to ``top_k_sampling_from_probs``.
    """
    if probs.device.type == "musa" or not _has_flashinfer:
        return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
    else:
        return _flashinfer_sampling.top_k_renorm_probs(probs, top_k)


top_k_renorm_prob = top_k_renorm_probs
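

# Hedged usage sketch (illustrative only, not part of the module's public API;
# assumes a CUDA device with sgl_kernel or flashinfer available). It exercises
# the documented contract: at most ``top_k`` entries survive per row, and each
# row is rescaled to sum to one.
def _example_top_k_renorm_probs() -> None:
    probs = torch.softmax(torch.randn(4, 128, device="cuda"), dim=-1)
    renormed = top_k_renorm_probs(probs, top_k=10)
    # No more than k probabilities remain nonzero in each row.
    assert (renormed > 0).sum(dim=-1).le(10).all()
    # The surviving probabilities form a proper distribution again.
    assert torch.allclose(renormed.sum(dim=-1), torch.ones(4, device="cuda"))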


def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    # The kernel expects float32 probabilities and, when per-request thresholds
    # are given, a float32 tensor of top-p values.
    probs = probs.float()
    maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernel.top_p_renorm_probs.default(
        probs, renorm_probs, maybe_top_p_arr, top_p_val
    )
    return renorm_probs


def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py

    Fused GPU kernel for renormalizing probabilities by top-p thresholding.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    top_p: Union[torch.Tensor, float]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the
        top-p threshold for re-normalizing probabilities; should be in ``(0, 1)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We mask out the probabilities below ``threshold``, where ``threshold`` is
        chosen so that the cumulative sum of ``probs[probs >= threshold]`` equals
        ``top_p``, and then renormalize the remaining probabilities.

    Returns
    -------
    renorm_probs: torch.Tensor
        Renormalized probabilities, shape ``(batch_size, num_classes)``.

    Note
    ----
    This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should
    be equivalent to ``top_p_sampling_from_probs``.
    """
    if probs.device.type == "musa" or not _has_flashinfer:
        return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
    else:
        return _flashinfer_sampling.top_p_renorm_probs(probs, top_p)


top_p_renorm_prob = top_p_renorm_probs
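

# Hedged usage sketch (illustrative only; assumes a CUDA device). Demonstrates
# the per-request form: one top-p threshold per row of ``probs``.
def _example_top_p_renorm_probs() -> None:
    probs = torch.softmax(torch.randn(4, 128, device="cuda"), dim=-1)
    top_p = torch.tensor([0.5, 0.7, 0.9, 0.95], device="cuda")
    renormed = top_p_renorm_probs(probs, top_p)
    # Masking only removes mass, so each row has no more nonzeros than before.
    assert (renormed > 0).sum(dim=-1).le((probs > 0).sum(dim=-1)).all()
    # The remaining probabilities are renormalized to sum to one per row.
    assert torch.allclose(renormed.sum(dim=-1), torch.ones(4, device="cuda"))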


def _top_k_mask_logits_internal(
    logits: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    # The kernel expects float32 logits and, when per-request thresholds are
    # given, an int32 tensor of k values.
    logits = logits.float()
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    mask_logits = torch.empty_like(logits)
    torch.ops.sgl_kernel.top_k_mask_logits.default(
        logits, mask_logits, maybe_top_k_arr, top_k_val
    )
    return mask_logits


def top_k_mask_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py

    Fused GPU kernel for masking logits by top-k thresholding.

    Parameters
    ----------
    logits: torch.Tensor
        Logits before softmax, shape ``(batch_size, num_classes)``.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the
        top-k threshold for masking logits; should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k logits and set the rest to negative infinity.

    Returns
    -------
    masked_logits: torch.Tensor
        Masked logits, shape ``(batch_size, num_classes)``.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> torch.manual_seed(42)
    >>> batch_size = 4
    >>> vocab_size = 5
    >>> top_k = 3
    >>> logits = torch.randn(batch_size, vocab_size).to(0)
    >>> logits
    tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
            [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
            [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
            [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
    >>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
    >>> masked_logits
    tensor([[ 1.9269,  1.4873,  0.9007,    -inf,    -inf],
            [ 1.0783,  0.8008,  1.6806,    -inf,    -inf],
            [   -inf,  0.2415, -0.2316,  0.0418,    -inf],
            [ 0.8599, -0.3097,    -inf,  0.8034,    -inf]], device='cuda:0')

    Note
    ----
    The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent
    to ``top_k_renorm_probs``.

    See Also
    --------
    top_k_renorm_probs
    """
    if logits.device.type == "musa" or not _has_flashinfer:
        return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))
    else:
        return _flashinfer_sampling.top_k_mask_logits(logits, top_k)
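

# Hedged consistency sketch (illustrative only; assumes a CUDA device). Per the
# docstring note above, masking logits and then applying softmax should agree
# with renormalizing the softmaxed probabilities directly.
def _example_top_k_mask_logits() -> None:
    logits = torch.randn(4, 128, device="cuda")
    via_mask = torch.softmax(top_k_mask_logits(logits, top_k=10), dim=-1)
    via_renorm = top_k_renorm_probs(torch.softmax(logits, dim=-1), top_k=10)
    assert torch.allclose(via_mask, via_renorm, rtol=1e-4, atol=1e-4)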