File size: 13,125 Bytes
d02d576 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 | from dataclasses import dataclass
from typing import List, Optional
import torch
from sgl_kernel.utils import is_arch_support_pdl
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Root mean square normalization.

    ``out[i] = (input[i] / RMS(input)) * weight[i]``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        Destination tensor. When given, the kernel writes into it in place;
        otherwise a new tensor is allocated with ``torch.empty_like(input)``.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        When None, the value returned by ``is_arch_support_pdl()`` is used
        (e.g. enabled on Hopper).

    Returns
    -------
    output: torch.Tensor
        Normalized tensor, shape (batch_size, hidden_size); this is ``out``
        itself when ``out`` was supplied.
    """
    result = torch.empty_like(input) if out is None else out
    use_pdl = is_arch_support_pdl() if enable_pdl is None else enable_pdl
    torch.ops.sgl_kernel.rmsnorm.default(result, input, weight, eps, use_pdl)
    return result
def fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""Fused residual add followed by root mean square normalization.

    Both tensors are updated in place by the kernel:

    Step 1:
        ``residual[i] += input[i]``
    Step 2:
        ``input[i] = (residual[i] / RMS(residual)) * weight[i]``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size); overwritten with the
        normalized result.
    residual: torch.Tensor
        Residual tensor, shape (batch_size, hidden_size); overwritten with
        ``residual + input``.
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        When None, the value returned by ``is_arch_support_pdl()`` is used
        (e.g. enabled on Hopper).
    """
    use_pdl = enable_pdl if enable_pdl is not None else is_arch_support_pdl()
    kernel = torch.ops.sgl_kernel.fused_add_rmsnorm.default
    kernel(input, residual, weight, eps, use_pdl)
