"""Python wrappers around the ``sgl_kernel`` KV-cache transfer ops.

Each function below forwards its arguments unchanged to the matching custom
op under ``torch.ops.sgl_kernel``. Naming note (inferred from the kernel
names, not documented in this file): ``pf``/``lf`` appear to abbreviate
page-first and layer-first pool layouts, ``ph`` a paged layout that also
needs ``page_size`` and ``head_num``, and the ``mla`` variants move a single
fused KV tensor instead of separate K and V tensors.
"""

from typing import List

import torch


def is_hip() -> bool:
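    """Return True when this PyTorch build targets AMD ROCm/HIP."""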
    return torch.version.hip is not None


_is_hip = is_hip()


def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
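    """Copy the K/V entries of one layer at ``src_indices`` into
    ``dst_indices``.

    Thin dispatch to ``torch.ops.sgl_kernel.transfer_kv_per_layer``;
    ``block_quota`` and ``num_warps_per_block`` tune the kernel launch.
    """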
    torch.ops.sgl_kernel.transfer_kv_per_layer.default(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_per_layer_pf_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
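    """Per-layer K/V copy between pools with different layouts (``pf`` to
    ``lf``); ``layer_id`` and ``src_layout_dim`` locate the layer inside
    the source buffer.
    """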
    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf.default(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_per_layer_ph_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    page_size: int,
    head_num: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
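    """Like :func:`transfer_kv_per_layer_pf_lf`, but the paged source
    layout additionally needs ``page_size`` and ``head_num`` to address
    entries.
    """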
    torch.ops.sgl_kernel.transfer_kv_per_layer_ph_lf.default(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        page_size,
        head_num,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer(
    src_k_layers: torch.Tensor,
    dst_k_layers: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
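    """Copy K/V entries for all ``num_layers`` layers in one launch; the
    ``*_layers`` tensors aggregate the per-layer K and V buffers.
    """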
    torch.ops.sgl_kernel.transfer_kv_all_layer.default(
        src_k_layers,
        dst_k_layers,
        src_v_layers,
        dst_v_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_lf_pf(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
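    """All-layer K/V copy from layer-first sources into a single
    destination buffer; ``dst_layout_dim`` locates each layer inside it.
    """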
    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf.default(
        src_k_layers,
        dst_k,
        src_v_layers,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_lf_ph(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    page_size: int,
    head_num: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
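    """Like :func:`transfer_kv_all_layer_lf_pf`, but the paged destination
    layout additionally needs ``page_size`` and ``head_num``.
    """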
    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_ph.default(
        src_k_layers,
        dst_k,
        src_v_layers,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        page_size,
        head_num,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_direct(
    src_layers: List[torch.Tensor],
    dst_layers: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
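    """Direct copy between lists of per-layer KV tensors, moving
    ``page_size`` entries per index pair.
    """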
    torch.ops.sgl_kernel.transfer_kv_direct.default(
        src_layers, dst_layers, src_indices, dst_indices, page_size
    )


def transfer_kv_per_layer_direct_pf_lf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    page_size: int,
):
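    """Direct variant of the per-layer ``pf_lf`` transfer, taking lists of
    per-layer tensors and copying whole pages for one ``layer_id``.
    """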
    torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf.default(
        src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
    )


def transfer_kv_all_layer_direct_lf_pf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
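    """Direct variant of the all-layer ``lf_pf`` transfer, copying whole
    pages across every layer in ``src_ptrs``/``dst_ptrs``.
    """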
    torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf.default(
        src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
    )


def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
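    """MLA variant of :func:`transfer_kv_per_layer`: K and V live in one
    fused cache tensor, so a single ``src``/``dst`` pair suffices.
    """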
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla.default(
        src,
        dst,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_per_layer_mla_pf_lf(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
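    """MLA variant of :func:`transfer_kv_per_layer_pf_lf` for a fused KV
    cache tensor.
    """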
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf.default(
        src,
        dst,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_mla(
    src_layers: torch.Tensor,
    dst_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
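    """MLA variant of :func:`transfer_kv_all_layer` for fused KV caches
    aggregated across layers.
    """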
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla.default(
        src_layers,
        dst_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_mla_lf_pf(
    src_layers: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
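    """MLA variant of :func:`transfer_kv_all_layer_lf_pf` for a fused KV
    cache tensor.
    """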
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf.default(
        src_layers,
        dst,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )
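

# Illustrative sketch (not part of the original module): one way a caller
# might pack the per-layer pointer tensors that the ``*_all_layer`` wrappers
# above take as ``src_k_layers``/``dst_k_layers``. Treating those arguments
# as uint64 tensors of per-layer device pointers is an assumption inferred
# from the ``_layers`` naming; check the op schema in sgl_kernel before
# relying on it.
def _example_build_layer_ptrs(layer_buffers: List[torch.Tensor]) -> torch.Tensor:
    # One raw device pointer per layer, packed so the kernel can index them.
    return torch.tensor(
        [buf.data_ptr() for buf in layer_buffers],
        dtype=torch.uint64,
        device=layer_buffers[0].device,
    )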