def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Gemma-style root mean square normalization.

    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,); applied as ``weight + 1``.
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        Destination tensor. When given, the kernel writes into it in place;
        otherwise a new tensor is allocated with ``torch.empty_like(input)``.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        When None, the value returned by ``is_arch_support_pdl()`` is used
        (e.g. enabled on Hopper).

    Returns
    -------
    output: torch.Tensor
        Gemma-normalized tensor, shape (batch_size, hidden_size); this is
        ``out`` itself when ``out`` was supplied.
    """
    dst = out if out is not None else torch.empty_like(input)
    if enable_pdl is None:
        pdl = is_arch_support_pdl()
    else:
        pdl = enable_pdl
    torch.ops.sgl_kernel.gemma_rmsnorm.default(dst, input, weight, eps, pdl)
    return dst
def gemma_fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""Gemma-style fused residual add followed by RMS normalization.

    Both tensors are updated in place by the kernel:

    Step 1:
        ``residual[i] += input[i]``
    Step 2:
        ``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size); overwritten with the
        normalized result.
    residual: torch.Tensor
        Residual tensor, shape (batch_size, hidden_size); overwritten with
        ``residual + input``.
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,); applied as ``weight + 1``.
    eps: float
        Epsilon for numerical stability.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        When None, the value returned by ``is_arch_support_pdl()`` is used
        (e.g. enabled on Hopper).
    """
    if enable_pdl is None:
        pdl = is_arch_support_pdl()
    else:
        pdl = enable_pdl
    kernel = torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default
    kernel(input, residual, weight, eps, pdl)
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
), f"{input.shape[:-1]} != {output.shape[:-1]}"
assert (
input.shape[-1] == 2 * output.shape[-1]
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.silu_and_mul.default(out, input)
return out
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
return out
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
return out
if torch.version.hip is not None:
    # gelu_quick is only defined on ROCm/HIP builds of PyTorch.

    def gelu_quick(
        input: torch.Tensor, out: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Quick-GELU: y = x * sigmoid(1.702 * x)

        The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores,
        so the last-dimension byte length must be a multiple of 16 bytes.

        Parameters
        ----------
        input: torch.Tensor
            Input tensor of any shape whose last dimension spans a multiple
            of 16 bytes.
        out: Optional[torch.Tensor]
            Preallocated destination with exactly ``input``'s shape, written
            in place. Only the shape is checked here (via ``assert``).

        Returns
        -------
        torch.Tensor
            ``out`` if given, otherwise ``torch.empty_like(input)`` filled by
            the kernel.

        Raises
        ------
        ValueError
            If the byte length of the last dimension is not a multiple of 16.
        """
        if input.shape[-1] * input.dtype.itemsize % 16 != 0:
            raise ValueError(
                f"The last dimension ({input.shape[-1]}) x itemsize "
                f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
            )
        if out is not None:
            assert input.shape == out.shape, f"{input.shape} != {out.shape}"
        else:
            out = torch.empty_like(input)
        torch.ops.sgl_kernel.gelu_quick(out, input)
        return out
@dataclass
class FusedSetKVBufferArg:
    """Arguments for fusing the KV-cache write into the RoPE kernel.

    Passed as ``fused_set_kv_buffer_arg`` to
    ``apply_rope_with_cos_sin_cache_inplace``, which views every tensor field
    as ``(rows, -1, head_size)`` and forwards them to the kernel.

    value : torch.Tensor
        Value tensor, shape: ``(nnz, num_v_heads * head_size)``. Required.
    k_buffer : torch.Tensor
        Buffer for keys, last dimension ``num_k_heads * head_size``. Required.
        The first dimension is presumably the number of cache slots addressed
        by ``cache_loc`` rather than ``nnz`` — TODO confirm.
    v_buffer : torch.Tensor
        Buffer for values, last dimension ``num_v_heads * head_size``.
        Required; same first-dimension caveat as ``k_buffer``.
    k_scale : Optional[float]
        Scale factor for keys. Not yet supported: the RoPE wrapper requires
        ``None``.
    v_scale : Optional[float]
        Scale factor for values. Not yet supported: the RoPE wrapper requires
        ``None``.
    cache_loc : torch.Tensor
        Cache location tensor used for indexing the KV cache. The RoPE wrapper
        requires dtype ``torch.int64``.
    """

    value: torch.Tensor
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor
    k_scale: Optional[float]
    v_scale: Optional[float]
    cache_loc: torch.Tensor
def _view_3d(x, head_size):
return x.view(x.shape[0], -1, head_size)
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is inplace applied to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``. Converted with ``.long()`` before
        being passed to the kernel.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``. Viewed as
        ``(nnz, num_q_heads, head_size)`` and used as both kernel input and
        output.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``. Viewed as
        ``(nnz, num_k_heads, head_size)`` and used as both kernel input and
        output.
    head_size : int
        Size of each attention head; used to reshape ``query``, ``key`` and
        the fused KV-buffer tensors.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
        Must be ``torch.float32``.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.
        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
    fused_set_kv_buffer_arg : Optional[FusedSetKVBufferArg]
        If given, fuse the set-kv-buffer operation into this kernel. Its
        ``k_scale`` and ``v_scale`` must be ``None`` (not yet supported) and
        its ``cache_loc`` must be ``torch.int64``.
    enable_pdl : Optional[bool]
        Whether to enable programmatic dependent launch. When None, it is
        enabled only if ``fused_set_kv_buffer_arg`` is given and
        ``is_arch_support_pdl()`` is truthy, because the non-fused path does
        not support PDL yet.

    Raises
    ------
    ValueError
        If ``cos_sin_cache`` is not ``torch.float32``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")
    if enable_pdl is None:
        # the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
        enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)

    # Optional fused KV-cache write arguments; all stay None when not fusing.
    value_3d = k_buffer_3d = v_buffer_3d = cache_loc = None
    kv_arg = fused_set_kv_buffer_arg
    if kv_arg is not None:
        assert kv_arg.k_scale is None, "k_scale is not yet supported"
        assert kv_arg.v_scale is None, "v_scale is not yet supported"
        assert kv_arg.cache_loc.dtype == torch.int64, f"{kv_arg.cache_loc.dtype=}"
        value_3d = _view_3d(kv_arg.value, head_size)
        k_buffer_3d = _view_3d(kv_arg.k_buffer, head_size)
        v_buffer_3d = _view_3d(kv_arg.v_buffer, head_size)
        cache_loc = kv_arg.cache_loc

    # query/key are passed twice: once as input and once as the in-place output.
    # The kernel takes an "interleave" flag, i.e. the negation of is_neox.
    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
        _view_3d(query, head_size),
        _view_3d(key, head_size),
        _view_3d(query, head_size),
        _view_3d(key, head_size),
        cos_sin_cache,
        positions.long(),
        (not is_neox),
        enable_pdl,
        value_3d,
        k_buffer_3d,
        v_buffer_3d,
        cache_loc,
    )
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
):
    """Thin wrapper over the ``sgl_kernel.rotary_embedding`` op.

    All arguments are forwarded unchanged and in order; nothing is returned.
    Unlike ``apply_rope_with_cos_sin_cache_inplace``, no dtype validation or
    reshaping is done here. Whether ``query``/``key`` are rotated in place is
    decided by the kernel — presumably yes (TODO confirm).
    """
    kernel = torch.ops.sgl_kernel.rotary_embedding.default
    kernel(positions, query, key, head_size, cos_sin_cache, is_neox)
def downcast_fp8(
    k: torch.Tensor,
    v: torch.Tensor,
    k_out: torch.Tensor,
    v_out: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    loc: torch.Tensor,
    mult: int = 1,
    offset: int = 0,
) -> None:
    """Thin wrapper over the ``sgl_kernel.downcast_fp8`` op.

    Forwards all nine arguments unchanged and returns None. Judging by the
    names, the kernel writes FP8-downcast, scaled copies of ``k``/``v`` into
    ``k_out``/``v_out`` at positions derived from ``loc``, ``mult`` and
    ``offset`` — TODO confirm against the kernel.
    """
    op = torch.ops.sgl_kernel.downcast_fp8
    op(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset)
def copy_to_gpu_no_ce(input: torch.Tensor, output: torch.Tensor):
    """Invoke the ``sgl_kernel.copy_to_gpu_no_ce`` op with ``(input, output)``.

    Returns None. Judging by the name, this copies ``input`` into ``output``
    on the GPU without using a copy engine — TODO confirm against the kernel.
    """
    op = torch.ops.sgl_kernel.copy_to_gpu_no_ce
    op(input, output)
def concat_mla_k(
    k: torch.Tensor,
    k_nope: torch.Tensor,
    k_rope: torch.Tensor,
):
    """Invoke the ``sgl_kernel.concat_mla_k`` op on ``(k, k_nope, k_rope)``.

    Returns None. ``k`` is presumably the destination that receives the
    concatenation of ``k_nope`` and ``k_rope`` (MLA key layout) — TODO confirm.
    """
    op = torch.ops.sgl_kernel.concat_mla_k
    op(k, k_nope, k_rope)
def concat_mla_absorb_q(
    a: torch.Tensor,
    b: torch.Tensor,
):
    """Allocate a ``(..., a.shape[-1] + b.shape[-1])`` tensor and fill it via
    the ``sgl_kernel.concat_mla_absorb_q`` op.

    The output takes ``a``'s leading dimensions, dtype and device. The kernel
    presumably writes ``a`` and ``b`` concatenated along the last dimension —
    TODO confirm.

    Returns
    -------
    torch.Tensor
        The newly allocated output tensor.
    """
    *leading_dims, a_last = a.shape
    out_shape = (*leading_dims, a_last + b.shape[-1])
    out = torch.empty(out_shape, device=a.device, dtype=a.dtype)
    torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)
    return out
|